From 402532218568d25505a2fa0a88c67eae48484d06 Mon Sep 17 00:00:00 2001 From: Fred Gylys-Colwell Date: Wed, 1 Mar 2017 13:19:12 -0800 Subject: [PATCH] Source and destination buffers may point to same buffer Merge from Widevine repo of http://go/wvgerrit/23581 This CL adds some unit tests to oemcrypto to verify that DecryptCENC and the generic encrypt and decrypt functions behave correctly when the input and output buffer is the same. i.e. decrypt in place. The mock and haystack are also updated to pass the tests. b/34080119 Change-Id: Ie295bdaddbb8058bebb36f6dab092d307f249ecd --- .../oemcrypto/mock/src/oemcrypto_session.cpp | 6 +- .../oemcrypto/test/oec_session_util.cpp | 8 - .../oemcrypto/test/oec_session_util.h | 14 -- .../oemcrypto/test/oemcrypto_test.cpp | 152 ++++++++++++++---- 4 files changed, 122 insertions(+), 58 deletions(-) diff --git a/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp b/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp index 743716cb..e409c0f0 100644 --- a/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp +++ b/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp @@ -1134,6 +1134,7 @@ OEMCryptoResult SessionContext::DecryptCBC( AES_KEY aes_key; AES_set_decrypt_key(&key[0], AES_BLOCK_SIZE * 8, &aes_key); uint8_t iv[AES_BLOCK_SIZE]; + uint8_t next_iv[AES_BLOCK_SIZE]; memcpy(iv, &initial_iv[0], AES_BLOCK_SIZE); size_t l = 0; @@ -1151,11 +1152,14 @@ OEMCryptoResult SessionContext::DecryptCBC( memcpy(&clear_data[l], &cipher_data[l], size); } else { uint8_t aes_output[AES_BLOCK_SIZE]; + // Save the iv for the next block, in case cipher_data is in the same + // buffer as clear_data. + memcpy(next_iv, &cipher_data[l], AES_BLOCK_SIZE); AES_decrypt(&cipher_data[l], aes_output, &aes_key); for (size_t n = 0; n < AES_BLOCK_SIZE; n++) { clear_data[l + n] = aes_output[n] ^ iv[n]; } - memcpy(iv, &cipher_data[l], AES_BLOCK_SIZE); + memcpy(iv, next_iv, AES_BLOCK_SIZE); } l += size; } diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp index d84285a7..f806ddb8 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp @@ -35,17 +35,9 @@ using namespace std; // GTest requires PrintTo to be in the same namespace as the thing it prints, // which is std::vector in this case. namespace std { - void PrintTo(const vector& value, ostream* os) { *os << wvcdm::b2a_hex(value); } - -void PrintTo(const PatternTestVariant& param, ostream* os) { - *os << ((param.mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode") - << ", encrypt=" << param.pattern.encrypt - << ", skip=" << param.pattern.skip; -} - } // namespace std namespace wvoec { diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.h b/libwvdrmengine/oemcrypto/test/oec_session_util.h index 03716f60..926f21dd 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.h +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.h @@ -18,21 +18,7 @@ using namespace std; // GTest requires PrintTo to be in the same namespace as the thing it prints, // which is std::vector in this case. namespace std { - -struct PatternTestVariant { - PatternTestVariant(size_t encrypt, size_t skip, OEMCryptoCipherMode mode) { - this->pattern.encrypt = encrypt; - this->pattern.skip = skip; - this->pattern.offset = 0; - this->mode = mode; - } - OEMCrypto_CENCEncryptPatternDesc pattern; - OEMCryptoCipherMode mode; -}; - void PrintTo(const vector& value, ostream* os); -void PrintTo(const PatternTestVariant& param, ostream* os); - } // namespace std namespace wvoec { diff --git a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp index 7b4544b6..42eb0bb5 100644 --- a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp +++ b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp @@ -36,10 +36,26 @@ #include "wv_cdm_constants.h" #include "wv_keybox.h" -using namespace std; -using ::testing::WithParamInterface; +using ::testing::Bool; +using ::testing::Combine; using ::testing::Range; using ::testing::Values; +using ::testing::WithParamInterface; +using namespace std; +using std::tr1::tuple; + +namespace std { // GTest wants PrintTo to be in the std namespace. +void PrintTo(const tuple& param, ostream* os) { + OEMCrypto_CENCEncryptPatternDesc pattern = std::tr1::get<0>(param); + OEMCryptoCipherMode mode = std::tr1::get<1>(param); + bool decrypt_inplace = std::tr1::get<2>(param); + *os << ((mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode") + << ", encrypt=" << pattern.encrypt + << ", skip=" << pattern.skip + << ", decrypt in place = " << (decrypt_inplace ? "true":"false"); +} +} namespace wvoec { @@ -1615,12 +1631,14 @@ struct SampleInitData { class OEMCryptoSessionTestsDecryptTests : public OEMCryptoSessionTests, - public WithParamInterface { + public WithParamInterface > { protected: virtual void SetUp() { OEMCryptoSessionTests::SetUp(); - pattern_ = GetParam().pattern; - cipher_mode_ = GetParam().mode; + pattern_ = std::tr1::get<0>(GetParam()); + cipher_mode_ = std::tr1::get<1>(GetParam()); + decrypt_inplace_ = std::tr1::get<2>(GetParam()); } void FindTotalSize() { @@ -1728,7 +1746,17 @@ class OEMCryptoSessionTestsDecryptTests ASSERT_EQ(OEMCrypto_SUCCESS, sts); // We decrypt each subsample. - vector outputBuffer(total_size_ + 16, 0xaa); + vector output_buffer(total_size_ + 16, 0xaa); + const uint8_t *input_buffer = NULL; + if (decrypt_inplace_) { // Use same buffer for input and output. + // Copy the useful data from encryptedData to output_buffer, which + // will be the same as input_buffer. Leave the 0xaa padding at the end. + for(int i=0; i < total_size_; i++) output_buffer[i] = encryptedData[i]; + // Now let input_buffer point to the same data. + input_buffer = &output_buffer[0]; + } else { + input_buffer = &encryptedData[0]; + } size_t buffer_offset = 0; for (size_t i = 0; i < subsample_size_.size(); i++) { OEMCrypto_CENCEncryptPatternDesc pattern = pattern_; @@ -1739,7 +1767,7 @@ class OEMCryptoSessionTestsDecryptTests uint8_t subsample_flags = 0; if (subsample_size_[i].clear_size > 0) { destBuffer.type = OEMCrypto_BufferType_Clear; - destBuffer.buffer.clear.address = &outputBuffer[buffer_offset]; + destBuffer.buffer.clear.address = &output_buffer[buffer_offset]; destBuffer.buffer.clear.max_length = total_size_ - buffer_offset; if (i == 0) subsample_flags |= OEMCrypto_FirstSubsample; if ((i == subsample_size_.size() - 1) && @@ -1747,7 +1775,7 @@ class OEMCryptoSessionTestsDecryptTests subsample_flags |= OEMCrypto_LastSubsample; } sts = - OEMCrypto_DecryptCENC(s.session_id(), &encryptedData[buffer_offset], + OEMCrypto_DecryptCENC(s.session_id(), input_buffer + buffer_offset, subsample_size_[i].clear_size, is_encrypted, sample_init_data_[i].iv, block_offset, &destBuffer, &pattern, subsample_flags); @@ -1756,7 +1784,7 @@ class OEMCryptoSessionTestsDecryptTests } if (subsample_size_[i].encrypted_size > 0) { destBuffer.type = OEMCrypto_BufferType_Clear; - destBuffer.buffer.clear.address = &outputBuffer[buffer_offset]; + destBuffer.buffer.clear.address = &output_buffer[buffer_offset]; destBuffer.buffer.clear.max_length = total_size_ - buffer_offset; is_encrypted = true; block_offset = sample_init_data_[i].block_offset; @@ -1768,7 +1796,7 @@ class OEMCryptoSessionTestsDecryptTests subsample_flags |= OEMCrypto_LastSubsample; } sts = OEMCrypto_DecryptCENC( - s.session_id(), &encryptedData[buffer_offset], + s.session_id(), input_buffer + buffer_offset, subsample_size_[i].encrypted_size, is_encrypted, sample_init_data_[i].iv, block_offset, &destBuffer, &pattern, subsample_flags); @@ -1782,13 +1810,14 @@ class OEMCryptoSessionTestsDecryptTests buffer_offset += subsample_size_[i].encrypted_size; } } - EXPECT_EQ(0xaa, outputBuffer[total_size_]) << "Buffer overrun."; - outputBuffer.resize(total_size_); - EXPECT_EQ(unencryptedData, outputBuffer); + EXPECT_EQ(0xaa, output_buffer[total_size_]) << "Buffer overrun."; + output_buffer.resize(total_size_); + EXPECT_EQ(unencryptedData, output_buffer); } OEMCrypto_CENCEncryptPatternDesc pattern_; OEMCryptoCipherMode cipher_mode_; + bool decrypt_inplace_; // If true, input and output buffers are the same. vector subsample_size_; size_t total_size_; vector sample_init_data_; @@ -2026,35 +2055,51 @@ TEST_F(OEMCryptoSessionTests, DecryptUnencryptedNoKey) { ASSERT_EQ(in_buffer, out_buffer); } +// Used to construct a specific pattern. +OEMCrypto_CENCEncryptPatternDesc MakePattern(size_t encrypt, size_t skip) { + OEMCrypto_CENCEncryptPatternDesc pattern; + pattern.encrypt = encrypt; + pattern.skip = skip; + pattern.offset = 0; // offset is deprecated. + return pattern; +} + INSTANTIATE_TEST_CASE_P(CTRTests, OEMCryptoSessionTestsPartialBlockTests, - Values(PatternTestVariant(0, 0, - OEMCrypto_CipherMode_CTR))); + Combine(Values(MakePattern(0,0)), + Values(OEMCrypto_CipherMode_CTR), + Bool())); INSTANTIATE_TEST_CASE_P( CBCTests, OEMCryptoSessionTestsPartialBlockTests, - Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC), - PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC), - // HLS Edge case. We should follow the CENC spec, not HLS spec. - PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC), - PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC), - PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC), - PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC))); + Combine( + Values(MakePattern(0, 0), + MakePattern(3, 7), + // HLS Edge case. We should follow the CENC spec, not HLS spec. + MakePattern(9, 1), + MakePattern(1, 9), + MakePattern(1, 3), + MakePattern(2, 1)), + Values(OEMCrypto_CipherMode_CBC), Bool())); INSTANTIATE_TEST_CASE_P( CTRTests, OEMCryptoSessionTestsDecryptTests, - Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CTR), - PatternTestVariant(3, 7, OEMCrypto_CipherMode_CTR), - // Pattern length should be 10, but that is not guaranteed. - PatternTestVariant(1, 3, OEMCrypto_CipherMode_CTR), - PatternTestVariant(2, 1, OEMCrypto_CipherMode_CTR))); + Combine( + Values(MakePattern(0, 0), + MakePattern(3, 7), + // Pattern length should be 10, but that is not guaranteed. + MakePattern(1, 3), + MakePattern(2, 1)), + Values(OEMCrypto_CipherMode_CTR), Bool())); INSTANTIATE_TEST_CASE_P( CBCTests, OEMCryptoSessionTestsDecryptTests, - Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC), - PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC), - // HLS Edge case. We should follow the CENC spec, not HLS spec. - PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC), - PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC), - // Pattern length should be 10, but that is not guaranteed. - PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC), - PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC))); + Combine( + Values(MakePattern(0, 0), + MakePattern(3, 7), + // HLS Edge case. We should follow the CENC spec, not HLS spec. + MakePattern(9, 1), + MakePattern(1, 9), + // Pattern length should be 10, but that is not guaranteed. + MakePattern(1, 3), + MakePattern(2, 1)), + Values(OEMCrypto_CipherMode_CBC), Bool())); TEST_F(OEMCryptoSessionTests, DecryptSecureToClear) { Session s; @@ -3892,6 +3937,25 @@ TEST_F(GenericCryptoTest, GenericKeyBadEncrypt) { BadEncrypt(3, OEMCrypto_AES_CBC_128_NO_PADDING, buffer_size_); } +TEST_F(GenericCryptoTest, GenericKeyEncryptSameBuffer) { + EncryptAndLoadKeys(); + unsigned int key_index = 0; + vector expected_encrypted; + EncryptBuffer(key_index, clear_buffer_, &expected_encrypted); + ASSERT_EQ( + OEMCrypto_SUCCESS, + OEMCrypto_SelectKey(session_.session_id(), + session_.license().keys[key_index].key_id, + session_.license().keys[key_index].key_id_length)); + // Input and output are same buffer: + vector buffer = clear_buffer_; + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_Generic_Encrypt( + session_.session_id(), &buffer[0], buffer.size(), + iv_, OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0])); + ASSERT_EQ(expected_encrypted, buffer); +} + TEST_F(GenericCryptoTest, GenericKeyDecrypt) { EncryptAndLoadKeys(); unsigned int key_index = 1; @@ -3910,6 +3974,24 @@ TEST_F(GenericCryptoTest, GenericKeyDecrypt) { ASSERT_EQ(clear_buffer_, resultant); } +TEST_F(GenericCryptoTest, GenericKeyDecryptSameBuffer) { + EncryptAndLoadKeys(); + unsigned int key_index = 1; + vector encrypted; + EncryptBuffer(key_index, clear_buffer_, &encrypted); + ASSERT_EQ( + OEMCrypto_SUCCESS, + OEMCrypto_SelectKey(session_.session_id(), + session_.license().keys[key_index].key_id, + session_.license().keys[key_index].key_id_length)); + vector buffer = encrypted; + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_Generic_Decrypt( + session_.session_id(), &buffer[0], buffer.size(), iv_, + OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0])); + ASSERT_EQ(clear_buffer_, buffer); +} + TEST_F(GenericCryptoTest, GenericSecureToClear) { session_.license().keys[1].control.control_bits |= htonl( wvoec_mock::kControlObserveDataPath | wvoec_mock::kControlDataPathSecure);