diff --git a/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp b/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp index 743716cb..e409c0f0 100644 --- a/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp +++ b/libwvdrmengine/oemcrypto/mock/src/oemcrypto_session.cpp @@ -1134,6 +1134,7 @@ OEMCryptoResult SessionContext::DecryptCBC( AES_KEY aes_key; AES_set_decrypt_key(&key[0], AES_BLOCK_SIZE * 8, &aes_key); uint8_t iv[AES_BLOCK_SIZE]; + uint8_t next_iv[AES_BLOCK_SIZE]; memcpy(iv, &initial_iv[0], AES_BLOCK_SIZE); size_t l = 0; @@ -1151,11 +1152,14 @@ OEMCryptoResult SessionContext::DecryptCBC( memcpy(&clear_data[l], &cipher_data[l], size); } else { uint8_t aes_output[AES_BLOCK_SIZE]; + // Save the iv for the next block, in case cipher_data is in the same + // buffer as clear_data. + memcpy(next_iv, &cipher_data[l], AES_BLOCK_SIZE); AES_decrypt(&cipher_data[l], aes_output, &aes_key); for (size_t n = 0; n < AES_BLOCK_SIZE; n++) { clear_data[l + n] = aes_output[n] ^ iv[n]; } - memcpy(iv, &cipher_data[l], AES_BLOCK_SIZE); + memcpy(iv, next_iv, AES_BLOCK_SIZE); } l += size; } diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp index d84285a7..f806ddb8 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.cpp +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.cpp @@ -35,17 +35,9 @@ using namespace std; // GTest requires PrintTo to be in the same namespace as the thing it prints, // which is std::vector in this case. namespace std { - void PrintTo(const vector& value, ostream* os) { *os << wvcdm::b2a_hex(value); } - -void PrintTo(const PatternTestVariant& param, ostream* os) { - *os << ((param.mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode") - << ", encrypt=" << param.pattern.encrypt - << ", skip=" << param.pattern.skip; -} - } // namespace std namespace wvoec { diff --git a/libwvdrmengine/oemcrypto/test/oec_session_util.h b/libwvdrmengine/oemcrypto/test/oec_session_util.h index 03716f60..926f21dd 100644 --- a/libwvdrmengine/oemcrypto/test/oec_session_util.h +++ b/libwvdrmengine/oemcrypto/test/oec_session_util.h @@ -18,21 +18,7 @@ using namespace std; // GTest requires PrintTo to be in the same namespace as the thing it prints, // which is std::vector in this case. namespace std { - -struct PatternTestVariant { - PatternTestVariant(size_t encrypt, size_t skip, OEMCryptoCipherMode mode) { - this->pattern.encrypt = encrypt; - this->pattern.skip = skip; - this->pattern.offset = 0; - this->mode = mode; - } - OEMCrypto_CENCEncryptPatternDesc pattern; - OEMCryptoCipherMode mode; -}; - void PrintTo(const vector& value, ostream* os); -void PrintTo(const PatternTestVariant& param, ostream* os); - } // namespace std namespace wvoec { diff --git a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp index 7b4544b6..42eb0bb5 100644 --- a/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp +++ b/libwvdrmengine/oemcrypto/test/oemcrypto_test.cpp @@ -36,10 +36,26 @@ #include "wv_cdm_constants.h" #include "wv_keybox.h" -using namespace std; -using ::testing::WithParamInterface; +using ::testing::Bool; +using ::testing::Combine; using ::testing::Range; using ::testing::Values; +using ::testing::WithParamInterface; +using namespace std; +using std::tr1::tuple; + +namespace std { // GTest wants PrintTo to be in the std namespace. +void PrintTo(const tuple& param, ostream* os) { + OEMCrypto_CENCEncryptPatternDesc pattern = std::tr1::get<0>(param); + OEMCryptoCipherMode mode = std::tr1::get<1>(param); + bool decrypt_inplace = std::tr1::get<2>(param); + *os << ((mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode") + << ", encrypt=" << pattern.encrypt + << ", skip=" << pattern.skip + << ", decrypt in place = " << (decrypt_inplace ? "true":"false"); +} +} namespace wvoec { @@ -1615,12 +1631,14 @@ struct SampleInitData { class OEMCryptoSessionTestsDecryptTests : public OEMCryptoSessionTests, - public WithParamInterface { + public WithParamInterface > { protected: virtual void SetUp() { OEMCryptoSessionTests::SetUp(); - pattern_ = GetParam().pattern; - cipher_mode_ = GetParam().mode; + pattern_ = std::tr1::get<0>(GetParam()); + cipher_mode_ = std::tr1::get<1>(GetParam()); + decrypt_inplace_ = std::tr1::get<2>(GetParam()); } void FindTotalSize() { @@ -1728,7 +1746,17 @@ class OEMCryptoSessionTestsDecryptTests ASSERT_EQ(OEMCrypto_SUCCESS, sts); // We decrypt each subsample. - vector outputBuffer(total_size_ + 16, 0xaa); + vector output_buffer(total_size_ + 16, 0xaa); + const uint8_t *input_buffer = NULL; + if (decrypt_inplace_) { // Use same buffer for input and output. + // Copy the useful data from encryptedData to output_buffer, which + // will be the same as input_buffer. Leave the 0xaa padding at the end. + for(int i=0; i < total_size_; i++) output_buffer[i] = encryptedData[i]; + // Now let input_buffer point to the same data. + input_buffer = &output_buffer[0]; + } else { + input_buffer = &encryptedData[0]; + } size_t buffer_offset = 0; for (size_t i = 0; i < subsample_size_.size(); i++) { OEMCrypto_CENCEncryptPatternDesc pattern = pattern_; @@ -1739,7 +1767,7 @@ class OEMCryptoSessionTestsDecryptTests uint8_t subsample_flags = 0; if (subsample_size_[i].clear_size > 0) { destBuffer.type = OEMCrypto_BufferType_Clear; - destBuffer.buffer.clear.address = &outputBuffer[buffer_offset]; + destBuffer.buffer.clear.address = &output_buffer[buffer_offset]; destBuffer.buffer.clear.max_length = total_size_ - buffer_offset; if (i == 0) subsample_flags |= OEMCrypto_FirstSubsample; if ((i == subsample_size_.size() - 1) && @@ -1747,7 +1775,7 @@ class OEMCryptoSessionTestsDecryptTests subsample_flags |= OEMCrypto_LastSubsample; } sts = - OEMCrypto_DecryptCENC(s.session_id(), &encryptedData[buffer_offset], + OEMCrypto_DecryptCENC(s.session_id(), input_buffer + buffer_offset, subsample_size_[i].clear_size, is_encrypted, sample_init_data_[i].iv, block_offset, &destBuffer, &pattern, subsample_flags); @@ -1756,7 +1784,7 @@ class OEMCryptoSessionTestsDecryptTests } if (subsample_size_[i].encrypted_size > 0) { destBuffer.type = OEMCrypto_BufferType_Clear; - destBuffer.buffer.clear.address = &outputBuffer[buffer_offset]; + destBuffer.buffer.clear.address = &output_buffer[buffer_offset]; destBuffer.buffer.clear.max_length = total_size_ - buffer_offset; is_encrypted = true; block_offset = sample_init_data_[i].block_offset; @@ -1768,7 +1796,7 @@ class OEMCryptoSessionTestsDecryptTests subsample_flags |= OEMCrypto_LastSubsample; } sts = OEMCrypto_DecryptCENC( - s.session_id(), &encryptedData[buffer_offset], + s.session_id(), input_buffer + buffer_offset, subsample_size_[i].encrypted_size, is_encrypted, sample_init_data_[i].iv, block_offset, &destBuffer, &pattern, subsample_flags); @@ -1782,13 +1810,14 @@ class OEMCryptoSessionTestsDecryptTests buffer_offset += subsample_size_[i].encrypted_size; } } - EXPECT_EQ(0xaa, outputBuffer[total_size_]) << "Buffer overrun."; - outputBuffer.resize(total_size_); - EXPECT_EQ(unencryptedData, outputBuffer); + EXPECT_EQ(0xaa, output_buffer[total_size_]) << "Buffer overrun."; + output_buffer.resize(total_size_); + EXPECT_EQ(unencryptedData, output_buffer); } OEMCrypto_CENCEncryptPatternDesc pattern_; OEMCryptoCipherMode cipher_mode_; + bool decrypt_inplace_; // If true, input and output buffers are the same. vector subsample_size_; size_t total_size_; vector sample_init_data_; @@ -2026,35 +2055,51 @@ TEST_F(OEMCryptoSessionTests, DecryptUnencryptedNoKey) { ASSERT_EQ(in_buffer, out_buffer); } +// Used to construct a specific pattern. +OEMCrypto_CENCEncryptPatternDesc MakePattern(size_t encrypt, size_t skip) { + OEMCrypto_CENCEncryptPatternDesc pattern; + pattern.encrypt = encrypt; + pattern.skip = skip; + pattern.offset = 0; // offset is deprecated. + return pattern; +} + INSTANTIATE_TEST_CASE_P(CTRTests, OEMCryptoSessionTestsPartialBlockTests, - Values(PatternTestVariant(0, 0, - OEMCrypto_CipherMode_CTR))); + Combine(Values(MakePattern(0,0)), + Values(OEMCrypto_CipherMode_CTR), + Bool())); INSTANTIATE_TEST_CASE_P( CBCTests, OEMCryptoSessionTestsPartialBlockTests, - Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC), - PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC), - // HLS Edge case. We should follow the CENC spec, not HLS spec. - PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC), - PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC), - PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC), - PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC))); + Combine( + Values(MakePattern(0, 0), + MakePattern(3, 7), + // HLS Edge case. We should follow the CENC spec, not HLS spec. + MakePattern(9, 1), + MakePattern(1, 9), + MakePattern(1, 3), + MakePattern(2, 1)), + Values(OEMCrypto_CipherMode_CBC), Bool())); INSTANTIATE_TEST_CASE_P( CTRTests, OEMCryptoSessionTestsDecryptTests, - Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CTR), - PatternTestVariant(3, 7, OEMCrypto_CipherMode_CTR), - // Pattern length should be 10, but that is not guaranteed. - PatternTestVariant(1, 3, OEMCrypto_CipherMode_CTR), - PatternTestVariant(2, 1, OEMCrypto_CipherMode_CTR))); + Combine( + Values(MakePattern(0, 0), + MakePattern(3, 7), + // Pattern length should be 10, but that is not guaranteed. + MakePattern(1, 3), + MakePattern(2, 1)), + Values(OEMCrypto_CipherMode_CTR), Bool())); INSTANTIATE_TEST_CASE_P( CBCTests, OEMCryptoSessionTestsDecryptTests, - Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC), - PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC), - // HLS Edge case. We should follow the CENC spec, not HLS spec. - PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC), - PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC), - // Pattern length should be 10, but that is not guaranteed. - PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC), - PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC))); + Combine( + Values(MakePattern(0, 0), + MakePattern(3, 7), + // HLS Edge case. We should follow the CENC spec, not HLS spec. + MakePattern(9, 1), + MakePattern(1, 9), + // Pattern length should be 10, but that is not guaranteed. + MakePattern(1, 3), + MakePattern(2, 1)), + Values(OEMCrypto_CipherMode_CBC), Bool())); TEST_F(OEMCryptoSessionTests, DecryptSecureToClear) { Session s; @@ -3892,6 +3937,25 @@ TEST_F(GenericCryptoTest, GenericKeyBadEncrypt) { BadEncrypt(3, OEMCrypto_AES_CBC_128_NO_PADDING, buffer_size_); } +TEST_F(GenericCryptoTest, GenericKeyEncryptSameBuffer) { + EncryptAndLoadKeys(); + unsigned int key_index = 0; + vector expected_encrypted; + EncryptBuffer(key_index, clear_buffer_, &expected_encrypted); + ASSERT_EQ( + OEMCrypto_SUCCESS, + OEMCrypto_SelectKey(session_.session_id(), + session_.license().keys[key_index].key_id, + session_.license().keys[key_index].key_id_length)); + // Input and output are same buffer: + vector buffer = clear_buffer_; + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_Generic_Encrypt( + session_.session_id(), &buffer[0], buffer.size(), + iv_, OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0])); + ASSERT_EQ(expected_encrypted, buffer); +} + TEST_F(GenericCryptoTest, GenericKeyDecrypt) { EncryptAndLoadKeys(); unsigned int key_index = 1; @@ -3910,6 +3974,24 @@ TEST_F(GenericCryptoTest, GenericKeyDecrypt) { ASSERT_EQ(clear_buffer_, resultant); } +TEST_F(GenericCryptoTest, GenericKeyDecryptSameBuffer) { + EncryptAndLoadKeys(); + unsigned int key_index = 1; + vector encrypted; + EncryptBuffer(key_index, clear_buffer_, &encrypted); + ASSERT_EQ( + OEMCrypto_SUCCESS, + OEMCrypto_SelectKey(session_.session_id(), + session_.license().keys[key_index].key_id, + session_.license().keys[key_index].key_id_length)); + vector buffer = encrypted; + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_Generic_Decrypt( + session_.session_id(), &buffer[0], buffer.size(), iv_, + OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0])); + ASSERT_EQ(clear_buffer_, buffer); +} + TEST_F(GenericCryptoTest, GenericSecureToClear) { session_.license().keys[1].control.control_bits |= htonl( wvoec_mock::kControlObserveDataPath | wvoec_mock::kControlDataPathSecure);