diff --git a/libwvdrmengine/cdm/core/src/crypto_session.cpp b/libwvdrmengine/cdm/core/src/crypto_session.cpp index c4a65104..2c2c780d 100644 --- a/libwvdrmengine/cdm/core/src/crypto_session.cpp +++ b/libwvdrmengine/cdm/core/src/crypto_session.cpp @@ -2099,6 +2099,11 @@ CdmResponseType CryptoSession::LoadProvisioning( metrics_, oemcrypto_load_provisioning_, status); }); + if (status == OEMCrypto_SUCCESS) { + wrapped_private_key->resize(wrapped_private_key_length); + return NO_ERROR; + } + wrapped_private_key->clear(); return MapOEMCryptoResult(status, LOAD_PROVISIONING_ERROR, "LoadProvisioning"); } @@ -2283,29 +2288,25 @@ bool CryptoSession::GetBuildInformation(RequestedSecurityLevel security_level, RequestedSecurityLevelToString(security_level)); RETURN_IF_UNINITIALIZED(false); RETURN_IF_NULL(info, false); - - OEMCryptoResult build_information; - std::string buf; - size_t buf_length = 0; - WithOecReadLock("GetBuildInformation", [&] { - build_information = - OEMCrypto_BuildInformation(&buf[0], &buf_length, security_level); + size_t info_length = 128; + info->assign(info_length, '\0'); + OEMCryptoResult result = WithOecReadLock("GetBuildInformation", [&] { + return OEMCrypto_BuildInformation(&info->front(), &info_length, + security_level); }); - if (build_information == OEMCrypto_ERROR_SHORT_BUFFER) { - buf.resize(buf_length); - WithOecReadLock("GetBuildInformation Attempt 2", [&] { - build_information = - OEMCrypto_BuildInformation(&buf[0], &buf_length, security_level); + if (result == OEMCrypto_ERROR_SHORT_BUFFER) { + info->assign(info_length, '\0'); + result = WithOecReadLock("GetBuildInformation Attempt 2", [&] { + return OEMCrypto_BuildInformation(&info->front(), &info_length, + security_level); }); } - - if (build_information == OEMCrypto_SUCCESS) { - *info = buf; - } else { - LOGE("Unexpected return value"); + if (result != OEMCrypto_SUCCESS) { + LOGE("GetBuildInformation failed: result = %d", result); + info->clear(); return false; } - + info->resize(info_length); return true; } diff --git a/libwvdrmengine/cdm/core/src/oemcrypto_adapter_dynamic.cpp b/libwvdrmengine/cdm/core/src/oemcrypto_adapter_dynamic.cpp index 97c00d65..0a2ecc63 100644 --- a/libwvdrmengine/cdm/core/src/oemcrypto_adapter_dynamic.cpp +++ b/libwvdrmengine/cdm/core/src/oemcrypto_adapter_dynamic.cpp @@ -1737,14 +1737,19 @@ OEMCryptoResult OEMCrypto_BuildInformation( if (fcn->BuildInformation_V16 == nullptr) { return OEMCrypto_ERROR_NOT_IMPLEMENTED; } + if (buffer_length == nullptr) return OEMCrypto_ERROR_INVALID_CONTEXT; + if (buffer == nullptr && *buffer_length > 0) + return OEMCrypto_ERROR_INVALID_CONTEXT; + constexpr size_t kMaxBuildInfoLength = 128; const char* build_info = fcn->BuildInformation_V16(); - size_t max_length = strnlen(build_info, 128); - if (*buffer_length < max_length) { - *buffer_length = max_length; + if (build_info == nullptr) return OEMCrypto_ERROR_UNKNOWN_FAILURE; + const size_t build_info_length = strnlen(build_info, kMaxBuildInfoLength); + if (*buffer_length < build_info_length) { + *buffer_length = build_info_length; return OEMCrypto_ERROR_SHORT_BUFFER; } - *buffer_length = max_length; - memcpy(buffer, build_info, *buffer_length); + *buffer_length = build_info_length; + memcpy(buffer, build_info, build_info_length); return OEMCrypto_SUCCESS; } return fcn->BuildInformation(buffer, buffer_length); diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp index a01c1fd3..5e6f7ec7 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp @@ -237,6 +237,9 @@ RoundTrip:: // We need to fill in core request and verify signature only for calls other // than OEMCryptoMemory buffer overflow test. Any test other than buffer // overflow will pass true. + if (result == OEMCrypto_SUCCESS) { + gen_signature.resize(gen_signature_length); + } if (!verify_request || result != OEMCrypto_SUCCESS) return result; if (global_features.api_version >= kCoreMessagesAPI) { std::string core_message(reinterpret_cast(data.data()), @@ -466,11 +469,14 @@ OEMCryptoResult ProvisioningRoundTrip::LoadResponse(Session* session) { sizeof(response_data_)); } size_t wrapped_key_length = 0; - const OEMCryptoResult sts = LoadResponseNoRetry(session, &wrapped_key_length); + OEMCryptoResult sts = LoadResponseNoRetry(session, &wrapped_key_length); if (sts != OEMCrypto_ERROR_SHORT_BUFFER) return sts; - wrapped_rsa_key_.clear(); wrapped_rsa_key_.assign(wrapped_key_length, 0); - return LoadResponseNoRetry(session, &wrapped_key_length); + sts = LoadResponseNoRetry(session, &wrapped_key_length); + if (sts == OEMCrypto_SUCCESS) { + wrapped_rsa_key_.resize(wrapped_key_length); + } + return sts; } #ifdef TEST_OEMCRYPTO_V15 @@ -1589,6 +1595,7 @@ void Session::LoadOEMCert(bool verify_cert) { public_cert.resize(public_cert_length); ASSERT_EQ(OEMCrypto_SUCCESS, OEMCrypto_GetOEMPublicCertificate( public_cert.data(), &public_cert_length)); + public_cert.resize(public_cert_length); ASSERT_EQ(OEMCrypto_SUCCESS, OEMCrypto_LoadOEMPrivateKey(session_id())); // The cert is a PKCS7 signed data type. First, parse it into an OpenSSL @@ -1871,6 +1878,8 @@ void Session::UpdateUsageEntry(std::vector* header_buffer) { OEMCrypto_UpdateUsageEntry( session_id(), header_buffer->data(), &header_buffer_length, encrypted_usage_entry_.data(), &entry_buffer_length)); + header_buffer->resize(header_buffer_length); + encrypted_usage_entry_.resize(entry_buffer_length); } void Session::LoadUsageEntry(uint32_t index, const vector& buffer) { @@ -1915,6 +1924,7 @@ void Session::GenerateReport(const std::string& pst, if (expected_result != OEMCrypto_SUCCESS) { return; } + pst_report_buffer_.resize(length); EXPECT_EQ(wvutil::Unpacked_PST_Report::report_size(pst.length()), length); vector computed_signature(SHA_DIGEST_LENGTH); key_deriver_.ClientSignPstReport(pst_report_buffer_, &computed_signature); diff --git a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp index c011dcdf..7a2e5147 100644 --- a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp +++ b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp @@ -308,6 +308,9 @@ TEST_F(OEMCryptoClientTest, VersionNumber) { sts = OEMCrypto_BuildInformation(&build_info[0], &buf_length); } ASSERT_EQ(OEMCrypto_SUCCESS, sts); + if (build_info.size() != buf_length) { + build_info.resize(buf_length); + } cout << " BuildInformation: " << build_info << endl; OEMCrypto_WatermarkingSupport support = OEMCrypto_GetWatermarkingSupport(); cout << " WatermarkingSupport: " << support << endl; @@ -484,7 +487,23 @@ TEST_F(OEMCryptoClientTest, CheckNullBuildInformationAPI17) { ASSERT_EQ(OEMCrypto_ERROR_INVALID_CONTEXT, sts); size_t buf_length = 0; sts = OEMCrypto_BuildInformation(nullptr, &buf_length); - ASSERT_EQ(OEMCrypto_ERROR_INVALID_CONTEXT, sts); + // Previous versions of the test expected the wrong error code. + // Although OEMCrypto_ERROR_INVALID_CONTEXT is still accepted by + // the tests, vendors should return OEMCrypto_ERROR_SHORT_BUFFER if + // |buffer| is null and |buf_length| is zero, assigning + // the correct length to |buf_length|. + // TODO(231514699): Remove case for ERROR_INVALID_CONTEXT. + ASSERT_TRUE(OEMCrypto_ERROR_SHORT_BUFFER == sts || + OEMCrypto_ERROR_INVALID_CONTEXT == sts); + if (sts == OEMCrypto_ERROR_INVALID_CONTEXT) { + printf( + "Warning: OEMCrypto_BuildInformation should return " + "ERROR_SHORT_BUFFER.\n"); + } + if (sts == OEMCrypto_ERROR_SHORT_BUFFER) { + constexpr size_t kZero = 0; + ASSERT_GT(buf_length, kZero); + } } TEST_F(OEMCryptoClientTest, CheckMaxNumberOfSessionsAPI10) { @@ -988,9 +1007,9 @@ TEST_F(OEMCryptoKeyboxTest, NormalGetDeviceId) { uint8_t dev_id[128] = {0}; size_t dev_id_len = 128; sts = OEMCrypto_GetDeviceID(dev_id, &dev_id_len); + ASSERT_EQ(OEMCrypto_SUCCESS, sts); cout << " NormalGetDeviceId: dev_id = " << MaybeHex(dev_id, dev_id_len) << " len = " << dev_id_len << endl; - ASSERT_EQ(OEMCrypto_SUCCESS, sts); } TEST_F(OEMCryptoKeyboxTest, OEMCryptoMemoryGetDeviceIdForHugeIdLength) { @@ -1133,7 +1152,6 @@ TEST_F(OEMCryptoProv30Test, GetDeviceId) { dev_id.resize(dev_id_len); cout << " NormalGetDeviceId: dev_id = " << MaybeHex(dev_id) << " len = " << dev_id_len << endl; - ASSERT_EQ(OEMCrypto_SUCCESS, sts); } // The OEM certificate must be valid. @@ -1333,6 +1351,9 @@ TEST_F(OEMCryptoProv40Test, GenerateCertificateKeyPairSuccess) { public_key_signature.data(), &public_key_signature_size, wrapped_private_key.data(), &wrapped_private_key_size, &key_type), OEMCrypto_SUCCESS); + public_key.resize(public_key_size); + public_key_signature.resize(public_key_signature_size); + wrapped_private_key.resize(wrapped_private_key_size); // Parse the public key generated to make sure it is correctly formatted. if (key_type == OEMCrypto_PrivateKeyType::OEMCrypto_RSA_Private_Key) { ASSERT_NO_FATAL_FAILURE( @@ -1543,6 +1564,7 @@ class OEMCryptoSessionTests : public OEMCryptoClientTest { &header_buffer_length); if (expect_success) { ASSERT_EQ(OEMCrypto_SUCCESS, sts); + encrypted_usage_header_.resize(header_buffer_length); } else { ASSERT_NE(OEMCrypto_SUCCESS, sts); } @@ -6065,11 +6087,10 @@ TEST_F(OEMCryptoLoadsCertificate, RSAPerformance) { licenseRequest.size()); } - uint8_t* signature = new uint8_t[signature_length]; - sts = OEMCrypto_GenerateRSASignature(s.session_id(), licenseRequest.data(), - licenseRequest.size(), signature, - &signature_length, kSign_RSASSA_PSS); - delete[] signature; + std::vector signature(signature_length, 0); + sts = OEMCrypto_GenerateRSASignature( + s.session_id(), licenseRequest.data(), licenseRequest.size(), + signature.data(), &signature_length, kSign_RSASSA_PSS); ASSERT_EQ(OEMCrypto_SUCCESS, sts); count++; } @@ -6243,7 +6264,7 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { EXPECT_NE(OEMCrypto_SUCCESS, sts) << "Signed with forbidden padding scheme=" << (int)scheme << ", size=" << (int)size; - vector zero(signature_length, 0); + const vector zero(signature.size(), 0); ASSERT_EQ(zero, signature); // signature should not be computed. } @@ -6265,19 +6286,19 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { ASSERT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, sts); ASSERT_NE(static_cast(0), signature_length); - uint8_t* signature = new uint8_t[signature_length]; - sts = OEMCrypto_GenerateRSASignature(s.session_id(), licenseRequest.data(), - licenseRequest.size(), signature, - &signature_length, scheme); + std::vector signature(signature_length, 0); + sts = OEMCrypto_GenerateRSASignature( + s.session_id(), licenseRequest.data(), licenseRequest.size(), + signature.data(), &signature_length, scheme); ASSERT_EQ(OEMCrypto_SUCCESS, sts) << "Failed to sign with padding scheme=" << (int)scheme - << ", size=" << (int)size; + << ", size=" << size; + signature.resize(signature_length); ASSERT_NO_FATAL_FAILURE( s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); - ASSERT_NO_FATAL_FAILURE(s.VerifyRSASignature(licenseRequest, signature, - signature_length, scheme)); - delete[] signature; + ASSERT_NO_FATAL_FAILURE(s.VerifyRSASignature( + licenseRequest, signature.data(), signature_length, scheme)); } void DisallowDeriveKeys() { @@ -6622,7 +6643,8 @@ class OEMCryptoCastReceiverTest : public OEMCryptoLoadsCertificateAlternates { ASSERT_EQ(OEMCrypto_SUCCESS, sts) << "Failed to sign with padding scheme=" << (int)scheme - << ", size=" << (int)message.size(); + << ", size=" << message.size(); + signature.resize(signature_length); ASSERT_NO_FATAL_FAILURE( s.PreparePublicKey(encoded_rsa_key_.data(), encoded_rsa_key_.size())); @@ -9393,6 +9415,9 @@ class OEMCryptoUsageTableDefragTest : public OEMCryptoUsageTableTest { new_size, encrypted_usage_header_.data(), &header_buffer_length); // For the second call, we always demand the expected result. ASSERT_EQ(expected_result, sts); + if (sts == OEMCrypto_SUCCESS) { + encrypted_usage_header_.resize(header_buffer_length); + } } };