diff --git a/libwvdrmengine/cdm/core/src/license_protocol.proto b/libwvdrmengine/cdm/core/src/license_protocol.proto index 79206d67..9101ae7c 100644 --- a/libwvdrmengine/cdm/core/src/license_protocol.proto +++ b/libwvdrmengine/cdm/core/src/license_protocol.proto @@ -221,6 +221,7 @@ message License { // for verifying the received ECM/EMM signature. Only EC key is supported // for now. PROVIDER_ECM_VERIFIER_PUBLIC_KEY = 7; + OEM_ENTITLEMENT = 8; // Partner-specific entitlement key. } // The SecurityLevel enumeration allows the server to communicate the level diff --git a/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp b/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp index 0ed472e7..8f266bed 100644 --- a/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp +++ b/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp @@ -2,7 +2,7 @@ // source code may only be used and distributed under the Widevine License // Agreement. // These tests are for the cdm engine, and code below it in the stack. In -// particular, we assume that the OEMCrypo layer works, and has a valid keybox. +// particular, we assume that the OEMCrypto layer works, and has a valid keybox. // This is because we need a valid RSA certificate, and will attempt to connect // to the provisioning server to request one if we don't. @@ -65,7 +65,7 @@ class WvCdmEnginePreProvTest : public WvCdmTestBaseWithEngine { CdmResponseType status = cdm_engine_.OpenSession( config_.key_system(), nullptr, nullptr, &session_id_); if (status == NEED_PROVISIONING) { - Provision(); + EnsureProvisioned(); status = cdm_engine_.OpenSession(config_.key_system(), nullptr, nullptr, &session_id_); } @@ -335,20 +335,22 @@ TEST_F(WvCdmEngineTest, SetLicensingServiceInvalidCertificate) { NO_ERROR); }; -TEST_F(WvCdmEnginePreProvTestStaging, ProvisioningTest) { Provision(); } +TEST_F(WvCdmEnginePreProvTestStaging, ProvisioningTest) { EnsureProvisioned(); } -TEST_F(WvCdmEnginePreProvTestUatBinary, ProvisioningTest) { Provision(); } +TEST_F(WvCdmEnginePreProvTestUatBinary, ProvisioningTest) { + EnsureProvisioned(); +} // Test that provisioning works. -TEST_F(WvCdmEngineTest, ProvisioningTest) { Provision(); } +TEST_F(WvCdmEngineTest, ProvisioningTest) { EnsureProvisioned(); } // Test that provisioning works, even if device is already provisioned. TEST_F(WvCdmEngineTest, ReprovisioningTest) { // Provision once. - Provision(); + EnsureProvisioned(); // Verify that we can provision a second time, even though we already // provisioned once. - Provision(); + EnsureProvisioned(); } TEST_F(WvCdmEngineTest, BaseIsoBmffMessageTest) { diff --git a/libwvdrmengine/cdm/core/test/crypto_session_unittest.cpp b/libwvdrmengine/cdm/core/test/crypto_session_unittest.cpp index a49a39e1..3e67165b 100644 --- a/libwvdrmengine/cdm/core/test/crypto_session_unittest.cpp +++ b/libwvdrmengine/cdm/core/test/crypto_session_unittest.cpp @@ -96,6 +96,9 @@ TEST_F(CryptoSessionMetricsTest, OpenSessionValidMetrics) { } else if (token_type == kClientTokenDrmCert) { // TODO(blueeyes): Add support for getting the system id from a // pre-installed DRM certificate.. + } else if (token_type == kClientTokenBootCertChain) { + EXPECT_EQ(OEMCrypto_BootCertificateChain, + metrics_proto.oemcrypto_provisioning_method().int_value()); } else { FAIL() << "Unexpected token type: " << token_type; } @@ -134,9 +137,9 @@ TEST_F(CryptoSessionMetricsTest, GetProvisioningTokenValidMetrics) { ASSERT_GE(metrics_proto.oemcrypto_get_oem_public_certificate().size(), 1); EXPECT_THAT(metrics_proto.oemcrypto_get_oem_public_certificate(0).count(), AllOf(Ge(1), Le(2))); - - ASSERT_GE(metrics_proto.crypto_session_get_token().size(), 1); - EXPECT_GE(metrics_proto.crypto_session_get_token(0).count(), 1); + } else if (token_type == kClientTokenBootCertChain) { + EXPECT_EQ(OEMCrypto_BootCertificateChain, + metrics_proto.oemcrypto_provisioning_method().int_value()); } else { ASSERT_EQ(0, metrics_proto.crypto_session_get_token().size()); } diff --git a/libwvdrmengine/cdm/core/test/fake_provisioning_server.cpp b/libwvdrmengine/cdm/core/test/fake_provisioning_server.cpp index 2c6b813b..4b550450 100644 --- a/libwvdrmengine/cdm/core/test/fake_provisioning_server.cpp +++ b/libwvdrmengine/cdm/core/test/fake_provisioning_server.cpp @@ -284,7 +284,8 @@ bool FakeProvisioningServer::MakeResponse( wvoec::KeyDeriver key_deriver; // Not only is this Prov 2.0 specific, it assumes the device is using the // standard test keybox. - key_deriver.DeriveKeys(wvoec::kTestKeybox.device_key_, mac_context_v, + key_deriver.DeriveKeys(wvoec::kTestKeybox.device_key_, + sizeof(wvoec::kTestKeybox.device_key_), mac_context_v, enc_context_v); // Create a structure to hold the RSA private key. This is used by the key diff --git a/libwvdrmengine/cdm/core/test/reboot_test.cpp b/libwvdrmengine/cdm/core/test/reboot_test.cpp index 30d60b6f..cd704077 100644 --- a/libwvdrmengine/cdm/core/test/reboot_test.cpp +++ b/libwvdrmengine/cdm/core/test/reboot_test.cpp @@ -222,9 +222,11 @@ void RebootTest::SetUp() { EXPECT_EQ(read, file_size) << "Error reading persistent data file."; EXPECT_TRUE(ParseDump(dump, &persistent_data_)); } + TestSleep::SyncFakeClock(); } void RebootTest::TearDown() { + TestSleep::SyncFakeClock(); auto file = file_system_->Open(persistent_data_filename_, FileSystem::kCreate | FileSystem::kTruncate); ASSERT_TRUE(file) << "Failed to open file: " << persistent_data_filename_; @@ -404,6 +406,7 @@ class OfflineLicense { // Fetch and load the license. The session is left open. void LoadLicense() { license_holder_.OpenSession(); + TestSleep::SyncFakeClock(); start_of_rental_clock_ = wvutil::Clock().GetCurrentTime(); license_holder_.FetchLicense(); license_holder_.LoadLicense(); @@ -433,6 +436,7 @@ class OfflineLicense { // Verify that the license may be used to decrypt content. void Decrypt() { + TestSleep::SyncFakeClock(); if (start_of_playback_ == 0) { start_of_playback_ = wvutil::Clock().GetCurrentTime(); } @@ -450,6 +454,7 @@ class OfflineLicense { // Verify that the license has expired, and may not be used to decrypt // content. void FailDecrypt() { + TestSleep::SyncFakeClock(); const KeyId key_id = "0000000000000000"; EXPECT_EQ(NEED_KEY, license_holder_.Decrypt(key_id)) << "Decrypt should have failed for " << content_id_ @@ -628,6 +633,7 @@ class OfflineLicenseTest : public RebootTest { int decrypt_count = 0; int fail_count = 0; for (auto time : interesting_times_) { + TestSleep::SyncFakeClock(); int64_t now = wvutil::Clock().GetCurrentTime(); int64_t delta = (time - now); // It is not necessarily an error for the delta to be negative. But it is @@ -672,6 +678,7 @@ class OfflineLicenseTest : public RebootTest { for (size_t i = first_valid_[n] + 1; i < test_case_[n].size(); i++) { OfflineLicense* license = test_case_[n][i].get(); ASSERT_NO_FATAL_FAILURE(license->ReloadLicense()); + TestSleep::SyncFakeClock(); int64_t now = wvutil::Clock().GetCurrentTime(); if (now <= license->cutoff() - kFudge) { license->Decrypt(); diff --git a/libwvdrmengine/cdm/test/coverage-test.mk b/libwvdrmengine/cdm/test/coverage-test.mk index 963ec633..0c771811 100644 --- a/libwvdrmengine/cdm/test/coverage-test.mk +++ b/libwvdrmengine/cdm/test/coverage-test.mk @@ -44,6 +44,8 @@ LOCAL_SRC_FILES := \ ../../oemcrypto/test/oec_device_features.cpp \ ../../oemcrypto/test/oec_key_deriver.cpp \ ../../oemcrypto/test/oec_session_util.cpp \ + ../../oemcrypto/util/src/oemcrypto_ecc_key.cpp \ + ../../oemcrypto/util/src/oemcrypto_rsa_key.cpp \ LOCAL_C_INCLUDES := \ vendor/widevine/libwvdrmengine/android/cdm/test \ @@ -58,6 +60,7 @@ LOCAL_C_INCLUDES := \ vendor/widevine/libwvdrmengine/oemcrypto/test/fuzz_tests \ vendor/widevine/libwvdrmengine/oemcrypto/odk/include \ vendor/widevine/libwvdrmengine/oemcrypto/odk/kdo/include \ + vendor/widevine/libwvdrmengine/oemcrypto/util/include \ LOCAL_C_INCLUDES += external/protobuf/src diff --git a/libwvdrmengine/cdm/test/integration-test.mk b/libwvdrmengine/cdm/test/integration-test.mk index 3eb314c8..92227ee8 100644 --- a/libwvdrmengine/cdm/test/integration-test.mk +++ b/libwvdrmengine/cdm/test/integration-test.mk @@ -34,6 +34,8 @@ LOCAL_SRC_FILES := \ ../../oemcrypto/test/oec_device_features.cpp \ ../../oemcrypto/test/oec_key_deriver.cpp \ ../../oemcrypto/test/oec_session_util.cpp \ + ../../oemcrypto/util/src/oemcrypto_ecc_key.cpp \ + ../../oemcrypto/util/src/oemcrypto_rsa_key.cpp \ ../util/test/test_sleep.cpp \ LOCAL_C_INCLUDES := \ @@ -49,6 +51,7 @@ LOCAL_C_INCLUDES := \ vendor/widevine/libwvdrmengine/oemcrypto/test/fuzz_tests \ vendor/widevine/libwvdrmengine/oemcrypto/odk/include \ vendor/widevine/libwvdrmengine/oemcrypto/odk/kdo/include \ + vendor/widevine/libwvdrmengine/oemcrypto/util/include \ LOCAL_C_INCLUDES += external/protobuf/src diff --git a/libwvdrmengine/cdm/util/test/test_clock.cpp b/libwvdrmengine/cdm/util/test/test_clock.cpp index 18331bb6..33edd9c3 100644 --- a/libwvdrmengine/cdm/util/test/test_clock.cpp +++ b/libwvdrmengine/cdm/util/test/test_clock.cpp @@ -2,7 +2,9 @@ // source code may only be used and distributed under the Widevine License // Agreement. // -// Clock - A fake clock just for running tests. +// Clock - A fake clock just for running tests. This is used when running +// OEMCrypto unit tests. It is not used when tests include the CE CDM source +// code because that uses the clock in cdm/test_host.cpp instead. #include diff --git a/libwvdrmengine/oemcrypto/odk/include/OEMCryptoCENCCommon.h b/libwvdrmengine/oemcrypto/odk/include/OEMCryptoCENCCommon.h index cb343abb..ce51b8d0 100644 --- a/libwvdrmengine/oemcrypto/odk/include/OEMCryptoCENCCommon.h +++ b/libwvdrmengine/oemcrypto/odk/include/OEMCryptoCENCCommon.h @@ -120,6 +120,11 @@ typedef enum OEMCrypto_Usage_Entry_Status { kInactiveUnused = 4, } OEMCrypto_Usage_Entry_Status; +typedef enum OEMCrypto_ProvisioningRenewalType { + OEMCrypto_NoRenewal = 0, + OEMCrypto_RenewalACert = 1, +} OEMCrypto_ProvisioningRenewalType; + /** * OEMCrypto_LicenseType is used in the license message to indicate if the key * objects are for content keys, or for entitlement keys. diff --git a/libwvdrmengine/oemcrypto/odk/include/core_message_deserialize.h b/libwvdrmengine/oemcrypto/odk/include/core_message_deserialize.h index 76dccc5c..545a8062 100644 --- a/libwvdrmengine/oemcrypto/odk/include/core_message_deserialize.h +++ b/libwvdrmengine/oemcrypto/odk/include/core_message_deserialize.h @@ -17,6 +17,8 @@ #ifndef WIDEVINE_ODK_INCLUDE_CORE_MESSAGE_DESERIALIZE_H_ #define WIDEVINE_ODK_INCLUDE_CORE_MESSAGE_DESERIALIZE_H_ +#include + #include "core_message_types.h" namespace oemcrypto_core_message { @@ -53,6 +55,18 @@ bool CoreProvisioningRequestFromMessage( const std::string& oemcrypto_core_message, ODK_ProvisioningRequest* core_provisioning_request); +/** + * Counterpart (deserializer) of ODK_PrepareCoreRenewedProvisioningRequest + * (serializer) + * + * Parameters: + * [in] oemcrypto_core_message + * [out] core_provisioning_request + */ +bool CoreRenewedProvisioningRequestFromMessage( + const std::string& oemcrypto_core_message, + ODK_ProvisioningRequest* core_provisioning_request); + /** * Serializer counterpart is not used and is therefore not implemented. * diff --git a/libwvdrmengine/oemcrypto/odk/include/core_message_features.h b/libwvdrmengine/oemcrypto/odk/include/core_message_features.h index 07113e6e..16289c6b 100644 --- a/libwvdrmengine/oemcrypto/odk/include/core_message_features.h +++ b/libwvdrmengine/oemcrypto/odk/include/core_message_features.h @@ -30,13 +30,13 @@ struct CoreMessageFeatures { uint32_t maximum_major_version = 17; uint32_t maximum_minor_version = 0; - bool operator==(const CoreMessageFeatures& other) const; - bool operator!=(const CoreMessageFeatures& other) const { + bool operator==(const CoreMessageFeatures &other) const; + bool operator!=(const CoreMessageFeatures &other) const { return !(*this == other); } }; -std::ostream& operator<<(std::ostream& os, const CoreMessageFeatures& features); +std::ostream &operator<<(std::ostream &os, const CoreMessageFeatures &features); } // namespace features } // namespace oemcrypto_core_message diff --git a/libwvdrmengine/oemcrypto/odk/include/core_message_serialize.h b/libwvdrmengine/oemcrypto/odk/include/core_message_serialize.h index 0e1c287b..bd6d6354 100644 --- a/libwvdrmengine/oemcrypto/odk/include/core_message_serialize.h +++ b/libwvdrmengine/oemcrypto/odk/include/core_message_serialize.h @@ -17,6 +17,8 @@ #ifndef WIDEVINE_ODK_INCLUDE_CORE_MESSAGE_SERIALIZE_H_ #define WIDEVINE_ODK_INCLUDE_CORE_MESSAGE_SERIALIZE_H_ +#include + #include "core_message_features.h" #include "core_message_types.h" #include "odk_structs.h" diff --git a/libwvdrmengine/oemcrypto/odk/include/core_message_types.h b/libwvdrmengine/oemcrypto/odk/include/core_message_types.h index 3d02aa91..5315913e 100644 --- a/libwvdrmengine/oemcrypto/odk/include/core_message_types.h +++ b/libwvdrmengine/oemcrypto/odk/include/core_message_types.h @@ -96,7 +96,8 @@ struct ODK_RenewalRequest { }; /** - * Output structure for CoreProvisioningRequestFromMessage + * Output structure for CoreProvisioningRequestFromMessage and + * CoreRenewedProvisioningRequestFromMessage * Input structure for CreateCoreProvisioningResponse */ struct ODK_ProvisioningRequest { @@ -105,6 +106,8 @@ struct ODK_ProvisioningRequest { uint32_t nonce; uint32_t session_id; std::string device_id; + uint16_t renewal_type; + std::string renewal_data; }; } // namespace oemcrypto_core_message diff --git a/libwvdrmengine/oemcrypto/odk/include/odk.h b/libwvdrmengine/oemcrypto/odk/include/odk.h index 941afc11..e3499da2 100644 --- a/libwvdrmengine/oemcrypto/odk/include/odk.h +++ b/libwvdrmengine/oemcrypto/odk/include/odk.h @@ -326,7 +326,7 @@ OEMCryptoResult ODK_PrepareCoreRenewalRequest(uint8_t* message, * OEMCrypto_GetDeviceID. The device ID shall be unique to the device, and * stable across reboots and factory resets for an L1 device. * - * NOTE: if the message pointer is null and/or input core_message_size is + * NOTE: if the message pointer is null and/or input core_message_length is * zero, this function returns OEMCrypto_ERROR_SHORT_BUFFER and sets output * core_message_size to the size needed. * @@ -351,10 +351,56 @@ OEMCryptoResult ODK_PrepareCoreRenewalRequest(uint8_t* message, * This method is new in version 16 of the API. */ OEMCryptoResult ODK_PrepareCoreProvisioningRequest( - uint8_t* message, size_t message_length, size_t* core_message_size, + uint8_t* message, size_t message_length, size_t* core_message_length, const ODK_NonceValues* nonce_values, const uint8_t* device_id, size_t device_id_length); +/** + * Modifies the message to include a core renewal provisioning request at the + * beginning of the message buffer. The values in nonce_values are used to + * populate the message. + * + * This shall be called by OEMCrypto from + * OEMCrypto_PrepAndSignProvisioningRequest. + * + * The buffer device_id shall be the same string returned by + * OEMCrypto_GetDeviceID. The device ID shall be unique to the device, and + * stable across reboots and factory resets for an L1 device. + * + * NOTE: if the message pointer is null and/or input core_message_length is + * zero, this function returns OEMCrypto_ERROR_SHORT_BUFFER and sets output + * core_message_size to the size needed. + * + * @param[in,out] message: pointer to memory for the entire message. Modified by + * the ODK library. + * @param[in] message_length: length of the entire message buffer. + * @param[in,out] core_message_size: length of the core message at the beginning + * of the message. (in) size of buffer reserved for the core message, in + * bytes. (out) actual length of the core message, in bytes. + * @param[in] nonce_values: pointer to the session's nonce data. + * @param[in] device_id: For devices with a keybox, this is the device ID from + * the keybox. For devices with an OEM Certificate, this is a device + * unique id string. + * @param[in] device_id_length: length of device_id. The device ID can be at + * most 64 bytes. + * @param[in] renewal_type: type of renewal used + * @param[in] renewal_data: renewal data used. For renewal_type = 1, + * renewal_data is the Android attestation batch certificate. + * @param[in] renewal_data_length: length of renewal_data + * + * @retval OEMCrypto_SUCCESS + * @retval OEMCrypto_ERROR_SHORT_BUFFER: core_message_size is too small + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * + * @version + * This method is new in version 17 of the API. + */ +OEMCryptoResult ODK_PrepareCoreRenewedProvisioningRequest( + uint8_t* message, size_t message_length, size_t* core_message_length, + const ODK_NonceValues* nonce_values, const uint8_t* device_id, + size_t device_id_length, uint16_t renewal_type, const uint8_t* renewal_data, + size_t renewal_data_length); + /// @} /// @addtogroup odk_timer diff --git a/libwvdrmengine/oemcrypto/odk/include/odk_message.h b/libwvdrmengine/oemcrypto/odk/include/odk_message.h index e4f135f6..075f28cc 100644 --- a/libwvdrmengine/oemcrypto/odk/include/odk_message.h +++ b/libwvdrmengine/oemcrypto/odk/include/odk_message.h @@ -35,10 +35,10 @@ extern "C" { */ #if defined(__GNUC__) || defined(__clang__) -# define ALIGNED __attribute__((aligned)) +#define ALIGNED __attribute__((aligned)) #else -# define ALIGNED -# error ODK_Message must be aligned to the maximum useful alignment of the \ +#define ALIGNED +#error ODK_Message must be aligned to the maximum useful alignment of the \ machine you are compiling for. Define the ALIGNED macro accordingly. #endif @@ -48,19 +48,19 @@ typedef struct { } ALIGNED ODK_Message; typedef enum { - MESSAGE_STATUS_OK = 0xe937fcf7, - MESSAGE_STATUS_UNKNOWN_ERROR = 0xe06c1190, - MESSAGE_STATUS_OVERFLOW_ERROR = 0xc43ae4bc, + MESSAGE_STATUS_OK = 0x7937fcf7, + MESSAGE_STATUS_UNKNOWN_ERROR = 0x706c1190, + MESSAGE_STATUS_OVERFLOW_ERROR = 0x543ae4bc, MESSAGE_STATUS_UNDERFLOW_ERROR = 0x7123cd0b, MESSAGE_STATUS_PARSE_ERROR = 0x0b9f6189, MESSAGE_STATUS_NULL_POINTER_ERROR = 0x2d66837a, MESSAGE_STATUS_API_VALUE_ERROR = 0x6ba34f47, - MESSAGE_STATUS_END_OF_MESSAGE_ERROR = 0x998db72a, - MESSAGE_STATUS_INVALID_ENUM_VALUE = 0xedb88197, + MESSAGE_STATUS_END_OF_MESSAGE_ERROR = 0x798db72a, + MESSAGE_STATUS_INVALID_ENUM_VALUE = 0x7db88197, MESSAGE_STATUS_INVALID_TAG_ERROR = 0x14dce06a, MESSAGE_STATUS_NOT_INITIALIZED = 0x2990b6c6, - MESSAGE_STATUS_OUT_OF_MEMORY = 0xfc5c64cc, - MESSAGE_STATUS_MAP_SHARED_MEMORY_FAILED = 0xfafecacf, + MESSAGE_STATUS_OUT_OF_MEMORY = 0x7c5c64cc, + MESSAGE_STATUS_MAP_SHARED_MEMORY_FAILED = 0x7afecacf, MESSAGE_STATUS_SECURE_BUFFER_ERROR = 0x78f0e873 } ODK_MessageStatus; diff --git a/libwvdrmengine/oemcrypto/odk/include/odk_structs.h b/libwvdrmengine/oemcrypto/odk/include/odk_structs.h index 72be13bd..fba3c3aa 100644 --- a/libwvdrmengine/oemcrypto/odk/include/odk_structs.h +++ b/libwvdrmengine/oemcrypto/odk/include/odk_structs.h @@ -16,10 +16,10 @@ extern "C" { /* The version of this library. */ #define ODK_MAJOR_VERSION 17 -#define ODK_MINOR_VERSION 0 +#define ODK_MINOR_VERSION 1 /* ODK Version string. Date changed automatically on each release. */ -#define ODK_RELEASE_DATE "ODK v17.0 2022-02-15" +#define ODK_RELEASE_DATE "ODK v17.1 2022-06-17" /* The lowest version number for an ODK message. */ #define ODK_FIRST_VERSION 16 @@ -27,6 +27,7 @@ extern "C" { /* Some useful constants. */ #define ODK_DEVICE_ID_LEN_MAX 64 #define ODK_SHA256_HASH_SIZE 32 +#define ODK_KEYBOX_RENEWAL_DATA_SIZE 1600 /// @addtogroup odk_timer /// @{ diff --git a/libwvdrmengine/oemcrypto/odk/src/core_message_deserialize.cpp b/libwvdrmengine/oemcrypto/odk/src/core_message_deserialize.cpp index 9f485d54..2e69641d 100644 --- a/libwvdrmengine/oemcrypto/odk/src/core_message_deserialize.cpp +++ b/libwvdrmengine/oemcrypto/odk/src/core_message_deserialize.cpp @@ -10,6 +10,7 @@ #include #include +#include "OEMCryptoCENCCommon.h" #include "odk_serialize.h" #include "odk_structs.h" #include "odk_structs_priv.h" @@ -52,6 +53,7 @@ bool ParseRequest(uint32_t message_type, core_request->api_minor_version = core_message.nonce_values.api_minor_version; core_request->nonce = core_message.nonce_values.nonce; core_request->session_id = core_message.nonce_values.session_id; + // Verify that the minor version matches the released version for the given // major version. if (core_request->api_major_version < ODK_FIRST_VERSION) { @@ -68,10 +70,13 @@ bool ParseRequest(uint32_t message_type, // For v16, a release and a renewal use the same message structure. // However, for future API versions, the release might be a separate // message. Otherwise, we expect an exact match of message types. + // A provisioning request may contain a renewed provisioning message. if (message_type != ODK_Common_Request_Type && core_message.message_type != message_type && !(message_type == ODK_Renewal_Request_Type && - core_message.message_type == ODK_Release_Request_Type)) { + core_message.message_type == ODK_Release_Request_Type) && + !(message_type == ODK_Provisioning_Request_Type && + core_message.message_type == ODK_Renewed_Provisioning_Request_Type)) { return false; } // Verify that the amount of buffer we read, which is GetOffset, is not more @@ -125,6 +130,42 @@ bool CoreProvisioningRequestFromMessage( } core_provisioning_request->device_id.assign( reinterpret_cast(device_id), device_id_length); + core_provisioning_request->renewal_type = OEMCrypto_NoRenewal; + core_provisioning_request->renewal_data.clear(); + return true; +} + +bool CoreRenewedProvisioningRequestFromMessage( + const std::string& oemcrypto_core_message, + ODK_ProvisioningRequest* core_provisioning_request) { + const auto unpacker = Unpack_ODK_PreparedRenewedProvisioningRequest; + ODK_PreparedRenewedProvisioningRequest prepared_provision = {}; + if (!ParseRequest(ODK_Renewed_Provisioning_Request_Type, + oemcrypto_core_message, core_provisioning_request, + &prepared_provision, unpacker)) { + return false; + } + const uint8_t* device_id = prepared_provision.device_id; + const uint32_t device_id_length = prepared_provision.device_id_length; + if (device_id_length > ODK_DEVICE_ID_LEN_MAX) { + return false; + } + uint8_t zero[ODK_DEVICE_ID_LEN_MAX] = {}; + if (memcmp(zero, device_id + device_id_length, + ODK_DEVICE_ID_LEN_MAX - device_id_length)) { + return false; + } + core_provisioning_request->device_id.assign( + reinterpret_cast(device_id), device_id_length); + + if (prepared_provision.renewal_data_length > + sizeof(prepared_provision.renewal_data)) { + return false; + } + core_provisioning_request->renewal_type = OEMCrypto_RenewalACert; + core_provisioning_request->renewal_data.assign( + reinterpret_cast(prepared_provision.renewal_data), + prepared_provision.renewal_data_length); return true; } diff --git a/libwvdrmengine/oemcrypto/odk/src/core_message_features.cpp b/libwvdrmengine/oemcrypto/odk/src/core_message_features.cpp index 9dbbecbd..615e4779 100644 --- a/libwvdrmengine/oemcrypto/odk/src/core_message_features.cpp +++ b/libwvdrmengine/oemcrypto/odk/src/core_message_features.cpp @@ -8,7 +8,7 @@ namespace oemcrypto_core_message { namespace features { const CoreMessageFeatures CoreMessageFeatures::kDefaultFeatures; -bool CoreMessageFeatures::operator==(const CoreMessageFeatures& other) const { +bool CoreMessageFeatures::operator==(const CoreMessageFeatures &other) const { return maximum_major_version == other.maximum_major_version && maximum_minor_version == other.maximum_minor_version; } @@ -23,7 +23,7 @@ CoreMessageFeatures CoreMessageFeatures::DefaultFeatures( features.maximum_minor_version = 5; // 16.5 break; case 17: - features.maximum_minor_version = 0; // 17.0 + features.maximum_minor_version = 1; // 17.1 break; default: features.maximum_minor_version = 0; @@ -31,8 +31,8 @@ CoreMessageFeatures CoreMessageFeatures::DefaultFeatures( return features; } -std::ostream& operator<<(std::ostream& os, - const CoreMessageFeatures& features) { +std::ostream &operator<<(std::ostream &os, + const CoreMessageFeatures &features) { return os << "v" << features.maximum_major_version << "." << features.maximum_minor_version; } diff --git a/libwvdrmengine/oemcrypto/odk/src/core_message_serialize.cpp b/libwvdrmengine/oemcrypto/odk/src/core_message_serialize.cpp index 334f4429..3c3590ef 100644 --- a/libwvdrmengine/oemcrypto/odk/src/core_message_serialize.cpp +++ b/libwvdrmengine/oemcrypto/odk/src/core_message_serialize.cpp @@ -13,6 +13,7 @@ #include "odk_serialize.h" #include "odk_structs.h" #include "odk_structs_priv.h" +#include "odk_target.h" #include "serialization_base.h" namespace oemcrypto_core_message { @@ -122,6 +123,9 @@ bool CreateCoreLicenseResponse(const CoreMessageFeatures& features, license_response)) { return false; } + if (ODK_MAX_NUM_KEYS < license_response.parsed_license->key_array_length) { + return false; + } if (license_response.request.core_message.nonce_values.api_major_version == 16) { ODK_LicenseResponseV16 license_response_v16; @@ -143,7 +147,9 @@ bool CreateCoreLicenseResponse(const CoreMessageFeatures& features, license_response_v16.parsed_license.key_array_length = license_response.parsed_license->key_array_length; uint32_t i; - for (i = 0; i < license_response_v16.parsed_license.key_array_length; i++) { + for (i = 0; i < license_response_v16.parsed_license.key_array_length && + i < license_response.parsed_license->key_array_length; + i++) { license_response_v16.parsed_license.key_array[i] = license_response.parsed_license->key_array[i]; } diff --git a/libwvdrmengine/oemcrypto/odk/src/core_message_serialize_proto.cpp b/libwvdrmengine/oemcrypto/odk/src/core_message_serialize_proto.cpp index 860ea267..5132bda2 100644 --- a/libwvdrmengine/oemcrypto/odk/src/core_message_serialize_proto.cpp +++ b/libwvdrmengine/oemcrypto/odk/src/core_message_serialize_proto.cpp @@ -101,8 +101,11 @@ bool CreateCoreLicenseResponseFromProto(const CoreMessageFeatures& features, } case video_widevine::License_KeyContainer::CONTENT: case video_widevine::License_KeyContainer::OPERATOR_SESSION: + case video_widevine::License_KeyContainer::OEM_CONTENT: + case video_widevine::License_KeyContainer::OEM_ENTITLEMENT: case video_widevine::License_KeyContainer::ENTITLEMENT: { - if (k.type() == video_widevine::License_KeyContainer::ENTITLEMENT) { + if (k.type() == video_widevine::License_KeyContainer::ENTITLEMENT || + k.type() == video_widevine::License_KeyContainer::OEM_ENTITLEMENT) { any_entitlement = true; } else { any_content = true; diff --git a/libwvdrmengine/oemcrypto/odk/src/odk.c b/libwvdrmengine/oemcrypto/odk/src/odk.c index 019d50e2..4f283898 100644 --- a/libwvdrmengine/oemcrypto/odk/src/odk.c +++ b/libwvdrmengine/oemcrypto/odk/src/odk.c @@ -72,6 +72,17 @@ static OEMCryptoResult ODK_PrepareRequest( &msg, (ODK_PreparedProvisioningRequest*)prepared_request_buffer); break; } + case ODK_Renewed_Provisioning_Request_Type: { + core_message->message_length = ODK_RENEWED_PROVISIONING_REQUEST_SIZE; + if (sizeof(ODK_PreparedRenewedProvisioningRequest) > + prepared_request_buffer_length) { + return ODK_ERROR_CORE_MESSAGE; + } + Pack_ODK_PreparedRenewedProvisioningRequest( + &msg, + (ODK_PreparedRenewedProvisioningRequest*)prepared_request_buffer); + break; + } default: { return ODK_ERROR_CORE_MESSAGE; } @@ -238,6 +249,37 @@ OEMCryptoResult ODK_PrepareCoreProvisioningRequest( sizeof(ODK_PreparedProvisioningRequest)); } +OEMCryptoResult ODK_PrepareCoreRenewedProvisioningRequest( + uint8_t* message, size_t message_length, size_t* core_message_length, + const ODK_NonceValues* nonce_values, const uint8_t* device_id, + size_t device_id_length, uint16_t renewal_type, const uint8_t* renewal_data, + size_t renewal_data_length) { + if (core_message_length == NULL || nonce_values == NULL) { + return ODK_ERROR_CORE_MESSAGE; + } + ODK_PreparedRenewedProvisioningRequest provisioning_request = {0}; + if (device_id_length > sizeof(provisioning_request.device_id)) { + return ODK_ERROR_CORE_MESSAGE; + } + provisioning_request.device_id_length = (uint32_t)device_id_length; + if (device_id) { + memcpy(provisioning_request.device_id, device_id, device_id_length); + } + if (renewal_data_length > sizeof(provisioning_request.renewal_data)) { + return ODK_ERROR_CORE_MESSAGE; + } + provisioning_request.renewal_type = renewal_type; + provisioning_request.renewal_data_length = (uint32_t)renewal_data_length; + if (renewal_data) { + memcpy(provisioning_request.renewal_data, renewal_data, + renewal_data_length); + } + return ODK_PrepareRequest(message, message_length, core_message_length, + ODK_Renewed_Provisioning_Request_Type, nonce_values, + &provisioning_request, + sizeof(provisioning_request)); +} + /* @@ parse response functions */ OEMCryptoResult ODK_ParseLicense( diff --git a/libwvdrmengine/oemcrypto/odk/src/odk_serialize.c b/libwvdrmengine/oemcrypto/odk/src/odk_serialize.c index 55ea3a4b..5c582000 100644 --- a/libwvdrmengine/oemcrypto/odk/src/odk_serialize.c +++ b/libwvdrmengine/oemcrypto/odk/src/odk_serialize.c @@ -128,12 +128,22 @@ void Pack_ODK_PreparedRenewalRequest(ODK_Message* msg, } void Pack_ODK_PreparedProvisioningRequest( - ODK_Message* msg, ODK_PreparedProvisioningRequest const* obj) { + ODK_Message* msg, const ODK_PreparedProvisioningRequest* obj) { Pack_ODK_CoreMessage(msg, &obj->core_message); Pack_uint32_t(msg, &obj->device_id_length); PackArray(msg, &obj->device_id[0], sizeof(obj->device_id)); } +void Pack_ODK_PreparedRenewedProvisioningRequest( + ODK_Message* msg, const ODK_PreparedRenewedProvisioningRequest* obj) { + Pack_ODK_CoreMessage(msg, &obj->core_message); + Pack_uint32_t(msg, &obj->device_id_length); + PackArray(msg, &obj->device_id[0], sizeof(obj->device_id)); + Pack_uint16_t(msg, &obj->renewal_type); + Pack_uint32_t(msg, &obj->renewal_data_length); + PackArray(msg, &obj->renewal_data[0], sizeof(obj->renewal_data)); +} + /* @@ kdo serialize */ void Pack_ODK_LicenseResponse(ODK_Message* msg, @@ -156,7 +166,7 @@ void Pack_ODK_RenewalResponse(ODK_Message* msg, } void Pack_ODK_ProvisioningResponse(ODK_Message* msg, - ODK_ProvisioningResponse const* obj) { + const ODK_ProvisioningResponse* obj) { Pack_ODK_PreparedProvisioningRequest(msg, &obj->request); Pack_ODK_ParsedProvisioning( msg, (const ODK_ParsedProvisioning*)obj->parsed_provisioning); @@ -302,6 +312,16 @@ void Unpack_ODK_PreparedProvisioningRequest( UnpackArray(msg, &obj->device_id[0], sizeof(obj->device_id)); } +void Unpack_ODK_PreparedRenewedProvisioningRequest( + ODK_Message* msg, ODK_PreparedRenewedProvisioningRequest* obj) { + Unpack_ODK_CoreMessage(msg, &obj->core_message); + Unpack_uint32_t(msg, &obj->device_id_length); + UnpackArray(msg, &obj->device_id[0], sizeof(obj->device_id)); + Unpack_uint16_t(msg, &obj->renewal_type); + Unpack_uint32_t(msg, &obj->renewal_data_length); + UnpackArray(msg, &obj->renewal_data[0], obj->renewal_data_length); +} + void Unpack_ODK_PreparedCommonRequest(ODK_Message* msg, ODK_PreparedCommonRequest* obj) { Unpack_ODK_CoreMessage(msg, &obj->core_message); diff --git a/libwvdrmengine/oemcrypto/odk/src/odk_serialize.h b/libwvdrmengine/oemcrypto/odk/src/odk_serialize.h index c08b4d52..0904700c 100644 --- a/libwvdrmengine/oemcrypto/odk/src/odk_serialize.h +++ b/libwvdrmengine/oemcrypto/odk/src/odk_serialize.h @@ -22,6 +22,8 @@ void Pack_ODK_PreparedRenewalRequest(ODK_Message* msg, const ODK_PreparedRenewalRequest* obj); void Pack_ODK_PreparedProvisioningRequest( ODK_Message* msg, const ODK_PreparedProvisioningRequest* obj); +void Pack_ODK_PreparedRenewedProvisioningRequest( + ODK_Message* msg, const ODK_PreparedRenewedProvisioningRequest* obj); /* odk unpack */ void Unpack_ODK_CoreMessage(ODK_Message* msg, ODK_CoreMessage* obj); @@ -47,6 +49,8 @@ void Unpack_ODK_PreparedRenewalRequest(ODK_Message* msg, ODK_PreparedRenewalRequest* obj); void Unpack_ODK_PreparedProvisioningRequest( ODK_Message* msg, ODK_PreparedProvisioningRequest* obj); +void Unpack_ODK_PreparedRenewedProvisioningRequest( + ODK_Message* msg, ODK_PreparedRenewedProvisioningRequest* obj); void Unpack_ODK_PreparedCommonRequest(ODK_Message* msg, ODK_PreparedCommonRequest* obj); diff --git a/libwvdrmengine/oemcrypto/odk/src/odk_structs_priv.h b/libwvdrmengine/oemcrypto/odk/src/odk_structs_priv.h index 1bfc5978..3fe73eed 100644 --- a/libwvdrmengine/oemcrypto/odk/src/odk_structs_priv.h +++ b/libwvdrmengine/oemcrypto/odk/src/odk_structs_priv.h @@ -24,6 +24,7 @@ typedef uint32_t ODK_MessageType; #define ODK_Renewal_Response_Type ((ODK_MessageType)4u) #define ODK_Provisioning_Request_Type ((ODK_MessageType)5u) #define ODK_Provisioning_Response_Type ((ODK_MessageType)6u) +#define ODK_Renewed_Provisioning_Request_Type ((ODK_MessageType)11u) // Reserve future message types to support forward compatibility. #define ODK_Release_Request_Type ((ODK_MessageType)7u) @@ -52,6 +53,15 @@ typedef struct { uint8_t device_id[ODK_DEVICE_ID_LEN_MAX]; } ODK_PreparedProvisioningRequest; +typedef struct { + ODK_CoreMessage core_message; + uint32_t device_id_length; + uint8_t device_id[ODK_DEVICE_ID_LEN_MAX]; + uint16_t renewal_type; + uint32_t renewal_data_length; + uint8_t renewal_data[ODK_KEYBOX_RENEWAL_DATA_SIZE]; +} ODK_PreparedRenewedProvisioningRequest; + typedef struct { ODK_CoreMessage core_message; } ODK_PreparedCommonRequest; @@ -96,6 +106,7 @@ typedef struct { #define ODK_LICENSE_REQUEST_SIZE 20u #define ODK_RENEWAL_REQUEST_SIZE 28u #define ODK_PROVISIONING_REQUEST_SIZE 88u +#define ODK_RENEWED_PROVISIONING_REQUEST_SIZE 1694u // These are the possible timer status values. #define ODK_CLOCK_TIMER_STATUS_UNDEFINED 0u // Should not happen. diff --git a/libwvdrmengine/oemcrypto/odk/src/serialization_base.c b/libwvdrmengine/oemcrypto/odk/src/serialization_base.c index 90b84b31..30af34cf 100644 --- a/libwvdrmengine/oemcrypto/odk/src/serialization_base.c +++ b/libwvdrmengine/oemcrypto/odk/src/serialization_base.c @@ -38,7 +38,7 @@ static void PackBytes(ODK_Message* message, const uint8_t* ptr, size_t count) { } void Pack_enum(ODK_Message* message, int value) { - uint32_t v32 = value; + uint32_t v32 = (uint32_t)value; Pack_uint32_t(message, &v32); } diff --git a/libwvdrmengine/oemcrypto/odk/test/odk_core_message_test.cpp b/libwvdrmengine/oemcrypto/odk/test/odk_core_message_test.cpp index 24fbe6da..22051b22 100644 --- a/libwvdrmengine/oemcrypto/odk/test/odk_core_message_test.cpp +++ b/libwvdrmengine/oemcrypto/odk/test/odk_core_message_test.cpp @@ -2,6 +2,8 @@ // source code may only be used and distributed under the Widevine // License Agreement. +#include + #include "OEMCryptoCENCCommon.h" #include "gtest/gtest.h" #include "odk.h" diff --git a/libwvdrmengine/oemcrypto/odk/test/odk_test.cpp b/libwvdrmengine/oemcrypto/odk/test/odk_test.cpp index d1a817d8..a244d25b 100644 --- a/libwvdrmengine/oemcrypto/odk/test/odk_test.cpp +++ b/libwvdrmengine/oemcrypto/odk/test/odk_test.cpp @@ -6,6 +6,7 @@ #include #include +#include #include "OEMCryptoCENCCommon.h" #include "core_message_deserialize.h" @@ -27,6 +28,8 @@ using oemcrypto_core_message::ODK_RenewalRequest; using oemcrypto_core_message::deserialize::CoreLicenseRequestFromMessage; using oemcrypto_core_message::deserialize::CoreProvisioningRequestFromMessage; using oemcrypto_core_message::deserialize::CoreRenewalRequestFromMessage; +using oemcrypto_core_message::deserialize:: + CoreRenewedProvisioningRequestFromMessage; using oemcrypto_core_message::features::CoreMessageFeatures; @@ -270,6 +273,35 @@ TEST(OdkTest, NullRequestTest) { ODK_PrepareCoreProvisioningRequest( message, ODK_PROVISIONING_REQUEST_SIZE, &core_message_length, &nonce_values, nullptr, 0uL)); + + EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, + ODK_PrepareCoreRenewedProvisioningRequest( + nullptr, 0uL, &core_message_length, nullptr, nullptr, 0uL, + OEMCrypto_RenewalACert, nullptr, 0uL)); + EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, + ODK_PrepareCoreRenewedProvisioningRequest( + nullptr, 0uL, nullptr, &nonce_values, nullptr, 0uL, + OEMCrypto_RenewalACert, nullptr, 0uL)); + + // Null device id in renewed provisioning request is ok + uint8_t renewed_message[ODK_RENEWED_PROVISIONING_REQUEST_SIZE] = {0}; + uint8_t renewal_data[ODK_KEYBOX_RENEWAL_DATA_SIZE] = {0}; + uint32_t renewal_data_length = ODK_KEYBOX_RENEWAL_DATA_SIZE; + core_message_length = ODK_RENEWED_PROVISIONING_REQUEST_SIZE; + EXPECT_EQ(OEMCrypto_SUCCESS, + ODK_PrepareCoreRenewedProvisioningRequest( + renewed_message, ODK_RENEWED_PROVISIONING_REQUEST_SIZE, + &core_message_length, &nonce_values, nullptr, 0uL, + OEMCrypto_RenewalACert, renewal_data, renewal_data_length)); + + // Null renewal data in renewed provisioning request is ok + uint8_t device_id[ODK_DEVICE_ID_LEN_MAX] = {0}; + uint32_t device_id_length = ODK_DEVICE_ID_LEN_MAX; + core_message_length = ODK_RENEWED_PROVISIONING_REQUEST_SIZE; + ODK_PrepareCoreRenewedProvisioningRequest( + renewed_message, ODK_RENEWED_PROVISIONING_REQUEST_SIZE, + &core_message_length, &nonce_values, device_id, device_id_length, + OEMCrypto_RenewalACert, nullptr, 0uL); } TEST(OdkTest, NullResponseTest) { @@ -422,6 +454,21 @@ TEST(OdkTest, PrepareCoreProvisioningRequest) { &core_message_length, &nonce_values, device_id, sizeof(device_id))); } +TEST(OdkTest, PrepareCoreRenewedProvisioningRequest) { + uint8_t provisioning_message[ODK_RENEWED_PROVISIONING_REQUEST_SIZE] = {0}; + size_t core_message_length = sizeof(provisioning_message); + ODK_NonceValues nonce_values; + memset(&nonce_values, 0, sizeof(nonce_values)); + uint8_t device_id[ODK_DEVICE_ID_LEN_MAX] = {0}; + uint8_t renewal_data[ODK_KEYBOX_RENEWAL_DATA_SIZE] = {0}; + EXPECT_EQ( + OEMCrypto_SUCCESS, + ODK_PrepareCoreRenewedProvisioningRequest( + provisioning_message, sizeof(provisioning_message), + &core_message_length, &nonce_values, device_id, sizeof(device_id), + OEMCrypto_RenewalACert, renewal_data, sizeof(renewal_data))); +} + TEST(OdkTest, PrepareCoreProvisioningRequestDeviceId) { uint8_t provisioning_message[ODK_PROVISIONING_REQUEST_SIZE] = {0}; size_t core_message_length = sizeof(provisioning_message); @@ -435,6 +482,36 @@ TEST(OdkTest, PrepareCoreProvisioningRequestDeviceId) { sizeof(device_id_invalid))); } +TEST(OdkTest, PrepareCoreRenewedProvisioningRequestDeviceId) { + uint8_t provisioning_message[ODK_PROVISIONING_REQUEST_SIZE] = {0}; + size_t core_message_length = sizeof(provisioning_message); + ODK_NonceValues nonce_values; + memset(&nonce_values, 0, sizeof(nonce_values)); + uint8_t device_id_invalid[ODK_DEVICE_ID_LEN_MAX + 1] = {0}; + uint8_t renewal_data[ODK_KEYBOX_RENEWAL_DATA_SIZE] = {0}; + EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, + ODK_PrepareCoreRenewedProvisioningRequest( + provisioning_message, sizeof(provisioning_message), + &core_message_length, &nonce_values, device_id_invalid, + sizeof(device_id_invalid), OEMCrypto_RenewalACert, renewal_data, + sizeof(renewal_data))); +} + +TEST(OdkTest, PrepareCoreRenewedProvisioningRequestRenewalDataInvalid) { + uint8_t provisioning_message[ODK_PROVISIONING_REQUEST_SIZE] = {0}; + size_t core_message_length = sizeof(provisioning_message); + ODK_NonceValues nonce_values; + memset(&nonce_values, 0, sizeof(nonce_values)); + uint8_t device_id[ODK_DEVICE_ID_LEN_MAX] = {0}; + uint8_t renewal_data_invalid[ODK_KEYBOX_RENEWAL_DATA_SIZE + 1] = {0}; + EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, + ODK_PrepareCoreRenewedProvisioningRequest( + provisioning_message, sizeof(provisioning_message), + &core_message_length, &nonce_values, device_id, + sizeof(device_id), OEMCrypto_RenewalACert, renewal_data_invalid, + sizeof(renewal_data_invalid))); +} + // Serialize and de-serialize license request TEST(OdkTest, LicenseRequestRoundtrip) { std::vector empty; @@ -497,6 +574,39 @@ TEST(OdkTest, ProvisionRequestRoundtrip) { kdo_parse_func); } +TEST(OdkTest, RenewedProvisionRequestRoundtrip) { + uint32_t device_id_length = ODK_DEVICE_ID_LEN_MAX / 2; + uint8_t device_id[ODK_DEVICE_ID_LEN_MAX] = {0}; + memset(device_id, 0xff, device_id_length); + uint16_t renewal_type = OEMCrypto_RenewalACert; + uint32_t renewal_data_length = ODK_KEYBOX_RENEWAL_DATA_SIZE / 2; + uint8_t renewal_data[ODK_KEYBOX_RENEWAL_DATA_SIZE] = {0}; + memset(renewal_data, 0xff, renewal_data_length); + std::vector extra_fields = { + {ODK_UINT32, &device_id_length, "device_id_length"}, + {ODK_DEVICEID, device_id, "device_id"}, + {ODK_UINT16, &renewal_type, "renewal_type"}, + {ODK_UINT32, &renewal_data_length, "renewal_data_length"}, + {ODK_RENEWALDATA, renewal_data, "renewal_data"}, + }; + auto odk_prepare_func = [&](uint8_t* const buf, size_t* size, + const ODK_NonceValues* nonce_values) { + return ODK_PrepareCoreRenewedProvisioningRequest( + buf, SIZE_MAX, size, nonce_values, device_id, device_id_length, + renewal_type, renewal_data, renewal_data_length); + }; + auto kdo_parse_func = + [&](const std::string& oemcrypto_core_message, + ODK_ProvisioningRequest* core_provisioning_request) { + bool ok = CoreRenewedProvisioningRequestFromMessage( + oemcrypto_core_message, core_provisioning_request); + return ok; + }; + ValidateRequest( + ODK_Renewed_Provisioning_Request_Type, extra_fields, odk_prepare_func, + kdo_parse_func); +} + TEST(OdkTest, ParseLicenseErrorNonce) { ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); @@ -761,6 +871,7 @@ std::vector TestCases() { {17, 16, 4, 16, 4}, {17, 16, 5, 16, 5}, {17, 17, 0, 17, 0}, + {17, 17, 1, 17, 1}, }; return test_cases; } diff --git a/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.cpp b/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.cpp index c1cf4658..dab9afa3 100644 --- a/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.cpp +++ b/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "OEMCryptoCENCCommon.h" @@ -329,6 +330,8 @@ size_t ODK_FieldLength(ODK_FieldType type) { return sizeof(uint32_t) + sizeof(uint32_t); case ODK_DEVICEID: return ODK_DEVICE_ID_LEN_MAX; + case ODK_RENEWALDATA: + return ODK_KEYBOX_RENEWAL_DATA_SIZE; case ODK_HASH: return ODK_SHA256_HASH_SIZE; default: @@ -385,6 +388,7 @@ OEMCryptoResult ODK_WriteSingleField(uint8_t* buf, const ODK_Field* field) { break; } case ODK_DEVICEID: + case ODK_RENEWALDATA: case ODK_HASH: { const size_t field_len = ODK_FieldLength(field->type); const uint8_t* const id = static_cast(field->value); @@ -444,6 +448,7 @@ OEMCryptoResult ODK_ReadSingleField(const uint8_t* buf, break; } case ODK_DEVICEID: + case ODK_RENEWALDATA: case ODK_HASH: { const size_t field_len = ODK_FieldLength(field->type); uint8_t* const id = static_cast(field->value); @@ -503,6 +508,7 @@ OEMCryptoResult ODK_DumpSingleField(const uint8_t* buf, break; } case ODK_DEVICEID: + case ODK_RENEWALDATA: case ODK_HASH: { const size_t field_len = ODK_FieldLength(field->type); std::cerr << field->name << ": "; diff --git a/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.h b/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.h index 650950b2..f825af13 100644 --- a/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.h +++ b/libwvdrmengine/oemcrypto/odk/test/odk_test_helper.h @@ -21,6 +21,7 @@ enum ODK_FieldType { ODK_UINT64, ODK_SUBSTRING, ODK_DEVICEID, + ODK_RENEWALDATA, ODK_HASH, // The "stressable" types are the ones we can put in a stress test that packs // and unpacks random data and can expect to get back the same thing. diff --git a/libwvdrmengine/oemcrypto/test/common.mk b/libwvdrmengine/oemcrypto/test/common.mk index fb173896..0486c8a7 100644 --- a/libwvdrmengine/oemcrypto/test/common.mk +++ b/libwvdrmengine/oemcrypto/test/common.mk @@ -30,6 +30,8 @@ LOCAL_SRC_FILES:= \ ota_keybox_test.cpp \ wvcrc.cpp \ ../../cdm/util/test/test_sleep.cpp \ + ../util/src/oemcrypto_ecc_key.cpp \ + ../util/src/oemcrypto_rsa_key.cpp \ LOCAL_C_INCLUDES += \ $(LOCAL_PATH)/fuzz_tests \ @@ -37,6 +39,7 @@ LOCAL_C_INCLUDES += \ $(LOCAL_PATH)/../odk/include \ $(LOCAL_PATH)/../odk/kdo/include \ $(LOCAL_PATH)/../ref/src \ + $(LOCAL_PATH)/../util/include \ vendor/widevine/libwvdrmengine/cdm/core/include \ vendor/widevine/libwvdrmengine/cdm/util/include \ vendor/widevine/libwvdrmengine/cdm/util/test \ diff --git a/libwvdrmengine/oemcrypto/test/oec_key_deriver.cpp b/libwvdrmengine/oemcrypto/test/oec_key_deriver.cpp index dc1d2371..21d50781 100644 --- a/libwvdrmengine/oemcrypto/test/oec_key_deriver.cpp +++ b/libwvdrmengine/oemcrypto/test/oec_key_deriver.cpp @@ -83,19 +83,23 @@ void Encryptor::PadAndEncryptProvisioningMessage( // This generates the data for deriving one key. If there are failures in // this function, then there is something wrong with the test program and its // dependency on BoringSSL. -void KeyDeriver::DeriveKey(const uint8_t* key, const vector& context, - int counter, vector* out) { +void KeyDeriver::DeriveKey(const uint8_t* key, size_t master_key_size, + const vector& context, int counter, + vector* out) { ASSERT_NE(key, nullptr); ASSERT_FALSE(context.empty()); ASSERT_GE(4, counter); ASSERT_LE(1, counter); ASSERT_NE(out, nullptr); + // For RSA, the master key is expected to be 16 bytes; for EC key, 32 bytes. + ASSERT_TRUE(master_key_size == KEY_SIZE || master_key_size == 2 * KEY_SIZE); - const EVP_CIPHER* cipher = EVP_aes_128_cbc(); + const EVP_CIPHER* cipher = + master_key_size == KEY_SIZE ? EVP_aes_128_cbc() : EVP_aes_256_cbc(); CMAC_CTX* cmac_ctx = CMAC_CTX_new(); ASSERT_NE(nullptr, cmac_ctx); - ASSERT_TRUE(CMAC_Init(cmac_ctx, key, KEY_SIZE, cipher, nullptr)); + ASSERT_TRUE(CMAC_Init(cmac_ctx, key, master_key_size, cipher, nullptr)); std::vector message; message.push_back(static_cast(counter)); @@ -114,24 +118,24 @@ void KeyDeriver::DeriveKey(const uint8_t* key, const vector& context, // This generates the data for deriving a set of keys. If there are failures in // this function, then there is something wrong with the test program and its // dependency on BoringSSL. -void KeyDeriver::DeriveKeys(const uint8_t* master_key, +void KeyDeriver::DeriveKeys(const uint8_t* master_key, size_t master_key_size, const vector& mac_key_context, const vector& enc_key_context) { // Generate derived key for mac key std::vector mac_key_part2; - DeriveKey(master_key, mac_key_context, 1, &mac_key_server_); - DeriveKey(master_key, mac_key_context, 2, &mac_key_part2); + DeriveKey(master_key, master_key_size, mac_key_context, 1, &mac_key_server_); + DeriveKey(master_key, master_key_size, mac_key_context, 2, &mac_key_part2); mac_key_server_.insert(mac_key_server_.end(), mac_key_part2.begin(), mac_key_part2.end()); - DeriveKey(master_key, mac_key_context, 3, &mac_key_client_); - DeriveKey(master_key, mac_key_context, 4, &mac_key_part2); + DeriveKey(master_key, master_key_size, mac_key_context, 3, &mac_key_client_); + DeriveKey(master_key, master_key_size, mac_key_context, 4, &mac_key_part2); mac_key_client_.insert(mac_key_client_.end(), mac_key_part2.begin(), mac_key_part2.end()); // Generate derived key for encryption key std::vector enc_key; - DeriveKey(master_key, enc_key_context, 1, &enc_key); + DeriveKey(master_key, master_key_size, enc_key_context, 1, &enc_key); set_enc_key(enc_key); } diff --git a/libwvdrmengine/oemcrypto/test/oec_key_deriver.h b/libwvdrmengine/oemcrypto/test/oec_key_deriver.h index 65b4ad71..4741da49 100644 --- a/libwvdrmengine/oemcrypto/test/oec_key_deriver.h +++ b/libwvdrmengine/oemcrypto/test/oec_key_deriver.h @@ -62,7 +62,7 @@ class KeyDeriver : public Encryptor { KeyDeriver& operator=(const KeyDeriver&) = default; // Generate mac and enc keys give the master key. - void DeriveKeys(const uint8_t* master_key, + void DeriveKeys(const uint8_t* master_key, size_t master_key_size, const std::vector& mac_key_context, const std::vector& enc_key_context); // Sign the buffer with server's mac key. @@ -80,9 +80,11 @@ class KeyDeriver : public Encryptor { void set_mac_keys(const uint8_t* mac_keys); private: - // Internal utility function to derive key using CMAC-128 - void DeriveKey(const uint8_t* key, const std::vector& context, - int counter, std::vector* out); + // Internal utility function to derive key using CMAC-128 or CMAC-256 based on + // master_key_size. + void DeriveKey(const uint8_t* key, size_t master_key_size, + const std::vector& context, int counter, + std::vector* out); std::vector mac_key_server_; std::vector mac_key_client_; diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp index 7ffd2cea..393fdfb5 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -309,12 +310,12 @@ void ProvisioningRoundTrip::PrepareSession( OEMCrypto_BootCertificateChain) { // TODO(chelu): change this to CSR provisioning. session_->LoadOEMCert(true); - session_->GenerateRSASessionKey(&message_key_, &encrypted_message_key_); + session_->GenerateRsaSessionKey(&message_key_, &encrypted_message_key_); encryptor_.set_enc_key(message_key_); } else { EXPECT_EQ(global_features.provisioning_method, OEMCrypto_OEMCertificate); session_->LoadOEMCert(true); - session_->GenerateRSASessionKey(&message_key_, &encrypted_message_key_); + session_->GenerateRsaSessionKey(&message_key_, &encrypted_message_key_); encryptor_.set_enc_key(message_key_); } } @@ -323,7 +324,7 @@ void ProvisioningRoundTrip::VerifyRequestSignature( const vector& data, const vector& generated_signature, size_t /* core_message_length */) { if (global_features.provisioning_method == OEMCrypto_OEMCertificate) { - session()->VerifyRSASignature(data, generated_signature.data(), + session()->VerifyRsaSignature(data, generated_signature.data(), generated_signature.size(), kSign_RSASSA_PSS); } else { EXPECT_EQ(global_features.provisioning_method, OEMCrypto_Keybox); @@ -564,12 +565,12 @@ void LicenseRoundTrip::VerifyRequestSignature( if (global_features.api_version < 17) { const std::vector subdata(data.begin() + core_message_length, data.end()); - session()->VerifyRSASignature(subdata, generated_signature.data(), + session()->VerifyRsaSignature(subdata, generated_signature.data(), generated_signature.size(), kSign_RSASSA_PSS); SHA256(data.data(), core_message_length, request_hash_); } else { - session()->VerifyRSASignature(data, generated_signature.data(), - generated_signature.size(), kSign_RSASSA_PSS); + session()->VerifySignature(data, generated_signature.data(), + generated_signature.size(), kSign_RSASSA_PSS); SHA256(data.data(), core_message_length, request_hash_); } } @@ -1422,17 +1423,15 @@ OEMCryptoResult RenewalRoundTrip::LoadResponse(Session* session) { } } -Session::Session() - : open_(false), - forced_session_id_(false), - session_id_(0), - nonce_(0), - public_rsa_(nullptr) {} +std::unordered_map, + std::hash> + Session::server_ephemeral_keys_; +std::mutex Session::ephemeral_key_map_lock_; + +Session::Session() {} Session::~Session() { if (!forced_session_id_ && open_) close(); - if (public_rsa_) RSA_free(public_rsa_); - if (public_ec_) EC_KEY_free(public_ec_); } void Session::open() { @@ -1505,17 +1504,19 @@ void Session::GenerateDerivedKeysFromKeybox( OEMCrypto_GenerateDerivedKeys( session_id(), mac_context.data(), mac_context.size(), enc_context.data(), enc_context.size())); - key_deriver_.DeriveKeys(keybox.device_key_, mac_context, enc_context); + key_deriver_.DeriveKeys(keybox.device_key_, sizeof(keybox.device_key_), + mac_context, enc_context); } void Session::GenerateDerivedKeysFromSessionKey() { // Uses test certificate. vector session_key; vector enc_session_key; - ASSERT_NE(public_rsa_, nullptr) << "No public RSA key loaded in test code."; + ASSERT_TRUE(public_rsa_ || public_ec_) + << "No public RSA/ECC key loaded in test code"; // A failure here probably indicates that there is something wrong with the // test program and its dependency on BoringSSL. - ASSERT_TRUE(GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_TRUE(GenerateSessionKey(&session_key, &enc_session_key)); vector mac_context; vector enc_context; FillDefaultContext(&mac_context, &enc_context); @@ -1526,7 +1527,8 @@ void Session::GenerateDerivedKeysFromSessionKey() { mac_context.data(), mac_context.size(), enc_context.data(), enc_context.size())); - key_deriver_.DeriveKeys(session_key.data(), mac_context, enc_context); + key_deriver_.DeriveKeys(session_key.data(), session_key.size(), mac_context, + enc_context); } void Session::TestDecryptCTR(bool select_key_first, @@ -1640,12 +1642,11 @@ void Session::LoadOEMCert(bool verify_cert) { boringssl_ptr pubkey(X509_get_pubkey(x509_cert)); ASSERT_TRUE(pubkey.NotNull()); if (i == 0) { - public_rsa_ = EVP_PKEY_get1_RSA(pubkey.get()); - if (!public_rsa_) { - cerr << "d2i_RSAPrivateKey failed.\n"; - dump_boringssl_error(); - ASSERT_TRUE(nullptr != public_rsa_); - } + public_rsa_ = + util::RsaPublicKey::FromSslHandle(EVP_PKEY_get0_RSA(pubkey.get())); + ASSERT_TRUE(public_rsa_) + << "Failed to extract public RSA key from OEM certificate"; + return; } if (verify_cert) { vector buffer(80); @@ -1677,199 +1678,181 @@ void Session::LoadOEMCert(bool verify_cert) { } } -void Session::PreparePublicKey(const uint8_t* rsa_key, size_t rsa_key_length) { - if (rsa_key == nullptr) { - rsa_key = kTestRSAPKCS8PrivateKeyInfo2_2048; - rsa_key_length = sizeof(kTestRSAPKCS8PrivateKeyInfo2_2048); - } - uint8_t* p = const_cast(rsa_key); - boringssl_ptr bio( - BIO_new_mem_buf(p, static_cast(rsa_key_length))); - ASSERT_TRUE(bio.NotNull()); - boringssl_ptr pkcs8_pki( - d2i_PKCS8_PRIV_KEY_INFO_bio(bio.get(), nullptr)); - ASSERT_TRUE(pkcs8_pki.NotNull()); - boringssl_ptr evp(EVP_PKCS82PKEY(pkcs8_pki.get())); - ASSERT_TRUE(evp.NotNull()); - if (public_rsa_) RSA_free(public_rsa_); - public_rsa_ = EVP_PKEY_get1_RSA(evp.get()); - if (!public_rsa_) { - cerr << "d2i_RSAPrivateKey failed. "; - dump_boringssl_error(); - FAIL() << "Could not parse public RSA key."; - } - switch (RSA_check_key(public_rsa_)) { - case 1: // valid. +void Session::SetTestRsaPublicKey() { + public_ec_.reset(); + public_rsa_ = util::RsaPublicKey::LoadPrivateKeyInfo( + kTestRSAPKCS8PrivateKeyInfo2_2048, + sizeof(kTestRSAPKCS8PrivateKeyInfo2_2048)); + ASSERT_TRUE(public_rsa_) << "Could not parse test RSA public key #2"; +} + +void Session::SetPublicKeyFromPrivateKeyInfo(OEMCrypto_PrivateKeyType key_type, + const uint8_t* buffer, + size_t length) { + switch (key_type) { + case OEMCrypto_RSA_Private_Key: + ASSERT_NO_FATAL_FAILURE( + SetRsaPublicKeyFromPrivateKeyInfo(buffer, length)); return; - case 0: // not valid. - dump_boringssl_error(); - FAIL() << "[rsa key not valid] "; - default: // -1 == check failed. - dump_boringssl_error(); - FAIL() << "[error checking rsa key] "; - } -} - -void Session::SetRsaPublicKey(const uint8_t* buffer, size_t length) { - if (public_rsa_) { - RSA_free(public_rsa_); - public_rsa_ = nullptr; - } - if (public_ec_) { - EC_KEY_free(public_ec_); - public_ec_ = nullptr; - } - public_rsa_ = d2i_RSA_PUBKEY(nullptr, &buffer, length); - if (!public_rsa_) { - cout << "d2i_RSAPrivateKey failed. "; - dump_boringssl_error(); - FAIL() << "Could not parse public RSA key."; - } - switch (RSA_check_key(public_rsa_)) { - case 1: // valid. + case OEMCrypto_ECC_Private_Key: + ASSERT_NO_FATAL_FAILURE( + SetEccPublicKeyFromPrivateKeyInfo(buffer, length)); return; - case 0: // not valid. - dump_boringssl_error(); - FAIL() << "[rsa key not valid] "; - default: // -1 == check failed. - dump_boringssl_error(); - FAIL() << "[error checking rsa key] "; } + FAIL() << "Unknown key type: " << static_cast(key_type); } -void Session::SetEcPublicKey(const uint8_t* buffer, size_t length) { - if (public_rsa_) { - RSA_free(public_rsa_); - public_rsa_ = nullptr; - } - if (public_ec_) { - EC_KEY_free(public_ec_); - public_ec_ = nullptr; - } - public_ec_ = d2i_EC_PUBKEY(nullptr, &buffer, length); - if (!public_ec_) { - cout << "d2i_RSAPrivateKey failed. "; - dump_boringssl_error(); - FAIL() << "Could not parse public RSA key."; - } - switch (EC_KEY_check_key(public_ec_)) { - case 1: // valid. +void Session::SetRsaPublicKeyFromPrivateKeyInfo(const uint8_t* buffer, + size_t length) { + public_ec_.reset(); + public_rsa_ = util::RsaPublicKey::LoadPrivateKeyInfo(buffer, length); + ASSERT_TRUE(public_rsa_) << "Could not parse RSA public key"; +} + +void Session::SetEccPublicKeyFromPrivateKeyInfo(const uint8_t* buffer, + size_t length) { + public_rsa_.reset(); + public_ec_ = util::EccPublicKey::LoadPrivateKeyInfo(buffer, length); + ASSERT_TRUE(public_ec_) << "Could not parse ECC public key"; +} + +void Session::SetPublicKeyFromSubjectPublicKey( + OEMCrypto_PrivateKeyType key_type, const uint8_t* buffer, size_t length) { + switch (key_type) { + case OEMCrypto_RSA_Private_Key: + ASSERT_NO_FATAL_FAILURE( + SetRsaPublicKeyFromSubjectPublicKey(buffer, length)); + return; + case OEMCrypto_ECC_Private_Key: + ASSERT_NO_FATAL_FAILURE( + SetEccPublicKeyFromSubjectPublicKey(buffer, length)); return; - case 0: // not valid. - default: - dump_boringssl_error(); - FAIL() << "[ec key not valid] "; } + FAIL() << "Unknown key type: " << static_cast(key_type); } -bool Session::VerifyPSSSignature(EVP_PKEY* pkey, const uint8_t* message, - size_t message_length, - const uint8_t* signature, - size_t signature_length) { - boringssl_ptr md_ctx(EVP_MD_CTX_new()); - EVP_PKEY_CTX* pkey_ctx = nullptr; - - if (EVP_DigestVerifyInit(md_ctx.get(), &pkey_ctx, EVP_sha1(), - nullptr /* no ENGINE */, pkey) != 1) { - LOGE("EVP_DigestVerifyInit failed in VerifyPSSSignature"); - goto err; - } - - if (EVP_PKEY_CTX_set_signature_md(pkey_ctx, - const_cast(EVP_sha1())) != 1) { - LOGE("EVP_PKEY_CTX_set_signature_md failed in VerifyPSSSignature"); - goto err; - } - - if (EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING) != 1) { - LOGE("EVP_PKEY_CTX_set_rsa_padding failed in VerifyPSSSignature"); - goto err; - } - - if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, SHA_DIGEST_LENGTH) != 1) { - LOGE("EVP_PKEY_CTX_set_rsa_pss_saltlen failed in VerifyPSSSignature"); - goto err; - } - - if (EVP_DigestVerifyUpdate(md_ctx.get(), message, message_length) != 1) { - LOGE("EVP_DigestVerifyUpdate failed in VerifyPSSSignature"); - goto err; - } - - if (EVP_DigestVerifyFinal(md_ctx.get(), const_cast(signature), - signature_length) != 1) { - LOGE( - "EVP_DigestVerifyFinal failed in VerifyPSSSignature. (Probably a bad " - "signature.)"); - goto err; - } - - return true; - -err: - dump_boringssl_error(); - return false; +void Session::SetRsaPublicKeyFromSubjectPublicKey(const uint8_t* buffer, + size_t length) { + public_ec_.reset(); + public_rsa_ = util::RsaPublicKey::Load(buffer, length); + ASSERT_TRUE(public_rsa_) << "Could not parse RSA public key"; } -void Session::VerifyRSASignature(const vector& message, +void Session::SetEccPublicKeyFromSubjectPublicKey(const uint8_t* buffer, + size_t length) { + public_rsa_.reset(); + public_ec_ = util::EccPublicKey::Load(buffer, length); + ASSERT_TRUE(public_ec_) << "Could not parse ECC public key"; +} + +void Session::VerifyRsaSignature(const vector& message, const uint8_t* signature, size_t signature_length, RSA_Padding_Scheme padding_scheme) { - ASSERT_NE(public_rsa_, nullptr) << "No public RSA key loaded in test code."; - - ASSERT_EQ(static_cast(RSA_size(public_rsa_)), signature_length) - << "Signature size is wrong. " << signature_length << ", should be " - << RSA_size(public_rsa_); - - if (padding_scheme == kSign_RSASSA_PSS) { - boringssl_ptr pkey(EVP_PKEY_new()); - ASSERT_EQ(1, EVP_PKEY_set1_RSA(pkey.get(), public_rsa_)); - - const bool ok = - VerifyPSSSignature(pkey.get(), message.data(), message.size(), - signature, signature_length); - EXPECT_TRUE(ok) << "PSS signature check failed."; - } else if (padding_scheme == kSign_PKCS1_Block1) { - vector padded_digest(signature_length); - int size; - // RSA_public_decrypt decrypts the signature, and then verifies that - // it was padded with RSA PKCS1 padding. - size = RSA_public_decrypt(static_cast(signature_length), signature, - padded_digest.data(), public_rsa_, - RSA_PKCS1_PADDING); - EXPECT_GT(size, 0); - padded_digest.resize(size); - EXPECT_EQ(message, padded_digest); - } else { - EXPECT_TRUE(false) << "Padding scheme not supported."; + ASSERT_TRUE(public_rsa_) << "No public RSA key loaded in test code"; + if (padding_scheme != kSign_RSASSA_PSS && + padding_scheme != kSign_PKCS1_Block1) { + FAIL() << "Padding scheme not supported: " << padding_scheme; + return; } + const util::RsaSignatureAlgorithm algorithm = + padding_scheme == kSign_RSASSA_PSS ? util::kRsaPssDefault + : util::kRsaPkcs1Cast; + const OEMCryptoResult result = public_rsa_->VerifySignature( + message.data(), message.size(), signature, signature_length, algorithm); + ASSERT_EQ(result, OEMCrypto_SUCCESS) << "RSA signature check failed"; } -bool Session::GenerateRSASessionKey(vector* session_key, +void Session::VerifyEccSignature(const vector& message, + const uint8_t* signature, + size_t signature_length) { + ASSERT_TRUE(public_ec_) << "No public ECC key loaded in test code"; + const OEMCryptoResult result = public_ec_->VerifySignature( + message.data(), message.size(), signature, signature_length); + ASSERT_EQ(result, OEMCrypto_SUCCESS) << "ECC signature check failed"; +} + +void Session::VerifySignature(const vector& message, + const uint8_t* signature, size_t signature_length, + RSA_Padding_Scheme padding_scheme) { + if (public_rsa_ != nullptr) { + return VerifyRsaSignature(message, signature, signature_length, + padding_scheme); + } else if (public_ec_ != nullptr) { + return VerifyEccSignature(message, signature, signature_length); + } + FAIL() << "No public RSA or ECC key loaded in test code"; +} + +bool Session::GenerateRsaSessionKey(vector* session_key, vector* enc_session_key) { if (!public_rsa_) { - cerr << "No public RSA key loaded in test code.\n"; + cerr << "No public RSA key loaded in test code\n"; return false; } *session_key = wvutil::a2b_hex("6fa479c731d2770b6a61a5d1420bb9d1"); - enc_session_key->assign(RSA_size(public_rsa_), 0); - int status = RSA_public_encrypt( - static_cast(session_key->size()), &(session_key->front()), - &(enc_session_key->front()), public_rsa_, RSA_PKCS1_OAEP_PADDING); - int size = static_cast(RSA_size(public_rsa_)); - if (status != size) { - cerr << "GenerateRSASessionKey error encrypting session key.\n"; - dump_boringssl_error(); + *enc_session_key = public_rsa_->EncryptSessionKey(*session_key); + if (enc_session_key->empty()) { return false; } return true; } -void Session::InstallRSASessionTestKey(const vector& wrapped_rsa_key) { +bool Session::GenerateEccSessionKey(vector* session_key, + vector* ecdh_public_key_data) { + if (!public_ec_) { + cerr << "No public ECC key loaded in test code\n"; + return false; + } + std::unique_lock lock(Session::ephemeral_key_map_lock_); + const util::EccCurve curve = public_ec_->curve(); + if (server_ephemeral_keys_.count(curve) == 0) { + server_ephemeral_keys_[curve] = util::EccPrivateKey::New(curve); + } + if (server_ephemeral_keys_.count(curve) == 0) { + cerr << "Failed to find/create server ECC key for curve " + << util::EccCurveToString(curve) << std::endl; + return false; + } + *session_key = server_ephemeral_keys_[curve]->DeriveSessionKey(*public_ec_); + if (session_key->empty()) { + return false; + } + *ecdh_public_key_data = server_ephemeral_keys_[curve]->SerializeAsPublicKey(); + if (ecdh_public_key_data->empty()) { + session_key->clear(); + return false; + } + return true; +} + +bool Session::GenerateSessionKey(vector* session_key, + vector* key_material) { + if (public_rsa_ != nullptr) { + return GenerateRsaSessionKey(session_key, key_material); + } else if (public_ec_ != nullptr) { + return GenerateEccSessionKey(session_key, key_material); + } + cerr << "No public RSA or ECC key loaded in test code\n"; + return false; +} + +void Session::LoadWrappedDrmKey(OEMCrypto_PrivateKeyType key_type, + const vector& wrapped_drm_key) { ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_LoadDRMPrivateKey(session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key.data(), - wrapped_rsa_key.size())); + OEMCrypto_LoadDRMPrivateKey(session_id(), key_type, + wrapped_drm_key.data(), + wrapped_drm_key.size())); +} + +void Session::LoadWrappedRsaDrmKey(const vector& wrapped_rsa_key) { + ASSERT_NO_FATAL_FAILURE( + LoadWrappedDrmKey(OEMCrypto_RSA_Private_Key, wrapped_rsa_key)); +} + +void Session::LoadWrappedEccDrmKey(const vector& wrapped_ecc_key) { + ASSERT_NO_FATAL_FAILURE( + LoadWrappedDrmKey(OEMCrypto_ECC_Private_Key, wrapped_ecc_key)); } void Session::CreateNewUsageEntry(OEMCryptoResult* status) { diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.h b/libwvdrmengine/oemcrypto/test/oec_session_util.h index a0fb1b85..741f96a8 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.h +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.h @@ -8,9 +8,8 @@ // OEMCrypto unit tests // #include -#include -#include #include +#include #include #include @@ -21,7 +20,9 @@ #include "odk.h" #include "oec_device_features.h" #include "oec_key_deriver.h" +#include "oemcrypto_ecc_key.h" #include "oemcrypto_fuzz_structs.h" +#include "oemcrypto_rsa_key.h" #include "oemcrypto_types.h" #include "pst_report.h" @@ -571,38 +572,75 @@ class Session { void RewrapRSAKey(const struct RSAPrivateKeyMessage& encrypted, size_t message_size, const std::vector& signature, vector* wrapped_key, bool force); - // Loads the specified RSA public key into public_rsa_. If rsa_key is null, - // the default test key is loaded. - void PreparePublicKey(const uint8_t* rsa_key = nullptr, - size_t rsa_key_length = 0); + // Loads the default test RSA public key into public_rsa_. + void SetTestRsaPublicKey(); + // Loads the specified DRM public key into the appropriate key. + // The provided key is serialized as an ASN.1 DER encoded PrivateKeyInfo. + void SetPublicKeyFromPrivateKeyInfo(OEMCrypto_PrivateKeyType key_type, + const uint8_t* buffer, size_t length); // Loads the specified RSA public key into public_rsa_. - void SetRsaPublicKey(const uint8_t* buffer, size_t length); + // The provided key is serialized as an ASN.1 DER encoded PrivateKeyInfo. + void SetRsaPublicKeyFromPrivateKeyInfo(const uint8_t* buffer, size_t length); // Loads the specified EC public key into public_ec_. - void SetEcPublicKey(const uint8_t* buffer, size_t length); + // The provided key is serialized as an ASN.1 DER encoded PrivateKeyInfo. + void SetEccPublicKeyFromPrivateKeyInfo(const uint8_t* buffer, size_t length); + + // Loads the specified DRM public key into the appropriate key. + // The provided key is serialized as an ASN.1 DER encoded SubjectPublicKey. + void SetPublicKeyFromSubjectPublicKey(OEMCrypto_PrivateKeyType key_type, + const uint8_t* buffer, size_t length); + // Loads the specified RSA public key into public_rsa_. + // The provided key is serialized as an ASN.1 DER encoded SubjectPublicKey. + void SetRsaPublicKeyFromSubjectPublicKey(const uint8_t* buffer, + size_t length); + // Loads the specified EC public key into public_ec_. + // The provided key is serialized as an ASN.1 DER encoded SubjectPublicKey. + void SetEccPublicKeyFromSubjectPublicKey(const uint8_t* buffer, + size_t length); - // Verifies the given signature is from the given message and RSA key, pkey. - static bool VerifyPSSSignature(EVP_PKEY* pkey, const uint8_t* message, - size_t message_length, - const uint8_t* signature, - size_t signature_length); // Verify that the message was signed by the private key associated with // |public_rsa_| using the specified padding scheme. - void VerifyRSASignature(const vector& message, + void VerifyRsaSignature(const vector& message, const uint8_t* signature, size_t signature_length, RSA_Padding_Scheme padding_scheme); + // Verify that the message was signed by the private key associated with + // |public_ecc_| using Widevine ECDSA. + void VerifyEccSignature(const vector& message, + const uint8_t* signature, size_t signature_length); + // Verify RSA or ECC signature based on the key type installed. The + // padding_scheme will be ignored in case of ECC key. + void VerifySignature(const vector& message, const uint8_t* signature, + size_t signature_length, + RSA_Padding_Scheme padding_scheme); + // Encrypts a known session key with public_rsa_ for use in future calls to // OEMCrypto_DeriveKeysFromSessionKey or OEMCrypto_RewrapDeviceRSAKey30. // The unencrypted session key is stored in session_key. - bool GenerateRSASessionKey(vector* session_key, + bool GenerateRsaSessionKey(vector* session_key, vector* enc_session_key); + // Derives a session key with public_ec_ and a ephemeral "server" ECC key + // for use in future calls to OEMCrypto_DeriveKeysFromSessionKey. + // The unencrypted session key is stored in session_key. + bool GenerateEccSessionKey(vector* session_key, + vector* ecdh_public_key_data); + // Based on the key type installed, call GenerateRsaSessionKey or + // GenerateEccSessionKey. + bool GenerateSessionKey(vector* session_key, + vector* key_material); + // Calls OEMCrypto_RewrapDeviceRSAKey30 with the given provisioning response // message. If force is true, we assert that the key loads successfully. void RewrapRSAKey30(const struct RSAPrivateKeyMessage& encrypted, const std::vector& encrypted_message_key, vector* wrapped_key, bool force); - // Loads the specified wrapped_rsa_key into OEMCrypto, and then runs - // GenerateDerivedKeysFromSessionKey to install known encryption and mac keys. - void InstallRSASessionTestKey(const vector& wrapped_rsa_key); + + void LoadWrappedDrmKey(OEMCrypto_PrivateKeyType key_type, + const vector& wrapped_drm_key); + // Loads the specified wrapped_rsa_key into OEMCrypto. + void LoadWrappedRsaDrmKey(const vector& wrapped_rsa_key); + // Loads the specified wrapped_ecc_key into OEMCrypto. + void LoadWrappedEccDrmKey(const vector& wrapped_ecc_key); + // Creates a new usage entry, and keeps track of the index. // If status is null, we expect success, otherwise status is set to the // return value. @@ -676,21 +714,34 @@ class Session { OEMCryptoResult actual_select_result, OEMCryptoResult actual_decryt_result); - bool open_; - bool forced_session_id_; - OEMCrypto_SESSION session_id_; + bool open_ = false; + bool forced_session_id_ = false; + OEMCrypto_SESSION session_id_ = 0; KeyDeriver key_deriver_; - uint32_t nonce_; + uint32_t nonce_ = 0; // Only one of RSA or EC should be set. - RSA* public_rsa_ = nullptr; - EC_KEY* public_ec_ = nullptr; + std::unique_ptr public_rsa_; + std::unique_ptr public_ec_; + // In provisioning 4.0, the shared session key is derived from either + // 1. (client side) client private key + server ephemeral public key, or + // 2. (server side) server ephemeral private key + client public key + // Encryption key and mac keys are derived from the shared session key, and + // are inserted in to the default license response which simulates the + // response from a license server. In order for these keys to be deterministic + // across multiple test calls of GenerateDerivedKeysFromSessionKey(), which + // simulates how the server derives keys, the ephemeral keys used by the + // "server" need to be stored for re-use. + static std::unordered_map< + util::EccCurve, std::unique_ptr, std::hash> + server_ephemeral_keys_; + static std::mutex ephemeral_key_map_lock_; vector pst_report_buffer_; MessageData license_ = {}; vector encrypted_usage_entry_; - uint32_t usage_entry_number_; + uint32_t usage_entry_number_ = 0; string pst_; -}; +}; // class Session // Used for OEMCrypto Fuzzing: Convert byte to a valid boolean to avoid errors // generated by msan. diff --git a/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.cpp b/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.cpp index f91a39a7..f84e37c2 100644 --- a/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.cpp +++ b/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.cpp @@ -75,29 +75,35 @@ void SessionUtil::EnsureTestKeys() { // are installed in OEMCrypto and in the test session. void SessionUtil::InstallTestRSAKey(Session* s) { if (global_features.provisioning_method == OEMCrypto_BootCertificateChain) { - const size_t buffer_size = 5000; // Make sure it is large enough. - std::vector public_key(buffer_size); - size_t public_key_size = buffer_size; - std::vector public_key_signature(buffer_size); - size_t public_key_signature_size = buffer_size; - std::vector wrapped_private_key(buffer_size); - size_t wrapped_private_key_size = buffer_size; - OEMCrypto_PrivateKeyType key_type; - // Assume OEM cert has been loaded. - ASSERT_EQ( - OEMCrypto_SUCCESS, - OEMCrypto_GenerateCertificateKeyPair( - s->session_id(), public_key.data(), &public_key_size, - public_key_signature.data(), &public_key_signature_size, - wrapped_private_key.data(), &wrapped_private_key_size, &key_type)); - // Assume the public key has been verified by the server and the DRM cert is - // returned. - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_LoadDRMPrivateKey(s->session_id(), key_type, - wrapped_private_key.data(), - wrapped_private_key_size)); - ASSERT_NO_FATAL_FAILURE( - s->SetRsaPublicKey(public_key.data(), public_key_size)); + if (wrapped_rsa_key_.size() == 0) { + // If we don't have a wrapped key yet, create one. + // This wrapped key will be shared by all sessions in the test. + const size_t buffer_size = 5000; // Make sure it is large enough. + std::vector public_key(buffer_size); + size_t public_key_size = buffer_size; + std::vector public_key_signature(buffer_size); + size_t public_key_signature_size = buffer_size; + std::vector wrapped_private_key(buffer_size); + size_t wrapped_private_key_size = buffer_size; + OEMCrypto_PrivateKeyType key_type; + // Assume OEM cert has been loaded. + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_GenerateCertificateKeyPair( + s->session_id(), public_key.data(), &public_key_size, + public_key_signature.data(), &public_key_signature_size, + wrapped_private_key.data(), &wrapped_private_key_size, + &key_type)); + // Assume the public key has been verified by the server and the DRM cert + // is returned. + wrapped_private_key.resize(wrapped_private_key_size); + public_key.resize(public_key_size); + wrapped_rsa_key_ = wrapped_private_key; + drm_public_key_ = public_key; + key_type_ = key_type; + } + ASSERT_NO_FATAL_FAILURE(s->LoadWrappedDrmKey(key_type_, wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s->SetPublicKeyFromSubjectPublicKey( + key_type_, drm_public_key_.data(), drm_public_key_.size())); return; } @@ -108,10 +114,10 @@ void SessionUtil::InstallTestRSAKey(Session* s) { ASSERT_NO_FATAL_FAILURE(CreateWrappedRSAKey()); } // Load the wrapped rsa test key. - ASSERT_NO_FATAL_FAILURE(s->InstallRSASessionTestKey(wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s->LoadWrappedRsaDrmKey(wrapped_rsa_key_)); } // Test RSA key should be loaded. - ASSERT_NO_FATAL_FAILURE(s->PreparePublicKey()); + ASSERT_NO_FATAL_FAILURE(s->SetTestRsaPublicKey()); } } // namespace wvoec diff --git a/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.h b/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.h index 10898435..6ceb03a4 100644 --- a/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.h +++ b/libwvdrmengine/oemcrypto/test/oemcrypto_session_tests_helper.h @@ -1,38 +1,40 @@ #include -#include -#include #include #include +#include +#include +#include "OEMCryptoCENC.h" #include "oec_session_util.h" #include "oec_test_data.h" -#include "OEMCryptoCENC.h" namespace wvoec { class SessionUtil { -public: - SessionUtil() - : encoded_rsa_key_(kTestRSAPKCS8PrivateKeyInfo2_2048, - kTestRSAPKCS8PrivateKeyInfo2_2048 + - sizeof(kTestRSAPKCS8PrivateKeyInfo2_2048)) {} + public: + SessionUtil() + : encoded_rsa_key_(kTestRSAPKCS8PrivateKeyInfo2_2048, + kTestRSAPKCS8PrivateKeyInfo2_2048 + + sizeof(kTestRSAPKCS8PrivateKeyInfo2_2048)) {} - // Create a new wrapped DRM Certificate. - void CreateWrappedRSAKey(); + // Create a new wrapped DRM Certificate. + void CreateWrappedRSAKey(); - // This is used to force installation of a keybox. This overwrites the - // production keybox -- it does NOT use OEMCrypto_LoadTestKeybox. - void InstallKeybox(const wvoec::WidevineKeybox& keybox, bool good); + // This is used to force installation of a keybox. This overwrites the + // production keybox -- it does NOT use OEMCrypto_LoadTestKeybox. + void InstallKeybox(const wvoec::WidevineKeybox& keybox, bool good); - // This loads the test keybox or the test RSA key, using LoadTestKeybox or - // LoadTestRSAKey as needed. - void EnsureTestKeys(); + // This loads the test keybox or the test RSA key, using LoadTestKeybox or + // LoadTestRSAKey as needed. + void EnsureTestKeys(); - void InstallTestRSAKey(Session* s); + void InstallTestRSAKey(Session* s); - std::vector encoded_rsa_key_; - std::vector wrapped_rsa_key_; - wvoec::WidevineKeybox keybox_; + std::vector encoded_rsa_key_; + std::vector wrapped_rsa_key_; + OEMCrypto_PrivateKeyType key_type_; + std::vector drm_public_key_; + wvoec::WidevineKeybox keybox_; }; } // namespace wvoec diff --git a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp index d8d531ef..b37ea8e9 100644 --- a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp +++ b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp @@ -275,7 +275,7 @@ TEST_F(OEMCryptoClientTest, VersionNumber) { // If any of the following fail, then it is time to update the log message // above. EXPECT_EQ(ODK_MAJOR_VERSION, 17); - EXPECT_EQ(ODK_MINOR_VERSION, 0); + EXPECT_EQ(ODK_MINOR_VERSION, 1); EXPECT_EQ(kCurrentAPI, 17u); OEMCrypto_Security_Level level = OEMCrypto_SecurityLevel(); EXPECT_GT(level, OEMCrypto_Level_Unknown); @@ -1214,8 +1214,8 @@ TEST_F(OEMCryptoProv30Test, GetCertOnlyAPI16) { Session s; ASSERT_NO_FATAL_FAILURE(s.open()); // Install the DRM Cert's RSA key. - ASSERT_NO_FATAL_FAILURE(s.InstallRSASessionTestKey(wrapped_rsa_key_)); - ASSERT_NO_FATAL_FAILURE(s.PreparePublicKey()); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s.SetTestRsaPublicKey()); // Request the OEM Cert. -- This should NOT load the OEM Private key. vector public_cert; size_t public_cert_length = 0; @@ -1247,8 +1247,8 @@ TEST_F(OEMCryptoProv30Test, OEMCryptoMemoryGetOEMPublicCertForHugeCertLength) { Session s; ASSERT_NO_FATAL_FAILURE(s.open()); // Install the DRM Cert's RSA key. - ASSERT_NO_FATAL_FAILURE(s.InstallRSASessionTestKey(wrapped_rsa_key_)); - ASSERT_NO_FATAL_FAILURE(s.PreparePublicKey()); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s.SetTestRsaPublicKey()); auto oemcrypto_function = [](size_t input_length) { size_t public_cert_length = input_length; @@ -1355,15 +1355,8 @@ TEST_F(OEMCryptoProv40Test, GenerateCertificateKeyPairSuccess) { public_key_signature.resize(public_key_signature_size); wrapped_private_key.resize(wrapped_private_key_size); // Parse the public key generated to make sure it is correctly formatted. - if (key_type == OEMCrypto_PrivateKeyType::OEMCrypto_RSA_Private_Key) { - ASSERT_NO_FATAL_FAILURE( - s.SetRsaPublicKey(public_key.data(), public_key_size)); - } else if (key_type == OEMCrypto_PrivateKeyType::OEMCrypto_ECC_Private_Key) { - ASSERT_NO_FATAL_FAILURE( - s.SetEcPublicKey(public_key.data(), public_key_size)); - } else { - FAIL() << "Unknown private key type: " << static_cast(key_type); - } + ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( + key_type, public_key.data(), public_key_size)); } // Verifies the generated key pairs are different on each call. @@ -1522,11 +1515,19 @@ TEST_F(OEMCryptoProv40Test, InstallOemPrivateKeyCanBeUsed) { wrapped_private_key2.resize(wrapped_private_key_size2); // Verify public_key_signature2 with public_key1. - ASSERT_NO_FATAL_FAILURE( - s.SetRsaPublicKey(public_key1.data(), public_key1.size())); - ASSERT_NO_FATAL_FAILURE( - s.VerifyRSASignature(public_key2, public_key_signature2.data(), - public_key_signature2.size(), kSign_RSASSA_PSS)); + if (key_type2 == OEMCrypto_PrivateKeyType::OEMCrypto_RSA_Private_Key) { + ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromSubjectPublicKey( + public_key1.data(), public_key1.size())); + ASSERT_NO_FATAL_FAILURE( + s.VerifyRsaSignature(public_key2, public_key_signature2.data(), + public_key_signature2.size(), kSign_RSASSA_PSS)); + } else if (key_type2 == OEMCrypto_PrivateKeyType::OEMCrypto_ECC_Private_Key) { + ASSERT_NO_FATAL_FAILURE(s.SetEccPublicKeyFromSubjectPublicKey( + public_key1.data(), public_key1.size())); + ASSERT_NO_FATAL_FAILURE(s.VerifyEccSignature(public_key2, + public_key_signature2.data(), + public_key_signature2.size())); + } } // @@ -5363,7 +5364,7 @@ TEST_F(OEMCryptoLoadsCertificate, LoadRSASessionKey) { ASSERT_NO_FATAL_FAILURE(CreateWrappedRSAKey()); Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - ASSERT_NO_FATAL_FAILURE(s.InstallRSASessionTestKey(wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); } TEST_F(OEMCryptoLoadsCertificate, SignProvisioningRequest) { @@ -5797,14 +5798,10 @@ TEST_F(OEMCryptoLoadsCertificate, // Test that a wrapped RSA key can be loaded. TEST_F(OEMCryptoLoadsCertificate, LoadWrappedRSAKey) { - OEMCryptoResult sts; ASSERT_NO_FATAL_FAILURE(CreateWrappedRSAKey()); Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); } TEST_F(OEMCryptoLoadsCertificate, @@ -5813,9 +5810,8 @@ TEST_F(OEMCryptoLoadsCertificate, auto oemcrypto_function = [&](size_t wrapped_rsa_key_length) { Session s; s.open(); - vector wrapped_rsa_key_buffer(wrapped_rsa_key_length); - memcpy(wrapped_rsa_key_buffer.data(), wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); + vector wrapped_rsa_key_buffer = wrapped_rsa_key_; + wrapped_rsa_key_buffer.resize(wrapped_rsa_key_length); OEMCryptoResult result = OEMCrypto_LoadDRMPrivateKey( s.session_id(), OEMCrypto_RSA_Private_Key, wrapped_rsa_key_buffer.data(), wrapped_rsa_key_buffer.size()); @@ -5856,9 +5852,9 @@ class OEMCryptoLoadsCertVariousKeys : public OEMCryptoLoadsCertificate { ASSERT_NO_FATAL_FAILURE(CreateWrappedRSAKey()); Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - ASSERT_NO_FATAL_FAILURE( - s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_NO_FATAL_FAILURE(s.InstallRSASessionTestKey(wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); LicenseRoundTrip license_messages(&s); ASSERT_NO_FATAL_FAILURE(license_messages.SignAndVerifyRequest()); @@ -5927,12 +5923,9 @@ TEST_F(OEMCryptoLoadsCertificate, TestMultipleRSAKeys) { Session s1; // Session s1 loads the default rsa key, but doesn't use it // until after s2 uses its key. ASSERT_NO_FATAL_FAILURE(s1.open()); - ASSERT_NO_FATAL_FAILURE( - s1.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_LoadDRMPrivateKey( - s1.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), wrapped_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s1.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s1.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); Session s2; // Session s2 uses a different rsa key. encoded_rsa_key_.assign(kTestRSAPKCS8PrivateKeyInfo4_2048, @@ -5940,9 +5933,9 @@ TEST_F(OEMCryptoLoadsCertificate, TestMultipleRSAKeys) { sizeof(kTestRSAPKCS8PrivateKeyInfo4_2048)); ASSERT_NO_FATAL_FAILURE(CreateWrappedRSAKey()); ASSERT_NO_FATAL_FAILURE(s2.open()); - ASSERT_NO_FATAL_FAILURE( - s2.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_NO_FATAL_FAILURE(s2.InstallRSASessionTestKey(wrapped_rsa_key_)); + ASSERT_NO_FATAL_FAILURE(s2.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s2.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); LicenseRoundTrip license_messages2(&s2); ASSERT_NO_FATAL_FAILURE(license_messages2.SignAndVerifyRequest()); ASSERT_NO_FATAL_FAILURE(license_messages2.CreateDefaultResponse()); @@ -5976,10 +5969,10 @@ TEST_F(OEMCryptoLoadsCertificate, TestMaxDRMKeys) { kTestRSAPKCS8PrivateKeys_2048[key_index].end()); ASSERT_NO_FATAL_FAILURE(CreateWrappedRSAKey()); ASSERT_NO_FATAL_FAILURE(sessions[i]->open()); - ASSERT_NO_FATAL_FAILURE(sessions[i]->PreparePublicKey( + ASSERT_NO_FATAL_FAILURE(sessions[i]->SetRsaPublicKeyFromPrivateKeyInfo( encoded_rsa_key_.data(), encoded_rsa_key_.size())); ASSERT_NO_FATAL_FAILURE( - sessions[i]->InstallRSASessionTestKey(wrapped_rsa_key_)); + sessions[i]->LoadWrappedRsaDrmKey(wrapped_rsa_key_)); } // Attempts to load one more key than the kMaxTotalDRMPrivateKeys @@ -6053,7 +6046,7 @@ class OEMCryptoUsesCertificate : public OEMCryptoLoadsCertificate { ASSERT_NO_FATAL_FAILURE(session_.open()); if (global_features.derive_key_method == DeviceFeatures::LOAD_TEST_RSA_KEY) { - ASSERT_NO_FATAL_FAILURE(session_.PreparePublicKey( + ASSERT_NO_FATAL_FAILURE(session_.SetRsaPublicKeyFromPrivateKeyInfo( encoded_rsa_key_.data(), encoded_rsa_key_.size())); } else { InstallTestRSAKey(&session_); @@ -6092,10 +6085,7 @@ TEST_F(OEMCryptoLoadsCertificate, RSAPerformance) { while (clock.now() - start_time < kTestDuration) { Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); const size_t size = 50; vector licenseRequest(size); GetRandBytes(licenseRequest.data(), licenseRequest.size()); @@ -6133,15 +6123,12 @@ TEST_F(OEMCryptoLoadsCertificate, RSAPerformance) { Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_LoadDRMPrivateKey( - s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), wrapped_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); vector session_key; vector enc_session_key; - ASSERT_NO_FATAL_FAILURE( - s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_TRUE(s.GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_TRUE(s.GenerateRsaSessionKey(&session_key, &enc_session_key)); vector mac_context; vector enc_context; s.FillDefaultContext(&mac_context, &enc_context); @@ -6188,7 +6175,7 @@ TEST_F(OEMCryptoLoadsCertificate, RSAPerformance) { TEST_F(OEMCryptoUsesCertificate, GenerateDerivedKeysLargeBuffer) { vector session_key; vector enc_session_key; - ASSERT_TRUE(session_.GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_TRUE(session_.GenerateRsaSessionKey(&session_key, &enc_session_key)); const size_t max_size = GetResourceValue(kLargeMessageSize); vector mac_context(max_size); vector enc_context(max_size); @@ -6208,7 +6195,7 @@ TEST_F(OEMCryptoUsesCertificate, OEMCryptoMemoryDeriveKeysFromSessionKeyForHugeMacContext) { vector session_key; vector enc_session_key; - ASSERT_TRUE(session_.GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_TRUE(session_.GenerateRsaSessionKey(&session_key, &enc_session_key)); vector mac_context; vector enc_context; session_.FillDefaultContext(&mac_context, &enc_context); @@ -6228,7 +6215,7 @@ TEST_F(OEMCryptoUsesCertificate, OEMCryptoMemoryDeriveKeysFromSessionKeyForHugeEncContext) { vector session_key; vector enc_session_key; - ASSERT_TRUE(session_.GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_TRUE(session_.GenerateRsaSessionKey(&session_key, &enc_session_key)); vector mac_context; vector enc_context; session_.FillDefaultContext(&mac_context, &enc_context); @@ -6248,7 +6235,7 @@ TEST_F(OEMCryptoUsesCertificate, OEMCryptoMemoryDeriveKeysFromSessionKeyForHugeEncSessionKey) { vector session_key; vector enc_session_key; - ASSERT_TRUE(session_.GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_TRUE(session_.GenerateRsaSessionKey(&session_key, &enc_session_key)); vector mac_context; vector enc_context; session_.FillDefaultContext(&mac_context, &enc_context); @@ -6272,10 +6259,7 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { OEMCryptoResult sts; Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); // Sign a Message vector licenseRequest(size); @@ -6302,20 +6286,16 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { } void TestSignature(RSA_Padding_Scheme scheme, size_t size) { - OEMCryptoResult sts; Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); vector licenseRequest(size); GetRandBytes(licenseRequest.data(), licenseRequest.size()); size_t signature_length = 0; - sts = OEMCrypto_GenerateRSASignature(s.session_id(), licenseRequest.data(), - licenseRequest.size(), nullptr, - &signature_length, scheme); + OEMCryptoResult sts = OEMCrypto_GenerateRSASignature( + s.session_id(), licenseRequest.data(), licenseRequest.size(), nullptr, + &signature_length, scheme); ASSERT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, sts); ASSERT_NE(static_cast(0), signature_length); @@ -6328,26 +6308,22 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { << "Failed to sign with padding scheme=" << (int)scheme << ", size=" << size; signature.resize(signature_length); - ASSERT_NO_FATAL_FAILURE( - s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_NO_FATAL_FAILURE(s.VerifyRSASignature( + ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s.VerifyRsaSignature( licenseRequest, signature.data(), signature_length, scheme)); } void DisallowDeriveKeys() { - OEMCryptoResult sts; Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); s.GenerateNonce(); vector session_key; vector enc_session_key; - ASSERT_NO_FATAL_FAILURE( - s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_TRUE(s.GenerateRSASessionKey(&session_key, &enc_session_key)); + ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_TRUE(s.GenerateRsaSessionKey(&session_key, &enc_session_key)); vector mac_context; vector enc_context; s.FillDefaultContext(&mac_context, &enc_context); @@ -6394,10 +6370,7 @@ TEST_F(OEMCryptoLoadsCertificateAlternates, if (key_loaded_) { Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); vector message_buffer(10); size_t signature_length = 0; @@ -6421,7 +6394,6 @@ TEST_F(OEMCryptoLoadsCertificateAlternates, TEST_F(OEMCryptoLoadsCertificateAlternates, OEMCryptoMemoryGenerateRSASignatureForHugeSignatureLength) { - OEMCryptoResult sts; LoadWithAllowedSchemes(kSign_PKCS1_Block1, false); // If the device is a cast receiver, then this scheme is required. if (global_features.cast_receiver) { @@ -6430,10 +6402,7 @@ TEST_F(OEMCryptoLoadsCertificateAlternates, if (key_loaded_) { Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); vector message_buffer(50); vector signature; @@ -6643,10 +6612,7 @@ class OEMCryptoCastReceiverTest : public OEMCryptoLoadsCertificateAlternates { OEMCryptoResult sts; Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - sts = OEMCrypto_LoadDRMPrivateKey(s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_rsa_key_.data(), - wrapped_rsa_key_.size()); - ASSERT_EQ(OEMCrypto_SUCCESS, sts); + ASSERT_NO_FATAL_FAILURE(s.LoadWrappedRsaDrmKey(wrapped_rsa_key_)); // The application will compute the SHA-1 Hash of the message, so this // test must do that also. @@ -6678,8 +6644,8 @@ class OEMCryptoCastReceiverTest : public OEMCryptoLoadsCertificateAlternates { << "Failed to sign with padding scheme=" << (int)scheme << ", size=" << message.size(); signature.resize(signature_length); - ASSERT_NO_FATAL_FAILURE( - s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); + ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromPrivateKeyInfo( + encoded_rsa_key_.data(), encoded_rsa_key_.size())); // Verify that the signature matches the official test vector. ASSERT_EQ(correct_signature.size(), signature_length); @@ -6688,9 +6654,9 @@ class OEMCryptoCastReceiverTest : public OEMCryptoLoadsCertificateAlternates { // Also verify that our verification algorithm agrees. This is not needed // to test OEMCrypto, but it does verify that this test is valid. - ASSERT_NO_FATAL_FAILURE(s.VerifyRSASignature(digest, signature.data(), + ASSERT_NO_FATAL_FAILURE(s.VerifyRsaSignature(digest, signature.data(), signature_length, scheme)); - ASSERT_NO_FATAL_FAILURE(s.VerifyRSASignature( + ASSERT_NO_FATAL_FAILURE(s.VerifyRsaSignature( digest, correct_signature.data(), correct_signature.size(), scheme)); } }; diff --git a/libwvdrmengine/oemcrypto/test/ota_keybox_test.cpp b/libwvdrmengine/oemcrypto/test/ota_keybox_test.cpp index ec8377e8..359da3a8 100644 --- a/libwvdrmengine/oemcrypto/test/ota_keybox_test.cpp +++ b/libwvdrmengine/oemcrypto/test/ota_keybox_test.cpp @@ -268,7 +268,7 @@ TEST_F(OTAKeyboxProvisioningTest, BasicTest) { std::copy(bit_size_string.begin(), bit_size_string.end(), std::back_inserter(enc_context)); KeyDeriver keys; - keys.DeriveKeys(model_key.data(), mac_context, enc_context); + keys.DeriveKeys(model_key.data(), model_key.size(), mac_context, enc_context); const std::vector message( request.data(), request.data() + request.size() - HMAC_SHA256_SIGNATURE_SIZE);