diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a95edc4..77f77c9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,34 @@ [TOC] +## 19.6.0 (2025-06-06) + +### Features + + - Added "form factor" and "platform" to client ID. + +### Bug Fixes + + - Gracefully handle OEMCrypto_GenerateNonce() implementations that set the + nonce reference argument to zero on failure. + - Cleaned up `CertificateProvisioning` state. This prevents mismatched public + and private cert keys when an app makes multiple provisioning 4.0 requests. + - Implemented various small fixes suggested by clang-format and clang-tidy. + +### Tests + + - Updated `OEMCryptoClientTest.CheckBuildInformation_OutputLengthAPI17` to + accept a returned SHORT_BUFFER size that is larger than the actual required + size. + - Updated `OEMCryptoClientTest.CheckJsonBuildInformationAPI18` to treat the + JSON fields in the `ree` block as optional. + +### Dependency Updates + + - Updated libcppbor to 61d9bff9605ad2ffd877bd99a3bde414e21f01a2. Modifed some + Android-specific include names in order to compile correctly without + Android. + ## 19.5.0 (2025-04-02) This is a minor release with bug fixes, test improvements, and dependency diff --git a/README.md b/README.md index b77e65a3..35ea1130 100644 --- a/README.md +++ b/README.md @@ -10,51 +10,33 @@ following to learn more about the contents of this project and how to use them: The [Widevine Developer Site][wv-devsite] documents the CDM API and describes how to integrate the CDM into a system. -## New in v19.5.0 - -This is a minor release with bug fixes, test improvements, and dependency -updates. +## New in v19.6.0 ### Features - - Updated `HasRootOfTrustBeenRenewed()` to detect Drm Reprovisioning - - Updated CE BCC extraction tool: - - Added a Makefile to `wv_factory_extractor` tool - - Added an option to run BCC validator with the tool - - Refactored `ProvisioningHolder` to separate generate, fetch, and load - operations + - Added "form factor" and "platform" to client ID. ### Bug Fixes - - Ignored certain errors during `RemoveOfflineLicense()` for apps which have - been unprovisioned - - Allowed `key_session` to be equal to `oec_session` when removing entitled - key sessions + - Gracefully handle OEMCrypto_GenerateNonce() implementations that set the + nonce reference argument to zero on failure. + - Cleaned up `CertificateProvisioning` state. This prevents mismatched public + and private cert keys when an app makes multiple provisioning 4.0 requests. + - Implemented various small fixes suggested by clang-format and clang-tidy. ### Tests - - Added test `PrintClientAndServerVersionNumber` that prints the core message - info for both provisioning and license request/response - - Updated provisioning server version test to accommodate v16 responses - - Relaxed BCC validation test requirements, downgrading non-critical errors - to warnings and improving output clarity - - Added integration test to verify that renewal is not needed before renewal - delay seconds - - Added check to verify that renewals include client ID when "always includes - client ID" is true - - Updated `CdmUseCase_LicenseWithRenewal` test to verify that request has - correct info for persistent license - - Added a test to verify PST with length 127 succeeds - - Added new duration tests: - - A test for loading licenses unrelated to the content being played back - - Tests for the `30sSoftRental_UnlimitedPlayback` scenario - - Tests with past and future license start time policies - - Tests for short playback timers with unlimited rental duration + - Updated `OEMCryptoClientTest.CheckBuildInformation_OutputLengthAPI17` to + accept a returned SHORT_BUFFER size that is larger than the actual required + size. + - Updated `OEMCryptoClientTest.CheckJsonBuildInformationAPI18` to treat the + JSON fields in the `ree` block as optional. ### Dependency Updates - - Updated BoringSSL to latest (e4b6d4f7) - - Updated googletest to latest (4902ea2) + - Updated libcppbor to 61d9bff9605ad2ffd877bd99a3bde414e21f01a2. Modifed some + Android-specific include names in order to compile correctly without + Android. [CHANGELOG.md](./CHANGELOG.md) lists the major changes for each past release. diff --git a/cdm/include/cdm.h b/cdm/include/cdm.h index e102eebb..306bd57a 100644 --- a/cdm/include/cdm.h +++ b/cdm/include/cdm.h @@ -58,6 +58,7 @@ class CDM_EXPORT Cdm : public ITimerClient { kPersistentLicense = 1, kPersistent = kPersistentLicense, // deprecated name from June 1 draft // kPersistentUsageRecord = 2, // deprecated, no longer supported. + kUnknownSessionType = -1, // For error conditions }; // Message types defined by EME. diff --git a/cdm/include/cdm_version.h b/cdm/include/cdm_version.h index edb284a2..2400a065 100644 --- a/cdm/include/cdm_version.h +++ b/cdm/include/cdm_version.h @@ -10,7 +10,7 @@ # define CDM_VERSION_MAJOR 19 #endif #ifndef CDM_VERSION_MINOR -# define CDM_VERSION_MINOR 5 +# define CDM_VERSION_MINOR 6 #endif #ifndef CDM_VERSION_PATCH # define CDM_VERSION_PATCH 0 diff --git a/cdm/src/cdm.cpp b/cdm/src/cdm.cpp index c5a4b5c6..a3da9f9c 100644 --- a/cdm/src/cdm.cpp +++ b/cdm/src/cdm.cpp @@ -363,7 +363,8 @@ class CdmImpl final : public Cdm, public WvCdmEventListener { int64_t expiration; KeyStatusMap key_statuses; - SessionMetadata() : callable(false), type((SessionType)-1), expiration(0) {} + SessionMetadata() + : callable(false), type(kUnknownSessionType), expiration(0) {} }; std::map sessions_; diff --git a/cdm/src/log.cpp b/cdm/src/log.cpp index 4f20090a..1499634c 100644 --- a/cdm/src/log.cpp +++ b/cdm/src/log.cpp @@ -40,8 +40,8 @@ void Log(const char* file, const char* function, int line, LogPriority level, } const char* severities[] = {"ERROR", "WARN", "INFO", "DEBUG", "VERBOSE"}; - if (level < 0 || level >= static_cast(sizeof(severities) / - sizeof(severities[0]))) { + if (level < 0 || static_cast(level) >= + sizeof(severities) / sizeof(severities[0])) { std::string fatal_message(kFallbackLogMessage); FormatString(&fatal_message, "[FATAL:%s(%d):%s] Invalid log priority level: %d", file, line, diff --git a/cdm/src/properties_ce.cpp b/cdm/src/properties_ce.cpp index a33ec558..262137e8 100644 --- a/cdm/src/properties_ce.cpp +++ b/cdm/src/properties_ce.cpp @@ -37,6 +37,8 @@ std::string product_name_; std::string device_name_; std::string arch_name_; std::string build_info_; +std::string platform_; +std::string form_factor_; bool isClientInfoFieldValid(const char* tag, const char* value) { constexpr char kForbiddenSeparator[] = " | "; @@ -59,14 +61,12 @@ bool isClientInfoFieldValid(const char* tag, const char* value) { } void SetClientInfo() { - std::string platform; - std::string form_factor; std::string version; #if defined(RUNTIME_CLIENT_INFO) if (!CDM_NAMESPACE::ReadClientInformation( &company_name_, &model_name_, &model_year_, &product_name_, - &device_name_, &arch_name_, &platform, &form_factor, &version)) { + &device_name_, &arch_name_, &platform_, &form_factor_, &version)) { LOGE("ReadClientInformation failed."); client_info_is_valid_ = false; return; @@ -78,8 +78,8 @@ void SetClientInfo() { isClientInfoFieldValid("product_name", product_name_.c_str()) && isClientInfoFieldValid("device_name", device_name_.c_str()) && isClientInfoFieldValid("arch_name", arch_name_.c_str()) && - isClientInfoFieldValid("platform", platform.c_str()) && - isClientInfoFieldValid("form_factor", form_factor.c_str()) && + isClientInfoFieldValid("platform", platform_.c_str()) && + isClientInfoFieldValid("form_factor", form_factor_.c_str()) && isClientInfoFieldValid("version", version.c_str()))) { client_info_is_valid_ = false; return; @@ -104,14 +104,14 @@ void SetClientInfo() { product_name_ = CLIENT_PRODUCT_NAME; device_name_ = CLIENT_DEVICE_NAME; arch_name_ = CLIENT_ARCH_NAME; - platform = CLIENT_PLATFORM; - form_factor = CLIENT_FORM_FACTOR; + platform_ = CLIENT_PLATFORM; + form_factor_ = CLIENT_FORM_FACTOR; version = CLIENT_VERSION; #endif if (!wvutil::FormatString( &build_info_, "%s | %s | %s | %s | CE CDM %s | %s | %s | %s", - version.c_str(), platform.c_str(), form_factor.c_str(), + version.c_str(), platform_.c_str(), form_factor_.c_str(), arch_name_.c_str(), CDM_VERSION, CPU_ARCH_MESSAGE, LOGGING_MESSAGE, BUILD_FLAVOR_MESSAGE)) { client_info_is_valid_ = false; @@ -237,6 +237,16 @@ bool Properties::GetWVCdmVersion(std::string* version) { return GetValue(CDM_VERSION, version); } +// static +bool Properties::GetPlatform(std::string* platform) { + return GetValue(platform_.c_str(), platform); +} + +// static +bool Properties::GetFormFactor(std::string* form_factor) { + return GetValue(form_factor_.c_str(), form_factor); +} + // static bool Properties::GetDeviceFilesBasePath(CdmSecurityLevel, std::string* base_path) { diff --git a/core/include/cdm_engine.h b/core/include/cdm_engine.h index 61282c13..a29494d3 100644 --- a/core/include/cdm_engine.h +++ b/core/include/cdm_engine.h @@ -214,6 +214,15 @@ class CdmEngine { // system. This will force the device to reprovision itself. virtual CdmResponseType Unprovision(CdmSecurityLevel security_level); + // Remove the system's REE-side OEM certificate for the specified + // |security_level|. + // Only effects two-stage provisioning devices which have an OEM cert + // in the REE side file system. + // Removing the OEM certificate will cause all DRM certificates tied to + // the OEM certificate to be invalidated and unloadable to future + // sessions. + virtual CdmResponseType UnprovisionOemCert(CdmSecurityLevel security_level); + // Return the list of key_set_ids stored on the current (origin-specific) // file system. virtual CdmResponseType ListStoredLicenses( diff --git a/core/include/cdm_session.h b/core/include/cdm_session.h index dc85922a..222e72c1 100644 --- a/core/include/cdm_session.h +++ b/core/include/cdm_session.h @@ -265,7 +265,7 @@ class CdmSession { // true otherwise. bool VerifyOfflineUsageEntry(); - bool HasRootOfTrustBeenRenewed(); + bool HasRootOfTrustBeenRenewed(bool is_load); CdmResponseType ResetCryptoSession(); @@ -327,6 +327,7 @@ class CdmSession { UsageEntryIndex usage_entry_index_ = 0; UsageEntry usage_entry_; std::string usage_provider_session_token_; + std::string exported_license_data_; // information useful for offline and usage scenarios CdmKeyMessage key_request_; diff --git a/core/include/certificate_provisioning.h b/core/include/certificate_provisioning.h index 4906304c..8e23f52a 100644 --- a/core/include/certificate_provisioning.h +++ b/core/include/certificate_provisioning.h @@ -73,6 +73,28 @@ class CertificateProvisioning { // |default_url| by GetProvisioningRequest(). static void GetProvisioningServerUrl(std::string* default_url); + enum State { + // Freshly created, not yet initialized. + kUninitialized, + // A successful call to Init() has been made. + kInitialized, + // Has generated a DRM request; apps are allowed generate + // another one even if a response has not been received. + kDrmRequestSent, + // Has received (and successfully loaded) a DRM response. + kDrmResponseReceived, + // Has generated an OEM (Prov 4.0) request; apps are allowed + // generate another one even if a response has not been + // received. + kOemRequestSent, + // Has received (and successfully loaded) an OEM response. + kOemResponseReceived, + }; + static const char* StateToString(State state); + + // State setter for testing only. + void SetStateForTesting(State state) { state_ = state; } + private: #if defined(UNIT_TEST) friend class CertificateProvisioningTest; @@ -123,18 +145,29 @@ class CertificateProvisioning { CdmResponseType CloseSessionOnError(CdmResponseType status); void CloseSession(); + // Tracks the state of CertificateProvisioning. + State state_ = kUninitialized; + std::unique_ptr crypto_session_; CdmCertificateType cert_type_; std::unique_ptr service_certificate_; std::string request_; + + // == Provisioning 4.0 Variables == // The wrapped private key in provisioning 4 generated by calling // GenerateCertificateKeyPair. It will be saved to file system if a valid // response is received. - std::string provisioning_40_wrapped_private_key_; - // Key type of the generated key pair in provisioning 4. - CryptoWrappedKey::Type provisioning_40_key_type_; - // Store the last provisioning request message - std::string provisioning_request_message_; + CryptoWrappedKey prov40_wrapped_private_key_; + // Cache of the most recently sent OEM/DRM public key sent. Used + // to match the response with the request. + // This MUST be matched with the current |prov40_wrapped_private_key_|. + std::string prov40_public_key_; + + // Store the last provisioning request message. + // This is the serialized ProvisioningRequest. + // Used for X.509 responses which require the original + // request to verify the signature of the response. + std::string prov40_request_; }; // class CertificateProvisioning } // namespace wvcdm #endif // WVCDM_CORE_CERTIFICATE_PROVISIONING_H_ diff --git a/core/include/crypto_session.h b/core/include/crypto_session.h index c5b84a83..6ace4e4e 100644 --- a/core/include/crypto_session.h +++ b/core/include/crypto_session.h @@ -359,6 +359,9 @@ class CryptoSession { RequestedSecurityLevel requested_security_level, CdmClientTokenType* token_type); + virtual CdmResponseType LoadLicenseData(const std::string& data); + virtual CdmResponseType SaveLicenseData(std::string* data); + // OTA Provisioning static bool needs_keybox_provisioning() { return needs_keybox_provisioning_; } diff --git a/core/include/crypto_wrapped_key.h b/core/include/crypto_wrapped_key.h index 750796f4..f8bb14fb 100644 --- a/core/include/crypto_wrapped_key.h +++ b/core/include/crypto_wrapped_key.h @@ -5,6 +5,7 @@ #define WVCDM_CORE_CRYPTO_WRAPPED_KEY_H_ #include +#include #include "wv_class_utils.h" @@ -20,6 +21,8 @@ class CryptoWrappedKey { WVCDM_DEFAULT_COPY_AND_MOVE(CryptoWrappedKey); CryptoWrappedKey(Type type, const std::string& key) : type_(type), key_(key) {} + CryptoWrappedKey(Type type, std::string&& key) + : type_(type), key_(std::move(key)) {} Type type() const { return type_; } void set_type(Type type) { type_ = type; } @@ -28,6 +31,7 @@ class CryptoWrappedKey { // Mutable reference getter for passing to OMECrypto. std::string& key() { return key_; } void set_key(const std::string& key) { key_ = key; } + void set_key(std::string&& key) { key_ = std::move(key); } void Clear() { type_ = kUninitialized; diff --git a/core/include/device_files.h b/core/include/device_files.h index f02c6389..aae1ead5 100644 --- a/core/include/device_files.h +++ b/core/include/device_files.h @@ -110,6 +110,8 @@ class DeviceFiles { UsageEntryIndex usage_entry_index; std::string drm_certificate; CryptoWrappedKey wrapped_private_key; + // Exported license data + std::string exported_license_data; }; struct CdmUsageData { diff --git a/core/include/properties.h b/core/include/properties.h index bda142da..a7978252 100644 --- a/core/include/properties.h +++ b/core/include/properties.h @@ -75,6 +75,8 @@ class Properties { static bool GetProductName(std::string* product_name); static bool GetBuildInfo(std::string* build_info); static bool GetWVCdmVersion(std::string* version); + static bool GetPlatform(std::string* platform); + static bool GetFormFactor(std::string* form_factor); // Gets the base path for the device non-secure storage. Note that, depending // on the value of device_files_is_a_real_filesystem, this may or may not be // a real filesystem path. diff --git a/core/include/wv_cdm_types.h b/core/include/wv_cdm_types.h index ab648caf..187c40de 100644 --- a/core/include/wv_cdm_types.h +++ b/core/include/wv_cdm_types.h @@ -465,6 +465,9 @@ enum CdmResponseEnum : int32_t { GET_DEVICE_SIGNED_CSR_PAYLOAD_ERROR = 399, GET_TOKEN_FROM_EMBEDDED_CERT_ERROR = 400, GET_BCC_SIGNATURE_TYPE_ERROR = 401, + PROVISIONING_UNEXPECTED_RESPONSE_ERROR = 402, + PROVISIONING_4_STALE_RESPONSE = 403, + PROVISIONING_4_FAILED_TO_VERIFY_CERT_KEY = 404, // Don't forget to add new values to // * core/src/wv_cdm_types.cpp // * android/include/mapErrors-inl.h diff --git a/core/src/cdm_engine.cpp b/core/src/cdm_engine.cpp index 88fcc350..52f868c5 100644 --- a/core/src/cdm_engine.cpp +++ b/core/src/cdm_engine.cpp @@ -1306,6 +1306,13 @@ CdmResponseType CdmEngine::HandleProvisioningResponse( LOGE("Device has been revoked, cannot provision: status = %s", ret.ToString().c_str()); cert_provisioning_.reset(); + } else if (ret == PROVISIONING_4_STALE_RESPONSE) { + // The response is considered "stale" (likely from generating multiple + // requests, and providing out of order responses). + // Drop message without returning error or resetting + // provisioning context. + LOGW("Stale response, app may try again"); + return CdmResponseType(NO_ERROR); } else { // It is possible that a provisioning attempt was made after this one was // requested but before the response was received, which will cause this @@ -1352,8 +1359,7 @@ CdmProvisioningStatus CdmEngine::GetProvisioningStatus( return kUnknownProvisionStatus; } - UsagePropertySet property_set; - if (handle.HasCertificate(property_set.use_atsc_mode())) { + if (handle.HasCertificate(/* atsc_mode_enabled = */ false)) { return kProvisioned; } if (crypto_session->GetPreProvisionTokenType() == kClientTokenBootCertChain) { @@ -1376,8 +1382,8 @@ CdmResponseType CdmEngine::Unprovision(CdmSecurityLevel security_level) { LOGD("OKP fallback to L3"); security_level = kSecurityLevelL3; } - // Devices with baked-in DRM certs cannot be reprovisioned and therefore must - // not be unprovisioned. + // Devices with baked-in DRM certs cannot be reprovisioned + // and therefore must not be unprovisioned. std::unique_ptr crypto_session( CryptoSession::MakeCryptoSession(metrics_->GetCryptoMetrics())); CdmClientTokenType token_type = kClientTokenUninitialized; @@ -1396,18 +1402,78 @@ CdmResponseType CdmEngine::Unprovision(CdmSecurityLevel security_level) { LOGE("Unable to initialize device files"); return CdmResponseType(UNPROVISION_ERROR_1); } - - // TODO(b/141705730): Remove usage entries during unprovisioning. - if (!file_system_->IsGlobal()) { - if (!handle.RemoveCertificate() || !handle.RemoveOemCertificate()) { - LOGE("Unable to delete certificate"); + // This if statement is misleading. There is no consistent + // concept of "global" vs "per-app/origin" storage in the + // core library. Android vs CE CDM behave very different. + // On CE device: + // file_system_->IsGlobal() is always true, even if app/origin + // specific. + // On Android: + // file_system_->IsGlobal() is always false, except for some C++ + // test code. + // TODO(b/142280599): Refactor this once CE CDM SPOIDs are supported + // by the file system. May require moving platform-dependent behavior + // to the platform-dependent layer. Only have this remove the + // certificate and nothing else. + if (!file_system_->IsGlobal()) { // AKA is Android + // TODO(b/141705730): Remove usage entries during unprovisioning. + // Not considered an error if no certificate exists. + if (handle.HasCertificate(/* atsc_mode_enabled = */ false) && + !handle.RemoveCertificate()) { + LOGE("Unable to delete DRM certificate"); return CdmResponseType(UNPROVISION_ERROR_2); } + // Maintaining old behavior expected by Android. + const CdmResponseType oem_cert_status = UnprovisionOemCert(security_level); + if (oem_cert_status != NO_ERROR) return oem_cert_status; + } else { // AKA is CE CDM (or some Android tests) + // On CE CDM, deleting all files only deletes the app/origin + // specific files. + // On Android, this will delete all files (only possible + // during testing). + if (!handle.DeleteAllFiles()) { + LOGE("Unable to delete files"); + return CdmResponseType(UNPROVISION_ERROR_3); + } + } + return CdmResponseType(NO_ERROR); +} + +CdmResponseType CdmEngine::UnprovisionOemCert(CdmSecurityLevel security_level) { + LOGI("security_level = %s", CdmSecurityLevelToString(security_level)); + if (security_level == kSecurityLevelL1 && OkpIsInFallbackMode()) { + LOGD("OKP fallback to L3"); + security_level = kSecurityLevelL3; + } + // Only BCC-based system have an OEM certificate that can + // unprovisioned. + // Prov 3.0 system's OEM certs are built into the TEE. + std::unique_ptr crypto_session( + CryptoSession::MakeCryptoSession(metrics_->GetCryptoMetrics())); + CdmClientTokenType token_type = kClientTokenUninitialized; + const CdmResponseType res = crypto_session->GetProvisioningMethod( + security_level == kSecurityLevelL3 ? kLevel3 : kLevelDefault, + &token_type); + if (res != NO_ERROR) { + return res; + } + if (token_type != kClientTokenBootCertChain) { + LOGD("Device does not support OEM certificate unprovisioning"); return CdmResponseType(NO_ERROR); } - if (!handle.DeleteAllFiles()) { - LOGE("Unable to delete files"); - return CdmResponseType(UNPROVISION_ERROR_3); + // For Prov 4.0 devices, this will cause every app/origin client + // to lose their offline content for the same TEE security level. + wvutil::FileSystem global_file_system; + DeviceFiles global_handle(&global_file_system); + if (!global_handle.Init(security_level)) { + LOGE("Unable to initialize global device files"); + return CdmResponseType(UNPROVISION_ERROR_1); + } + // Not considered an error if no certificate exists. + if (global_handle.HasOemCertificate() && + !global_handle.RemoveOemCertificate()) { + LOGE("Unable to delete OEM certificate"); + return CdmResponseType(UNPROVISION_ERROR_2); } return CdmResponseType(NO_ERROR); } diff --git a/core/src/cdm_session.cpp b/core/src/cdm_session.cpp index 407dc22d..c1153531 100644 --- a/core/src/cdm_session.cpp +++ b/core/src/cdm_session.cpp @@ -160,7 +160,9 @@ CdmResponseType CdmSession::Init(CdmClientPropertySet* cdm_client_property_set, return CdmResponseType(NEED_PROVISIONING); // Require reprovisioning if the root of trust has changed - if (HasRootOfTrustBeenRenewed()) return CdmResponseType(NEED_PROVISIONING); + if (HasRootOfTrustBeenRenewed(forced_session_id)) { + return CdmResponseType(NEED_PROVISIONING); + } if (forced_session_id) { key_set_id_ = *forced_session_id; @@ -278,8 +280,8 @@ CdmResponseType CdmSession::RestoreOfflineSession(const CdmKeySetId& key_set_id, // Only restore offline licenses if they are active or this is a release // retry. - if (!(license_type == kLicenseTypeRelease || - license_data.state == kLicenseStateActive)) { + if (license_type != kLicenseTypeRelease && + license_data.state != kLicenseStateActive) { LOGE("Invalid offline license state: state = %s, license_type = %s", CdmOfflineLicenseStateToString(license_data.state), CdmLicenseTypeToString(license_type)); @@ -335,6 +337,12 @@ CdmResponseType CdmSession::RestoreOfflineSession(const CdmKeySetId& key_set_id, if (result != NO_ERROR) return result; } + if (!license_data.exported_license_data.empty()) { + result = + crypto_session_->LoadLicenseData(license_data.exported_license_data); + if (result != NO_ERROR) return result; + } + if (license_type == kLicenseTypeRelease) { result = license_parser_->RestoreLicenseForRelease( license_data.drm_certificate, key_request_, key_response_); @@ -596,6 +604,10 @@ CdmResponseType CdmSession::AddKeyInternal(const CdmKeyResponse& key_response) { if (sts != KEY_ADDED) return (sts == KEY_ERROR) ? CdmResponseType(ADD_KEY_ERROR) : sts; + // If we are L1 or export is not supported, this call will do nothing. + sts = crypto_session_->SaveLicenseData(&exported_license_data_); + if (sts != NO_ERROR) return sts; + license_received_ = true; key_response_ = key_response; @@ -993,7 +1005,8 @@ bool CdmSession::StoreLicense(CdmOfflineLicenseState state, int* error_detail) { usage_entry_, usage_entry_index_, drm_certificate_, - wrapped_private_key_}; + wrapped_private_key_, + exported_license_data_}; bool result = file_handle_->StoreLicense(license_data, &error_detail_alt); if (error_detail != nullptr) { @@ -1246,14 +1259,16 @@ CdmResponseType CdmSession::LoadPrivateKey( // Use a change in system ID as an indication that Root of Trust // has been renewed. -bool CdmSession::HasRootOfTrustBeenRenewed() { +bool CdmSession::HasRootOfTrustBeenRenewed(bool is_load) { if (atsc_mode_enabled_) return false; // Ignore System ID changes for non-Rikers L3 as the root of trust might not - // have changed even if the system ID has. + // have changed even if the system ID has. Also ignore for the Rikers L3 when + // loading an existing license since we can still load it without renewing. if (crypto_session_->GetSecurityLevel() == kSecurityLevelL3 && - crypto_session_->GetPreProvisionTokenType() != - kClientTokenDrmCertificateReprovisioning) { + (crypto_session_->GetPreProvisionTokenType() != + kClientTokenDrmCertificateReprovisioning || + is_load)) { return false; } diff --git a/core/src/cdm_usage_table.cpp b/core/src/cdm_usage_table.cpp index f2dbe368..1807e405 100644 --- a/core/src/cdm_usage_table.cpp +++ b/core/src/cdm_usage_table.cpp @@ -815,8 +815,7 @@ CdmResponseType CdmUsageTable::StoreEntry(UsageEntryIndex entry_index, case kStorageUsageInfo: { UsageEntry retrieved_entry; UsageEntryIndex retrieved_entry_index; - std::string provider_session_token, init_data, key_request, key_response, - key_renewal_request; + std::string provider_session_token, key_request, key_response; std::string drm_certificate; CryptoWrappedKey wrapped_private_key; if (!device_files->RetrieveUsageInfoByKeySetId( diff --git a/core/src/certificate_provisioning.cpp b/core/src/certificate_provisioning.cpp index bb0faedc..35152b5e 100644 --- a/core/src/certificate_provisioning.cpp +++ b/core/src/certificate_provisioning.cpp @@ -1,9 +1,10 @@ // Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary // source code may only be used and distributed under the Widevine License // Agreement. - #include "certificate_provisioning.h" +#include + #include "client_identification.h" #include "crypto_wrapped_key.h" #include "device_files.h" @@ -88,6 +89,127 @@ bool RetrieveOemCertificateAndLoadPrivateKey(CryptoSession& crypto_session, return true; } +// Checks if any instances of |needle| sequences found in the |haystack|. +// +// Special cases: +// - An empty |needle| is always present, even if |haystack| is empty. +// Note: This is a convention used by many string utility +// libraries. +bool StringContains(const std::string& haystack, const std::string& needle) { + if (needle.empty()) return true; + if (haystack.size() < needle.size()) return false; + return haystack.find(needle) != std::string::npos; +} + +// Checks if the |needle| sequences found at the end of |haystack|. +// +// Special cases: +// - An empty |needle| is always present, even if |haystack| is empty. +// Note: This is a convention used by many string utility +// libraries. +bool StringEndsWith(const std::string& haystack, const std::string& needle) { + if (haystack.size() < needle.size()) return false; + return std::equal(haystack.rbegin(), haystack.rbegin() + needle.size(), + needle.rbegin(), needle.rend()); +} + +// Checks the actual length of an ASN.1 DER encoded message +// roughly matches the expected length from within the message. +// Technically, the DER message may contain some trailing +// end-of-contents bytes (at most 2). +// +// Parameters: +// |actual_length| - The real length of the DER message +// |expected_length| - The reported length of the DER message plus +// the header bytes parsed. +bool IsAsn1ExpectedLength(size_t actual_length, size_t expected_length) { + return actual_length >= expected_length && + actual_length <= (expected_length + 2); +} + +// Checks if the provided |message| resembles ASN.1 DER encoded +// message. +// This is a light check, it verifies the type (SEQUENCE) and that +// the encoded length matches the total message length. +bool IsAsn1DerSequenceLike(const std::string& message) { + // Anything less than 3 bytes will not be an ASN.1 sequence. + if (message.size() < 3) return false; + // Verify type header + // class = universal(0) - bits 6-7 + // p/c = constructed(1) - bit 5 + // tag = sequence(0x10) - bits 0-4 + static constexpr uint8_t kUniversal = (0 << 6); + static constexpr uint8_t kConstructBit = (1 << 5); + static constexpr uint8_t kSequenceTag = 0x10; + static constexpr uint8_t kSequenceHeader = + kUniversal | kConstructBit | kSequenceTag; + const uint8_t type_header = static_cast(message.front()); + if (type_header != kSequenceHeader) return false; + + // Verify length. + const uint8_t length_header = static_cast(message[1]); + // A reserved length is never used. If |length_header| is + // reserved length, then this is not an ASN.1 message. + static constexpr uint8_t kReservedLength = 0xff; + if (length_header == kReservedLength) return false; + + static constexpr uint8_t kIndefiniteLength = 0x80; + if (length_header == kIndefiniteLength) { + // If length is indefinite, then search for two "end of contents" + // octets at the end. + static constexpr uint8_t kAsnEndOfContents = 0x00; + const std::string kDoubleEoc(2, kAsnEndOfContents); + return StringEndsWith(message, kDoubleEoc); + } + + // Definite lengths may be long or short (most likely long for our case). + static constexpr uint8_t kLongLengthBit = 0x80; + + if ((length_header & kLongLengthBit) != kLongLengthBit) { + // Short length (unlikely, but check anyways). + // For short lengths, the value component of the length + // header is the payload length. + static constexpr uint8_t kShortLengthMask = 0x7f; + const size_t payload_length = + static_cast(length_header & kShortLengthMask); + + // The total message is: type header + length header + payload. + const size_t total_length = 2 + payload_length; + return IsAsn1ExpectedLength(message.size(), total_length); + } + + // Long length. + // |length_header| contains the number of bytes following the + // length header containing the payload length. + static constexpr uint8_t kLengthSizeMask = 0x7f; + const size_t length_length = + static_cast(length_header & kLengthSizeMask); + // For long-lengths, the first two bytes were type header and + // length header. + static constexpr size_t kPayloadLengthOffset = 2; + // If the message is smaller than needed to obtain the length, + // it is either not ASN.1 (or an incomplete message, which is still + // invalid). + if ((message.size()) < (length_length + kPayloadLengthOffset)) return false; + // DER encoding should use the minimum number of bytes necessary + // to encode the length, and if the number of bytes to encode the + // length is more than 3 (payload is larged than 16 MB) which is much + // larger than any expected certificate chain. + if (length_length > 3) return false; + + // Decode the length as big-endian. + size_t payload_length = 0; + for (size_t i = 0; i < length_length; i++) { + // Casting from char to uint8_t to size_t is necessary. + const uint8_t length_byte = + static_cast(message[kPayloadLengthOffset + i]); + payload_length = (payload_length << 8) + static_cast(length_byte); + } + + // Total message is: type header + length header + payload length + payload. + const size_t total_length = 2 + length_length + payload_length; + return IsAsn1ExpectedLength(message.size(), total_length); +} } // namespace // Protobuf generated classes. using video_widevine::DrmCertificate; @@ -99,6 +221,25 @@ using video_widevine::PublicKeyToCertify; using video_widevine::SignedDrmCertificate; using video_widevine::SignedProvisioningMessage; +// static +const char* CertificateProvisioning::StateToString(State state) { + switch (state) { + case kUninitialized: + return "Uninitialized"; + case kInitialized: + return "Initialized"; + case kDrmRequestSent: + return "DrmRequestSent"; + case kDrmResponseReceived: + return "DrmResponseReceived"; + case kOemRequestSent: + return "OemRequestSent"; + case kOemResponseReceived: + return "OemResponseReceived"; + } + return ""; +} + // static void CertificateProvisioning::GetProvisioningServerUrl( std::string* default_url) { @@ -115,7 +256,11 @@ CdmResponseType CertificateProvisioning::Init( service_certificate.empty() ? wvutil::a2bs_hex(kCpProductionServiceCertificate) : service_certificate; - return service_certificate_->Init(certificate); + const CdmResponseType result = service_certificate_->Init(certificate); + if (result == NO_ERROR) { + state_ = kInitialized; + } + return result; } // Fill in the appropriate SPOID (Stable Per-Origin IDentifier) option. @@ -207,11 +352,18 @@ CdmResponseType CertificateProvisioning::GetProvisioningRequestInternal( default_url->assign(kProvisioningServerUrl); + if (state_ != kInitialized) { + LOGD("Overriding old request: state = %s", StateToString(state_)); + // Once the previous session is closed, there is no way to complete + // an in-flight request. + state_ = kInitialized; + } + CloseSession(); CdmResponseType status = crypto_session_->Open(requested_security_level); if (NO_ERROR != status) { - LOGE("Failed to create a crypto session: status = %d", - static_cast(status)); + LOGE("Failed to create a crypto session: status = %s", + status.ToString().c_str()); return status; } @@ -300,6 +452,7 @@ CdmResponseType CertificateProvisioning::GetProvisioningRequestInternal( } else { *request = std::move(serialized_request); } + state_ = kDrmRequestSent; return CdmResponseType(NO_ERROR); } @@ -325,7 +478,11 @@ CdmResponseType CertificateProvisioning::GetProvisioning40RequestInternal( return CdmResponseType(PROVISIONING_4_FAILED_TO_INITIALIZE_DEVICE_FILES); } - ProvisioningRequest provisioning_request; + if (!service_certificate_) { + LOGE("Service certificate not set"); + return CdmResponseType(CERT_PROVISIONING_EMPTY_SERVICE_CERTIFICATE); + } + // Determine the current stage by checking if OEM cert exists. std::string stored_oem_cert; if (global_file_handle.HasOemCertificate()) { @@ -342,9 +499,11 @@ CdmResponseType CertificateProvisioning::GetProvisioning40RequestInternal( } } } + const bool is_oem_prov_request = stored_oem_cert.empty(); // Retrieve the Spoid, but put it to the client identification instead, so it // is encrypted. + ProvisioningRequest provisioning_request; CdmAppParameterMap additional_parameter; CdmResponseType status = SetSpoidParameter(origin, spoid, &provisioning_request); @@ -364,7 +523,7 @@ CdmResponseType CertificateProvisioning::GetProvisioning40RequestInternal( provisioning_request.clear_stable_id(); } - if (stored_oem_cert.empty()) { + if (is_oem_prov_request) { // This is the first stage provisioning. default_url->assign(std::string(kProvisioningServerUrl) + kProv40FirstStageServerUrlSuffix); @@ -378,8 +537,8 @@ CdmResponseType CertificateProvisioning::GetProvisioning40RequestInternal( // Since |stored_oem_cert| is empty, the client identification token will be // retrieved from OEMCrypto, which is the BCC in this case. - status = FillEncryptedClientId(stored_oem_cert, provisioning_request, - wv_service_cert); + status = FillEncryptedClientId(/* client_token = */ std::string(), + provisioning_request, wv_service_cert); if (status != NO_ERROR) return status; } else { // This is the second stage provisioning. @@ -417,25 +576,24 @@ CdmResponseType CertificateProvisioning::GetProvisioning40RequestInternal( std::string public_key; std::string public_key_signature; - provisioning_40_wrapped_private_key_.clear(); - provisioning_40_key_type_ = CryptoWrappedKey::kUninitialized; + std::string wrapped_private_key; + CryptoWrappedKey::Type private_key_type = CryptoWrappedKey::kUninitialized; status = crypto_session_->GenerateCertificateKeyPair( - &public_key, &public_key_signature, &provisioning_40_wrapped_private_key_, - &provisioning_40_key_type_); + &public_key, &public_key_signature, &wrapped_private_key, + &private_key_type); if (status != NO_ERROR) return status; PublicKeyToCertify* key_to_certify = provisioning_request.mutable_certificate_public_key(); key_to_certify->set_public_key(public_key); key_to_certify->set_signature(public_key_signature); - key_to_certify->set_key_type(provisioning_40_key_type_ == - CryptoWrappedKey::kRsa + key_to_certify->set_key_type(private_key_type == CryptoWrappedKey::kRsa ? PublicKeyToCertify::RSA : PublicKeyToCertify::ECC); std::string serialized_message; provisioning_request.SerializeToString(&serialized_message); - provisioning_request_message_ = serialized_message; + prov40_request_ = serialized_message; SignedProvisioningMessage signed_provisioning_msg; signed_provisioning_msg.set_message(serialized_message); @@ -491,6 +649,15 @@ CdmResponseType CertificateProvisioning::GetProvisioning40RequestInternal( *request = std::move(serialized_request); } request_ = std::move(serialized_message); + // Need the wrapped Prov 4.0 private key to store once the response + // is received. The wrapped key is not available in the response. + prov40_wrapped_private_key_ = + CryptoWrappedKey(private_key_type, std::move(wrapped_private_key)); + // Store the public key from the request. This is used to match + // up the response with the most recently generated request. + prov40_public_key_ = std::move(public_key); + + state_ = is_oem_prov_request ? kOemRequestSent : kDrmRequestSent; return CdmResponseType(NO_ERROR); } @@ -553,6 +720,18 @@ CdmResponseType CertificateProvisioning::HandleProvisioning40Response( return CdmResponseType(PROVISIONING_4_RESPONSE_HAS_ERROR_STATUS); } } + if (state_ == kOemResponseReceived || state_ == kDrmResponseReceived) { + // A response has already been received (successfully), this + // response can be silently dropped. + LOGW("Response already received: state = %s", StateToString(state_)); + return CdmResponseType(NO_ERROR); + } + if (state_ != kOemRequestSent && state_ != kDrmRequestSent) { + LOGE("Not expecting a response: state = %s", StateToString(state_)); + return CdmResponseType(PROVISIONING_UNEXPECTED_RESPONSE_ERROR); + } + LOGD("Handling response: state = %s", StateToString(state_)); + const bool is_oem_prov_response = (state_ == kOemRequestSent); const std::string& device_certificate = provisioning_response.device_certificate(); @@ -561,17 +740,16 @@ CdmResponseType CertificateProvisioning::HandleProvisioning40Response( return CdmResponseType(PROVISIONING_4_RESPONSE_HAS_NO_CERTIFICATE); } - if (provisioning_40_wrapped_private_key_.empty()) { - LOGE("No private key was generated"); + if (!prov40_wrapped_private_key_.IsValid() || prov40_public_key_.empty()) { + LOGE("No %s key was generated", + !prov40_wrapped_private_key_.IsValid() ? "private" : "public"); return CdmResponseType(PROVISIONING_4_NO_PRIVATE_KEY); } - const CryptoWrappedKey private_key(provisioning_40_key_type_, - provisioning_40_wrapped_private_key_); - if (cert_type_ == kCertificateX509) { // Load csr private key to decrypt session key - auto status = crypto_session_->LoadCertificatePrivateKey(private_key); + auto status = + crypto_session_->LoadCertificatePrivateKey(prov40_wrapped_private_key_); if (status != NO_ERROR) { LOGE("Failed to load x509 certificate."); return status; @@ -582,9 +760,8 @@ CdmResponseType CertificateProvisioning::HandleProvisioning40Response( const std::string& signature = signed_response.signature(); const std::string& core_message = signed_response.oemcrypto_core_message(); status = crypto_session_->LoadProvisioningCast( - signed_response.session_key(), provisioning_request_message_, - response_message, core_message, signature, - &cast_cert_private_key.key()); + signed_response.session_key(), prov40_request_, response_message, + core_message, signature, &cast_cert_private_key.key()); if (status != NO_ERROR) { LOGE("Failed to generate wrapped key for cast cert."); return status; @@ -594,11 +771,79 @@ CdmResponseType CertificateProvisioning::HandleProvisioning40Response( *cert = device_certificate; *wrapped_key = cast_cert_private_key.key(); + state_ = is_oem_prov_response ? kOemResponseReceived : kDrmResponseReceived; + prov40_wrapped_private_key_.Clear(); + prov40_public_key_.clear(); return CdmResponseType(NO_ERROR); } + // Verify that the response contains the same key as the request. + // It is possible that multiple requests were generated, the CDM + // can only accept the response from the most recently generated + // one. + // + // Check the first few bytes to determine the type of message. + // OEM responses: + // ASN.1 DER encoded ContentInfo (containing an X.509 certificate). + // DRM responses: + // Protobuf SignedDrmCertificate + if (is_oem_prov_response) { + // Here |device_certificate| (haystack) is an X.509 cert chain, and + // |prov40_public_key_| (needle) is a SubjectPublicKeyInfo. + // The cert chain should contain a byte-for-byte copy of the + // public key. + // TODO(b/391469176): Use RSA/ECC key loading to detected mismatched + // keys. + if (!StringContains(/* haystack = */ device_certificate, + /* needle */ prov40_public_key_)) { + LOGD("OEM response is stale"); + return CdmResponseType(PROVISIONING_4_STALE_RESPONSE); + } + } else { // Is DRM response + video_widevine::SignedDrmCertificate signed_certificate; + if (!signed_certificate.ParseFromString(device_certificate)) { + // Check if ASN.1 like. + if (IsAsn1DerSequenceLike(device_certificate)) { + // This might be a late OEM certificate response + // generated from before the DRM response was received. + LOGD("Received late OEM certificate response"); + return CdmResponseType(PROVISIONING_4_STALE_RESPONSE); + } + LOGE("Unable to parse Signed DRM certificate"); + return CdmResponseType(PROVISIONING_4_FAILED_TO_VERIFY_CERT_KEY); + } + video_widevine::DrmCertificate drm_certificate; + if (!drm_certificate.ParseFromString( + signed_certificate.drm_certificate())) { + LOGE("Unable to parse DRM certificate"); + return CdmResponseType(PROVISIONING_4_FAILED_TO_VERIFY_CERT_KEY); + } + // The sent public key is of the format SubjectPublicKeyInfo; + // however, the received format is RSAPublicKey (RSA only) or + // SubjectPublicKeyInfo (ECC, and future RSA). + // Here |prov40_public_key_| (haystack) is SubjectPublicKeyInfo, + // and |drm_certificate.public_key()| (needle) may be + // SubjectPublicKeyInfo or RSAPublicKey. + // If the DRM cert's public key is in SubjectPublicKeyInfo format + // it should be a byte-for-byte copy. If the DRM cert's public key + // is RSAPublicKey format then hopefully a byte-for-byte copy is + // found within the SubjectPublicKeyInfo. Note: SubjectPublicKeyInfo + // containing an RSA public key uses RSAPublicKey to store the + // key fields. + // TODO(b/391469176): Use RSA/ECC key loading to detected mismatched + // keys. + if (!StringContains(/* haystack = */ prov40_public_key_, + /* needle = */ drm_certificate.public_key())) { + // This might be a response from a previously generated DRM + // certificate response. + LOGD("DRM response is stale"); + return CdmResponseType(PROVISIONING_4_STALE_RESPONSE); + } + } + // Can clear the |prov40_public_key_| after validating. + prov40_public_key_.clear(); + const CdmSecurityLevel security_level = crypto_session_->GetSecurityLevel(); - CloseSession(); wvutil::FileSystem global_file_system; DeviceFiles global_file_handle(&global_file_system); if (!global_file_handle.Init(security_level)) { @@ -608,10 +853,21 @@ CdmResponseType CertificateProvisioning::HandleProvisioning40Response( // Check the stage of the provisioning by checking if an OEM cert is already // stored in the file system. - if (!global_file_handle.HasOemCertificate()) { + if (is_oem_prov_response) { + if (global_file_handle.HasOemCertificate()) { + // Possible that concurrent apps were generated provisioning + // requests, and this one arrived after an other one. + LOGI("CDM has already received an OEM certificate"); + CloseSession(); + state_ = kOemResponseReceived; + prov40_wrapped_private_key_.Clear(); + prov40_public_key_.clear(); + return CdmResponseType(NO_ERROR); + } + // No OEM cert already stored => the response is expected to be an OEM cert. if (!global_file_handle.StoreOemCertificate(device_certificate, - private_key)) { + prov40_wrapped_private_key_)) { LOGE("Failed to store provisioning 4 OEM certificate"); return CdmResponseType(PROVISIONING_4_FAILED_TO_STORE_OEM_CERTIFICATE); } @@ -629,20 +885,27 @@ CdmResponseType CertificateProvisioning::HandleProvisioning40Response( LOGW("Failed to extract system id from OEM certificate"); } } - } else { - // The response is assumed to be an DRM cert. - DeviceFiles per_origin_file_handle(file_system); - if (!per_origin_file_handle.Init(security_level)) { - LOGE("Failed to initialize per-origin DeviceFiles"); - return CdmResponseType( - PROVISIONING_4_FAILED_TO_INITIALIZE_DEVICE_FILES_3); - } - if (!per_origin_file_handle.StoreCertificate(device_certificate, - private_key)) { - LOGE("Failed to store provisioning 4 DRM certificate"); - return CdmResponseType(PROVISIONING_4_FAILED_TO_STORE_DRM_CERTIFICATE); - } + CloseSession(); + state_ = kOemResponseReceived; + prov40_wrapped_private_key_.Clear(); + prov40_public_key_.clear(); + return CdmResponseType(NO_ERROR); } + // The response is assumed to be a DRM cert. + DeviceFiles per_origin_file_handle(file_system); + if (!per_origin_file_handle.Init(security_level)) { + LOGE("Failed to initialize per-origin DeviceFiles"); + return CdmResponseType(PROVISIONING_4_FAILED_TO_INITIALIZE_DEVICE_FILES_3); + } + if (!per_origin_file_handle.StoreCertificate(device_certificate, + prov40_wrapped_private_key_)) { + LOGE("Failed to store provisioning 4 DRM certificate"); + return CdmResponseType(PROVISIONING_4_FAILED_TO_STORE_DRM_CERTIFICATE); + } + CloseSession(); + state_ = kDrmResponseReceived; + prov40_wrapped_private_key_.Clear(); + prov40_public_key_.clear(); return CdmResponseType(NO_ERROR); } @@ -692,6 +955,15 @@ CdmResponseType CertificateProvisioning::HandleProvisioningResponse( wrapped_key); } + if (state_ == kDrmResponseReceived) { + LOGD("Response already received"); + return CdmResponseType(NO_ERROR); + } + if (state_ != kDrmRequestSent) { + LOGE("Not expecting a response: state = %s", StateToString(state_)); + return CdmResponseType(PROVISIONING_UNEXPECTED_RESPONSE_ERROR); + } + bool error = false; if (!signed_response.has_signature()) { LOGE("Signed response does not have signature"); @@ -767,6 +1039,7 @@ CdmResponseType CertificateProvisioning::HandleProvisioningResponse( if (cert_type_ == kCertificateX509) { *cert = device_cert_data; *wrapped_key = private_key.key(); + state_ = kDrmResponseReceived; return CdmResponseType(NO_ERROR); } @@ -813,6 +1086,7 @@ CdmResponseType CertificateProvisioning::HandleProvisioningResponse( return CdmResponseType(CERT_PROVISIONING_RESPONSE_ERROR_8); } + state_ = kDrmResponseReceived; return CdmResponseType(NO_ERROR); } diff --git a/core/src/client_identification.cpp b/core/src/client_identification.cpp index d3c4817d..06b5d250 100644 --- a/core/src/client_identification.cpp +++ b/core/src/client_identification.cpp @@ -15,6 +15,7 @@ namespace wvcdm { namespace { +// These keys come form the Widevine License Exchange Protocol. const std::string kKeyCompanyName = "company_name"; const std::string kKeyModelName = "model_name"; const std::string kKeyModelYear = "model_year"; @@ -27,10 +28,14 @@ const std::string kKeyOemCryptoSecurityPatchLevel = "oem_crypto_security_patch_level"; const std::string kKeyOemCryptoBuildInformation = "oem_crypto_build_information"; +// CDM uses "form_factor", though documentation may refer to this +// as "device_type". +const std::string kKeyFormFactor = "form_factor"; +const std::string kKeyPlatformName = "platform_name"; // These client identification keys are used by the CDM for relaying // important device information that cannot be overwritten by the app. -const std::array kReservedProperties = { +const std::array kReservedProperties = { kKeyCompanyName, kKeyModelName, kKeyModelYear, @@ -41,6 +46,8 @@ const std::array kReservedProperties = { kKeyWvCdmVersion, kKeyOemCryptoSecurityPatchLevel, kKeyOemCryptoBuildInformation, + kKeyFormFactor, + kKeyPlatformName, // TODO(b/148813171,b/142280599): include "origin" and "application_name" // to this list once collection of this information has been moved // to the core CDM. @@ -213,6 +220,16 @@ CdmResponseType ClientIdentification::Prepare( client_info->set_name(kKeyWvCdmVersion); client_info->set_value(value); } + if (Properties::GetPlatform(&value)) { + client_info = client_id->add_client_info(); + client_info->set_name(kKeyPlatformName); + client_info->set_value(value); + } + if (Properties::GetFormFactor(&value)) { + client_info = client_id->add_client_info(); + client_info->set_name(kKeyFormFactor); + client_info->set_value(value); + } client_info = client_id->add_client_info(); client_info->set_name(kKeyOemCryptoSecurityPatchLevel); client_info->set_value( diff --git a/core/src/crypto_session.cpp b/core/src/crypto_session.cpp index 195491c8..0ae8497a 100644 --- a/core/src/crypto_session.cpp +++ b/core/src/crypto_session.cpp @@ -356,6 +356,42 @@ CdmResponseType CryptoSession::GetProvisioningMethod( return CdmResponseType(NO_ERROR); } +CdmResponseType CryptoSession::LoadLicenseData(const std::string& data) { + RETURN_IF_UNINITIALIZED(CRYPTO_SESSION_NOT_INITIALIZED); + auto sts = WithOecSessionLock("LoadLicenseData", [&] { + return OEMCrypto_LoadLicenseData( + oec_session_id_, reinterpret_cast(data.data()), + data.size()); + }); + // level3_adapter may return this, and in case partners implement it, we + // ignore not-implemented errors. + if (sts == OEMCrypto_ERROR_NOT_IMPLEMENTED) sts = OEMCrypto_SUCCESS; + return MapOEMCryptoResult(sts, LOAD_KEY_ERROR, "LoadLicenseData"); +} + +CdmResponseType CryptoSession::SaveLicenseData(std::string* data) { + RETURN_IF_UNINITIALIZED(CRYPTO_SESSION_NOT_INITIALIZED); + size_t data_size = data->size(); + auto sts = WithOecSessionLock("SaveLicenseData - attempt 1", [&] { + return OEMCrypto_SaveLicenseData( + oec_session_id_, reinterpret_cast(data->data()), &data_size); + }); + if (sts == OEMCrypto_ERROR_SHORT_BUFFER) { + data->resize(data_size); + sts = WithOecSessionLock("SaveLicenseData - attempt 2", [&] { + return OEMCrypto_SaveLicenseData( + oec_session_id_, reinterpret_cast(data->data()), + &data_size); + }); + } + data->resize(data_size); + + // level3_adapter may return this, and in case partners implement it, we + // ignore not-implemented errors. + if (sts == OEMCrypto_ERROR_NOT_IMPLEMENTED) sts = OEMCrypto_SUCCESS; + return MapOEMCryptoResult(sts, STORE_LICENSE_ERROR_1, "SaveLicenseData"); +} + void CryptoSession::Init() { LOGV("Initializing crypto session"); bool initialized = false; @@ -2268,12 +2304,17 @@ bool CryptoSession::IsAntiRollbackHwPresent() { CdmResponseType CryptoSession::GenerateNonce(uint32_t* nonce) { RETURN_IF_NULL(nonce, PARAMETER_NULL); - OEMCryptoResult result; - WithOecWriteLock("GenerateNonce", [&] { - result = OEMCrypto_GenerateNonce(oec_session_id_, nonce); + // Some OEMCrypto implementation might modify the provided + // |nonce| value on failure (setting zero). + // Using an intermediate |temp_nonce| to protect against this. + uint32_t temp_nonce = 0; + const OEMCryptoResult result = WithOecWriteLock("GenerateNonce", [&] { + return OEMCrypto_GenerateNonce(oec_session_id_, &temp_nonce); }); metrics_->oemcrypto_generate_nonce_.Increment(result); - + if (result == OEMCrypto_SUCCESS) { + *nonce = temp_nonce; + } return MapOEMCryptoResult(result, NONCE_GENERATION_ERROR, "GenerateNonce"); } @@ -2406,7 +2447,8 @@ CdmResponseType CryptoSession::LoadProvisioningCast( CdmResponseType CryptoSession::GetHdcpCapabilities(HdcpCapability* current, HdcpCapability* max) { - LOGV("Getting HDCP capabilities: id = %u", oec_session_id_); + LOGV("Getting HDCP capabilities: security_level = %s", + RequestedSecurityLevelToString(requested_security_level_)); RETURN_IF_NOT_OPEN(CRYPTO_SESSION_NOT_OPEN); return GetHdcpCapabilities(requested_security_level_, current, max); } @@ -2414,8 +2456,8 @@ CdmResponseType CryptoSession::GetHdcpCapabilities(HdcpCapability* current, CdmResponseType CryptoSession::GetHdcpCapabilities( RequestedSecurityLevel security_level, HdcpCapability* current, HdcpCapability* max) { - LOGV("Getting HDCP capabilities: id = %u, security_level = %s", - oec_session_id_, RequestedSecurityLevelToString(security_level)); + LOGV("Getting HDCP capabilities: security_level = %s", + RequestedSecurityLevelToString(security_level)); RETURN_IF_UNINITIALIZED(CRYPTO_SESSION_NOT_INITIALIZED); RETURN_IF_NULL(current, PARAMETER_NULL); RETURN_IF_NULL(max, PARAMETER_NULL); @@ -2438,7 +2480,8 @@ CdmResponseType CryptoSession::GetHdcpCapabilities( bool CryptoSession::GetSupportedCertificateTypes( SupportedCertificateTypes* support) { - LOGV("Getting supported certificate types: id = %u", oec_session_id_); + LOGV("Getting supported certificate types: security_level = %s", + RequestedSecurityLevelToString(requested_security_level_)); RETURN_IF_UNINITIALIZED(false); RETURN_IF_NULL(support, false); const uint32_t oec_support = @@ -2456,8 +2499,8 @@ bool CryptoSession::GetSupportedCertificateTypes( CdmResponseType CryptoSession::GetNumberOfOpenSessions( RequestedSecurityLevel security_level, size_t* count) { - LOGV("Getting number of open sessions: id = %u, security_level = %s", - oec_session_id_, RequestedSecurityLevelToString(security_level)); + LOGV("Getting number of open sessions: security_level = %s", + RequestedSecurityLevelToString(security_level)); RETURN_IF_UNINITIALIZED(CRYPTO_SESSION_NOT_INITIALIZED); RETURN_IF_NULL(count, PARAMETER_NULL); @@ -2480,8 +2523,8 @@ CdmResponseType CryptoSession::GetNumberOfOpenSessions( CdmResponseType CryptoSession::GetMaxNumberOfSessions( RequestedSecurityLevel security_level, size_t* max) { - LOGV("Getting max number of sessions: id = %u, security_level = %s", - oec_session_id_, RequestedSecurityLevelToString(security_level)); + LOGV("Getting max number of sessions: security_level = %s", + RequestedSecurityLevelToString(security_level)); RETURN_IF_UNINITIALIZED(CRYPTO_SESSION_NOT_INITIALIZED); RETURN_IF_NULL(max, PARAMETER_NULL); @@ -3236,7 +3279,8 @@ CdmResponseType CryptoSession::MoveUsageEntry(UsageEntryIndex new_entry_index) { bool CryptoSession::GetAnalogOutputCapabilities(bool* can_support_output, bool* can_disable_output, bool* can_support_cgms_a) { - LOGV("Getting analog output capabilities: id = %u", oec_session_id_); + LOGV("Getting analog output capabilities: security_level = %s", + RequestedSecurityLevelToString(requested_security_level_)); RETURN_IF_UNINITIALIZED(false); const uint32_t flags = WithOecReadLock("GetAnalogOutputCapabilities", [&] { return OEMCrypto_GetAnalogOutputFlags(requested_security_level_); diff --git a/core/src/device_files.cpp b/core/src/device_files.cpp index c3ca252a..034d531a 100644 --- a/core/src/device_files.cpp +++ b/core/src/device_files.cpp @@ -709,17 +709,26 @@ bool DeviceFiles::RemoveCertificate() { RETURN_FALSE_IF_UNINITIALIZED() std::string certificate_file_name; - if (GetCertificateFileName(kCertificateLegacy, &certificate_file_name)) - RemoveFile(certificate_file_name); - if (GetCertificateFileName(kCertificateDefault, &certificate_file_name)) - return RemoveFile(certificate_file_name); - return true; + // Return true so long as at least one certificate was removed. + // This is to compliment the behavior of HasCertificate() which + // returns true if at least one certificate exists. + bool result = false; + if (GetCertificateFileName(kCertificateLegacy, &certificate_file_name)) { + LOGI("Removing legacy DRM cert"); + result |= RemoveFile(certificate_file_name); + } + if (GetCertificateFileName(kCertificateDefault, &certificate_file_name)) { + LOGI("Removing DRM cert"); + result |= RemoveFile(certificate_file_name); + } + return result; } bool DeviceFiles::RemoveOemCertificate() { RETURN_FALSE_IF_UNINITIALIZED() std::string certificate_file_name; if (GetOemCertificateFileName(&certificate_file_name)) { + LOGI("Removing OEM certificate"); return RemoveFile(certificate_file_name); } return true; @@ -881,6 +890,9 @@ bool DeviceFiles::StoreLicense(const CdmLicenseData& license_data, } license->set_usage_entry(license_data.usage_entry); license->set_usage_entry_index(license_data.usage_entry_index); + if (!license_data.exported_license_data.empty()) { + license->set_exported_license_data(license_data.exported_license_data); + } if (!license_data.drm_certificate.empty()) { DeviceCertificate* device_certificate = license->mutable_drm_certificate(); if (!SetDeviceCertificate(license_data.drm_certificate, @@ -974,6 +986,8 @@ bool DeviceFiles::RetrieveLicense(const std::string& key_set_id, license_data->usage_entry_index = static_cast(license.usage_entry_index()); + license_data->exported_license_data = license.exported_license_data(); + if (!license.has_drm_certificate()) { license_data->drm_certificate.clear(); license_data->wrapped_private_key.Clear(); @@ -2026,8 +2040,6 @@ DeviceFiles::ResponseType DeviceFiles::StoreFileRaw( DeviceFiles::ResponseType DeviceFiles::RetrieveHashedFile( const std::string& name, video_widevine_client::sdk::File* deserialized_file) { - std::string serialized_file; - if (deserialized_file == nullptr) { LOGE("File handle parameter |deserialized_file| not provided"); return kParameterNull; diff --git a/core/src/device_files.proto b/core/src/device_files.proto index caf93563..dcf64aa0 100644 --- a/core/src/device_files.proto +++ b/core/src/device_files.proto @@ -81,6 +81,12 @@ message License { optional bytes usage_entry = 12; optional int64 usage_entry_index = 13; optional DeviceCertificate drm_certificate = 14; + + // OEMCrypto-specific data for the license that will need to be loaded later + // to be able to use the license. For example, encryption keys that are + // associated with the license. + // Currently only used for the L3. + optional bytes exported_license_data = 15; } message UsageInfo { diff --git a/core/src/oemcrypto_adapter_static.cpp b/core/src/oemcrypto_adapter_static.cpp index 1832a4c2..0725dd97 100644 --- a/core/src/oemcrypto_adapter_static.cpp +++ b/core/src/oemcrypto_adapter_static.cpp @@ -243,3 +243,13 @@ WEAK OEMCryptoResult OEMCrypto_UseSecondaryKey(OEMCrypto_SESSION, bool) { WEAK OEMCryptoResult OEMCrypto_MarkOfflineSession(OEMCrypto_SESSION) { return OEMCrypto_ERROR_NOT_IMPLEMENTED; } + +WEAK OEMCryptoResult OEMCrypto_LoadLicenseData(OEMCrypto_SESSION, + const uint8_t*, size_t) { + return OEMCrypto_ERROR_NOT_IMPLEMENTED; +} + +WEAK OEMCryptoResult OEMCrypto_SaveLicenseData(OEMCrypto_SESSION, uint8_t*, + size_t*) { + return OEMCrypto_ERROR_NOT_IMPLEMENTED; +} diff --git a/core/src/wv_cdm_types.cpp b/core/src/wv_cdm_types.cpp index a42f3d11..57cab6ec 100644 --- a/core/src/wv_cdm_types.cpp +++ b/core/src/wv_cdm_types.cpp @@ -891,6 +891,12 @@ const char* CdmResponseEnumToString(CdmResponseEnum cdm_response_enum) { return "SESSION_NOT_FOUND_24"; case GET_BCC_SIGNATURE_TYPE_ERROR: return "GET_BCC_SIGNATURE_TYPE_ERROR"; + case PROVISIONING_UNEXPECTED_RESPONSE_ERROR: + return "PROVISIONING_UNEXPECTED_RESPONSE_ERROR"; + case PROVISIONING_4_STALE_RESPONSE: + return "PROVISIONING_4_STALE_RESPONSE"; + case PROVISIONING_4_FAILED_TO_VERIFY_CERT_KEY: + return "PROVISIONING_4_FAILED_TO_VERIFY_CERT_KEY"; } return UnknownValueRep(cdm_response_enum); } diff --git a/core/test/cdm_usage_table_unittest.cpp b/core/test/cdm_usage_table_unittest.cpp index ef5d0bf0..f9db0b84 100644 --- a/core/test/cdm_usage_table_unittest.cpp +++ b/core/test/cdm_usage_table_unittest.cpp @@ -140,6 +140,7 @@ const CdmUsageEntryInfo kDummyUsageEntryInfo = { /* offline_license_expiry_time = */ kDefaultExpireDuration}; const std::vector kEmptyLicenseList; +const std::string kNoExportedData; const std::string kLicenseArray[] = { kUsageEntryInfoOfflineLicense1.key_set_id, @@ -1892,7 +1893,8 @@ TEST_F(CdmUsageTableTest, kUsageEntry, static_cast(3) /* Mismatch */, kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense3.key_set_id, NotNull(), NotNull())) @@ -1914,7 +1916,8 @@ TEST_F(CdmUsageTableTest, kUsageEntry, static_cast(2) /* Mismatch */, kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense2.key_set_id, NotNull(), NotNull())) @@ -2079,7 +2082,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_LastEntriesAreStorageTypeUnknown) { kUsageEntry, static_cast(3), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense2.key_set_id, NotNull(), NotNull())) @@ -2180,7 +2184,8 @@ TEST_F(CdmUsageTableTest, kUsageEntry, /* usage_entry_index = */ 4, kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense3.key_set_id, NotNull(), NotNull())) @@ -2356,7 +2361,8 @@ TEST_F(CdmUsageTableTest, kUsageEntry, 4, kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense3.key_set_id, NotNull(), NotNull())) @@ -2382,7 +2388,8 @@ TEST_F(CdmUsageTableTest, kUsageEntry, 3, kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense2.key_set_id, NotNull(), NotNull())) @@ -2573,7 +2580,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_LastEntryIsOffline) { kUsageEntry, static_cast(4), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense3.key_set_id, NotNull(), NotNull())) @@ -2602,7 +2610,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_LastEntryIsOffline) { kUsageEntry, static_cast(3), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense2.key_set_id, NotNull(), NotNull())) @@ -2824,7 +2833,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_LastEntriesAreOfflineAndUnknknown) { kUsageEntry, static_cast(4), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense3.key_set_id, NotNull(), NotNull())) @@ -2853,7 +2863,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_LastEntriesAreOfflineAndUnknknown) { kUsageEntry, static_cast(3), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense2.key_set_id, NotNull(), NotNull())) @@ -3127,7 +3138,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_MaxSessionReached) { kUsageEntry, static_cast(1), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense1.key_set_id, NotNull(), NotNull())) @@ -3206,7 +3218,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_FirstEntry_MaxSessionReached) { kUsageEntry, static_cast(1), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense1.key_set_id, NotNull(), NotNull())) @@ -3283,7 +3296,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_SystemInvalidation_OnMove) { kUsageEntry, static_cast(1), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense1.key_set_id, NotNull(), NotNull())) @@ -3365,7 +3379,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_SessionInvalidation_OnMove) { kUsageEntry, static_cast(1), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense1.key_set_id, NotNull(), NotNull())) @@ -3444,7 +3459,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_ShrinkFails) { kUsageEntry, static_cast(1), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense1.key_set_id, NotNull(), NotNull())) @@ -3536,7 +3552,8 @@ TEST_F(CdmUsageTableTest, InvalidateEntry_DestinationInUse_OnMove) { kUsageEntry, static_cast(1), kDrmCertificate, - kCryptoWrappedKey}; + kCryptoWrappedKey, + kNoExportedData}; EXPECT_CALL(*device_files_, RetrieveLicense(kUsageEntryInfoOfflineLicense1.key_set_id, NotNull(), NotNull())) diff --git a/core/test/certificate_provisioning_unittest.cpp b/core/test/certificate_provisioning_unittest.cpp index fcc62117..73ace83a 100644 --- a/core/test/certificate_provisioning_unittest.cpp +++ b/core/test/certificate_provisioning_unittest.cpp @@ -413,6 +413,9 @@ TEST_P(CertificateProvisioningTest, ProvisioningRequestFailsEmptySignature) { TEST_P(CertificateProvisioningTest, ProvisioningResponseFailsWithEmptyResponse) { certificate_provisioning_->Init(""); + // Must set state if not generating request. + certificate_provisioning_->SetStateForTesting( + CertificateProvisioning::kDrmRequestSent); MockFileSystem file_system; std::string certificate; @@ -425,6 +428,9 @@ TEST_P(CertificateProvisioningTest, TEST_P(CertificateProvisioningTest, ProvisioningResponseFailsIfDeviceIsRevoked) { certificate_provisioning_->Init(""); + // Must set state if not generating request. + certificate_provisioning_->SetStateForTesting( + CertificateProvisioning::kDrmRequestSent); MockFileSystem file_system; std::string response_certificate; @@ -445,6 +451,10 @@ TEST_P(CertificateProvisioningTest, TEST_P(CertificateProvisioningTest, ProvisioningResponseSuccess) { certificate_provisioning_->Init(""); + // Must set state if not generating request. + certificate_provisioning_->SetStateForTesting( + CertificateProvisioning::kDrmRequestSent); + std::string expected_certificate; std::string response; ASSERT_TRUE(MakeSignedDrmCertificate(kFakePublicKey, kSerialNumber, kSystemId, diff --git a/core/test/core_integration_test.cpp b/core/test/core_integration_test.cpp index ebcbae1f..29e43cbe 100644 --- a/core/test/core_integration_test.cpp +++ b/core/test/core_integration_test.cpp @@ -301,4 +301,539 @@ TEST_F(CoreIntegrationTest, NeedKeyBeforeLicenseLoad) { EXPECT_EQ(NEED_KEY, holder.Decrypt(key_id)); ASSERT_NO_FATAL_FAILURE(holder.CloseSession()); } + +class Prov40IntegrationTest : public WvCdmTestBaseWithEngine { + public: + void SetUp() override { + WvCdmTestBaseWithEngine::SetUp(); + // Ensure CDM is operating using Provisioning 4.0. + std::string prov_model; + CdmResponseType status = cdm_engine_.QueryStatus( + kLevelDefault, QUERY_KEY_PROVISIONING_MODEL, &prov_model); + ASSERT_EQ(status, NO_ERROR) << "Failed to determine provisioning model"; + if (prov_model != QUERY_VALUE_BOOT_CERTIFICATE_CHAIN) { + GTEST_SKIP() << "Test is for Prov4.0 only"; + return; + } + // Ensure CDM is not provisioned. + if (IsProvisioned()) { + status = cdm_engine_.Unprovision(kSecurityLevelL1); + ASSERT_EQ(status, NO_ERROR) << "Failed to unprovision DRM cert"; + status = cdm_engine_.UnprovisionOemCert(kSecurityLevelL1); + ASSERT_EQ(status, NO_ERROR) << "Failed to unprovision OEM cert"; + ASSERT_EQ(GetProvisioningStatus(), kNeedsOemCertProvisioning); + } + } + + CdmProvisioningStatus GetProvisioningStatus() { + return cdm_engine_.GetProvisioningStatus(kSecurityLevelL1); + } + + bool IsProvisioned() { return cdm_engine_.IsProvisioned(kSecurityLevelL1); } + + void PreDrmProvisioningCheck() { + ASSERT_EQ(GetProvisioningStatus(), kNeedsOemCertProvisioning) + << "Not in valid state for pre DRM provisioning check"; + ProvisioningHolder provisioner(&cdm_engine_, config_); + // OEM provisioning. + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "OEM Certificate provisioning attempt failed"; + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning) + << "OEM Certificate provisioning was not completed"; + } + + void PostIncompleteOemProvisioningCheck() { + ASSERT_EQ(GetProvisioningStatus(), kNeedsOemCertProvisioning) + << "Not in valid state for post incomplete OEM provisioning check"; + ProvisioningHolder provisioner(&cdm_engine_, config_); + // OEM provisioning. + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "OEM Certificate provisioning attempt failed"; + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning) + << "OEM Certificate provisioning was not completed"; + // DRM provisioning. + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "DRM Certificate provisioning attempt failed"; + ASSERT_EQ(GetProvisioningStatus(), kProvisioned) + << "DRM Certificate provisioning was not completed"; + // Remaining is the same as post DRM provisioning. + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()) + << "Failed post incomplete OEM provisioning check after DRM " + "provisioning"; + } + + void PostOemProvisioningCheck() { + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning) + << "Not in valid state for post OEM provisioning check"; + ProvisioningHolder provisioner(&cdm_engine_, config_); + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "DRM Certificate provisioning failed"; + ASSERT_EQ(GetProvisioningStatus(), kProvisioned) + << "DRM Certificate provisioning was not completed"; + // Remaining is the same as post DRM provisioning. + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()) + << "Failed post OEM provisioning check after DRM provisioning"; + } + + void PostIncompleteDrmProvisioningCheck() { + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning) + << "Not in valid state for post incomplete DRM provisioning check"; + ProvisioningHolder provisioner(&cdm_engine_, config_); + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "DRM Certificate provisioning failed"; + ASSERT_EQ(GetProvisioningStatus(), kProvisioned) + << "DRM Certificate provisioning was not completed"; + // Remaining is the same as post DRM provisioning. + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()) + << "Failed post incomplete DRM provisioning check after DRM " + "provisioning"; + } + + void PostDrmProvisioningCheck() { + ASSERT_EQ(GetProvisioningStatus(), kProvisioned) + << "Not in valid state for post DRM provisioning check"; + LicenseHolder holder("CDM_Streaming", &cdm_engine_, config_); + ASSERT_NO_FATAL_FAILURE(holder.OpenSession()); + ASSERT_NO_FATAL_FAILURE(holder.FetchLicense()); + ASSERT_NO_FATAL_FAILURE(holder.LoadLicense()); + ASSERT_NO_FATAL_FAILURE(holder.CloseSession()); + } +}; // class Prov40IntegrationTest + +// Expected flow of an app; 1 OEM request-response, 1 DRM request-response. +// +// Case: OemReq1, OemResp1, DrmReq1, DrmResp1 +// +// Notes: +// This is Widevine's expected behavior by an app. +// +// Post-Case: Load license +TEST_F(Prov40IntegrationTest, UsualOrder_LoadOem1_LoadDrm1) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + ASSERT_EQ(GetProvisioningStatus(), kNeedsOemCertProvisioning); + + // Round 1 - OEM provisioning (OemReq1, OemResp1). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "OEM Certificate provisioning failed"; + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + // Round 2 - DRM provisioning (DrmReq1, DrmResp1). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "DRM Certificate provisioning failed"; + ASSERT_EQ(GetProvisioningStatus(), kProvisioned); + + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()); +} + +// Case: OemReq1, OemReq2, OemResp1 (OemResp2 is never acquired) +// Expectation: +// CDM handles OemResp1, but does not complete OEM +// provisioning. +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// Apps that encounter this situation are likely generating many +// provisioning requests and loading them in whatever order they +// arrive. +// +// Post-Case: OEM provisioning, DRM provisioning, load license +TEST_F(Prov40IntegrationTest, UnusualOrder_DropOem2_LoadOem1) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + // OEM provisioning. + // Generate first request (OemReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string oem_request1 = provisioner.request(); + + // Generate second request (OemReq2). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + // Never send for the second request. + + // Use first request for fetching/loading response (OemResp1). + // CDM may or may not return an error, but OEM provisioning is still + // needed. + provisioner.set_request(oem_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + ASSERT_EQ(GetProvisioningStatus(), kNeedsOemCertProvisioning); + + ASSERT_NO_FATAL_FAILURE(PostIncompleteOemProvisioningCheck()); +} + +// Case: OemReq1, OemReq2, OemResp2 (OemResp1 is never acquired) +// Expectation: +// CDM handles OemReq2 (NO_ERROR), and OEM provisioning is +// completed. +// +// Notes: +// This is OK behavior by the app. +// Only the OEM response from the most recent OEM request will +// complete provisioning. +// +// Post-Case: OEM provisioning, DRM provisioning, load license +TEST_F(Prov40IntegrationTest, UnusualOrder_DropOem1_LoadOem2) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + // OEM provisioning. + // Generate first request (OemReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + // Never send for the first request. + + // Generate, fetch and load second request (OemReq2, OemResp2). + // This should complete OEM provisioning. + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "OEM Certificate provisioning failed"; + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + ASSERT_NO_FATAL_FAILURE(PostOemProvisioningCheck()); +} + +// Case: OemReq1, OemReq2, OemResp1, OemResp2 +// Expectation: +// OemResp1 is handled by the CDM, but does not complete +// provisioning. OemResp2 is accepted by the CDM +// and completes provisioning. +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// Only the OEM response from the most recent OEM request will +// complete provisioning. +// +// Post-Case: DRM provisioning, load license +TEST_F(Prov40IntegrationTest, UnusualOrder_LoadOem1_LoadOem2) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + // OEM provisioning. + // Generate first request, store it for later (OemReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string oem_request1 = provisioner.request(); + + // Generate second request, store it for later (OemReq2). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string oem_request2 = provisioner.request(); + + // Use first request for fetching/loading response (OemResp1). + // CDM may or may not return an error, but OEM provisioning is still + // needed. + provisioner.set_request(oem_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + ASSERT_EQ(GetProvisioningStatus(), kNeedsOemCertProvisioning); + + // Use second request for fetching/loading response (OemResp2). + // CDM should accept the second response as valid (so long as + // a third was not generated). + provisioner.set_request(oem_request2); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + ASSERT_NO_FATAL_FAILURE(provisioner.LoadResponse(binary_provisioning_)) + << "OEM Certificate provisioning failed"; + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + ASSERT_NO_FATAL_FAILURE(PostOemProvisioningCheck()); +} + +// Case: OemReq1, OemReq2, OemResp2, OemResp1 +// Expectation: +// OemResp2 is accepted by the CDM and comletes OEM provisioning. +// OemResp1 does not cause the CDM to be corrupted. +// +// Notes: +// This is undesirable behavior by the app, cannot be handle +// by the CDM. +// In single-staged provisioning, the CDM silently drops +// any additional provisioning responses; but in two-stage +// this cannot easily by determine that the response is a +// late OEM response. +// +// Post-Case: DRM provisioning, load license +TEST_F(Prov40IntegrationTest, UnusualOrder_LoadOem2_LoadOem1) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + // OEM provisioning. + // Generate first request, store it for later (OemReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string oem_request1 = provisioner.request(); + + // Generate, fetch and load second request (OemReq2, OemResp2). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "OEM Certificate provisioning failed"; + // Provisioning should be complete. + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + // Use first request for fetching/loading response (OemResp1). + // CDM may or may not return an error, but DRM provisioning + // should still be allowed after. + provisioner.set_request(oem_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + // Should not effect existing provisioning state. + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning) + << "Late OEM Certificate response invalidated original response"; + + ASSERT_NO_FATAL_FAILURE(PostOemProvisioningCheck()); +} + +// Case: DrmReq1, DrmReq2, DrmResp1, (DrmResp2 is never acquired) +// Expectation: +// DrmResp1 is handled by the CDM, but does not complete +// provisioning. +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// Apps that encounter this situation are likely generating many +// provisioning requests and loading them in whatever order they +// arrive. +// For single-stage, this situation usually returns a signature +// failure. +// +// Pre-Case: OEM provisioning +// Post-Case: DRM provisioning, load license +TEST_F(Prov40IntegrationTest, UnusualOrder_DropDrm2_LoadDrm1) { + ASSERT_NO_FATAL_FAILURE(PreDrmProvisioningCheck()); + + ProvisioningHolder provisioner(&cdm_engine_, config_); + // DRM provisioning. + // Generate first request, store it for later (DrmReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string drm_request1 = provisioner.request(); + + // Generate second request (DrmReq2). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + // Never send for the second request. + + // Use first request for fetching/loading response (DrmResp1). + // CDM may or may not return an error, but DRM provisioning is still + // needed. + provisioner.set_request(drm_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + ASSERT_NO_FATAL_FAILURE(PostIncompleteDrmProvisioningCheck()); +} + +// Case: DrmReq1, DrmReq2, DrmResp2 (DrmResp1 is never acquired) +// Expectation: +// CDM accepts DrmReq2 (NO_ERROR), and DRM provisioning is +// completed. +// +// Notes: +// This is OK behavior by the app. +// Only the DRM response from the most recent DRM request will +// complete provisioning. +// +// Pre-Case: OEM provisioning +// Post-Case: Load license +TEST_F(Prov40IntegrationTest, UnusualOrder_DropDrm1_LoadDrm2) { + ASSERT_NO_FATAL_FAILURE(PreDrmProvisioningCheck()); + + ProvisioningHolder provisioner(&cdm_engine_, config_); + // DRM provisioning. + // Generate first request (DrmReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + // Never send for the first request. + + // Generate, fetch and load second request (DrmReq2, DrmResp2). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "DRM Certificate provisioning failed"; + ASSERT_TRUE(IsProvisioned()); + + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()); +} + +// Case: DrmReq1, DrmReq2, DrmResp1, DrmResp2 +// Expectation: +// DrmResp1 is handled by the CDM, but does not complete +// provisioning. DrmResp2 is accepted by the CDM and +// completes provisioning. +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// Only the DRM response from the most recent DRM request will +// complete provisioning. +// +// Pre-Case: OEM provisioning +// Post-Case: Load license +TEST_F(Prov40IntegrationTest, UnusualOrder_LoadDrm1_LoadDrm2) { + ASSERT_NO_FATAL_FAILURE(PreDrmProvisioningCheck()); + + ProvisioningHolder provisioner(&cdm_engine_, config_); + // DRM provisioning. + // Generate first request, store it for later (DrmReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string drm_request1 = provisioner.request(); + + // Generate second request, store it for later (DrmReq2). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string drm_request2 = provisioner.request(); + + // Use first request for fetching/loading response (DrmResp1). + // CDM may or may not return an error, but DRM provisioning is still + // needed. + provisioner.set_request(drm_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + // Use second request for fetching/loading response (DrmResp2). + // CDM should accept the second response as valid (so long as + // a third was not generated). + provisioner.set_request(drm_request2); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + ASSERT_NO_FATAL_FAILURE(provisioner.LoadResponse(binary_provisioning_)) + << "DRM Certificate provisioning failed"; + ASSERT_TRUE(IsProvisioned()); + + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()); +} + +// Case: DrmReq1, DrmReq2, DrmResp2, DrmResp1 +// Expectation: +// DrmResp2 is accepted by the CDM (NO_ERROR) and completes +// provisioning. DrmResp1 is handled by the CDM, but is dropped +// without causing issues with existing certificates. +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// +// Pre-Case: OEM provisioning +// Post-Case: Load license +TEST_F(Prov40IntegrationTest, UnusualOrder_LoadDrm2_LoadDrm1) { + ASSERT_NO_FATAL_FAILURE(PreDrmProvisioningCheck()); + + ProvisioningHolder provisioner(&cdm_engine_, config_); + // DRM provisioning. + // Generate first request, store it for later (DrmReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string drm_request1 = provisioner.request(); + + // Generate, fetch and load second request (DrmReq2, DrmResp2). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)) + << "DRM Certificate provisioning failed"; + ASSERT_TRUE(IsProvisioned()); + + // Use first request for fetching/loading response (DrmResp1). + // CDM may or may not return an error, and the CDM should still + // be considered provisioned. + provisioner.set_request(drm_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + // Should not effect existing provisioning state. + ASSERT_TRUE(IsProvisioned()) + << "Late DRM Certificate response invalidated original response"; + + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()); +} + +// Case: OemReq1, OemReq2, OemResp2, DrmReq1, OemResp1, DrmResp1 +// Expectation: +// OemResp2 will complete OEM provisioning, allowing the +// creation of DrmReq1. +// OemResp1 (being received after OEM provisioning is completed, +// and DRM provisioning initiated) is handled by the CDM +// and does not prevent the completion of DRM provisioning. +// +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// Stale OEM responses should not interrupt DRM provisioning in +// progress. +// +// Post-Case: Load license +TEST_F(Prov40IntegrationTest, UnusualOrder_LoadOem2_LoadDrm1_LoadOem1AsDrm) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + // Round 1 - OEM provisioning. + // Generated and stored first OEM request (OemReq1) + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string oem_request1 = provisioner.request(); + + // Complete provisioning on the second attempt (OemReq2, OemResp2). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)); + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + // Round 2 - DRM provisioning. + // Generate DRM certificate request (DrmReq1). + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string drm_request1 = provisioner.request(); + + // Use OEM request 1 to get an OEM response (OemResp1). + // CDM should detect that the OEM response is no longer needed + // and should drop the response with or without errors. + provisioner.set_request(oem_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + // Should not effect existing provisioning state. + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + // Use DRM request 1 to get a DRM response (DrmResp1). + provisioner.set_request(drm_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + ASSERT_NO_FATAL_FAILURE(provisioner.LoadResponse(binary_provisioning_)) + << "Real DRM Certificate provisioning failed"; + ASSERT_TRUE(IsProvisioned()); + + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()); +} + +// Case: OemReq1, OemReq2, OemResp2, DrmReq1, DrmResp1, OemResp1 +// Expectation: +// OemResp2 will complete OEM provisioning, allowing the +// creation of DrmReq1. +// DrmResp1 will complete DRM provisioning. +// OemResp1 (being received after OEM provisioning is completed, +// and after DRM provisioning is complete) is handled by the CDM +// and does not cause any other issue. +// +// Notes: +// This is undesirable behavior by the app, but can be partially +// handle by the CDM. +// Any provisioning response received after DRM provisioning +// is completed is ignored. +// +// Post-Case: Load license +TEST_F(Prov40IntegrationTest, UnusualOrder_LoadOem2_LoadOem1AsDrm_LoadDrm1) { + ProvisioningHolder provisioner(&cdm_engine_, config_); + + // Round 1 - OEM provisioning. + // Generated and stored first OEM request (OemReq1) + ASSERT_NO_FATAL_FAILURE(provisioner.GenerateRequest(binary_provisioning_)); + const std::string oem_request1 = provisioner.request(); + + // Complete provisioning on the second attempt (OemReq2, OemResp2). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)); + ASSERT_EQ(GetProvisioningStatus(), kNeedsDrmCertProvisioning); + + // Round 2 - DRM provisioning (DrmReq1, DrmReq2). + ASSERT_NO_FATAL_FAILURE(provisioner.Provision(binary_provisioning_)); + ASSERT_TRUE(IsProvisioned()); + + // Use OEM request 1 to get an OEM response (OemResp2). + // CDM should detect that CDM is fully provisioned and should drop + // the response with or without errors. + provisioner.set_request(oem_request1); + ASSERT_NO_FATAL_FAILURE(provisioner.FetchResponse()); + // Do not enforce any particular error (including NO_ERROR). + provisioner.LoadResponseReturnStatus(binary_provisioning_); + // Should not effect existing provisioning state. + ASSERT_TRUE(IsProvisioned()) + << "Late OEM Certificate response invalidated DRM certificate"; + ; + + ASSERT_NO_FATAL_FAILURE(PostDrmProvisioningCheck()); +} } // namespace wvcdm diff --git a/core/test/device_files_unittest.cpp b/core/test/device_files_unittest.cpp index 575cf041..1ac97242 100644 --- a/core/test/device_files_unittest.cpp +++ b/core/test/device_files_unittest.cpp @@ -5231,7 +5231,8 @@ TEST_P(DeviceFilesStoreTest, StoreLicense) { kLicenseTestData[license_num].usage_entry_index, kLicenseTestData[license_num].drm_certificate, CryptoWrappedKey(kLicenseTestData[license_num].key_type, - kLicenseTestData[license_num].private_key)}; + kLicenseTestData[license_num].private_key), + /* exported_license_data= */ ""}; EXPECT_TRUE(device_files.StoreLicense(license_data, &sub_error_code)); EXPECT_EQ(DeviceFiles::kNoError, sub_error_code); } @@ -5300,7 +5301,8 @@ TEST_F(DeviceFilesTest, StoreLicenses) { kLicenseTestData[i].usage_entry_index, kLicenseTestData[i].drm_certificate, CryptoWrappedKey(kLicenseTestData[i].key_type, - kLicenseTestData[i].private_key)}; + kLicenseTestData[i].private_key), + /* exported_license_data= */ ""}; EXPECT_TRUE(device_files.StoreLicense(license_data, &sub_error_code)); EXPECT_EQ(DeviceFiles::kNoError, sub_error_code); } @@ -5466,7 +5468,8 @@ TEST_F(DeviceFilesTest, UpdateLicenseState) { kLicenseUpdateTestData[0].usage_entry_index, kLicenseUpdateTestData[0].drm_certificate, CryptoWrappedKey(kLicenseTestData[0].key_type, - kLicenseTestData[0].private_key)}; + kLicenseTestData[0].private_key), + /* exported_license_data= */ ""}; DeviceFiles::ResponseType sub_error_code; EXPECT_TRUE(device_files.StoreLicense(license_data, &sub_error_code)); EXPECT_EQ(DeviceFiles::kNoError, sub_error_code); diff --git a/core/test/provisioning_holder.cpp b/core/test/provisioning_holder.cpp index b84c69a8..a1536113 100644 --- a/core/test/provisioning_holder.cpp +++ b/core/test/provisioning_holder.cpp @@ -139,6 +139,54 @@ void ProvisioningHolder::LoadResponse(bool binary_provisioning) { if (log_core_message_) MessageDumper::PrintProvisioningResponse(response_); } +CdmResponseType ProvisioningHolder::LoadResponseReturnStatus( + bool binary_provisioning) { + // Preconditions. + if (response_.empty()) { + ADD_FAILURE() << "No response was fetched"; + return CdmResponseType(UNKNOWN_ERROR); + } + + std::string cdm_prov_response; + if (binary_provisioning) { + // CDM is expecting the response to be in binary form, response + // must be extracted and decoded. + std::string base_64_response; + if (!ExtractSignedMessage(response_, &base_64_response)) { + ADD_FAILURE() + << "Failed to extract signed serialized response from JSON response"; + return CdmResponseType(UNKNOWN_ERROR); + } + if (base_64_response.empty()) { + ADD_FAILURE() + << "Base64 encoded provisioning response is unexpectedly empty"; + return CdmResponseType(UNKNOWN_ERROR); + } + LOGV("Extracted response message: \n%s\n", base_64_response.c_str()); + + const std::vector response_vec = + wvutil::Base64SafeDecode(base_64_response); + if (response_vec.empty()) { + ADD_FAILURE() << "Failed to decode base64 response: " << base_64_response; + return CdmResponseType(UNKNOWN_ERROR); + } + cdm_prov_response.assign(response_vec.begin(), response_vec.end()); + } else { + cdm_prov_response = response_; + } + // HandleProvisioningResponse() may or may not succeed, + // left to caller to determine if this is considered a + // test failure. + const CdmResponseType status = cdm_engine_->HandleProvisioningResponse( + cdm_prov_response, kLevelDefault, &certificate_, &wrapped_key_); + if (status == NO_ERROR) { + // Only dump data if successful. + if (config_.dump_golden_data()) MessageDumper::DumpProvisioning(response_); + if (log_core_message_) MessageDumper::PrintProvisioningResponse(response_); + } + return status; +} + bool ProvisioningHolder::ExtractSignedMessage(const std::string& response, std::string* result) { static const std::string kMessageStart = "\"signedResponse\": \""; diff --git a/core/test/provisioning_holder.h b/core/test/provisioning_holder.h index 4618b273..c4bf0db2 100644 --- a/core/test/provisioning_holder.h +++ b/core/test/provisioning_holder.h @@ -40,7 +40,11 @@ class ProvisioningHolder { // JSON message of the response to |response_|. void FetchResponse(); + // Loads the response into the |cdm_engine_|, expecting success. void LoadResponse(bool binary_provisioning); + // Loads the response into the |cdm_engine_|, returning the + // result from CDM. + CdmResponseType LoadResponseReturnStatus(bool binary_provisioning); const std::string& request() const { return request_; } // Sets the request to be used on next call to FetchResponse(). diff --git a/factory_upload_tool/ce/Makefile b/factory_upload_tool/ce/Makefile new file mode 100644 index 00000000..e3397092 --- /dev/null +++ b/factory_upload_tool/ce/Makefile @@ -0,0 +1,200 @@ +# +# Copyright 2024 Google LLC. All Rights Reserved. This file and proprietary +# source code may only be used and distributed under the Widevine +# License Agreement. +# +CDM_DIR ?= $(shell pwd)/../.. + +# CROSS_COMPILE: prefix for cross compilers, e.g. arm-none-gnueabihf- +cc ?= $(CROSS_COMPILE)gcc +cxx ?= $(CROSS_COMPILE)g++ +ARCH ?= 64 +IS_ARM ?= 0 + +project := wv_factory_extractor +srcdir := $(shell realpath --relative-to=$(CURDIR) $(CDM_DIR)) +output = $(project) + +# Place outputs in $CDM_DIR/out/wv_factory_extractor/ +default_build_dir := $(CDM_DIR)/out/$(project) +# Check if OUT_DIR is set and not empty +ifeq ($(strip $(OUT_DIR)),) + builddir := $(default_build_dir)/ +else + builddir := $(OUT_DIR)/ +endif + +# All file locations are relative to the $CDM_DIR path. +REPO_TOP := + +# Path to the modified example_main file used by tests +MODIFIED_MAIN_SRC := example_main_modified.cpp + +# Conditionally add the source file based on its existence +ifneq ($(wildcard $(MODIFIED_MAIN_SRC)),) + extractor_sources += $(REPO_TOP)/factory_upload_tool/ce/$MODIFIED_MAIN_SRC + $(info Using modified source for tests) +else + extractor_sources += $(REPO_TOP)/factory_upload_tool/ce/example_main.cpp +endif + +# other sources +extractor_sources += \ + $(REPO_TOP)/factory_upload_tool/ce/log.cpp \ + $(REPO_TOP)/factory_upload_tool/ce/properties_ce.cpp \ + $(REPO_TOP)/factory_upload_tool/ce/wv_factory_extractor.cpp \ + $(REPO_TOP)/util/src/string_conversions.cpp \ + $(REPO_TOP)/factory_upload_tool/common/src/WidevineOemcryptoInterface.cpp \ + +extractor_includes := \ + $(REPO_TOP)/factory_upload_tool/common/include \ + $(REPO_TOP)/oemcrypto/include \ + $(REPO_TOP)/util/include + +srcs := $(extractor_sources) +incs := $(extractor_includes) + +ifdef USE_VALIDATOR + oemcrypto_util_dir := $(REPO_TOP)/oemcrypto/util + + validator_sources := \ + $(oemcrypto_util_dir)/src/bcc_validator.cpp \ + $(oemcrypto_util_dir)/src/cbor_validator.cpp \ + $(oemcrypto_util_dir)/src/device_info_validator.cpp \ + $(oemcrypto_util_dir)/src/prov4_validation_helper.cpp \ + $(oemcrypto_util_dir)/src/cmac.cpp \ + $(oemcrypto_util_dir)/src/oemcrypto_drm_key.cpp \ + $(oemcrypto_util_dir)/src/oemcrypto_ecc_key.cpp \ + $(oemcrypto_util_dir)/src/oemcrypto_key_deriver.cpp \ + $(oemcrypto_util_dir)/src/oemcrypto_oem_cert.cpp \ + $(oemcrypto_util_dir)/src/oemcrypto_rsa_key.cpp \ + $(oemcrypto_util_dir)/src/signed_csr_payload_validator.cpp \ + $(oemcrypto_util_dir)/src/wvcrc.cpp + + validator_includes := \ + $(oemcrypto_util_dir)/include + + cppbor_dir := $(REPO_TOP)/third_party/libcppbor + + cppbor_sources += \ + $(cppbor_dir)/src/cppbor.cpp \ + $(cppbor_dir)/src/cppbor_parse.cpp + + cppbor_includes += \ + $(cppbor_dir)/include \ + $(cppbor_dir)/include/cppbor + + include $(CDM_DIR)/third_party/boringssl/kit/sources.mk + boringssl_sources_raw += $(crypto_sources) + ifeq ($(ARCH), 64) + ifneq ($(IS_ARM), 0) + boringssl_sources_raw += $(linux_aarch64_sources) + else + boringssl_sources_raw += $(linux_x86_64_sources) + endif + else ifeq ($(ARCH), 32) + ifneq ($(IS_ARM), 0) + boringssl_sources_raw += $(linux_arm_sources) + else + boringssl_sources_raw += $(linux_x86_sources) + endif + else + $(error No known value for ARCH; assembly not included for BoringSSL) + endif + + boringssl_dir := $(REPO_TOP)/third_party/boringssl + + boringssl_sources += \ + $(addprefix $(boringssl_dir)/kit/, $(boringssl_sources_raw)) + + boringssl_includes += \ + $(boringssl_dir)/kit/src/include + + srcs += \ + $(validator_sources) \ + $(cppbor_sources) \ + $(boringssl_sources) + + incs += \ + $(validator_includes) \ + $(cppbor_includes) \ + $(boringssl_includes) +endif + +# flags +cflags += \ + -Wall \ + -Werror \ + -fPIC \ + -fsanitize=address \ + $(addprefix -I, $(includes)) + +cflags_c += \ + $(cflags) \ + -std=c11 \ + -D_POSIX_C_SOURCE=200809L \ + -fno-inline + +cppflags += \ + $(cflags) + +ifdef USE_VALIDATOR + cppflags += -DUSE_VALIDATOR +endif + +ldflags = \ + -ldl \ + -rdynamic \ + -fsanitize=address \ + +# make rules +ifneq ($V,1) +q := @ +cmd-echo := true +cmd-echo-silent := echo +else +q := +cmd-echo := echo +cmd-echo-silent := true +endif + +ssrc := $(patsubst %.S, %.o, $(filter %.S, $(srcs))) +csrc := $(patsubst %.c, %.o, $(filter %.c, $(srcs))) +cppsrc := $(patsubst %.cpp, %.o, $(filter %.cpp, $(srcs))) +ccsrc := $(patsubst %.cc, %.o, $(filter %.cc, $(srcs))) +objs := $(sort $(addprefix $(builddir), $(csrc) $(cppsrc) $(ccsrc) $(ssrc))) + +includes := $(addprefix $(srcdir), $(incs)) $(global-incs) + +.PHONY: all +all: $(builddir)$(output) + +$(builddir)$(output): $(objs) + @$(cmd-echo-silent) ' LD $@' + ${q}$(cxx) -o $@ $(objs) $(ldflags) + +$(builddir)%.o: $(srcdir)%.c + ${q}mkdir -p $(shell dirname $@) + @$(cmd-echo-silent) ' CC $@' + ${q}$(cc) $(cflags_c) -c $< -o $@ + +$(builddir)%.o: $(srcdir)%.cc + ${q}mkdir -p $(shell dirname $@) + @$(cmd-echo-silent) ' CPP $@' + ${q}$(cxx) $(cppflags) -c $< -o $@ + +$(builddir)%.o: $(srcdir)%.cpp + ${q}mkdir -p $(shell dirname $@) + @$(cmd-echo-silent) ' CPP $@' + $(cxx) $(cppflags) -c $< -o $@ + +$(builddir)%.o: $(srcdir)%.S + ${q}mkdir -p $(shell dirname $@) + @$(cmd-echo-silent) ' CC $@' + ${q}$(cc) $(cflags_c) -c $< -o $@ + +.PHONY: clean +clean: + @$(cmd-echo-silent) ' CLEAN $(builddir)' + ${q}rm -f $(objs) $(output) + @if [ -d $(builddir) ]; then rm -r $(builddir); fi diff --git a/factory_upload_tool/ce/properties_ce.cpp b/factory_upload_tool/ce/properties_ce.cpp index c922485c..b265272f 100644 --- a/factory_upload_tool/ce/properties_ce.cpp +++ b/factory_upload_tool/ce/properties_ce.cpp @@ -71,6 +71,16 @@ bool Properties::GetBuildInfo(std::string* build_info) { return GetValue(client_info_.build_info, build_info); } +// static +bool Properties::GetPlatform(std::string* /* platform */) { + return false; // No attempt for upload tool. +} + +// static +bool Properties::GetFormFactor(std::string* /* form_factor */) { + return false; // No attempt for upload tool. +} + // static bool Properties::GetOEMCryptoPath(std::string* path) { if (path == nullptr) return false; diff --git a/factory_upload_tool/ce/wv_upload_tool.py b/factory_upload_tool/ce/wv_upload_tool.py index e70ce74c..e1d34b66 100755 --- a/factory_upload_tool/ce/wv_upload_tool.py +++ b/factory_upload_tool/ce/wv_upload_tool.py @@ -130,9 +130,13 @@ def parse_args(): Returns: An argparse.Namespace object populated with the arguments. """ - parser = argparse.ArgumentParser(description='Widevine BCC Batch Upload/Check Tool') + parser = argparse.ArgumentParser( + description='Widevine BCC Batch Upload/Check Tool' + ) - parser.add_argument("--version", action="version", version="20240822") #yyyymmdd + parser.add_argument( + '--version', action='version', version='20240822' + ) # yyyymmdd parser.add_argument( '--json-csr', @@ -200,6 +204,7 @@ def parse_json_csrs(filename, batches): batches: Output dict containing a mapping from json dumped device metadata to BCCs. """ + base_filename = os.path.basename(filename) line_count = 0 for line in open(filename): line_count = line_count + 1 @@ -210,7 +215,10 @@ def parse_json_csrs(filename, batches): die(f'{e.msg} {filename}:{line_count}, char {e.pos}') try: - bcc = {'boot_certificate_chain': obj['bcc']} + bcc = { + 'boot_certificate_chain': obj['bcc'], + 'name': f'{base_filename}#{line_count}', + } device_metadata = json.dumps({ 'company': obj['company'], 'architecture': obj['architecture'], @@ -227,7 +235,7 @@ def parse_json_csrs(filename, batches): die(f'Invalid object at {filename}:{line_count}, missing {e}') if line_count == 0: - die('Empty BCC file!') + die('Empty BCC file!') def format_request_body(args, device_metadata, bccs): @@ -481,6 +489,7 @@ def batch_action_single_attempt(args, path, body): eprint('Failed with unexpected response:') eprint(response_body) + def main(): args = parse_args() if args.dryrun: diff --git a/oemcrypto/include/OEMCryptoCENC.h b/oemcrypto/include/OEMCryptoCENC.h index f2cf3891..0ff0c313 100644 --- a/oemcrypto/include/OEMCryptoCENC.h +++ b/oemcrypto/include/OEMCryptoCENC.h @@ -3,7 +3,7 @@ // License Agreement. /** - * @mainpage OEMCrypto API v19.5 + * @mainpage OEMCrypto API v19.6 * * OEMCrypto is the low level library implemented by the OEM to provide key and * content protection, usually in a separate secure memory or process space. The @@ -766,6 +766,8 @@ typedef enum OEMCrypto_SignatureHashAlgorithm { #define OEMCrypto_GetBCCSignatureType _oecc156 #define OEMCrypto_GetPVRKey _oecc157 #define OEMCrypto_LoadPVRKey _oecc158 +#define OEMCrypto_LoadLicenseData _oecc159 +#define OEMCrypto_SaveLicenseData _oecc160 // clang-format on /// @addtogroup initcontrol @@ -1027,7 +1029,10 @@ OEMCryptoResult OEMCrypto_CloseSession(OEMCrypto_SESSION session); * state, an error of OEMCrypto_ERROR_INVALID_CONTEXT is returned. * * @param[in] session: handle for the session to be used. - * @param[out] nonce: pointer to memory to receive the computed nonce. + * @param[out] nonce pointer to memory to receive the computed nonce. The nonce + * will only be stored into this memory location if the function returns + * OEMCrypto_SUCCESS. If any other OEMCryptoResult is returned, the contents + * of the memory pointed to by nonce will remain unchanged. * * Results: * nonce: the nonce is also stored in secure memory. @@ -3639,7 +3644,9 @@ uint32_t OEMCrypto_MinorAPIVersion(void); * defined * * While not required, another optional top level struct can be added to the - * build information string to provide information about liboemcrypto.so: + * build information string to provide information about liboemcrypto.so. The + * fields within this struct are not required, but if they are included they + * must match the listed data type: * - "ree" { * - "liboemcrypto_ver" [string]: liboemcrypto.so version in string format * eg "2.15.0+tag". Note that this is separate from the "ta_ver" field @@ -4314,8 +4321,8 @@ OEMCryptoResult OEMCrypto_LoadProvisioning( * Receiver certificates may refuse to load these keys and return an error of * OEMCrypto_ERROR_NOT_IMPLEMENTED. The main use case for these alternative * signing algorithms is to support devices that use X509 certificates for - * authentication when acting as a ChromeCast receiver. This is not needed for - * devices that wish to send data to a ChromeCast. Keys loaded from this + * authentication when acting as a Google Cast receiver. This is not needed for + * devices that wish to send data to a Google Cast. Keys loaded from this * function may not be used with OEMCrypto_PrepAndSignLicenseRequest(). * * First, OEMCrypto should generate three secondary keys, mac_key[server], @@ -4388,8 +4395,8 @@ OEMCryptoResult OEMCrypto_LoadProvisioning( * algorithms may refuse to load these keys and return an error of * OEMCrypto_ERROR_NOT_IMPLEMENTED. The main use case for these * alternative signing algorithms is to support devices that use X.509 - * certificates for authentication when acting as a ChromeCast receiver. - * This is not needed for devices that wish to send data to a ChromeCast. + * certificates for authentication when acting as a Google Cast receiver. + * This is not needed for devices that wish to send data to a Google Cast. * 7. After possibly skipping past the first 8 bytes signifying the allowed * signing algorithm, the rest of the buffer private_key contains an ECC * private key or an RSA private key in PKCS#8 binary DER encoded @@ -4562,7 +4569,7 @@ OEMCryptoResult OEMCrypto_LoadTestRSAKey(void); * * The second padding scheme is for devices that use X509 certificates for * authentication. The main example is devices that work as a Cast receiver, - * like a ChromeCast, not for devices that wish to send to the Cast device, + * like a Google Cast, not for devices that wish to send to the Cast device, * such as almost all Android devices. OEMs that do not support X509 * certificate authentication need not implement this function and can return * OEMCrypto_ERROR_NOT_IMPLEMENTED. @@ -6398,6 +6405,44 @@ OEMCryptoResult OEMCrypto_UseSecondaryKey(OEMCrypto_SESSION session_id, */ OEMCryptoResult OEMCrypto_MarkOfflineSession(OEMCrypto_SESSION session); +/** + * Loads the license data into the given session. + * + * @param[in] session: session id for operation. + * @param[in] data: the buffer to import. + * @param[in] data_length: the number of bytes in |data|. + * + * @ignore + * @retval OEMCrypto_SUCCESS on success + * @retval OEMCrypto_ERROR_INVALID_SESSION + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * @retval OEMCrypto_ERROR_SESSION_STATE_LOST + * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED + * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED + */ +OEMCryptoResult OEMCrypto_LoadLicenseData(OEMCrypto_SESSION session, + const uint8_t* data, + size_t data_length); + +/** + * Saves the license data for the given session. + * + * @param[in] session: session id for operation. + * @param[out] data: the buffer to export into. + * @param[in,out] data_length: (in) length of the data buffer, in bytes. + * (out) actual length of the data, in bytes. + * + * @ignore + * @retval OEMCrypto_SUCCESS on success + * @retval OEMCrypto_ERROR_INVALID_SESSION + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * @retval OEMCrypto_ERROR_SESSION_STATE_LOST + * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED + * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED + */ +OEMCryptoResult OEMCrypto_SaveLicenseData(OEMCrypto_SESSION session, + uint8_t* data, size_t* data_length); + #ifdef __cplusplus } #endif diff --git a/oemcrypto/odk/include/core_message_features.h b/oemcrypto/odk/include/core_message_features.h index 3edfc027..1ac3c756 100644 --- a/oemcrypto/odk/include/core_message_features.h +++ b/oemcrypto/odk/include/core_message_features.h @@ -26,9 +26,9 @@ struct CoreMessageFeatures { // This is the published version of the ODK Core Message library. The default // behavior is for the server to restrict messages to at most this version - // number. The default is 19.5. + // number. The default is 19.6. uint32_t maximum_major_version = 19; - uint32_t maximum_minor_version = 5; + uint32_t maximum_minor_version = 6; bool operator==(const CoreMessageFeatures &other) const; bool operator!=(const CoreMessageFeatures &other) const { diff --git a/oemcrypto/odk/include/odk.h b/oemcrypto/odk/include/odk.h index 679d4e00..f050d2ab 100644 --- a/oemcrypto/odk/include/odk.h +++ b/oemcrypto/odk/include/odk.h @@ -98,6 +98,36 @@ OEMCryptoResult ODK_InitializeSessionValues(ODK_TimerLimits* timer_limits, uint32_t api_major_version, uint32_t session_id); +/* + * This function initializes the session's data structures. It shall be + * called from OEMCrypto_OpenSession. + * + * This function is an extended "Ex" version of + * ODK_InitializeSessionValues(). It is not intended for production systems; + * ODK_InitializeSessionValues() should be used instead. + * + * This function is intentionally excluded from Doxygen. + * + * @param[out] timer_limits: the session's timer limits. + * @param[out] clock_values: the session's clock values. + * @param[out] nonce_values: the session's ODK nonce values. + * @param[in] api_major_version: the API major version of OEMCrypto. + * @param[in] api_minor_version: the API minor version of OEMCrypto. + * @param[in] session_id: the session id of the newly created session. + * + * @retval OEMCrypto_SUCCESS + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * + * @version + * This method is new in version 19.6 of the API. + */ +OEMCryptoResult ODK_InitializeSessionValuesEx(ODK_TimerLimits* timer_limits, + ODK_ClockValues* clock_values, + ODK_NonceValues* nonce_values, + uint32_t api_major_version, + uint32_t api_minor_version, + uint32_t session_id); + /** * This function sets the nonce value in the session's nonce structure. It * shall be called from OEMCrypto_GenerateNonce. diff --git a/oemcrypto/odk/include/odk_structs.h b/oemcrypto/odk/include/odk_structs.h index 0b905aa9..803f8329 100644 --- a/oemcrypto/odk/include/odk_structs.h +++ b/oemcrypto/odk/include/odk_structs.h @@ -16,10 +16,10 @@ extern "C" { /* The version of this library. */ #define ODK_MAJOR_VERSION 19 -#define ODK_MINOR_VERSION 5 +#define ODK_MINOR_VERSION 6 /* ODK Version string. Date changed automatically on each release. */ -#define ODK_RELEASE_DATE "ODK v19.5 2025-03-11" +#define ODK_RELEASE_DATE "ODK v19.6 2025-06-03" /* The lowest version number for an ODK message. */ #define ODK_FIRST_VERSION 16 diff --git a/oemcrypto/odk/src/core_message_deserialize.cpp b/oemcrypto/odk/src/core_message_deserialize.cpp index e1cdbd0c..f3320b63 100644 --- a/oemcrypto/odk/src/core_message_deserialize.cpp +++ b/oemcrypto/odk/src/core_message_deserialize.cpp @@ -89,10 +89,8 @@ bool ParseRequest(uint32_t message_type, return true; } -} // namespace - -static bool GetNonceFromMessage(const std::string& oemcrypto_core_message, - ODK_NonceValues* nonce_values) { +bool GetNonceFromMessage(const std::string& oemcrypto_core_message, + ODK_NonceValues* nonce_values) { if (nonce_values == nullptr) return false; if (oemcrypto_core_message.size() < sizeof(ODK_CoreMessage)) return false; @@ -125,6 +123,8 @@ bool CopyCounterInfo(ODK_MessageCounter* dest, ODK_MessageCounterInfo* src) { return true; } +} // namespace + bool CoreLicenseRequestFromMessage(const std::string& oemcrypto_core_message, ODK_LicenseRequest* core_license_request) { ODK_NonceValues nonce; diff --git a/oemcrypto/odk/src/core_message_features.cpp b/oemcrypto/odk/src/core_message_features.cpp index 3e8abf42..760ce06e 100644 --- a/oemcrypto/odk/src/core_message_features.cpp +++ b/oemcrypto/odk/src/core_message_features.cpp @@ -33,7 +33,7 @@ CoreMessageFeatures CoreMessageFeatures::DefaultFeatures( features.maximum_minor_version = 4; // 18.4 break; case 19: - features.maximum_minor_version = 5; // 19.5 + features.maximum_minor_version = 6; // 19.6 break; default: features.maximum_minor_version = 0; diff --git a/oemcrypto/odk/src/odk_timer.c b/oemcrypto/odk/src/odk_timer.c index a1a9eb8d..3de8eeb8 100644 --- a/oemcrypto/odk/src/odk_timer.c +++ b/oemcrypto/odk/src/odk_timer.c @@ -3,12 +3,16 @@ // License Agreement. #include -#include #include "odk.h" #include "odk_attributes.h" #include "odk_overflow.h" #include "odk_structs_priv.h" +#include "odk_versions.h" + +/* This is a special value used to signal that the latest API + * minor version should be used for a particular API major version. */ +#define ODK_LATEST_API_MINOR_VERSION UINT32_MAX /* Private function. Checks to see if the license is active. Returns * ODK_TIMER_EXPIRED if the license is valid but inactive. Returns @@ -241,6 +245,62 @@ OEMCryptoResult ODK_ComputeRenewalDuration(const ODK_TimerLimits* timer_limits, return ODK_SET_TIMER; } +/* Private function. Initialize the timer limits to default values. */ +static void InitializeTimerLimits(ODK_TimerLimits* timer_limits) { + if (timer_limits == NULL) { + return; + } + timer_limits->soft_enforce_rental_duration = false; + timer_limits->soft_enforce_playback_duration = false; + timer_limits->earliest_playback_start_seconds = 0; + timer_limits->rental_duration_seconds = 0; + timer_limits->total_playback_duration_seconds = 0; + timer_limits->initial_renewal_duration_seconds = 0; +} + +/* Private function. Obtains the maximum minor version for a given major + * version. */ +static uint32_t GetApiMinorVersion(uint32_t api_major_version) { + /* This needs to be updated with new major version releases. */ + switch (api_major_version) { + case 16: + return ODK_V16_MINOR_VERSION; + case 17: + return ODK_V17_MINOR_VERSION; + case 18: + return ODK_V18_MINOR_VERSION; + case 19: + return ODK_V19_MINOR_VERSION; + default: + return 0; + } +} + +/* Private function. Initialize the nonce values. + * Note: |api_minor_version| may be set to ODK_LATEST_API_MINOR_VERSION.*/ +static void InitializeNonceValues(ODK_NonceValues* nonce_values, + uint32_t api_major_version, + uint32_t api_minor_version, + uint32_t session_id) { + if (nonce_values == NULL) { + return; + } + if (api_major_version > ODK_MAJOR_VERSION) { + api_major_version = ODK_MAJOR_VERSION; + } + /* Floor the API minor version to the maximum minor version for the API major + * version. */ + const uint32_t max_api_minor_version = GetApiMinorVersion(api_major_version); + if (api_minor_version > max_api_minor_version) { + api_minor_version = max_api_minor_version; + } + + nonce_values->api_major_version = api_major_version; + nonce_values->api_minor_version = api_minor_version; + nonce_values->nonce = 0; + nonce_values->session_id = session_id; +} + /************************************************************************/ /************************************************************************/ /* Public functions, declared in odk.h. */ @@ -254,38 +314,27 @@ OEMCryptoResult ODK_InitializeSessionValues(ODK_TimerLimits* timer_limits, if (timer_limits == NULL || clock_values == NULL || nonce_values == NULL) { return OEMCrypto_ERROR_INVALID_CONTEXT; } - timer_limits->soft_enforce_rental_duration = false; - timer_limits->soft_enforce_playback_duration = false; - timer_limits->earliest_playback_start_seconds = 0; - timer_limits->rental_duration_seconds = 0; - timer_limits->total_playback_duration_seconds = 0; - timer_limits->initial_renewal_duration_seconds = 0; - + InitializeTimerLimits(timer_limits); ODK_InitializeClockValues(clock_values, 0); + InitializeNonceValues(nonce_values, api_major_version, + ODK_LATEST_API_MINOR_VERSION, session_id); + return OEMCrypto_SUCCESS; +} - nonce_values->api_major_version = api_major_version; - // This needs to be updated with new version releases in the default features - // of core message features. - switch (nonce_values->api_major_version) { - case 16: - nonce_values->api_minor_version = 5; - break; - case 17: - nonce_values->api_minor_version = 2; - break; - case 18: - nonce_values->api_minor_version = 4; - break; - case 19: - nonce_values->api_minor_version = 5; - break; - default: - nonce_values->api_minor_version = 0; - break; +/* This is called when certain OEMCrypto implementations opens a new session. */ +OEMCryptoResult ODK_InitializeSessionValuesEx(ODK_TimerLimits* timer_limits, + ODK_ClockValues* clock_values, + ODK_NonceValues* nonce_values, + uint32_t api_major_version, + uint32_t api_minor_version, + uint32_t session_id) { + if (timer_limits == NULL || clock_values == NULL || nonce_values == NULL) { + return OEMCrypto_ERROR_INVALID_CONTEXT; } - nonce_values->nonce = 0; - nonce_values->session_id = session_id; - + InitializeTimerLimits(timer_limits); + ODK_InitializeClockValues(clock_values, 0); + InitializeNonceValues(nonce_values, api_major_version, api_minor_version, + session_id); return OEMCrypto_SUCCESS; } diff --git a/oemcrypto/odk/src/odk_versions.h b/oemcrypto/odk/src/odk_versions.h new file mode 100644 index 00000000..17fa449d --- /dev/null +++ b/oemcrypto/odk/src/odk_versions.h @@ -0,0 +1,24 @@ +// Copyright 2025 Google LLC. This file and proprietary +// source code may only be used and distributed under the Widevine +// License Agreement. + +#ifndef WIDEVINE_ODK_SRC_ODK_VERSIONS_H_ +#define WIDEVINE_ODK_SRC_ODK_VERSIONS_H_ + +#include + +#include "odk_structs.h" + +/* Highest ODK minor version number by major version. */ +#define ODK_V16_MINOR_VERSION 5 +#define ODK_V17_MINOR_VERSION 8 +#define ODK_V18_MINOR_VERSION 10 + +/* Whenever the next major version is released, this should be updated to the + * new major version. */ +#if ODK_MAJOR_VERSION != 19 +# error "ODK_MAJOR_VERSION has changed. Please update this file." +#endif +#define ODK_V19_MINOR_VERSION ODK_MINOR_VERSION + +#endif // WIDEVINE_ODK_SRC_ODK_VERSIONS_H_ diff --git a/oemcrypto/odk/test/odk_test.cpp b/oemcrypto/odk/test/odk_test.cpp index bd148832..80b0f523 100644 --- a/oemcrypto/odk/test/odk_test.cpp +++ b/oemcrypto/odk/test/odk_test.cpp @@ -1275,7 +1275,7 @@ std::vector TestCases() { {16, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 16, 5}, {17, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 17, 2}, {18, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 18, 4}, - {19, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 19, 5}, + {19, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 19, 6}, // Here are some known good versions. Make extra sure they work. {ODK_MAJOR_VERSION, 16, 3, 16, 3}, {ODK_MAJOR_VERSION, 16, 4, 16, 4}, @@ -1291,6 +1291,7 @@ std::vector TestCases() { {ODK_MAJOR_VERSION, 19, 3, 19, 3}, {ODK_MAJOR_VERSION, 19, 4, 19, 4}, {ODK_MAJOR_VERSION, 19, 5, 19, 5}, + {ODK_MAJOR_VERSION, 19, 6, 19, 6}, {0, 16, 3, 16, 3}, {0, 16, 4, 16, 4}, {0, 16, 5, 16, 5}, @@ -1304,6 +1305,7 @@ std::vector TestCases() { {0, 19, 3, 19, 3}, {0, 19, 4, 19, 4}, {0, 19, 5, 19, 5}, + {0, 19, 6, 19, 6}, }; return test_cases; } diff --git a/oemcrypto/test/GEN_api_lock_file.c b/oemcrypto/test/GEN_api_lock_file.c index 9a5af65c..0481006e 100644 --- a/oemcrypto/test/GEN_api_lock_file.c +++ b/oemcrypto/test/GEN_api_lock_file.c @@ -450,3 +450,11 @@ OEMCryptoResult _oecc157(OEMCrypto_SESSION session, uint8_t* wrapped_pvr_key, OEMCryptoResult _oecc158(OEMCrypto_SESSION session, const uint8_t* wrapped_pvr_key, size_t wrapped_pvr_key_length); + +// OEMCrypto_LoadLicenseData defined in v19.6 +OEMCryptoResult _oecc159(OEMCrypto_SESSION session, const uint8_t* data, + size_t data_length); + +// OEMCrypto_SaveLicenseData defined in v19.6 +OEMCryptoResult _oecc160(OEMCrypto_SESSION session, uint8_t* data, + size_t* data_length); diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc index 231cf406..e270e4ff 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc @@ -7,13 +7,13 @@ namespace { void OpenOEMCryptoTASession() { uint8_t request_body[] = { - 0x06, // TAG_UINT32 + 0x07, // TAG_UINT32 0x09, 0x00, 0x00, 0x00, // API value (0x09) - 0x07, // TAG_UINT64 + 0x08, // TAG_UINT64 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Timestamp 0x01, // TAG_BOOL 0x00, // value (false) - 0x0a // TAG_EOM + 0x0b // TAG_EOM }; ODK_Message request = ODK_Message_Create(request_body, sizeof(request_body)); ODK_Message_SetSize(&request, sizeof(request_body)); @@ -23,11 +23,11 @@ void OpenOEMCryptoTASession() { void InitializeOEMCryptoTA() { uint8_t request_body[] = { - 0x06, // TAG_UINT32 + 0x07, // TAG_UINT32 0x01, 0x00, 0x00, 0x00, // API value (0x01) - 0x07, // TAG_UINT64 + 0x08, // TAG_UINT64 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Timestamp - 0x0a // TAG_EOM + 0x0b // TAG_EOM }; ODK_Message request = ODK_Message_Create(request_body, sizeof(request_body)); ODK_Message_SetSize(&request, sizeof(request_body)); diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp b/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp index be0ab62c..52d368d7 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp +++ b/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp @@ -42,12 +42,8 @@ { 'target_name': 'oemcrypto_opk_dispatcher_fuzz', 'include_dirs': [ - '<(oemcrypto_dir)/opk/serialization/common', '<(oemcrypto_dir)/opk/serialization/common/include', '<(oemcrypto_dir)/opk/serialization/os_interfaces', - '<(oemcrypto_dir)/opk/serialization/tee', - '<(oemcrypto_dir)/opk/serialization/tee/include', - '<(oemcrypto_dir)/opk/ports/trusty/include/', ], 'dependencies': [ '<(oemcrypto_dir)/opk/serialization/tee/tee.gyp:opk_tee', @@ -55,9 +51,9 @@ 'sources': [ 'oemcrypto_opk_dispatcher_fuzz.cc', '<(oemcrypto_dir)/opk/serialization/test/tos_secure_buffers.c', - '<(oemcrypto_dir)/opk/serialization/test/tos_transport_interface.c', '<(oemcrypto_dir)/opk/serialization/test/tos_logging.c', - '<(oemcrypto_dir)/opk/ports/trusty/serialization_adapter/shared_memory.c', + '<(oemcrypto_dir)/opk/serialization/test/tos_shared_memory.c', + '<(oemcrypto_dir)/opk/serialization/test/tos_transport_interface.c', ], }, { diff --git a/oemcrypto/test/oec_session_util.cpp b/oemcrypto/test/oec_session_util.cpp index 70297a81..1e169b41 100644 --- a/oemcrypto/test/oec_session_util.cpp +++ b/oemcrypto/test/oec_session_util.cpp @@ -1796,7 +1796,11 @@ void Session::close() { void Session::GenerateNonce(int* error_counter) { // We make one attempt. If it fails, we assume there was a nonce flood. - if (OEMCrypto_SUCCESS == OEMCrypto_GenerateNonce(session_id(), &nonce_)) { + // Using |temp_nonce| to avoid member |nonce_| being modified + // during failure. + uint32_t temp_nonce = 0; + if (OEMCrypto_SUCCESS == OEMCrypto_GenerateNonce(session_id(), &temp_nonce)) { + nonce_ = temp_nonce; return; } if (error_counter) { @@ -1806,7 +1810,8 @@ void Session::GenerateNonce(int* error_counter) { // The following is after a 1 second pause, so it cannot be from a nonce // flood. ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_GenerateNonce(session_id(), &nonce_)); + OEMCrypto_GenerateNonce(session_id(), &temp_nonce)); + nonce_ = temp_nonce; } } diff --git a/oemcrypto/test/oemcrypto_basic_test.cpp b/oemcrypto/test/oemcrypto_basic_test.cpp index 71ad453e..e37e2d5e 100644 --- a/oemcrypto/test/oemcrypto_basic_test.cpp +++ b/oemcrypto/test/oemcrypto_basic_test.cpp @@ -317,14 +317,14 @@ TEST_F(OEMCryptoClientTest, FreeUnallocatedSecureBufferNoFailure) { */ TEST_F(OEMCryptoClientTest, VersionNumber) { const std::string log_message = - "OEMCrypto unit tests for API 19.5. Tests last updated 2025-03-11"; + "OEMCrypto unit tests for API 19.6. Tests last updated 2025-06-03"; cout << " " << log_message << "\n"; cout << " " << "These tests are part of Android V." << "\n"; LOGI("%s", log_message.c_str()); // If any of the following fail, then it is time to update the log message // above. EXPECT_EQ(ODK_MAJOR_VERSION, 19); - EXPECT_EQ(ODK_MINOR_VERSION, 5); + EXPECT_EQ(ODK_MINOR_VERSION, 6); EXPECT_EQ(kCurrentAPI, static_cast(ODK_MAJOR_VERSION)); RecordWvProperty("test_major_version", std::to_string(ODK_MAJOR_VERSION)); RecordWvProperty("test_minor_version", std::to_string(ODK_MINOR_VERSION)); @@ -498,45 +498,58 @@ TEST_F(OEMCryptoClientTest, CheckBuildInformation_OutputLengthAPI17) { ASSERT_GT(build_info_length, kZero) << "Signaling ERROR_SHORT_BUFFER should have assigned a length"; + // Try again using the size they provided, ensuring that it + // is successful. + const size_t initial_estimate_length = build_info_length; + build_info.assign(build_info_length, kNullChar); + result = OEMCrypto_BuildInformation(&build_info[0], &build_info_length); + ASSERT_EQ(result, OEMCrypto_SUCCESS) + << "initial_estimate_length = " << initial_estimate_length + << ", build_info_length (output) = " << build_info_length; + ASSERT_GT(build_info_length, kZero) << "Build info cannot be empty"; + // Ensure the real length is within the size originally specified. + // OK if final length is smaller than estimated length. + ASSERT_LE(build_info_length, initial_estimate_length); + const size_t expected_length = build_info_length; + // Force a ERROR_SHORT_BUFFER using a non-zero value. // Note: It is assumed that vendors will provide more than a single // character of info. - const size_t second_attempt_length = - (build_info_length >= 2) ? build_info_length / 2 : 1; - build_info.assign(second_attempt_length, kNullChar); + const size_t short_length = (expected_length >= 2) ? expected_length / 2 : 1; + build_info.assign(short_length, kNullChar); build_info_length = build_info.size(); result = OEMCrypto_BuildInformation(&build_info[0], &build_info_length); ASSERT_EQ(result, OEMCrypto_ERROR_SHORT_BUFFER) - << "second_attempt_length = " << second_attempt_length - << ", build_info_length" << build_info_length; + << "short_length = " << short_length + << ", expected_length = " << expected_length << ", build_info_length" + << build_info_length; // OEM specified build info length should be larger than the // original length if returning ERROR_SHORT_BUFFER. - ASSERT_GT(build_info_length, second_attempt_length); + ASSERT_GT(build_info_length, short_length); // Final attempt with a buffer large enough buffer, padding to // ensure the caller truncates. constexpr size_t kBufferPadSize = 42; - const size_t expected_length = build_info_length; - const size_t final_attempt_length = expected_length + kBufferPadSize; - build_info.assign(final_attempt_length, kNullChar); + const size_t oversize_length = expected_length + kBufferPadSize; + build_info.assign(oversize_length, kNullChar); build_info_length = build_info.size(); result = OEMCrypto_BuildInformation(&build_info[0], &build_info_length); ASSERT_EQ(result, OEMCrypto_SUCCESS) - << "final_attempt_length = " << final_attempt_length + << "oversize_length = " << oversize_length << ", expected_length = " << expected_length - << ", build_info_length = " << build_info_length; + << ", build_info_length (output) = " << build_info_length; // Ensure not empty. ASSERT_GT(build_info_length, kZero) << "Build info cannot be empty"; // Ensure it was truncated down from the padded length. - ASSERT_LT(build_info_length, final_attempt_length) + ASSERT_LT(build_info_length, oversize_length) << "Should have truncated from oversized buffer: expected_length = " << expected_length; - // Ensure the real length is within the size originally specified. - // OK if final length is smaller than estimated length. - ASSERT_LE(build_info_length, expected_length); + // Ensure that length is equal to the length of the previous + // successful call. + ASSERT_EQ(build_info_length, expected_length); } // Verifies that OEMCrypto_BuildInformation() is behaving as expected @@ -680,7 +693,7 @@ TEST_F(OEMCryptoClientTest, CheckJsonBuildInformationAPI18) { // Whether this was built with FACTORY_MODE_ONLY defined {"is_factory_mode", JSMN_PRIMITIVE}, // ... provide information about liboemcrypto.so - // Special case, see kOptionalReeFields for details. + // Special case, see kReeOptionalFields for details. {kSpecialCaseReeKey, JSMN_OBJECT}, // Technically required, but several implementations // do not implement this fields. @@ -778,7 +791,7 @@ TEST_F(OEMCryptoClientTest, CheckJsonBuildInformationAPI18) { // The optional field "ree", if present, must follow the required // format. - const std::map kReeRequiredFields = { + const std::map kReeOptionalFields = { // liboemcrypto.so version in string format eg "2.15.0+tag" {"liboemcrypto_ver", JSMN_STRING}, // git hash of code that compiled liboemcrypto.so @@ -786,7 +799,6 @@ TEST_F(OEMCryptoClientTest, CheckJsonBuildInformationAPI18) { // ISO 8601 timestamp for when liboemcrypto.so was built {"build_timestamp", JSMN_STRING}}; - found_required_fields.clear(); for (int32_t i = 0; (i + 1) < static_cast(ree_tokens.size()); i += 2) { const jsmntok_t& key_token = ree_tokens[i]; @@ -796,11 +808,10 @@ TEST_F(OEMCryptoClientTest, CheckJsonBuildInformationAPI18) { const std::string key = build_info.substr(key_token.start, key_token.end - key_token.start); - if (kReeRequiredFields.find(key) != kReeRequiredFields.end()) { - ASSERT_EQ(value_token.type, kReeRequiredFields.at(key)) + if (kReeOptionalFields.find(key) != kReeOptionalFields.end()) { + ASSERT_EQ(value_token.type, kReeOptionalFields.at(key)) << "Unexpected optional REE field type: ree_field = " << key << ", build_info = " << build_info; - found_required_fields.insert(key); RecordWvProperty(kReeBuildInfoRecordPrefix + key, build_info.substr(value_token.start, value_token.end - value_token.start)); @@ -810,25 +821,6 @@ TEST_F(OEMCryptoClientTest, CheckJsonBuildInformationAPI18) { i += JsmnAncestorCount(ree_tokens, i + 1); } - // Step 4b: Ensure all required fields of the "ree" object were found. - if (found_required_fields.size() == kReeRequiredFields.size()) return; - // Generate a list of all the missing REE fields. - std::string missing_ree_fields; - for (const auto& required_field : kReeRequiredFields) { - if (found_required_fields.find(required_field.first) != - found_required_fields.end()) - continue; - if (!missing_ree_fields.empty()) { - missing_ree_fields.append(", "); - } - missing_ree_fields.push_back('"'); - missing_ree_fields.append(required_field.first); - missing_ree_fields.push_back('"'); - } - - FAIL() << "REE info JSON object does not contain all required keys; " - << "missing_ree_fields = [" << missing_ree_fields - << "], build_info = " << build_info; } TEST_F(OEMCryptoClientTest, CheckMaxNumberOfSessionsAPI10) { diff --git a/oemcrypto/test/oemcrypto_corpus_generator_helper.cpp b/oemcrypto/test/oemcrypto_corpus_generator_helper.cpp index 6b603dcf..52dc0a01 100644 --- a/oemcrypto/test/oemcrypto_corpus_generator_helper.cpp +++ b/oemcrypto/test/oemcrypto_corpus_generator_helper.cpp @@ -9,7 +9,9 @@ namespace wvoec { +namespace { bool g_generate_corpus; +} void AppendToFile(const std::string& file_name, const char* message, const size_t message_size) { diff --git a/oemcrypto/test/oemcrypto_decrypt_test.cpp b/oemcrypto/test/oemcrypto_decrypt_test.cpp index 34a26474..fa0ee5ea 100644 --- a/oemcrypto/test/oemcrypto_decrypt_test.cpp +++ b/oemcrypto/test/oemcrypto_decrypt_test.cpp @@ -604,11 +604,13 @@ TEST_P(OEMCryptoSessionTestsDecryptTests, ContinueDecryptionAfterIdleAndWake) { ASSERT_NO_FATAL_FAILURE(TestDecryptCENC()); } +namespace { // Used to construct a specific pattern. constexpr OEMCrypto_CENCEncryptPatternDesc MakePattern(size_t encrypt, size_t skip) { return {encrypt, skip}; } +} // namespace INSTANTIATE_TEST_SUITE_P( CTRTests, OEMCryptoSessionTestsDecryptTests, diff --git a/oemcrypto/test/oemcrypto_session_tests_helper.cpp b/oemcrypto/test/oemcrypto_session_tests_helper.cpp index 175bb518..4caa7d04 100644 --- a/oemcrypto/test/oemcrypto_session_tests_helper.cpp +++ b/oemcrypto/test/oemcrypto_session_tests_helper.cpp @@ -8,18 +8,6 @@ using namespace wvoec; namespace wvoec { -// Make this function available when in Fuzz mode because we are not inheriting -// from OEMCryptoClientTest. -const uint8_t* find(const vector& message, - const vector& substring) { - vector::const_iterator pos = search( - message.begin(), message.end(), substring.begin(), substring.end()); - if (pos == message.end()) { - return nullptr; - } - return &(*pos); -} - void SessionUtil::CreateWrappedDRMKey() { if (global_features.provisioning_method == OEMCrypto_BootCertificateChain) { // Have the device create a wrapped key. diff --git a/oemcrypto/util/src/wvcrc.cpp b/oemcrypto/util/src/wvcrc.cpp index 097ca707..5e1917d1 100644 --- a/oemcrypto/util/src/wvcrc.cpp +++ b/oemcrypto/util/src/wvcrc.cpp @@ -11,6 +11,7 @@ namespace wvoec { namespace util { #define INIT_CRC32 0xffffffff +namespace { uint32_t wvrunningcrc32(const uint8_t* p_begin, size_t i_count, uint32_t i_crc) { constexpr uint32_t CRC32[256] = { @@ -67,6 +68,7 @@ uint32_t wvrunningcrc32(const uint8_t* p_begin, size_t i_count, return(i_crc); } +} // namespace uint32_t wvcrc32(const uint8_t* p_begin, size_t i_count) { return(wvrunningcrc32(p_begin, i_count, INIT_CRC32)); diff --git a/oemcrypto/util/test/hmac_unittest.cpp b/oemcrypto/util/test/hmac_unittest.cpp index bc4f355c..0cbdb612 100644 --- a/oemcrypto/util/test/hmac_unittest.cpp +++ b/oemcrypto/util/test/hmac_unittest.cpp @@ -13,7 +13,8 @@ namespace wvoec { namespace util { -namespace { + +// Putting type in non-anonymous namespace to prevent linkage warnings. struct HmacTestVector { std::vector key; std::vector message; @@ -43,6 +44,7 @@ void PrintTo(const HmacTestVector& v, std::ostream* os) { *os << "signature_sha1 = " << wvutil::b2a_hex(v.signature_sha1) << "}"; } +namespace { std::vector FromString(const std::string& s) { return std::vector(s.begin(), s.end()); } diff --git a/third_party/libcppbor/README.md b/third_party/libcppbor/README.md index 0fbbcf3e..fc292040 100644 --- a/third_party/libcppbor/README.md +++ b/third_party/libcppbor/README.md @@ -1,12 +1,17 @@ LibCppBor: A Modern C++ CBOR Parser and Generator ============================================== -TODO(b/254108623): -This is a modified version of LibCppBor and is C++-14 compliant. The released -version can be found at -https://android.googlesource.com/platform/external/libcppbor, which requires -C++-17. This is a reminder of refreshing the library with the latest source -above once we officially move to C++-17. +This is a modified version of LibCppBor, based on commit +`61d9bff9605ad2ffd877bd99a3bde414e21f01a2` from the upstream source at +https://android.googlesource.com/platform/external/libcppbor. + +It's worth noting that while the latest LibCppBor release requires C++20, +CE CDM currently operates on C++17. This serves as a reminder to refresh our +local copy of the library with the latest upstream source it officially +transitions to C++20. + +Below is copied from the latest README.md of LibCppBor. +============================================== LibCppBor provides a natural and easy-to-use syntax for constructing and parsing CBOR messages. It does not (yet) support all features of diff --git a/third_party/libcppbor/include/cppbor/cppbor.h b/third_party/libcppbor/include/cppbor/cppbor.h index 75e577a0..ad28345d 100644 --- a/third_party/libcppbor/include/cppbor/cppbor.h +++ b/third_party/libcppbor/include/cppbor/cppbor.h @@ -24,63 +24,50 @@ #include #include #include +#include +#include +#include #include -#if __cplusplus >= 201402L || \ - (defined __cpp_lib_make_unique && __cpp_lib_make_unique >= 201304L) || \ - (defined(_MSC_VER) && _MSC_VER >= 1900) -using std::make_unique; -#else -template -std::unique_ptr make_unique(Args&&... args) { - return std::unique_ptr(new T(std::forward(args)...)); -} -#endif +#ifdef OS_WINDOWS +#include -template -using enable_if_t = typename std::enable_if::type; +#define ssize_t SSIZE_T +#endif // OS_WINDOWS -template -using remove_cv_t = typename std::remove_cv::type; - -template -using remove_reference_t = typename std::remove_reference::type; - -template -using remove_pointer_t = typename std::remove_pointer::type; - -template -using decay_t = typename std::decay::type; - -template -struct is_null_pointer : std::is_same> {}; +#ifdef TRUE +#undef TRUE +#endif // TRUE +#ifdef FALSE +#undef FALSE +#endif // FALSE namespace cppbor { enum MajorType : uint8_t { - UINT = 0 << 5, - NINT = 1 << 5, - BSTR = 2 << 5, - TSTR = 3 << 5, - ARRAY = 4 << 5, - MAP = 5 << 5, - SEMANTIC = 6 << 5, - SIMPLE = 7 << 5, + UINT = 0 << 5, + NINT = 1 << 5, + BSTR = 2 << 5, + TSTR = 3 << 5, + ARRAY = 4 << 5, + MAP = 5 << 5, + SEMANTIC = 6 << 5, + SIMPLE = 7 << 5, }; enum SimpleType { - BOOLEAN, - NULL_T, // Only two supported, as yet. + BOOLEAN, + NULL_T, // Only two supported, as yet. }; enum SpecialAddlInfoValues : uint8_t { - FALSE = 20, - TRUE = 21, - NULL_V = 22, - ONE_BYTE_LENGTH = 24, - TWO_BYTE_LENGTH = 25, - FOUR_BYTE_LENGTH = 26, - EIGHT_BYTE_LENGTH = 27, + FALSE = 20, + TRUE = 21, + NULL_V = 22, + ONE_BYTE_LENGTH = 24, + TWO_BYTE_LENGTH = 25, + FOUR_BYTE_LENGTH = 26, + EIGHT_BYTE_LENGTH = 27, }; class Item; @@ -96,811 +83,877 @@ class Map; class Null; class SemanticTag; class EncodedItem; +class ViewTstr; +class ViewBstr; /** - * Returns the size of a CBOR header that contains the additional info value - * addlInfo. + * Returns the size of a CBOR header that contains the additional info value addlInfo. */ size_t headerSize(uint64_t addlInfo); /** - * Encodes a CBOR header with the specified type and additional info into the - * range [pos, end). Returns a pointer to one past the last byte written, or - * nullptr if there isn't sufficient space to write the header. + * Encodes a CBOR header with the specified type and additional info into the range [pos, end). + * Returns a pointer to one past the last byte written, or nullptr if there isn't sufficient space + * to write the header. */ -uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, - const uint8_t* end); +uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end); using EncodeCallback = std::function; /** - * Encodes a CBOR header with the specified type and additional info, passing - * each byte in turn to encodeCallback. + * Encodes a CBOR header with the specified type and additional info, passing each byte in turn to + * encodeCallback. */ -void encodeHeader(MajorType type, uint64_t addlInfo, - EncodeCallback encodeCallback); +void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback); /** - * Encodes a CBOR header with the specified type and additional info, writing - * each byte to the provided OutputIterator. + * Encodes a CBOR header witht he specified type and additional info, writing each byte to the + * provided OutputIterator. */ template ::iterator_category>::value>> + typename = std::enable_if_t::iterator_category>>> void encodeHeader(MajorType type, uint64_t addlInfo, OutputIterator iter) { - return encodeHeader(type, addlInfo, [&](uint8_t v) { *iter++ = v; }); + return encodeHeader(type, addlInfo, [&](uint8_t v) { *iter++ = v; }); } /** - * Item represents a CBOR-encodeable data item. Item is an abstract interface - * with a set of virtual methods that allow encoding of the item or conversion - * to the appropriate derived type. + * Item represents a CBOR-encodeable data item. Item is an abstract interface with a set of virtual + * methods that allow encoding of the item or conversion to the appropriate derived type. */ class Item { - public: - virtual ~Item() {} + public: + virtual ~Item() {} - /** - * Returns the CBOR type of the item. - */ - virtual MajorType type() const = 0; + /** + * Returns the CBOR type of the item. + */ + virtual MajorType type() const = 0; - // These methods safely downcast an Item to the appropriate subclass. - virtual Int* asInt() { return nullptr; } - const Int* asInt() const { return const_cast(this)->asInt(); } - virtual Uint* asUint() { return nullptr; } - const Uint* asUint() const { return const_cast(this)->asUint(); } - virtual Nint* asNint() { return nullptr; } - const Nint* asNint() const { return const_cast(this)->asNint(); } - virtual Tstr* asTstr() { return nullptr; } - const Tstr* asTstr() const { return const_cast(this)->asTstr(); } - virtual Bstr* asBstr() { return nullptr; } - const Bstr* asBstr() const { return const_cast(this)->asBstr(); } - virtual Simple* asSimple() { return nullptr; } - const Simple* asSimple() const { return const_cast(this)->asSimple(); } - virtual Map* asMap() { return nullptr; } - const Map* asMap() const { return const_cast(this)->asMap(); } - virtual Array* asArray() { return nullptr; } - const Array* asArray() const { return const_cast(this)->asArray(); } + // These methods safely downcast an Item to the appropriate subclass. + virtual Int* asInt() { return nullptr; } + const Int* asInt() const { return const_cast(this)->asInt(); } + virtual Uint* asUint() { return nullptr; } + const Uint* asUint() const { return const_cast(this)->asUint(); } + virtual Nint* asNint() { return nullptr; } + const Nint* asNint() const { return const_cast(this)->asNint(); } + virtual Tstr* asTstr() { return nullptr; } + const Tstr* asTstr() const { return const_cast(this)->asTstr(); } + virtual Bstr* asBstr() { return nullptr; } + const Bstr* asBstr() const { return const_cast(this)->asBstr(); } + virtual Simple* asSimple() { return nullptr; } + const Simple* asSimple() const { return const_cast(this)->asSimple(); } + virtual Map* asMap() { return nullptr; } + const Map* asMap() const { return const_cast(this)->asMap(); } + virtual Array* asArray() { return nullptr; } + const Array* asArray() const { return const_cast(this)->asArray(); } - // Like those above, these methods safely downcast an Item when it's actually - // a SemanticTag. However, if you think you want to use these methods, you - // probably don't. Typically, the way you should handle tagged Items is by - // calling the appropriate method above (e.g. asInt()) which will return a - // pointer to the tagged Item, rather than the tag itself. If you want to - // find out if the Item* you're holding is to something with one or more tags - // applied, see semanticTagCount() and semanticTag() below. - virtual SemanticTag* asSemanticTag() { return nullptr; } - const SemanticTag* asSemanticTag() const { - return const_cast(this)->asSemanticTag(); - } + virtual ViewTstr* asViewTstr() { return nullptr; } + const ViewTstr* asViewTstr() const { return const_cast(this)->asViewTstr(); } + virtual ViewBstr* asViewBstr() { return nullptr; } + const ViewBstr* asViewBstr() const { return const_cast(this)->asViewBstr(); } - /** - * Returns the number of semantic tags prefixed to this Item. - */ - virtual size_t semanticTagCount() const { return 0; } + // Like those above, these methods safely downcast an Item when it's actually a SemanticTag. + // However, if you think you want to use these methods, you probably don't. Typically, the way + // you should handle tagged Items is by calling the appropriate method above (e.g. asInt()) + // which will return a pointer to the tagged Item, rather than the tag itself. If you want to + // find out if the Item* you're holding is to something with one or more tags applied, see + // semanticTagCount() and semanticTag() below. + virtual SemanticTag* asSemanticTag() { return nullptr; } + const SemanticTag* asSemanticTag() const { return const_cast(this)->asSemanticTag(); } - /** - * Returns the semantic tag at the specified nesting level `nesting`, iff - * `nesting` is less than the value returned by semanticTagCount(). - * - * CBOR tags are "nested" by applying them in sequence. The "rightmost" tag - * is the "inner" tag. That is, given: - * - * 4(5(6("AES"))) which encodes as C1 C2 C3 63 414553 - * - * The tstr "AES" is tagged with 6. The combined entity ("AES" tagged with 6) - * is tagged with 5, etc. So in this example, semanticTagCount() would return - * 3, and semanticTag(0) would return 5 semanticTag(1) would return 5 and - * semanticTag(2) would return 4. For values of n > 2, semanticTag(n) will - * return 0, but this is a meaningless value. - * - * If this layering is confusing, you probably don't have to worry about it. - * Nested tagging does not appear to be common, so semanticTag(0) is the only - * one you'll use. - */ - virtual uint64_t semanticTag(size_t /* nesting */ = 0) const { return 0; } + /** + * Returns the number of semantic tags prefixed to this Item. + */ + virtual size_t semanticTagCount() const { return 0; } - /** - * Returns true if this is a "compound" item, i.e. one that contains one or - * more other items. - */ - virtual bool isCompound() const { return false; } + /** + * Returns the semantic tag at the specified nesting level `nesting`, iff `nesting` is less than + * the value returned by semanticTagCount(). + * + * CBOR tags are "nested" by applying them in sequence. The "rightmost" tag is the "inner" tag. + * That is, given: + * + * 4(5(6("AES"))) which encodes as C1 C2 C3 63 414553 + * + * The tstr "AES" is tagged with 6. The combined entity ("AES" tagged with 6) is tagged with 5, + * etc. So in this example, semanticTagCount() would return 3, and semanticTag(0) would return + * 5 semanticTag(1) would return 5 and semanticTag(2) would return 4. For values of n > 2, + * semanticTag(n) will return 0, but this is a meaningless value. + * + * If this layering is confusing, you probably don't have to worry about it. Nested tagging does + * not appear to be common, so semanticTag(0) is the only one you'll use. + */ + virtual uint64_t semanticTag(size_t /* nesting */ = 0) const { return 0; } - bool operator==(const Item& other) const&; - bool operator!=(const Item& other) const& { return !(*this == other); } + /** + * Returns true if this is a "compound" item, i.e. one that contains one or more other items. + */ + virtual bool isCompound() const { return false; } - /** - * Returns the number of bytes required to encode this Item into CBOR. Note - * that if this is a complex Item, calling this method will require walking - * the whole tree. - */ - virtual size_t encodedSize() const = 0; + bool operator==(const Item& other) const&; + bool operator!=(const Item& other) const& { return !(*this == other); } - /** - * Encodes the Item into buffer referenced by range [*pos, end). Returns a - * pointer to one past the last position written. Returns nullptr if there - * isn't enough space to encode. - */ - virtual uint8_t* encode(uint8_t* pos, const uint8_t* end) const = 0; + /** + * Returns the number of bytes required to encode this Item into CBOR. Note that if this is a + * complex Item, calling this method will require walking the whole tree. + */ + virtual size_t encodedSize() const = 0; - /** - * Encodes the Item by passing each encoded byte to encodeCallback. - */ - virtual void encode(EncodeCallback encodeCallback) const = 0; + /** + * Encodes the Item into buffer referenced by range [*pos, end). Returns a pointer to one past + * the last position written. Returns nullptr if there isn't enough space to encode. + */ + virtual uint8_t* encode(uint8_t* pos, const uint8_t* end) const = 0; - /** - * Clones the Item - */ - virtual std::unique_ptr clone() const = 0; + /** + * Encodes the Item by passing each encoded byte to encodeCallback. + */ + virtual void encode(EncodeCallback encodeCallback) const = 0; - /** - * Encodes the Item into the provided OutputIterator. - */ - template ::iterator_category> - void encode(OutputIterator i) const { - return encode([&](uint8_t v) { *i++ = v; }); - } + /** + * Clones the Item + */ + virtual std::unique_ptr clone() const = 0; - /** - * Encodes the Item into a new std::vector. - */ - std::vector encode() const { - std::vector retval; - retval.reserve(encodedSize()); - encode(std::back_inserter(retval)); - return retval; - } + /** + * Encodes the Item into the provided OutputIterator. + */ + template ::iterator_category> + void encode(OutputIterator i) const { + return encode([&](uint8_t v) { *i++ = v; }); + } - /** - * Encodes the Item into a new std::string. - */ - std::string toString() const { - std::string retval; - retval.reserve(encodedSize()); - encode([&](uint8_t v) { retval.push_back(v); }); - return retval; - } + /** + * Encodes the Item into a new std::vector. + */ + std::vector encode() const { + std::vector retval; + retval.reserve(encodedSize()); + encode(std::back_inserter(retval)); + return retval; + } - /** - * Encodes only the header of the Item. - */ - inline uint8_t* encodeHeader(uint64_t addlInfo, uint8_t* pos, - const uint8_t* end) const { - return ::cppbor::encodeHeader(type(), addlInfo, pos, end); - } + /** + * Encodes the Item into a new std::string. + */ + std::string toString() const { + std::string retval; + retval.reserve(encodedSize()); + encode([&](uint8_t v) { retval.push_back(v); }); + return retval; + } - /** - * Encodes only the header of the Item. - */ - inline void encodeHeader(uint64_t addlInfo, - EncodeCallback encodeCallback) const { - ::cppbor::encodeHeader(type(), addlInfo, std::move(encodeCallback)); - } + /** + * Encodes only the header of the Item. + */ + inline uint8_t* encodeHeader(uint64_t addlInfo, uint8_t* pos, const uint8_t* end) const { + return ::cppbor::encodeHeader(type(), addlInfo, pos, end); + } + + /** + * Encodes only the header of the Item. + */ + inline void encodeHeader(uint64_t addlInfo, EncodeCallback encodeCallback) const { + ::cppbor::encodeHeader(type(), addlInfo, encodeCallback); + } }; /** - * EncodedItem represents a bit of already-encoded CBOR. Caveat emptor: It does - * no checking to ensure that the provided data is a valid encoding, cannot be - * meaninfully-compared with other kinds of items and you cannot use the as*() - * methods to find out what's inside it. + * EncodedItem represents a bit of already-encoded CBOR. Caveat emptor: It does no checking to + * ensure that the provided data is a valid encoding, cannot be meaninfully-compared with other + * kinds of items and you cannot use the as*() methods to find out what's inside it. */ class EncodedItem : public Item { - public: - explicit EncodedItem(std::vector value) : mValue(std::move(value)) {} + public: + explicit EncodedItem(std::vector value) : mValue(std::move(value)) {} - bool operator==(const EncodedItem& other) const& { - return mValue == other.mValue; - } + bool operator==(const EncodedItem& other) const& { return mValue == other.mValue; } - // Type can't be meaningfully-obtained. We could extract the type from the - // first byte and return it, but you can't do any of the normal things with an - // EncodedItem so there's no point. - MajorType type() const override { - assert(false); - return static_cast(-1); - } - size_t encodedSize() const override { return mValue.size(); } - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { - if (end - pos < static_cast(mValue.size())) return nullptr; - return std::copy(mValue.begin(), mValue.end(), pos); - } - void encode(EncodeCallback encodeCallback) const override { - std::for_each(mValue.begin(), mValue.end(), std::move(encodeCallback)); - } - std::unique_ptr clone() const override { - return make_unique(mValue); - } + // Type can't be meaningfully-obtained. We could extract the type from the first byte and return + // it, but you can't do any of the normal things with an EncodedItem so there's no point. + MajorType type() const override { + assert(false); + return static_cast(-1); + } + size_t encodedSize() const override { return mValue.size(); } + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { + if (end - pos < static_cast(mValue.size())) return nullptr; + return std::copy(mValue.begin(), mValue.end(), pos); + } + void encode(EncodeCallback encodeCallback) const override { + std::for_each(mValue.begin(), mValue.end(), encodeCallback); + } + std::unique_ptr clone() const override { return std::make_unique(mValue); } - private: - std::vector mValue; + private: + std::vector mValue; }; /** - * Int is an abstraction that allows Uint and Nint objects to be manipulated - * without caring about the sign. + * Int is an abstraction that allows Uint and Nint objects to be manipulated without caring about + * the sign. */ class Int : public Item { - public: - bool operator==(const Int& other) const& { return value() == other.value(); } + public: + bool operator==(const Int& other) const& { return value() == other.value(); } - virtual int64_t value() const = 0; - using Item::asInt; - Int* asInt() override { return this; } + virtual int64_t value() const = 0; + using Item::asInt; + Int* asInt() override { return this; } }; /** * Uint is a concrete Item that implements CBOR major type 0. */ class Uint : public Int { - public: - static constexpr MajorType kMajorType = UINT; + public: + static constexpr MajorType kMajorType = UINT; - explicit Uint(uint64_t v) : mValue(v) {} + explicit Uint(uint64_t v) : mValue(v) {} - bool operator==(const Uint& other) const& { return mValue == other.mValue; } + bool operator==(const Uint& other) const& { return mValue == other.mValue; } - MajorType type() const override { return kMajorType; } - using Item::asUint; - Uint* asUint() override { return this; } + MajorType type() const override { return kMajorType; } + using Item::asUint; + Uint* asUint() override { return this; } - size_t encodedSize() const override { return headerSize(mValue); } + size_t encodedSize() const override { return headerSize(mValue); } - int64_t value() const override { return mValue; } - uint64_t unsignedValue() const { return mValue; } + int64_t value() const override { return mValue; } + uint64_t unsignedValue() const { return mValue; } - using Item::encode; - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { - return encodeHeader(mValue, pos, end); - } - void encode(EncodeCallback encodeCallback) const override { - encodeHeader(mValue, std::move(encodeCallback)); - } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { + return encodeHeader(mValue, pos, end); + } + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(mValue, encodeCallback); + } - std::unique_ptr clone() const override { - return make_unique(mValue); - } + std::unique_ptr clone() const override { return std::make_unique(mValue); } - private: - uint64_t mValue; + private: + uint64_t mValue; }; /** * Nint is a concrete Item that implements CBOR major type 1. - * Note that it is incapable of expressing the full range of major type 1 - values, becaue it can only - * express values that fall into the range [std::numeric_limits::min(), - -1]. It cannot + * Note that it is incapable of expressing the full range of major type 1 values, becaue it can only + * express values that fall into the range [std::numeric_limits::min(), -1]. It cannot * express values in the range [std::numeric_limits::min() - 1, * -std::numeric_limits::max()]. */ class Nint : public Int { - public: - static constexpr MajorType kMajorType = NINT; + public: + static constexpr MajorType kMajorType = NINT; - explicit Nint(int64_t v); + explicit Nint(int64_t v); - bool operator==(const Nint& other) const& { return mValue == other.mValue; } + bool operator==(const Nint& other) const& { return mValue == other.mValue; } - MajorType type() const override { return kMajorType; } - using Item::asNint; - Nint* asNint() override { return this; } - size_t encodedSize() const override { return headerSize(addlInfo()); } + MajorType type() const override { return kMajorType; } + using Item::asNint; + Nint* asNint() override { return this; } + size_t encodedSize() const override { return headerSize(addlInfo()); } - int64_t value() const override { return mValue; } + int64_t value() const override { return mValue; } - using Item::encode; - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { - return encodeHeader(addlInfo(), pos, end); - } - void encode(EncodeCallback encodeCallback) const override { - encodeHeader(addlInfo(), std::move(encodeCallback)); - } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { + return encodeHeader(addlInfo(), pos, end); + } + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(addlInfo(), encodeCallback); + } - std::unique_ptr clone() const override { - return make_unique(mValue); - } + std::unique_ptr clone() const override { return std::make_unique(mValue); } - private: - uint64_t addlInfo() const { return -1ll - mValue; } + private: + uint64_t addlInfo() const { return -1ll - mValue; } - int64_t mValue; + int64_t mValue; }; /** * Bstr is a concrete Item that implements major type 2. */ class Bstr : public Item { - public: - static constexpr MajorType kMajorType = BSTR; + public: + static constexpr MajorType kMajorType = BSTR; - // Construct an empty Bstr - explicit Bstr() {} + // Construct an empty Bstr + explicit Bstr() {} - // Construct from a vector - explicit Bstr(std::vector v) : mValue(std::move(v)) {} + // Construct from a vector + explicit Bstr(std::vector v) : mValue(std::move(v)) {} - // Construct from a string - explicit Bstr(const std::string& v) - : mValue(reinterpret_cast(v.data()), - reinterpret_cast(v.data()) + v.size()) {} + // Construct from a string + explicit Bstr(const std::string& v) + : mValue(reinterpret_cast(v.data()), + reinterpret_cast(v.data()) + v.size()) {} - // Construct from a pointer/size pair - explicit Bstr(const std::pair& buf) - : mValue(buf.first, buf.first + buf.second) {} + // Construct from a pointer/size pair + explicit Bstr(const std::pair& buf) + : mValue(buf.first, buf.first + buf.second) {} - // Construct from a pair of iterators - template ::iterator_category, - typename = typename std::iterator_traits::iterator_category> - explicit Bstr(const std::pair& pair) - : mValue(pair.first, pair.second) {} + // Construct from a pair of iterators + template ::iterator_category, + typename = typename std::iterator_traits::iterator_category> + explicit Bstr(const std::pair& pair) : mValue(pair.first, pair.second) {} - // Construct from an iterator range. - template ::iterator_category, - typename = typename std::iterator_traits::iterator_category> - Bstr(I1 begin, I2 end) : mValue(begin, end) {} + // Construct from an iterator range. + template ::iterator_category, + typename = typename std::iterator_traits::iterator_category> + Bstr(I1 begin, I2 end) : mValue(begin, end) {} - bool operator==(const Bstr& other) const& { return mValue == other.mValue; } + bool operator==(const Bstr& other) const& { return mValue == other.mValue; } - MajorType type() const override { return kMajorType; } - using Item::asBstr; - Bstr* asBstr() override { return this; } - size_t encodedSize() const override { - return headerSize(mValue.size()) + mValue.size(); - } - using Item::encode; - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; - void encode(EncodeCallback encodeCallback) const override { - encodeHeader(mValue.size(), encodeCallback); - encodeValue(std::move(encodeCallback)); - } + MajorType type() const override { return kMajorType; } + using Item::asBstr; + Bstr* asBstr() override { return this; } + size_t encodedSize() const override { return headerSize(mValue.size()) + mValue.size(); } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(mValue.size(), encodeCallback); + encodeValue(encodeCallback); + } - const std::vector& value() const { return mValue; } - std::vector&& moveValue() { return std::move(mValue); } + const std::vector& value() const { return mValue; } + std::vector&& moveValue() { return std::move(mValue); } - std::unique_ptr clone() const override { - return make_unique(mValue); - } + std::unique_ptr clone() const override { return std::make_unique(mValue); } - private: - void encodeValue(EncodeCallback encodeCallback) const; + private: + void encodeValue(EncodeCallback encodeCallback) const; - std::vector mValue; + std::vector mValue; +}; + +/** + * ViewBstr is a read-only version of Bstr backed by std::string_view + */ +class ViewBstr : public Item { + public: + static constexpr MajorType kMajorType = BSTR; + + // Construct an empty ViewBstr + explicit ViewBstr() {} + + // Construct from a string_view of uint8_t values + explicit ViewBstr(std::basic_string_view v) : mView(std::move(v)) {} + + // Construct from a string_view + explicit ViewBstr(std::string_view v) + : mView(reinterpret_cast(v.data()), v.size()) {} + + // Construct from an iterator range + template ::iterator_category, + typename = typename std::iterator_traits::iterator_category> + ViewBstr(I1 begin, I2 end) : mView(begin, end) {} + + // Construct from a uint8_t pointer pair + ViewBstr(const uint8_t* begin, const uint8_t* end) + : mView(begin, std::distance(begin, end)) {} + + bool operator==(const ViewBstr& other) const& { return mView == other.mView; } + + MajorType type() const override { return kMajorType; } + using Item::asViewBstr; + ViewBstr* asViewBstr() override { return this; } + size_t encodedSize() const override { return headerSize(mView.size()) + mView.size(); } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(mView.size(), encodeCallback); + encodeValue(encodeCallback); + } + + const std::basic_string_view& view() const { return mView; } + + std::unique_ptr clone() const override { return std::make_unique(mView); } + + private: + void encodeValue(EncodeCallback encodeCallback) const; + + std::basic_string_view mView; }; /** * Tstr is a concrete Item that implements major type 3. */ class Tstr : public Item { - public: - static constexpr MajorType kMajorType = TSTR; + public: + static constexpr MajorType kMajorType = TSTR; - // Construct from a string - explicit Tstr(std::string v) : mValue(std::move(v)) {} + // Construct from a string + explicit Tstr(std::string v) : mValue(std::move(v)) {} - // Construct from a C string - explicit Tstr(const char* v) : mValue(std::string(v)) {} + // Construct from a string_view + explicit Tstr(const std::string_view& v) : mValue(v) {} - // Construct from a pair of iterators - template ::iterator_category, - typename = typename std::iterator_traits::iterator_category> - explicit Tstr(const std::pair& pair) - : mValue(pair.first, pair.second) {} + // Construct from a C string + explicit Tstr(const char* v) : mValue(std::string(v)) {} - // Construct from an iterator range - template ::iterator_category, - typename = typename std::iterator_traits::iterator_category> - Tstr(I1 begin, I2 end) : mValue(begin, end) {} + // Construct from a pair of iterators + template ::iterator_category, + typename = typename std::iterator_traits::iterator_category> + explicit Tstr(const std::pair& pair) : mValue(pair.first, pair.second) {} - bool operator==(const Tstr& other) const& { return mValue == other.mValue; } + // Construct from an iterator range + template ::iterator_category, + typename = typename std::iterator_traits::iterator_category> + Tstr(I1 begin, I2 end) : mValue(begin, end) {} - MajorType type() const override { return kMajorType; } - using Item::asTstr; - Tstr* asTstr() override { return this; } - size_t encodedSize() const override { - return headerSize(mValue.size()) + mValue.size(); - } - using Item::encode; - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; - void encode(EncodeCallback encodeCallback) const override { - encodeHeader(mValue.size(), encodeCallback); - encodeValue(std::move(encodeCallback)); - } + bool operator==(const Tstr& other) const& { return mValue == other.mValue; } - const std::string& value() const { return mValue; } - std::string&& moveValue() { return std::move(mValue); } + MajorType type() const override { return kMajorType; } + using Item::asTstr; + Tstr* asTstr() override { return this; } + size_t encodedSize() const override { return headerSize(mValue.size()) + mValue.size(); } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(mValue.size(), encodeCallback); + encodeValue(encodeCallback); + } - std::unique_ptr clone() const override { - return make_unique(mValue); - } + const std::string& value() const { return mValue; } + std::string&& moveValue() { return std::move(mValue); } - private: - void encodeValue(EncodeCallback encodeCallback) const; + std::unique_ptr clone() const override { return std::make_unique(mValue); } - std::string mValue; + private: + void encodeValue(EncodeCallback encodeCallback) const; + + std::string mValue; +}; + +/** + * ViewTstr is a read-only version of Tstr backed by std::string_view + */ +class ViewTstr : public Item { + public: + static constexpr MajorType kMajorType = TSTR; + + // Construct an empty ViewTstr + explicit ViewTstr() {} + + // Construct from a string_view + explicit ViewTstr(std::string_view v) : mView(std::move(v)) {} + + // Construct from an iterator range + template ::iterator_category, + typename = typename std::iterator_traits::iterator_category> + ViewTstr(I1 begin, I2 end) : mView(begin, end) {} + + // Construct from a uint8_t pointer pair + ViewTstr(const uint8_t* begin, const uint8_t* end) + : mView(reinterpret_cast(begin), + std::distance(begin, end)) {} + + bool operator==(const ViewTstr& other) const& { return mView == other.mView; } + + MajorType type() const override { return kMajorType; } + using Item::asViewTstr; + ViewTstr* asViewTstr() override { return this; } + size_t encodedSize() const override { return headerSize(mView.size()) + mView.size(); } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(mView.size(), encodeCallback); + encodeValue(encodeCallback); + } + + const std::string_view& view() const { return mView; } + + std::unique_ptr clone() const override { return std::make_unique(mView); } + + private: + void encodeValue(EncodeCallback encodeCallback) const; + + std::string_view mView; }; /* * Array is a concrete Item that implements CBOR major type 4. * - * Note that Arrays are not copyable. This is because copying them is expensive - * and making them move-only ensures that they're never copied accidentally. If - * you actually want to copy an Array, use the clone() method. + * Note that Arrays are not copyable. This is because copying them is expensive and making them + * move-only ensures that they're never copied accidentally. If you actually want to copy an Array, + * use the clone() method. */ class Array : public Item { - public: - static constexpr MajorType kMajorType = ARRAY; + public: + static constexpr MajorType kMajorType = ARRAY; - Array() = default; - Array(const Array& other) = delete; - Array(Array&&) = default; - Array& operator=(const Array&) = delete; - Array& operator=(Array&&) = default; + Array() = default; + Array(const Array& other) = delete; + Array(Array&&) = default; + Array& operator=(const Array&) = delete; + Array& operator=(Array&&) = default; - bool operator==(const Array& other) const&; + bool operator==(const Array& other) const&; - /** - * Construct an Array from a variable number of arguments of different types. - * See details::makeItem below for details on what types may be provided. In - * general, this accepts all of the types you'd expect and doest the things - * you'd expect (integral values are addes as Uint or Nint, std::string and - * char* are added as Tstr, bools are added as Bool, etc.). - */ - // template - // Array(Args&&... args); + /** + * Construct an Array from a variable number of arguments of different types. See + * details::makeItem below for details on what types may be provided. In general, this accepts + * all of the types you'd expect and doest the things you'd expect (integral values are addes as + * Uint or Nint, std::string and char* are added as Tstr, bools are added as Bool, etc.). + */ + template + Array(Args&&... args); - /** - * Append a single element to the Array, of any compatible type. - */ - template - Array& add(T&& v) &; - template - Array&& add(T&& v) &&; + /** + * The above variadic constructor is disabled if sizeof(Args) != 1, so special + * case an explicit Array constructor for creating an Array with one Item. + */ + template + explicit Array(T&& v); - bool isCompound() const override { return true; } + /** + * Append a single element to the Array, of any compatible type. + */ + template + Array& add(T&& v) &; + template + Array&& add(T&& v) &&; - virtual size_t size() const { return mEntries.size(); } + bool isCompound() const override { return true; } - size_t encodedSize() const override { - return std::accumulate(mEntries.begin(), mEntries.end(), headerSize(size()), - [](size_t sum, const std::unique_ptr& entry) { - return sum + entry->encodedSize(); - }); - } + virtual size_t size() const { return mEntries.size(); } - using Item::encode; // Make base versions visible. - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; - void encode(EncodeCallback encodeCallback) const override; + size_t encodedSize() const override { + return std::accumulate(mEntries.begin(), mEntries.end(), headerSize(size()), + [](size_t sum, auto& entry) { return sum + entry->encodedSize(); }); + } - const std::unique_ptr& operator[](size_t index) const { - return get(index); - } - std::unique_ptr& operator[](size_t index) { return get(index); } + using Item::encode; // Make base versions visible. + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override; - const std::unique_ptr& get(size_t index) const { - return mEntries[index]; - } - std::unique_ptr& get(size_t index) { return mEntries[index]; } + const std::unique_ptr& operator[](size_t index) const { return get(index); } + std::unique_ptr& operator[](size_t index) { return get(index); } - MajorType type() const override { return kMajorType; } - using Item::asArray; - Array* asArray() override { return this; } + const std::unique_ptr& get(size_t index) const { return mEntries[index]; } + std::unique_ptr& get(size_t index) { return mEntries[index]; } - std::unique_ptr clone() const override; + MajorType type() const override { return kMajorType; } + using Item::asArray; + Array* asArray() override { return this; } - std::vector>::iterator begin() { - return mEntries.begin(); - } - std::vector>::const_iterator begin() const { - return mEntries.begin(); - } - std::vector>::iterator end() { return mEntries.end(); } - std::vector>::const_iterator end() const { - return mEntries.end(); - } + std::unique_ptr clone() const override; - protected: - std::vector> mEntries; + auto begin() { return mEntries.begin(); } + auto begin() const { return mEntries.begin(); } + auto end() { return mEntries.end(); } + auto end() const { return mEntries.end(); } + + protected: + std::vector> mEntries; }; /* * Map is a concrete Item that implements CBOR major type 5. * - * Note that Maps are not copyable. This is because copying them is expensive - * and making them move-only ensures that they're never copied accidentally. If - * you actually want to copy a Map, use the clone() method. + * Note that Maps are not copyable. This is because copying them is expensive and making them + * move-only ensures that they're never copied accidentally. If you actually want to copy a + * Map, use the clone() method. */ class Map : public Item { - public: - static constexpr MajorType kMajorType = MAP; + public: + static constexpr MajorType kMajorType = MAP; - using entry_type = std::pair, std::unique_ptr>; + using entry_type = std::pair, std::unique_ptr>; - Map() = default; - Map(const Map& other) = delete; - Map(Map&&) = default; - Map& operator=(const Map& other) = delete; - Map& operator=(Map&&) = default; + Map() = default; + Map(const Map& other) = delete; + Map(Map&&) = default; + Map& operator=(const Map& other) = delete; + Map& operator=(Map&&) = default; - bool operator==(const Map& other) const&; + bool operator==(const Map& other) const&; - /** - * Construct a Map from a variable number of arguments of different types. An - * even number of arguments must be provided (this is verified statically). - * See details::makeItem below for details on what types may be provided. In - * general, this accepts all of the types you'd expect and doest the things - * you'd expect (integral values are addes as Uint or Nint, std::string and - * char* are added as Tstr, bools are added as Bool, etc.). - */ - template - Map(Args&&... args); + /** + * Construct a Map from a variable number of arguments of different types. An even number of + * arguments must be provided (this is verified statically). See details::makeItem below for + * details on what types may be provided. In general, this accepts all of the types you'd + * expect and doest the things you'd expect (integral values are addes as Uint or Nint, + * std::string and char* are added as Tstr, bools are added as Bool, etc.). + */ + template + Map(Args&&... args); - /** - * Append a key/value pair to the Map, of any compatible types. - */ - template - Map& add(Key&& key, Value&& value) &; - template - Map&& add(Key&& key, Value&& value) &&; + /** + * Append a key/value pair to the Map, of any compatible types. + */ + template + Map& add(Key&& key, Value&& value) &; + template + Map&& add(Key&& key, Value&& value) &&; - bool isCompound() const override { return true; } + bool isCompound() const override { return true; } - virtual size_t size() const { return mEntries.size(); } + virtual size_t size() const { return mEntries.size(); } - size_t encodedSize() const override { - return std::accumulate(mEntries.begin(), mEntries.end(), headerSize(size()), - [](size_t sum, const entry_type& entry) { - return sum + entry.first->encodedSize() + - entry.second->encodedSize(); - }); - } + size_t encodedSize() const override { + return std::accumulate( + mEntries.begin(), mEntries.end(), headerSize(size()), [](size_t sum, auto& entry) { + return sum + entry.first->encodedSize() + entry.second->encodedSize(); + }); + } - using Item::encode; // Make base versions visible. - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; - void encode(EncodeCallback encodeCallback) const override; + using Item::encode; // Make base versions visible. + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override; - /** - * Find and return the value associated with `key`, if any. - * - * If the searched-for `key` is not present, returns `nullptr`. - * - * Note that if the map is canonicalized (sorted), Map::get() peforms a binary - * search. If your map is large and you're searching in it many times, it may - * be worthwhile to canonicalize it to make Map::get() faster. Any use of a - * method that might modify the map disables the speedup. - */ - template - const std::unique_ptr& get(Key key) const; + /** + * Find and return the value associated with `key`, if any. + * + * If the searched-for `key` is not present, returns `nullptr`. + * + * Note that if the map is canonicalized (sorted), Map::get() performs a binary search. If your + * map is large and you're searching in it many times, it may be worthwhile to canonicalize it + * to make Map::get() faster. Any use of a method that might modify the map disables the + * speedup. + */ + template + const std::unique_ptr& get(Key key) const; - // Note that use of non-const operator[] marks the map as not canonicalized. - entry_type& operator[](size_t index) { - mCanonicalized = false; - return mEntries[index]; - } - const entry_type& operator[](size_t index) const { return mEntries[index]; } + // Note that use of non-const operator[] marks the map as not canonicalized. + auto& operator[](size_t index) { + mCanonicalized = false; + return mEntries[index]; + } + const auto& operator[](size_t index) const { return mEntries[index]; } - MajorType type() const override { return kMajorType; } - using Item::asMap; - Map* asMap() override { return this; } + MajorType type() const override { return kMajorType; } + using Item::asMap; + Map* asMap() override { return this; } - /** - * Sorts the map in canonical order, as defined in RFC 7049. Use this before - * encoding if you want canonicalization; cppbor does not canonicalize by - * default, though the integer encodings are always canonical and cppbor does - * not support indefinite-length encodings, so map order canonicalization is - * the only thing that needs to be done. - * - * @param recurse If set to true, canonicalize() will also walk the contents - * of the map and canonicalize any contained maps as well. - */ - Map& canonicalize(bool recurse = false) &; - Map&& canonicalize(bool recurse = false) && { - canonicalize(recurse); - return std::move(*this); - } + /** + * Sorts the map in canonical order, as defined in RFC 7049. Use this before encoding if you + * want canonicalization; cppbor does not canonicalize by default, though the integer encodings + * are always canonical and cppbor does not support indefinite-length encodings, so map order + * canonicalization is the only thing that needs to be done. + * + * @param recurse If set to true, canonicalize() will also walk the contents of the map and + * canonicalize any contained maps as well. + */ + Map& canonicalize(bool recurse = false) &; + Map&& canonicalize(bool recurse = false) && { + canonicalize(recurse); + return std::move(*this); + } - bool isCanonical() { return mCanonicalized; } + bool isCanonical() { return mCanonicalized; } - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::vector::iterator begin() { - mCanonicalized = false; - return mEntries.begin(); - } - std::vector::const_iterator begin() const { - return mEntries.begin(); - } - std::vector::iterator end() { - mCanonicalized = false; - return mEntries.end(); - } - std::vector::const_iterator end() const { return mEntries.end(); } + auto begin() { + mCanonicalized = false; + return mEntries.begin(); + } + auto begin() const { return mEntries.begin(); } + auto end() { + mCanonicalized = false; + return mEntries.end(); + } + auto end() const { return mEntries.end(); } - // Returns true if a < b, per CBOR map key canonicalization rules. - static bool keyLess(const Item* a, const Item* b); + // Returns true if a < b, per CBOR map key canonicalization rules. + static bool keyLess(const Item* a, const Item* b); - protected: - std::vector mEntries; + protected: + std::vector mEntries; - private: - bool mCanonicalized = false; + private: + bool mCanonicalized = false; }; class SemanticTag : public Item { - public: - static constexpr MajorType kMajorType = SEMANTIC; + public: + static constexpr MajorType kMajorType = SEMANTIC; - template - SemanticTag(uint64_t tagValue, T&& taggedItem); - SemanticTag(const SemanticTag& other) = delete; - SemanticTag(SemanticTag&&) = default; - SemanticTag& operator=(const SemanticTag& other) = delete; - SemanticTag& operator=(SemanticTag&&) = default; + template + SemanticTag(uint64_t tagValue, T&& taggedItem); + SemanticTag(const SemanticTag& other) = delete; + SemanticTag(SemanticTag&&) = default; + SemanticTag& operator=(const SemanticTag& other) = delete; + SemanticTag& operator=(SemanticTag&&) = default; - bool operator==(const SemanticTag& other) const& { - return mValue == other.mValue && *mTaggedItem == *other.mTaggedItem; - } + bool operator==(const SemanticTag& other) const& { + return mValue == other.mValue && *mTaggedItem == *other.mTaggedItem; + } - bool isCompound() const override { return true; } + bool isCompound() const override { return true; } - virtual size_t size() const { return 1; } + virtual size_t size() const { return 1; } - // Encoding returns the tag + enclosed Item. - size_t encodedSize() const override { - return headerSize(mValue) + mTaggedItem->encodedSize(); - } + // Encoding returns the tag + enclosed Item. + size_t encodedSize() const override { return headerSize(mValue) + mTaggedItem->encodedSize(); } - using Item::encode; // Make base versions visible. - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; - void encode(EncodeCallback encodeCallback) const override; + using Item::encode; // Make base versions visible. + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override; + void encode(EncodeCallback encodeCallback) const override; - // type() is a bit special. In normal usage it should return the wrapped - // type, but during parsing when we haven't yet parsed the tagged item, it - // needs to return SEMANTIC. - MajorType type() const override { - return mTaggedItem ? mTaggedItem->type() : SEMANTIC; - } - using Item::asSemanticTag; - SemanticTag* asSemanticTag() override { return this; } + // type() is a bit special. In normal usage it should return the wrapped type, but during + // parsing when we haven't yet parsed the tagged item, it needs to return SEMANTIC. + MajorType type() const override { return mTaggedItem ? mTaggedItem->type() : SEMANTIC; } + using Item::asSemanticTag; + SemanticTag* asSemanticTag() override { return this; } - // Type information reflects the enclosed Item. Note that if the - // immediately-enclosed Item is another tag, these methods will recurse down - // to the non-tag Item. - using Item::asInt; - Int* asInt() override { return mTaggedItem->asInt(); } - using Item::asUint; - Uint* asUint() override { return mTaggedItem->asUint(); } - using Item::asNint; - Nint* asNint() override { return mTaggedItem->asNint(); } - using Item::asTstr; - Tstr* asTstr() override { return mTaggedItem->asTstr(); } - using Item::asBstr; - Bstr* asBstr() override { return mTaggedItem->asBstr(); } - using Item::asSimple; - Simple* asSimple() override { return mTaggedItem->asSimple(); } - using Item::asMap; - Map* asMap() override { return mTaggedItem->asMap(); } - using Item::asArray; - Array* asArray() override { return mTaggedItem->asArray(); } + // Type information reflects the enclosed Item. Note that if the immediately-enclosed Item is + // another tag, these methods will recurse down to the non-tag Item. + using Item::asInt; + Int* asInt() override { return mTaggedItem->asInt(); } + using Item::asUint; + Uint* asUint() override { return mTaggedItem->asUint(); } + using Item::asNint; + Nint* asNint() override { return mTaggedItem->asNint(); } + using Item::asTstr; + Tstr* asTstr() override { return mTaggedItem->asTstr(); } + using Item::asBstr; + Bstr* asBstr() override { return mTaggedItem->asBstr(); } + using Item::asSimple; + Simple* asSimple() override { return mTaggedItem->asSimple(); } + using Item::asMap; + Map* asMap() override { return mTaggedItem->asMap(); } + using Item::asArray; + Array* asArray() override { return mTaggedItem->asArray(); } + using Item::asViewTstr; + ViewTstr* asViewTstr() override { return mTaggedItem->asViewTstr(); } + using Item::asViewBstr; + ViewBstr* asViewBstr() override { return mTaggedItem->asViewBstr(); } - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - size_t semanticTagCount() const override; - uint64_t semanticTag(size_t nesting = 0) const override; + size_t semanticTagCount() const override; + uint64_t semanticTag(size_t nesting = 0) const override; - protected: - SemanticTag() = default; - SemanticTag(uint64_t value) : mValue(value) {} - uint64_t mValue; - std::unique_ptr mTaggedItem; + protected: + SemanticTag() = default; + SemanticTag(uint64_t value) : mValue(value) {} + uint64_t mValue; + std::unique_ptr mTaggedItem; }; /** - * Simple is abstract Item that implements CBOR major type 7. It is intended to - * be subclassed to create concrete Simple types. At present only Bool is - * provided. + * Simple is abstract Item that implements CBOR major type 7. It is intended to be subclassed to + * create concrete Simple types. At present only Bool is provided. */ class Simple : public Item { - public: - static constexpr MajorType kMajorType = SIMPLE; + public: + static constexpr MajorType kMajorType = SIMPLE; - bool operator==(const Simple& other) const&; + bool operator==(const Simple& other) const&; - virtual SimpleType simpleType() const = 0; - MajorType type() const override { return kMajorType; } + virtual SimpleType simpleType() const = 0; + MajorType type() const override { return kMajorType; } - Simple* asSimple() override { return this; } + Simple* asSimple() override { return this; } - virtual const Bool* asBool() const { return nullptr; }; - virtual const Null* asNull() const { return nullptr; }; + virtual const Bool* asBool() const { return nullptr; }; + virtual const Null* asNull() const { return nullptr; }; }; /** - * Bool is a concrete type that implements CBOR major type 7, with additional - * item values for TRUE and FALSE. + * Bool is a concrete type that implements CBOR major type 7, with additional item values for TRUE + * and FALSE. */ class Bool : public Simple { - public: - static constexpr SimpleType kSimpleType = BOOLEAN; + public: + static constexpr SimpleType kSimpleType = BOOLEAN; - explicit Bool(bool v) : mValue(v) {} + explicit Bool(bool v) : mValue(v) {} - bool operator==(const Bool& other) const& { return mValue == other.mValue; } + bool operator==(const Bool& other) const& { return mValue == other.mValue; } - SimpleType simpleType() const override { return kSimpleType; } - const Bool* asBool() const override { return this; } + SimpleType simpleType() const override { return kSimpleType; } + const Bool* asBool() const override { return this; } - size_t encodedSize() const override { return 1; } + size_t encodedSize() const override { return 1; } - using Item::encode; - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { - return encodeHeader(mValue ? TRUE : FALSE, pos, end); - } - void encode(EncodeCallback encodeCallback) const override { - encodeHeader(mValue ? TRUE : FALSE, std::move(encodeCallback)); - } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { + return encodeHeader(mValue ? TRUE : FALSE, pos, end); + } + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(mValue ? TRUE : FALSE, encodeCallback); + } - bool value() const { return mValue; } + bool value() const { return mValue; } - std::unique_ptr clone() const override { - return make_unique(mValue); - } + std::unique_ptr clone() const override { return std::make_unique(mValue); } - private: - bool mValue; + private: + bool mValue; }; /** - * Null is a concrete type that implements CBOR major type 7, with additional - * item value for NULL + * Null is a concrete type that implements CBOR major type 7, with additional item value for NULL */ class Null : public Simple { - public: - static constexpr SimpleType kSimpleType = NULL_T; + public: + static constexpr SimpleType kSimpleType = NULL_T; - explicit Null() {} + explicit Null() {} - SimpleType simpleType() const override { return kSimpleType; } - const Null* asNull() const override { return this; } + SimpleType simpleType() const override { return kSimpleType; } + const Null* asNull() const override { return this; } - size_t encodedSize() const override { return 1; } + size_t encodedSize() const override { return 1; } - using Item::encode; - uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { - return encodeHeader(NULL_V, pos, end); - } - void encode(EncodeCallback encodeCallback) const override { - encodeHeader(NULL_V, std::move(encodeCallback)); - } + using Item::encode; + uint8_t* encode(uint8_t* pos, const uint8_t* end) const override { + return encodeHeader(NULL_V, pos, end); + } + void encode(EncodeCallback encodeCallback) const override { + encodeHeader(NULL_V, encodeCallback); + } - std::unique_ptr clone() const override { return make_unique(); } + std::unique_ptr clone() const override { return std::make_unique(); } }; /** * Returns pretty-printed CBOR for |item| * - * If a byte-string is larger than |maxBStrSize| its contents will not be - * printed, instead the value of the form "" will be printed. Pass zero - * for |maxBStrSize| to disable this. + * If a byte-string is larger than |maxBStrSize| its contents will not be printed, instead the value + * of the form "" will be + * printed. Pass zero for |maxBStrSize| to disable this. * - * The |mapKeysToNotPrint| parameter specifies the name of map values to not - * print. This is useful for unit tests. + * The |mapKeysToNotPrint| parameter specifies the name of map values to not print. This is useful + * for unit tests. */ std::string prettyPrint(const Item* item, size_t maxBStrSize = 32, - const std::vector& mapKeysNotToPrint = {}); + const std::vector& mapKeysToNotPrint = {}); /** - * Details. Mostly you shouldn't have to look below, except perhaps at the - * docstring for makeItem. + * Returns pretty-printed CBOR for |value|. + * + * Only valid CBOR should be passed to this function. + * + * If a byte-string is larger than |maxBStrSize| its contents will not be printed, instead the value + * of the form "" will be + * printed. Pass zero for |maxBStrSize| to disable this. + * + * The |mapKeysToNotPrint| parameter specifies the name of map values to not print. This is useful + * for unit tests. + */ +std::string prettyPrint(const std::vector& encodedCbor, size_t maxBStrSize = 32, + const std::vector& mapKeysToNotPrint = {}); + +/** + * Details. Mostly you shouldn't have to look below, except perhaps at the docstring for makeItem. */ namespace details { @@ -909,38 +962,35 @@ struct is_iterator_pair_over : public std::false_type {}; template struct is_iterator_pair_over< - std::pair, V, - typename std::enable_if::value_type>::value>::type> + std::pair, V, + typename std::enable_if_t::value_type>>> : public std::true_type {}; template struct is_unique_ptr_of_subclass_of_v : public std::false_type {}; template -struct is_unique_ptr_of_subclass_of_v< - T, std::unique_ptr

, - typename std::enable_if::value>::type> +struct is_unique_ptr_of_subclass_of_v, + typename std::enable_if_t>> : public std::true_type {}; -/* check if type is one of std::string (1), std::string_view (2), - * null-terminated char* (3) or pair of iterators (4)*/ +/* check if type is one of std::string (1), std::string_view (2), null-terminated char* (3) or pair + * of iterators (4)*/ template struct is_text_type_v : public std::false_type {}; template struct is_text_type_v< - T, typename std::enable_if< - /* case 1 */ // - std::is_same>, std::string>::value - /* case 2 */ // - // || is_same_v>, std::string_view> - /* case 3 */ // - || std::is_same>, char*>::value // - || std::is_same>, const char*>::value - /* case 4 */ - || details::is_iterator_pair_over::value>::type> - : public std::true_type {}; + T, typename std::enable_if_t< + /* case 1 */ // + std::is_same_v>, std::string> + /* case 2 */ // + || std::is_same_v>, std::string_view> + /* case 3 */ // + || std::is_same_v>, char*> // + || std::is_same_v>, const char*> + /* case 4 */ + || details::is_iterator_pair_over::value>> : public std::true_type {}; /** * Construct a unique_ptr from many argument types. Accepts: @@ -948,180 +998,144 @@ struct is_text_type_v< * (a) booleans; * (b) integers, all sizes and signs; * (c) text strings, as defined by is_text_type_v above; - * (d) byte strings, as std::vector(d1), pair of iterators (d2) or - * pair (d3); and (e) Item subclass instances, including Array - * and Map. Items may be provided by naked pointer (e1), unique_ptr (e2), - * reference (e3) or value (e3). If provided by reference or value, will be - * moved if possible. If provided by pointer, ownership is taken. (f) null - * pointer; (g) enums, using the underlying integer value. + * (d) byte strings, as std::vector(d1), pair of iterators (d2) or pair + * (d3); and + * (e) Item subclass instances, including Array and Map. Items may be provided by naked pointer + * (e1), unique_ptr (e2), reference (e3) or value (e3). If provided by reference or value, will + * be moved if possible. If provided by pointer, ownership is taken. + * (f) null pointer; + * (g) enums, using the underlying integer value. */ -template ::value>::type* = nullptr> +template std::unique_ptr makeItem(T v) { - Item* p = new Bool(v); - return std::unique_ptr(p); -} - -template ::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = nullptr; - if (v < 0) { - p = new Nint(v); - } else { - p = new Uint(static_cast(v)); - } - return std::unique_ptr(p); -} - -template ::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = new Tstr(v); - return std::unique_ptr(p); -} - -template < - typename T, - typename std::enable_if< - /* case d1 */ std::is_same>, - std::vector>::value - /* case d2 */ // - || details::is_iterator_pair_over::value - /* case d3 */ // - || std::is_same>, - std::pair>::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = new Bstr(std::move(v)); - return std::unique_ptr(p); -} - -template < - typename T, - typename std::enable_if< - /* case e1 */ std::is_pointer::value && - std::is_base_of>::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = v; - return std::unique_ptr(p); -} - -template ::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = v.release(); - return std::unique_ptr(p); -} - -template ::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = new T(std::move(v)); - return std::unique_ptr(p); -} - -template ::value>::type* = nullptr> -std::unique_ptr makeItem(T) { - Item* p = new Null(); - return std::unique_ptr(p); -} - -template ::value>::type* = nullptr> -std::unique_ptr makeItem(T v) { - Item* p = makeItem(static_cast::type>(v)); - return std::unique_ptr(p); + Item* p = nullptr; + if constexpr (/* case a */ std::is_same_v) { + p = new Bool(v); + } else if constexpr (/* case b */ std::is_integral_v) { // b + if (v < 0) { + p = new Nint(v); + } else { + p = new Uint(static_cast(v)); + } + } else if constexpr (/* case c */ // + details::is_text_type_v::value) { + p = new Tstr(v); + } else if constexpr (/* case d1 */ // + std::is_same_v>, + std::vector> + /* case d2 */ // + || details::is_iterator_pair_over::value + /* case d3 */ // + || std::is_same_v>, + std::pair>) { + p = new Bstr(v); + } else if constexpr (/* case e1 */ // + std::is_pointer_v && + std::is_base_of_v>) { + p = v; + } else if constexpr (/* case e2 */ // + details::is_unique_ptr_of_subclass_of_v::value) { + p = v.release(); + } else if constexpr (/* case e3 */ // + std::is_base_of_v) { + p = new T(std::move(v)); + } else if constexpr (/* case f */ std::is_null_pointer_v) { + p = new Null(); + } else if constexpr (/* case g */ std::is_enum_v) { + return makeItem(static_cast>(v)); + } else { + // It's odd that this can't be static_assert(false), since it shouldn't be evaluated if one + // of the above ifs matches. But static_assert(false) always triggers. + static_assert(std::is_same_v, "makeItem called with unsupported type"); + } + return std::unique_ptr(p); } inline void map_helper(Map& /* map */) {} template inline void map_helper(Map& map, Key&& key, Value&& value, Rest&&... rest) { - map.add(std::forward(key), std::forward(value)); - map_helper(map, std::forward(rest)...); + map.add(std::forward(key), std::forward(value)); + map_helper(map, std::forward(rest)...); } } // namespace details -// template >>::value || ...)>> -// Array::Array(Args&&... args) { -// mEntries.reserve(sizeof...(args)); -// (mEntries.push_back(details::makeItem(std::forward(args))), ...); -// } +template > +Array::Array(Args&&... args) { + mEntries.reserve(sizeof...(args)); + (mEntries.push_back(details::makeItem(std::forward(args))), ...); +} + +template >>>> +Array::Array(T&& v) { + mEntries.push_back(details::makeItem(std::forward(v))); +} template Array& Array::add(T&& v) & { - mEntries.push_back(details::makeItem(std::forward(v))); - return *this; + mEntries.push_back(details::makeItem(std::forward(v))); + return *this; } template Array&& Array::add(T&& v) && { - mEntries.push_back(details::makeItem(std::forward(v))); - return std::move(*this); + mEntries.push_back(details::makeItem(std::forward(v))); + return std::move(*this); } template > + /* Prevent use as copy ctor */ typename = std::enable_if_t<(sizeof...(Args)) != 1>> Map::Map(Args&&... args) { - static_assert((sizeof...(Args)) % 2 == 0, - "Map must have an even number of entries"); - mEntries.reserve(sizeof...(args) / 2); - details::map_helper(*this, std::forward(args)...); + static_assert((sizeof...(Args)) % 2 == 0, "Map must have an even number of entries"); + mEntries.reserve(sizeof...(args) / 2); + details::map_helper(*this, std::forward(args)...); } template Map& Map::add(Key&& key, Value&& value) & { - mEntries.push_back({details::makeItem(std::forward(key)), - details::makeItem(std::forward(value))}); - mCanonicalized = false; - return *this; + mEntries.push_back({details::makeItem(std::forward(key)), + details::makeItem(std::forward(value))}); + mCanonicalized = false; + return *this; } template Map&& Map::add(Key&& key, Value&& value) && { - this->add(std::forward(key), std::forward(value)); - return std::move(*this); + this->add(std::forward(key), std::forward(value)); + return std::move(*this); } static const std::unique_ptr kEmptyItemPtr; template ::value || - std::is_enum::value || - details::is_text_type_v::value>> + typename = std::enable_if_t || std::is_enum_v || + details::is_text_type_v::value>> const std::unique_ptr& Map::get(Key key) const { - auto keyItem = details::makeItem(key); + auto keyItem = details::makeItem(key); - if (mCanonicalized) { - // It's sorted, so binary-search it. - auto found = - std::lower_bound(begin(), end(), keyItem.get(), - [](const entry_type& entry, const Item* itemKey) { - return keyLess(entry.first.get(), itemKey); - }); - return (found == end() || *found->first != *keyItem) ? kEmptyItemPtr - : found->second; - } else { - // Unsorted, do a linear search. - auto found = std::find_if(begin(), end(), [&](const entry_type& entry) { - return *entry.first == *keyItem; - }); - return found == end() ? kEmptyItemPtr : found->second; - } + if (mCanonicalized) { + // It's sorted, so binary-search it. + auto found = std::lower_bound(begin(), end(), keyItem.get(), + [](const entry_type& entry, const Item* ikey) { + return keyLess(entry.first.get(), ikey); + }); + return (found == end() || *found->first != *keyItem) ? kEmptyItemPtr : found->second; + } else { + // Unsorted, do a linear search. + auto found = std::find_if( + begin(), end(), [&](const entry_type& entry) { return *entry.first == *keyItem; }); + return found == end() ? kEmptyItemPtr : found->second; + } } template SemanticTag::SemanticTag(uint64_t value, T&& taggedItem) - : mValue(value), - mTaggedItem(details::makeItem(std::forward(taggedItem))) {} + : mValue(value), mTaggedItem(details::makeItem(std::forward(taggedItem))) {} -} // namespace cppbor \ No newline at end of file +} // namespace cppbor diff --git a/third_party/libcppbor/include/cppbor/cppbor_parse.h b/third_party/libcppbor/include/cppbor/cppbor_parse.h index f1b36472..22cd18d0 100644 --- a/third_party/libcppbor/include/cppbor/cppbor_parse.h +++ b/third_party/libcppbor/include/cppbor/cppbor_parse.h @@ -36,6 +36,24 @@ using ParseResult = std::tuple /* result */, const uint8_t */ ParseResult parse(const uint8_t* begin, const uint8_t* end); +/** + * Parse the first CBOR data item (possibly compound) from the range [begin, end). + * + * Returns a tuple of Item pointer, buffer pointer and error message. If parsing is successful, the + * Item pointer is non-null, the buffer pointer points to the first byte after the + * successfully-parsed item and the error message string is empty. If parsing fails, the Item + * pointer is null, the buffer pointer points to the first byte that was unparseable (the first byte + * of a data item header that is malformed in some way, e.g. an invalid value, or a length that is + * too large for the remaining buffer, etc.) and the string contains an error message describing the + * problem encountered. + * + * The returned CBOR data item will contain View* items backed by + * std::string_view types over the input range. + * WARNING! If the input range changes underneath, the corresponding views will + * carry the same change. + */ +ParseResult parseWithViews(const uint8_t* begin, const uint8_t* end); + /** * Parse the first CBOR data item (possibly compound) from the byte vector. * @@ -66,6 +84,26 @@ inline ParseResult parse(const uint8_t* begin, size_t size) { return parse(begin, begin + size); } +/** + * Parse the first CBOR data item (possibly compound) from the range [begin, begin + size). + * + * Returns a tuple of Item pointer, buffer pointer and error message. If parsing is successful, the + * Item pointer is non-null, the buffer pointer points to the first byte after the + * successfully-parsed item and the error message string is empty. If parsing fails, the Item + * pointer is null, the buffer pointer points to the first byte that was unparseable (the first byte + * of a data item header that is malformed in some way, e.g. an invalid value, or a length that is + * too large for the remaining buffer, etc.) and the string contains an error message describing the + * problem encountered. + * + * The returned CBOR data item will contain View* items backed by + * std::string_view types over the input range. + * WARNING! If the input range changes underneath, the corresponding views will + * carry the same change. + */ +inline ParseResult parseWithViews(const uint8_t* begin, size_t size) { + return parseWithViews(begin, begin + size); +} + /** * Parse the first CBOR data item (possibly compound) from the value contained in a Bstr. * @@ -91,6 +129,13 @@ class ParseClient; */ void parse(const uint8_t* begin, const uint8_t* end, ParseClient* parseClient); +/** + * Parse the CBOR data in the range [begin, end) in streaming fashion, calling methods on the + * provided ParseClient when elements are found. Uses the View* item types + * instead of the copying ones. + */ +void parseWithViews(const uint8_t* begin, const uint8_t* end, ParseClient* parseClient); + /** * Parse the CBOR data in the vector in streaming fashion, calling methods on the * provided ParseClient when elements are found. diff --git a/third_party/libcppbor/src/cppbor.cpp b/third_party/libcppbor/src/cppbor.cpp index 254f2041..45adb599 100644 --- a/third_party/libcppbor/src/cppbor.cpp +++ b/third_party/libcppbor/src/cppbor.cpp @@ -17,9 +17,12 @@ #include "cppbor.h" #include +#include #include +#include "cppbor_parse.h" + using std::string; using std::vector; @@ -29,539 +32,570 @@ namespace cppbor { namespace { -template ::value>> +template ::value>> Iterator writeBigEndian(T value, Iterator pos) { - for (unsigned i = 0; i < sizeof(value); ++i) { - *pos++ = static_cast(value >> (8 * (sizeof(value) - 1))); - value = static_cast(value << 8); - } - return pos; + for (unsigned i = 0; i < sizeof(value); ++i) { + *pos++ = static_cast(value >> (8 * (sizeof(value) - 1))); + value = static_cast(value << 8); + } + return pos; } template ::value>> void writeBigEndian(T value, std::function& cb) { - for (unsigned i = 0; i < sizeof(value); ++i) { - cb(static_cast(value >> (8 * (sizeof(value) - 1)))); - value = static_cast(value << 8); - } + for (unsigned i = 0; i < sizeof(value); ++i) { + cb(static_cast(value >> (8 * (sizeof(value) - 1)))); + value = static_cast(value << 8); + } } bool cborAreAllElementsNonCompound(const Item* compoundItem) { - if (compoundItem->type() == ARRAY) { - const Array* array = compoundItem->asArray(); - for (size_t n = 0; n < array->size(); n++) { - const Item* entry = (*array)[n].get(); - switch (entry->type()) { - case ARRAY: - case MAP: - return false; - default: - break; - } + if (compoundItem->type() == ARRAY) { + const Array* array = compoundItem->asArray(); + for (size_t n = 0; n < array->size(); n++) { + const Item* entry = (*array)[n].get(); + switch (entry->type()) { + case ARRAY: + case MAP: + return false; + default: + break; + } + } + } else { + const Map* map = compoundItem->asMap(); + for (auto& [keyEntry, valueEntry] : *map) { + switch (keyEntry->type()) { + case ARRAY: + case MAP: + return false; + default: + break; + } + switch (valueEntry->type()) { + case ARRAY: + case MAP: + return false; + default: + break; + } + } } - } else { - const Map* map = compoundItem->asMap(); - for (auto& entry : *map) { - auto& keyEntry = entry.first; - auto& valueEntry = entry.second; - - switch (keyEntry->type()) { - case ARRAY: - case MAP: - return false; - default: - break; - } - switch (valueEntry->type()) { - case ARRAY: - case MAP: - return false; - default: - break; - } - } - } - return true; + return true; } -bool prettyPrintInternal(const Item* item, string& out, size_t indent, - size_t maxBStrSize, +bool prettyPrintInternal(const Item* item, string& out, size_t indent, size_t maxBStrSize, const vector& mapKeysToNotPrint) { - if (!item) { - out.append(""); - return false; - } - - char buf[80]; - - string indentString(indent, ' '); - - size_t tagCount = item->semanticTagCount(); - while (tagCount > 0) { - --tagCount; - snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", item->semanticTag(tagCount)); - out.append(buf); - } - - switch (item->type()) { - case SEMANTIC: - // Handled above. - break; - - case UINT: - snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue()); - out.append(buf); - break; - - case NINT: - snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value()); - out.append(buf); - break; - - case BSTR: { - const uint8_t* valueData; - size_t valueSize; - const Bstr* bstr = item->asBstr(); - if (bstr == nullptr) { + if (!item) { + out.append(""); return false; - } - const vector& value = bstr->value(); - valueData = value.data(); - valueSize = value.size(); + } - out.append("{"); - for (size_t n = 0; n < valueSize; n++) { - if (n > 0) { - out.append(", "); - } - snprintf(buf, sizeof(buf), "0x%02x", valueData[n]); + char buf[80]; + + string indentString(indent, ' '); + + size_t tagCount = item->semanticTagCount(); + while (tagCount > 0) { + --tagCount; + snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", item->semanticTag(tagCount)); out.append(buf); - } - out.append("}"); - } break; + } - case TSTR: - out.append("'"); - { - // TODO: escape "'" characters - if (item->asTstr() != nullptr) { - out.append(item->asTstr()->value().c_str()); - } else { - } - } - out.append("'"); - break; + switch (item->type()) { + case SEMANTIC: + // Handled above. + break; - case ARRAY: { - const Array* array = item->asArray(); - if (array->size() == 0) { - out.append("[]"); - } else if (cborAreAllElementsNonCompound(array)) { - out.append("["); - for (size_t n = 0; n < array->size(); n++) { - if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, - maxBStrSize, mapKeysToNotPrint)) { - return false; - } - out.append(", "); - } - out.append("]"); - } else { - out.append("[\n" + indentString); - for (size_t n = 0; n < array->size(); n++) { - out.append(" "); - if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, - maxBStrSize, mapKeysToNotPrint)) { - return false; - } - out.append(",\n" + indentString); - } - out.append("]"); - } - } break; + case UINT: + snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue()); + out.append(buf); + break; - case MAP: { - const Map* map = item->asMap(); + case NINT: + snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value()); + out.append(buf); + break; - if (map->size() == 0) { - out.append("{}"); - } else { - out.append("{\n" + indentString); - for (auto& entry : *map) { - auto& map_key = entry.first; - auto& map_value = entry.second; + case BSTR: { + const uint8_t* valueData; + size_t valueSize; + const Bstr* bstr = item->asBstr(); + if (bstr != nullptr) { + const vector& value = bstr->value(); + valueData = value.data(); + valueSize = value.size(); + } else { + const ViewBstr* viewBstr = item->asViewBstr(); + assert(viewBstr != nullptr); - out.append(" "); - - if (!prettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize, - mapKeysToNotPrint)) { - return false; - } - out.append(" : "); - if (map_key->type() == TSTR && - std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(), - map_key->asTstr()->value()) != - mapKeysToNotPrint.end()) { - out.append(""); - } else { - if (!prettyPrintInternal(map_value.get(), out, indent + 2, - maxBStrSize, mapKeysToNotPrint)) { - return false; + valueData = viewBstr->view().data(); + valueSize = viewBstr->view().size(); } - } - out.append(",\n" + indentString); - } - out.append("}"); - } - } break; - case SIMPLE: - const Bool* asBool = item->asSimple()->asBool(); - const Null* asNull = item->asSimple()->asNull(); - if (asBool != nullptr) { - out.append(asBool->value() ? "true" : "false"); - } else if (asNull != nullptr) { - out.append("null"); - } else { - return false; - } - break; - } + if (valueSize > maxBStrSize) { + snprintf(buf, sizeof(buf), "", valueSize); + out.append(buf); + } else { + out.append("{"); + for (size_t n = 0; n < valueSize; n++) { + if (n > 0) { + out.append(", "); + } + snprintf(buf, sizeof(buf), "0x%02x", valueData[n]); + out.append(buf); + } + out.append("}"); + } + } break; - return true; + case TSTR: + out.append("'"); + { + // TODO: escape "'" characters + if (item->asTstr() != nullptr) { + out.append(item->asTstr()->value().c_str()); + } else { + const ViewTstr* viewTstr = item->asViewTstr(); + assert(viewTstr != nullptr); + out.append(viewTstr->view()); + } + } + out.append("'"); + break; + + case ARRAY: { + const Array* array = item->asArray(); + if (array->size() == 0) { + out.append("[]"); + } else if (cborAreAllElementsNonCompound(array)) { + out.append("["); + for (size_t n = 0; n < array->size(); n++) { + if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize, + mapKeysToNotPrint)) { + return false; + } + out.append(", "); + } + out.append("]"); + } else { + out.append("[\n" + indentString); + for (size_t n = 0; n < array->size(); n++) { + out.append(" "); + if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize, + mapKeysToNotPrint)) { + return false; + } + out.append(",\n" + indentString); + } + out.append("]"); + } + } break; + + case MAP: { + const Map* map = item->asMap(); + + if (map->size() == 0) { + out.append("{}"); + } else { + out.append("{\n" + indentString); + for (auto& [map_key, map_value] : *map) { + out.append(" "); + + if (!prettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize, + mapKeysToNotPrint)) { + return false; + } + out.append(" : "); + if (map_key->type() == TSTR && + std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(), + map_key->asTstr()->value()) != mapKeysToNotPrint.end()) { + out.append(""); + } else { + if (!prettyPrintInternal(map_value.get(), out, indent + 2, maxBStrSize, + mapKeysToNotPrint)) { + return false; + } + } + out.append(",\n" + indentString); + } + out.append("}"); + } + } break; + + case SIMPLE: + const Bool* asBool = item->asSimple()->asBool(); + const Null* asNull = item->asSimple()->asNull(); + if (asBool != nullptr) { + out.append(asBool->value() ? "true" : "false"); + } else if (asNull != nullptr) { + out.append("null"); + } else { + return false; + } + break; + } + + return true; } } // namespace size_t headerSize(uint64_t addlInfo) { - if (addlInfo < ONE_BYTE_LENGTH) return 1; - if (addlInfo <= std::numeric_limits::max()) return 2; - if (addlInfo <= std::numeric_limits::max()) return 3; - if (addlInfo <= std::numeric_limits::max()) return 5; - return 9; + if (addlInfo < ONE_BYTE_LENGTH) return 1; + if (addlInfo <= std::numeric_limits::max()) return 2; + if (addlInfo <= std::numeric_limits::max()) return 3; + if (addlInfo <= std::numeric_limits::max()) return 5; + return 9; } -uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, - const uint8_t* end) { - size_t sz = headerSize(addlInfo); - if (end - pos < static_cast(sz)) return nullptr; - switch (sz) { - case 1: - *pos++ = type | static_cast(addlInfo); - return pos; - case 2: - *pos++ = type | ONE_BYTE_LENGTH; - *pos++ = static_cast(addlInfo); - return pos; - case 3: - *pos++ = type | TWO_BYTE_LENGTH; - return writeBigEndian(static_cast(addlInfo), pos); - case 5: - *pos++ = type | FOUR_BYTE_LENGTH; - return writeBigEndian(static_cast(addlInfo), pos); - case 9: - *pos++ = type | EIGHT_BYTE_LENGTH; - return writeBigEndian(addlInfo, pos); - default: - CHECK(false); // Impossible to get here. - return nullptr; - } +uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end) { + size_t sz = headerSize(addlInfo); + if (end - pos < static_cast(sz)) return nullptr; + switch (sz) { + case 1: + *pos++ = type | static_cast(addlInfo); + return pos; + case 2: + *pos++ = type | ONE_BYTE_LENGTH; + *pos++ = static_cast(addlInfo); + return pos; + case 3: + *pos++ = type | TWO_BYTE_LENGTH; + return writeBigEndian(static_cast(addlInfo), pos); + case 5: + *pos++ = type | FOUR_BYTE_LENGTH; + return writeBigEndian(static_cast(addlInfo), pos); + case 9: + *pos++ = type | EIGHT_BYTE_LENGTH; + return writeBigEndian(addlInfo, pos); + default: + CHECK(false); // Impossible to get here. + return nullptr; + } } -void encodeHeader(MajorType type, uint64_t addlInfo, - EncodeCallback encodeCallback) { - size_t sz = headerSize(addlInfo); - switch (sz) { - case 1: - encodeCallback(type | static_cast(addlInfo)); - break; - case 2: - encodeCallback(type | ONE_BYTE_LENGTH); - encodeCallback(static_cast(addlInfo)); - break; - case 3: - encodeCallback(type | TWO_BYTE_LENGTH); - writeBigEndian(static_cast(addlInfo), encodeCallback); - break; - case 5: - encodeCallback(type | FOUR_BYTE_LENGTH); - writeBigEndian(static_cast(addlInfo), encodeCallback); - break; - case 9: - encodeCallback(type | EIGHT_BYTE_LENGTH); - writeBigEndian(addlInfo, encodeCallback); - break; - default: - CHECK(false); // Impossible to get here. - } +void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback) { + size_t sz = headerSize(addlInfo); + switch (sz) { + case 1: + encodeCallback(type | static_cast(addlInfo)); + break; + case 2: + encodeCallback(type | ONE_BYTE_LENGTH); + encodeCallback(static_cast(addlInfo)); + break; + case 3: + encodeCallback(type | TWO_BYTE_LENGTH); + writeBigEndian(static_cast(addlInfo), encodeCallback); + break; + case 5: + encodeCallback(type | FOUR_BYTE_LENGTH); + writeBigEndian(static_cast(addlInfo), encodeCallback); + break; + case 9: + encodeCallback(type | EIGHT_BYTE_LENGTH); + writeBigEndian(addlInfo, encodeCallback); + break; + default: + CHECK(false); // Impossible to get here. + } } bool Item::operator==(const Item& other) const& { - if (type() != other.type()) return false; - switch (type()) { - case UINT: - return *asUint() == *(other.asUint()); - case NINT: - return *asNint() == *(other.asNint()); - case BSTR: - if (asBstr() != nullptr && other.asBstr() != nullptr) { - return *asBstr() == *(other.asBstr()); - } - // Interesting corner case: comparing a Bstr and ViewBstr with - // identical contents. The function currently returns false for - // this case. - // TODO: if it should return true, this needs a deep comparison - return false; - case TSTR: - if (asTstr() != nullptr && other.asTstr() != nullptr) { - return *asTstr() == *(other.asTstr()); - } - // Same corner case as Bstr - return false; - case ARRAY: - return *asArray() == *(other.asArray()); - case MAP: - return *asMap() == *(other.asMap()); - case SIMPLE: - return *asSimple() == *(other.asSimple()); - case SEMANTIC: - return *asSemanticTag() == *(other.asSemanticTag()); - default: - CHECK(false); // Impossible to get here. - return false; - } + if (type() != other.type()) return false; + switch (type()) { + case UINT: + return *asUint() == *(other.asUint()); + case NINT: + return *asNint() == *(other.asNint()); + case BSTR: + if (asBstr() != nullptr && other.asBstr() != nullptr) { + return *asBstr() == *(other.asBstr()); + } + if (asViewBstr() != nullptr && other.asViewBstr() != nullptr) { + return *asViewBstr() == *(other.asViewBstr()); + } + // Interesting corner case: comparing a Bstr and ViewBstr with + // identical contents. The function currently returns false for + // this case. + // TODO: if it should return true, this needs a deep comparison + return false; + case TSTR: + if (asTstr() != nullptr && other.asTstr() != nullptr) { + return *asTstr() == *(other.asTstr()); + } + if (asViewTstr() != nullptr && other.asViewTstr() != nullptr) { + return *asViewTstr() == *(other.asViewTstr()); + } + // Same corner case as Bstr + return false; + case ARRAY: + return *asArray() == *(other.asArray()); + case MAP: + return *asMap() == *(other.asMap()); + case SIMPLE: + return *asSimple() == *(other.asSimple()); + case SEMANTIC: + return *asSemanticTag() == *(other.asSemanticTag()); + default: + CHECK(false); // Impossible to get here. + return false; + } } -Nint::Nint(int64_t v) : mValue(v) { CHECK(v < 0); } +Nint::Nint(int64_t v) : mValue(v) { + CHECK(v < 0); +} bool Simple::operator==(const Simple& other) const& { - if (simpleType() != other.simpleType()) return false; + if (simpleType() != other.simpleType()) return false; - switch (simpleType()) { - case BOOLEAN: - return *asBool() == *(other.asBool()); - case NULL_T: - return true; - default: - CHECK(false); // Impossible to get here. - return false; - } + switch (simpleType()) { + case BOOLEAN: + return *asBool() == *(other.asBool()); + case NULL_T: + return true; + default: + CHECK(false); // Impossible to get here. + return false; + } } uint8_t* Bstr::encode(uint8_t* pos, const uint8_t* end) const { - pos = encodeHeader(mValue.size(), pos, end); - if (!pos || end - pos < static_cast(mValue.size())) return nullptr; - return std::copy(mValue.begin(), mValue.end(), pos); + pos = encodeHeader(mValue.size(), pos, end); + if (!pos || end - pos < static_cast(mValue.size())) return nullptr; + return std::copy(mValue.begin(), mValue.end(), pos); } void Bstr::encodeValue(EncodeCallback encodeCallback) const { - for (auto c : mValue) { - encodeCallback(c); - } + for (auto c : mValue) { + encodeCallback(c); + } +} + +uint8_t* ViewBstr::encode(uint8_t* pos, const uint8_t* end) const { + pos = encodeHeader(mView.size(), pos, end); + if (!pos || end - pos < static_cast(mView.size())) return nullptr; + return std::copy(mView.begin(), mView.end(), pos); +} + +void ViewBstr::encodeValue(EncodeCallback encodeCallback) const { + for (auto c : mView) { + encodeCallback(static_cast(c)); + } } uint8_t* Tstr::encode(uint8_t* pos, const uint8_t* end) const { - pos = encodeHeader(mValue.size(), pos, end); - if (!pos || end - pos < static_cast(mValue.size())) return nullptr; - return std::copy(mValue.begin(), mValue.end(), pos); + pos = encodeHeader(mValue.size(), pos, end); + if (!pos || end - pos < static_cast(mValue.size())) return nullptr; + return std::copy(mValue.begin(), mValue.end(), pos); } void Tstr::encodeValue(EncodeCallback encodeCallback) const { - for (auto c : mValue) { - encodeCallback(static_cast(c)); - } + for (auto c : mValue) { + encodeCallback(static_cast(c)); + } +} + +uint8_t* ViewTstr::encode(uint8_t* pos, const uint8_t* end) const { + pos = encodeHeader(mView.size(), pos, end); + if (!pos || end - pos < static_cast(mView.size())) return nullptr; + return std::copy(mView.begin(), mView.end(), pos); +} + +void ViewTstr::encodeValue(EncodeCallback encodeCallback) const { + for (auto c : mView) { + encodeCallback(static_cast(c)); + } } bool Array::operator==(const Array& other) const& { - return size() == other.size() - // Can't use vector::operator== because the contents are pointers. - // std::equal lets us provide a predicate that does the dereferencing. - && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(), - [](const std::unique_ptr& a, - const std::unique_ptr& b) -> bool { - return *a == *b; - }); + return size() == other.size() + // Can't use vector::operator== because the contents are pointers. std::equal lets us + // provide a predicate that does the dereferencing. + && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(), + [](auto& a, auto& b) -> bool { return *a == *b; }); } uint8_t* Array::encode(uint8_t* pos, const uint8_t* end) const { - pos = encodeHeader(size(), pos, end); - if (!pos) return nullptr; - for (auto& entry : mEntries) { - pos = entry->encode(pos, end); + pos = encodeHeader(size(), pos, end); if (!pos) return nullptr; - } - return pos; + for (auto& entry : mEntries) { + pos = entry->encode(pos, end); + if (!pos) return nullptr; + } + return pos; } void Array::encode(EncodeCallback encodeCallback) const { - encodeHeader(size(), encodeCallback); - for (auto& entry : mEntries) { - entry->encode(encodeCallback); - } + encodeHeader(size(), encodeCallback); + for (auto& entry : mEntries) { + entry->encode(encodeCallback); + } } std::unique_ptr Array::clone() const { - auto res = make_unique(); - for (size_t i = 0; i < mEntries.size(); i++) { - res->add(mEntries[i]->clone()); - } - return res; + auto res = std::make_unique(); + for (size_t i = 0; i < mEntries.size(); i++) { + res->add(mEntries[i]->clone()); + } + return res; } bool Map::operator==(const Map& other) const& { - return size() == other.size() - // Can't use vector::operator== because the contents are pairs of - // pointers. std::equal lets us provide a predicate that does the - // dereferencing. - && std::equal(begin(), end(), other.begin(), - [](const entry_type& a, const entry_type& b) { - return *a.first == *b.first && *a.second == *b.second; - }); + return size() == other.size() + // Can't use vector::operator== because the contents are pairs of pointers. std::equal + // lets us provide a predicate that does the dereferencing. + && std::equal(begin(), end(), other.begin(), [](auto& a, auto& b) { + return *a.first == *b.first && *a.second == *b.second; + }); } uint8_t* Map::encode(uint8_t* pos, const uint8_t* end) const { - pos = encodeHeader(size(), pos, end); - if (!pos) return nullptr; - for (auto& entry : mEntries) { - pos = entry.first->encode(pos, end); + pos = encodeHeader(size(), pos, end); if (!pos) return nullptr; - pos = entry.second->encode(pos, end); - if (!pos) return nullptr; - } - return pos; + for (auto& entry : mEntries) { + pos = entry.first->encode(pos, end); + if (!pos) return nullptr; + pos = entry.second->encode(pos, end); + if (!pos) return nullptr; + } + return pos; } void Map::encode(EncodeCallback encodeCallback) const { - encodeHeader(size(), encodeCallback); - for (auto& entry : mEntries) { - entry.first->encode(encodeCallback); - entry.second->encode(encodeCallback); - } + encodeHeader(size(), encodeCallback); + for (auto& entry : mEntries) { + entry.first->encode(encodeCallback); + entry.second->encode(encodeCallback); + } } bool Map::keyLess(const Item* a, const Item* b) { - // CBOR map canonicalization rules are: + // CBOR map canonicalization rules are: - // 1. If two keys have different lengths, the shorter one sorts earlier. - if (a->encodedSize() < b->encodedSize()) return true; - if (a->encodedSize() > b->encodedSize()) return false; + // 1. If two keys have different lengths, the shorter one sorts earlier. + if (a->encodedSize() < b->encodedSize()) return true; + if (a->encodedSize() > b->encodedSize()) return false; - // 2. If two keys have the same length, the one with the lower value in - // (byte-wise) lexical order sorts earlier. This requires encoding both - // items. - auto encodedA = a->encode(); - auto encodedB = b->encode(); + // 2. If two keys have the same length, the one with the lower value in (byte-wise) lexical + // order sorts earlier. This requires encoding both items. + auto encodedA = a->encode(); + auto encodedB = b->encode(); - return std::lexicographical_compare(encodedA.begin(), encodedA.end(), // - encodedB.begin(), encodedB.end()); + return std::lexicographical_compare(encodedA.begin(), encodedA.end(), // + encodedB.begin(), encodedB.end()); } void recursivelyCanonicalize(std::unique_ptr& item) { - switch (item->type()) { - case UINT: - case NINT: - case BSTR: - case TSTR: - case SIMPLE: - return; + switch (item->type()) { + case UINT: + case NINT: + case BSTR: + case TSTR: + case SIMPLE: + return; - case ARRAY: - std::for_each(item->asArray()->begin(), item->asArray()->end(), - recursivelyCanonicalize); - return; + case ARRAY: + std::for_each(item->asArray()->begin(), item->asArray()->end(), + recursivelyCanonicalize); + return; - case MAP: - item->asMap()->canonicalize(true /* recurse */); - return; + case MAP: + item->asMap()->canonicalize(true /* recurse */); + return; - case SEMANTIC: - // This can't happen. SemanticTags delegate their type() method to the - // contained Item's type. - assert(false); - return; - } + case SEMANTIC: + // This can't happen. SemanticTags delegate their type() method to the contained Item's + // type. + assert(false); + return; + } } Map& Map::canonicalize(bool recurse) & { - if (recurse) { - for (auto& entry : mEntries) { - recursivelyCanonicalize(entry.first); - recursivelyCanonicalize(entry.second); + if (recurse) { + for (auto& entry : mEntries) { + recursivelyCanonicalize(entry.first); + recursivelyCanonicalize(entry.second); + } } - } - if (size() < 2 || mCanonicalized) { - // Trivially or already canonical; do nothing. + if (size() < 2 || mCanonicalized) { + // Trivially or already canonical; do nothing. + return *this; + } + + std::sort(begin(), end(), + [](auto& a, auto& b) { return keyLess(a.first.get(), b.first.get()); }); + mCanonicalized = true; return *this; - } - - std::sort(begin(), end(), [](const entry_type& a, const entry_type& b) { - return keyLess(a.first.get(), b.first.get()); - }); - mCanonicalized = true; - return *this; } std::unique_ptr Map::clone() const { - auto res = make_unique(); - for (auto& entry : *this) { - auto& key = entry.first; - auto& value = entry.second; - res->add(key->clone(), value->clone()); - } - res->mCanonicalized = mCanonicalized; - return res; + auto res = std::make_unique(); + for (auto& [key, value] : *this) { + res->add(key->clone(), value->clone()); + } + res->mCanonicalized = mCanonicalized; + return res; } std::unique_ptr SemanticTag::clone() const { - return make_unique(mValue, mTaggedItem->clone()); + return std::make_unique(mValue, mTaggedItem->clone()); } uint8_t* SemanticTag::encode(uint8_t* pos, const uint8_t* end) const { - // Can't use the encodeHeader() method that calls type() to get the major - // type, since that will return the tagged Item's type. - pos = ::cppbor::encodeHeader(kMajorType, mValue, pos, end); - if (!pos) return nullptr; - return mTaggedItem->encode(pos, end); + // Can't use the encodeHeader() method that calls type() to get the major type, since that will + // return the tagged Item's type. + pos = ::cppbor::encodeHeader(kMajorType, mValue, pos, end); + if (!pos) return nullptr; + return mTaggedItem->encode(pos, end); } void SemanticTag::encode(EncodeCallback encodeCallback) const { - // Can't use the encodeHeader() method that calls type() to get the major - // type, since that will return the tagged Item's type. - ::cppbor::encodeHeader(kMajorType, mValue, encodeCallback); - mTaggedItem->encode(std::move(encodeCallback)); + // Can't use the encodeHeader() method that calls type() to get the major type, since that will + // return the tagged Item's type. + ::cppbor::encodeHeader(kMajorType, mValue, encodeCallback); + mTaggedItem->encode(encodeCallback); } size_t SemanticTag::semanticTagCount() const { - size_t levelCount = 1; // Count this level. - const SemanticTag* cur = this; - while (cur->mTaggedItem && - (cur = cur->mTaggedItem->asSemanticTag()) != nullptr) - ++levelCount; - return levelCount; + size_t levelCount = 1; // Count this level. + const SemanticTag* cur = this; + while (cur->mTaggedItem && (cur = cur->mTaggedItem->asSemanticTag()) != nullptr) ++levelCount; + return levelCount; } uint64_t SemanticTag::semanticTag(size_t nesting) const { - // Getting the value of a specific nested tag is a bit tricky, because we - // start with the outer tag and don't know how many are inside. We count the - // number of nesting levels to find out how many there are in total, then to - // get the one we want we have to walk down levelCount - nesting steps. - size_t levelCount = semanticTagCount(); - if (nesting >= levelCount) return 0; + // Getting the value of a specific nested tag is a bit tricky, because we start with the outer + // tag and don't know how many are inside. We count the number of nesting levels to find out + // how many there are in total, then to get the one we want we have to walk down levelCount - + // nesting steps. + size_t levelCount = semanticTagCount(); + if (nesting >= levelCount) return 0; - levelCount -= nesting; - const SemanticTag* cur = this; - while (--levelCount > 0) cur = cur->mTaggedItem->asSemanticTag(); + levelCount -= nesting; + const SemanticTag* cur = this; + while (--levelCount > 0) cur = cur->mTaggedItem->asSemanticTag(); - return cur->mValue; + return cur->mValue; } -string prettyPrint(const Item* item, size_t maxBStrSize, +string prettyPrint(const Item* item, size_t maxBStrSize, const vector& mapKeysToNotPrint) { + string out; + prettyPrintInternal(item, out, 0, maxBStrSize, mapKeysToNotPrint); + return out; +} +string prettyPrint(const vector& encodedCbor, size_t maxBStrSize, const vector& mapKeysToNotPrint) { - string out; - prettyPrintInternal(item, out, 0, maxBStrSize, mapKeysToNotPrint); - return out; + auto [item, _, message] = parse(encodedCbor); + if (item == nullptr) { + return ""; + } + + return prettyPrint(item.get(), maxBStrSize, mapKeysToNotPrint); } } // namespace cppbor diff --git a/third_party/libcppbor/src/cppbor_parse.cpp b/third_party/libcppbor/src/cppbor_parse.cpp index 4388b7bc..6ddf7c71 100644 --- a/third_party/libcppbor/src/cppbor_parse.cpp +++ b/third_party/libcppbor/src/cppbor_parse.cpp @@ -19,10 +19,12 @@ #include #include #include +#include #ifndef __has_feature #define __has_feature(x) 0 #endif +#define CHECK(x) (void)(x) namespace cppbor { @@ -36,7 +38,7 @@ std::string insufficientLengthString(size_t bytesNeeded, size_t bytesAvail, return std::string(buf); } -template ::value>> +template >> std::tuple parseLength(const uint8_t* pos, const uint8_t* end, ParseClient* parseClient) { if (pos + sizeof(T) > end) { @@ -235,9 +237,8 @@ std::tuple parseRecursively(const uint8_t* begin, break; default: - // It's impossible to get here - parseClient->error(begin, "Invalid tag."); - return {}; + CHECK(false); // It's impossible to get here + break; } } @@ -251,10 +252,18 @@ std::tuple parseRecursively(const uint8_t* begin, return handleNint(addlData, begin, pos, parseClient); case BSTR: - return handleString(addlData, begin, pos, end, "byte string", parseClient); + if (emitViews) { + return handleString(addlData, begin, pos, end, "byte string", parseClient); + } else { + return handleString(addlData, begin, pos, end, "byte string", parseClient); + } case TSTR: - return handleString(addlData, begin, pos, end, "text string", parseClient); + if (emitViews) { + return handleString(addlData, begin, pos, end, "text string", parseClient); + } else { + return handleString(addlData, begin, pos, end, "text string", parseClient); + } case ARRAY: return handleCompound(std::make_unique(addlData), addlData, begin, pos, @@ -280,8 +289,7 @@ std::tuple parseRecursively(const uint8_t* begin, return {begin, nullptr}; } } - // Impossible to get here. - parseClient->error(begin, "Invalid type."); + CHECK(false); // Impossible to get here. return {}; } @@ -310,9 +318,7 @@ class FullParseClient : public ParseClient { virtual ParseClient* itemEnd(std::unique_ptr& item, const uint8_t*, const uint8_t*, const uint8_t* end) override { - if (!item->isCompound() || item.get() != mParentStack.top()) { - return nullptr; - } + CHECK(item->isCompound() && item.get() == mParentStack.top()); mParentStack.pop(); if (mParentStack.empty()) { @@ -340,10 +346,10 @@ class FullParseClient : public ParseClient { private: void appendToLastParent(std::unique_ptr item) { auto parent = mParentStack.top(); + #if __has_feature(cxx_rtti) assert(dynamic_cast(parent)); #endif - IncompleteItem* parentItem{}; if (parent->type() == ARRAY) { parentItem = static_cast(parent); @@ -352,7 +358,7 @@ class FullParseClient : public ParseClient { } else if (parent->asSemanticTag()) { parentItem = static_cast(parent); } else { - // Impossible to get here. + CHECK(false); // Impossible to get here. } parentItem->add(std::move(item)); } @@ -377,4 +383,16 @@ parse(const uint8_t* begin, const uint8_t* end) { return parseClient.parseResult(); } +void parseWithViews(const uint8_t* begin, const uint8_t* end, ParseClient* parseClient) { + parseRecursively(begin, end, true, parseClient); +} + +std::tuple /* result */, const uint8_t* /* newPos */, + std::string /* errMsg */> +parseWithViews(const uint8_t* begin, const uint8_t* end) { + FullParseClient parseClient; + parseWithViews(begin, end, &parseClient); + return parseClient.parseResult(); +} + } // namespace cppbor diff --git a/third_party/libcppbor/tests/cppbor_test.cpp b/third_party/libcppbor/tests/cppbor_test.cpp index d40de092..68778dc4 100644 --- a/third_party/libcppbor/tests/cppbor_test.cpp +++ b/third_party/libcppbor/tests/cppbor_test.cpp @@ -152,6 +152,31 @@ TEST(SimpleValueTest, NestedSemanticTagEncoding) { tripleTagged.toString()); } +TEST(SimpleValueTest, ViewByteStringEncodings) { + EXPECT_EQ("\x40", ViewBstr("").toString()); + EXPECT_EQ("\x41\x61", ViewBstr("a").toString()); + EXPECT_EQ("\x41\x41", ViewBstr("A").toString()); + EXPECT_EQ("\x44\x49\x45\x54\x46", ViewBstr("IETF").toString()); + EXPECT_EQ("\x42\x22\x5c", ViewBstr("\"\\").toString()); + EXPECT_EQ("\x42\xc3\xbc", ViewBstr("\xc3\xbc").toString()); + EXPECT_EQ("\x43\xe6\xb0\xb4", ViewBstr("\xe6\xb0\xb4").toString()); + EXPECT_EQ("\x44\xf0\x90\x85\x91", ViewBstr("\xf0\x90\x85\x91").toString()); + EXPECT_EQ("\x44\x01\x02\x03\x04", ViewBstr("\x01\x02\x03\x04").toString()); + EXPECT_EQ("\x44\x40\x40\x40\x40", ViewBstr("@@@@").toString()); +} + +TEST(SimpleValueTest, ViewTextStringEncodings) { + EXPECT_EQ("\x60"s, ViewTstr("").toString()); + EXPECT_EQ("\x61\x61"s, ViewTstr("a").toString()); + EXPECT_EQ("\x61\x41"s, ViewTstr("A").toString()); + EXPECT_EQ("\x64\x49\x45\x54\x46"s, ViewTstr("IETF").toString()); + EXPECT_EQ("\x62\x22\x5c"s, ViewTstr("\"\\").toString()); + EXPECT_EQ("\x62\xc3\xbc"s, ViewTstr("\xc3\xbc").toString()); + EXPECT_EQ("\x63\xe6\xb0\xb4"s, ViewTstr("\xe6\xb0\xb4").toString()); + EXPECT_EQ("\x64\xf0\x90\x85\x91"s, ViewTstr("\xf0\x90\x85\x91").toString()); + EXPECT_EQ("\x64\x01\x02\x03\x04"s, ViewTstr("\x01\x02\x03\x04").toString()); +} + TEST(IsIteratorPairOverTest, All) { EXPECT_TRUE(( details::is_iterator_pair_over, char>::value)); @@ -230,6 +255,13 @@ TEST(MakeEntryTest, StdStrings) { details::makeItem(std::move(s1))->toString()); // move string } +TEST(MakeEntryTest, StdStringViews) { + string_view s1("hello"); + const string_view s2("hello"); + EXPECT_EQ("\x65\x68\x65\x6c\x6c\x6f"s, details::makeItem(s1)->toString()); + EXPECT_EQ("\x65\x68\x65\x6c\x6c\x6f"s, details::makeItem(s2)->toString()); +} + TEST(MakeEntryTest, CStrings) { char s1[] = "hello"; const char s2[] = "hello"; @@ -496,6 +528,8 @@ TEST(EqualityTest, Uint) { EXPECT_NE(val, Bool(false)); EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("99")); } TEST(EqualityTest, Nint) { @@ -509,6 +543,8 @@ TEST(EqualityTest, Nint) { EXPECT_NE(val, Bool(false)); EXPECT_NE(val, Array(99)); EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("99")); } TEST(EqualityTest, Tstr) { @@ -523,6 +559,8 @@ TEST(EqualityTest, Tstr) { EXPECT_NE(val, Bool(false)); EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("99")); } TEST(EqualityTest, Bstr) { @@ -537,6 +575,8 @@ TEST(EqualityTest, Bstr) { EXPECT_NE(val, Bool(false)); EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("99")); } TEST(EqualityTest, Bool) { @@ -551,6 +591,8 @@ TEST(EqualityTest, Bool) { EXPECT_NE(val, Bool(true)); EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("98")); } TEST(EqualityTest, Array) { @@ -567,6 +609,8 @@ TEST(EqualityTest, Array) { EXPECT_NE(val, Array(98, 1)); EXPECT_NE(val, Array(99, 1, 2)); EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("98")); } TEST(EqualityTest, Map) { @@ -582,6 +626,8 @@ TEST(EqualityTest, Map) { EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 2)); EXPECT_NE(val, Map(99, 1, 99, 2)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("98")); } TEST(EqualityTest, Null) { @@ -597,6 +643,8 @@ TEST(EqualityTest, Null) { EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 2)); EXPECT_NE(val, Map(99, 1, 99, 2)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("98")); } TEST(EqualityTest, SemanticTag) { @@ -627,6 +675,40 @@ TEST(EqualityTest, NestedSemanticTag) { EXPECT_NE(val, Array(99, 1)); EXPECT_NE(val, Map(99, 2)); EXPECT_NE(val, Null()); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("98")); +} + +TEST(EqualityTest, ViewTstr) { + ViewTstr val("99"); + EXPECT_EQ(val, ViewTstr("99")); + + EXPECT_NE(val, Uint(99)); + EXPECT_NE(val, Nint(-1)); + EXPECT_NE(val, Nint(-4)); + EXPECT_NE(val, Tstr("99")); + EXPECT_NE(val, Bstr("99")); + EXPECT_NE(val, Bool(false)); + EXPECT_NE(val, Array(99, 1)); + EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("98")); + EXPECT_NE(val, ViewBstr("99")); +} + +TEST(EqualityTest, ViewBstr) { + ViewBstr val("99"); + EXPECT_EQ(val, ViewBstr("99")); + + EXPECT_NE(val, Uint(99)); + EXPECT_NE(val, Nint(-1)); + EXPECT_NE(val, Nint(-4)); + EXPECT_NE(val, Tstr("99")); + EXPECT_NE(val, Bstr("99")); + EXPECT_NE(val, Bool(false)); + EXPECT_NE(val, Array(99, 1)); + EXPECT_NE(val, Map(99, 1)); + EXPECT_NE(val, ViewTstr("99")); + EXPECT_NE(val, ViewBstr("98")); } TEST(ConvertTest, Uint) { @@ -641,6 +723,8 @@ TEST(ConvertTest, Uint) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(10, item->asInt()->value()); EXPECT_EQ(10, item->asUint()->value()); @@ -658,6 +742,8 @@ TEST(ConvertTest, Nint) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(-10, item->asInt()->value()); EXPECT_EQ(-10, item->asNint()->value()); @@ -675,6 +761,8 @@ TEST(ConvertTest, Tstr) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ("hello"s, item->asTstr()->value()); } @@ -692,6 +780,8 @@ TEST(ConvertTest, Bstr) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(vec, item->asBstr()->value()); } @@ -708,6 +798,8 @@ TEST(ConvertTest, Bool) { EXPECT_NE(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(cppbor::BOOLEAN, item->asSimple()->simpleType()); EXPECT_NE(nullptr, item->asSimple()->asBool()); @@ -728,6 +820,8 @@ TEST(ConvertTest, Map) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_NE(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(0U, item->asMap()->size()); } @@ -744,6 +838,8 @@ TEST(ConvertTest, Array) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_NE(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(0U, item->asArray()->size()); } @@ -759,6 +855,8 @@ TEST(ConvertTest, SemanticTag) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); // Both asTstr() (the contained type) and asSemanticTag() return non-null. EXPECT_NE(nullptr, item->asTstr()); @@ -785,6 +883,8 @@ TEST(ConvertTest, NestedSemanticTag) { EXPECT_EQ(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); // Both asTstr() (the contained type) and asSemanticTag() return non-null. EXPECT_NE(nullptr, item->asTstr()); @@ -814,12 +914,52 @@ TEST(ConvertTest, Null) { EXPECT_NE(nullptr, item->asSimple()); EXPECT_EQ(nullptr, item->asMap()); EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); EXPECT_EQ(NULL_T, item->asSimple()->simpleType()); EXPECT_EQ(nullptr, item->asSimple()->asBool()); EXPECT_NE(nullptr, item->asSimple()->asNull()); } +TEST(ConvertTest, ViewTstr) { + unique_ptr item = details::makeItem(ViewTstr("hello")); + + EXPECT_EQ(TSTR, item->type()); + EXPECT_EQ(nullptr, item->asInt()); + EXPECT_EQ(nullptr, item->asUint()); + EXPECT_EQ(nullptr, item->asNint()); + EXPECT_EQ(nullptr, item->asTstr()); + EXPECT_EQ(nullptr, item->asBstr()); + EXPECT_EQ(nullptr, item->asSimple()); + EXPECT_EQ(nullptr, item->asMap()); + EXPECT_EQ(nullptr, item->asArray()); + EXPECT_NE(nullptr, item->asViewTstr()); + EXPECT_EQ(nullptr, item->asViewBstr()); + + EXPECT_EQ("hello"sv, item->asViewTstr()->view()); +} + +TEST(ConvertTest, ViewBstr) { + array vec{0x23, 0x24, 0x22}; + basic_string_view sv(vec.data(), vec.size()); + unique_ptr item = details::makeItem(ViewBstr(sv)); + + EXPECT_EQ(BSTR, item->type()); + EXPECT_EQ(nullptr, item->asInt()); + EXPECT_EQ(nullptr, item->asUint()); + EXPECT_EQ(nullptr, item->asNint()); + EXPECT_EQ(nullptr, item->asTstr()); + EXPECT_EQ(nullptr, item->asBstr()); + EXPECT_EQ(nullptr, item->asSimple()); + EXPECT_EQ(nullptr, item->asMap()); + EXPECT_EQ(nullptr, item->asArray()); + EXPECT_EQ(nullptr, item->asViewTstr()); + EXPECT_NE(nullptr, item->asViewBstr()); + + EXPECT_EQ(sv, item->asViewBstr()->view()); +} + TEST(CloningTest, Uint) { Uint item(10); auto clone = item.clone(); @@ -928,6 +1068,26 @@ TEST(CloningTest, NestedSemanticTag) { EXPECT_EQ(*clone->asSemanticTag(), copy); } +TEST(CloningTest, ViewTstr) { + ViewTstr item("qwertyasdfgh"); + auto clone = item.clone(); + EXPECT_EQ(clone->type(), TSTR); + EXPECT_NE(clone->asViewTstr(), nullptr); + EXPECT_EQ(item, *clone->asViewTstr()); + EXPECT_EQ(*clone->asViewTstr(), ViewTstr("qwertyasdfgh")); +} + +TEST(CloningTest, ViewBstr) { + array vec{1, 2, 3, 255, 0}; + basic_string_view sv(vec.data(), vec.size()); + ViewBstr item(sv); + auto clone = item.clone(); + EXPECT_EQ(clone->type(), BSTR); + EXPECT_NE(clone->asViewBstr(), nullptr); + EXPECT_EQ(item, *clone->asViewBstr()); + EXPECT_EQ(*clone->asViewBstr(), ViewBstr(sv)); +} + TEST(PrettyPrintingTest, NestedSemanticTag) { SemanticTag item(20, // SemanticTag(30, // @@ -1354,6 +1514,36 @@ TEST(StreamParseTest, Map) { parse(encoded.data(), encoded.data() + encoded.size(), &mpc); } +TEST(StreamParseTest, ViewTstr) { + MockParseClient mpc; + + ViewTstr val("Hello"); + auto encoded = val.encode(); + uint8_t* encBegin = encoded.data(); + uint8_t* encEnd = encoded.data() + encoded.size(); + + EXPECT_CALL(mpc, item(MatchesItem(val), encBegin, encBegin + 1, encEnd)).WillOnce(Return(&mpc)); + EXPECT_CALL(mpc, itemEnd(_, _, _, _)).Times(0); + EXPECT_CALL(mpc, error(_, _)).Times(0); + + parseWithViews(encoded.data(), encoded.data() + encoded.size(), &mpc); +} + +TEST(StreamParseTest, ViewBstr) { + MockParseClient mpc; + + ViewBstr val("Hello"); + auto encoded = val.encode(); + uint8_t* encBegin = encoded.data(); + uint8_t* encEnd = encoded.data() + encoded.size(); + + EXPECT_CALL(mpc, item(MatchesItem(val), encBegin, encBegin + 1, encEnd)).WillOnce(Return(&mpc)); + EXPECT_CALL(mpc, itemEnd(_, _, _, _)).Times(0); + EXPECT_CALL(mpc, error(_, _)).Times(0); + + parseWithViews(encoded.data(), encoded.data() + encoded.size(), &mpc); +} + TEST(FullParserTest, Uint) { Uint val(10); @@ -1506,6 +1696,25 @@ TEST(FullParserTest, MapWithTruncatedEntry) { EXPECT_EQ("Need 4 byte(s) for length field, have 3.", message); } +TEST(FullParserTest, ViewTstr) { + ViewTstr val("Hello"); + + auto enc = val.encode(); + auto [item, pos, message] = parseWithViews(enc.data(), enc.size()); + EXPECT_THAT(item, MatchesItem(val)); +} + +TEST(FullParserTest, ViewBstr) { + const std::string strVal = "\x00\x01\x02"s; + const ViewBstr val(strVal); + EXPECT_EQ(val.toString(), "\x43\x00\x01\x02"s); + + auto enc = val.encode(); + auto [item, pos, message] = parseWithViews(enc.data(), enc.size()); + EXPECT_THAT(item, MatchesItem(val)); + EXPECT_EQ(hexDump(item->toString()), hexDump(val.toString())); +} + TEST(FullParserTest, ReservedAdditionalInformation) { vector reservedVal = {0x1D};