Source and destination buffers may point to same buffer

Merge from Widevine repo of http://go/wvgerrit/23581

This CL adds some unit tests to oemcrypto to verify that DecryptCENC
and the generic encrypt and decrypt functions behave correctly when
the input and output buffer is the same. i.e. decrypt in place.

The mock and haystack are also updated to pass the tests.

b/34080119

Change-Id: Ie295bdaddbb8058bebb36f6dab092d307f249ecd
This commit is contained in:
Fred Gylys-Colwell
2017-03-01 13:19:12 -08:00
parent 27c01e82b5
commit 4025322185
4 changed files with 122 additions and 58 deletions

View File

@@ -1134,6 +1134,7 @@ OEMCryptoResult SessionContext::DecryptCBC(
AES_KEY aes_key;
AES_set_decrypt_key(&key[0], AES_BLOCK_SIZE * 8, &aes_key);
uint8_t iv[AES_BLOCK_SIZE];
uint8_t next_iv[AES_BLOCK_SIZE];
memcpy(iv, &initial_iv[0], AES_BLOCK_SIZE);
size_t l = 0;
@@ -1151,11 +1152,14 @@ OEMCryptoResult SessionContext::DecryptCBC(
memcpy(&clear_data[l], &cipher_data[l], size);
} else {
uint8_t aes_output[AES_BLOCK_SIZE];
// Save the iv for the next block, in case cipher_data is in the same
// buffer as clear_data.
memcpy(next_iv, &cipher_data[l], AES_BLOCK_SIZE);
AES_decrypt(&cipher_data[l], aes_output, &aes_key);
for (size_t n = 0; n < AES_BLOCK_SIZE; n++) {
clear_data[l + n] = aes_output[n] ^ iv[n];
}
memcpy(iv, &cipher_data[l], AES_BLOCK_SIZE);
memcpy(iv, next_iv, AES_BLOCK_SIZE);
}
l += size;
}

View File

@@ -35,17 +35,9 @@ using namespace std;
// GTest requires PrintTo to be in the same namespace as the thing it prints,
// which is std::vector in this case.
namespace std {
void PrintTo(const vector<uint8_t>& value, ostream* os) {
*os << wvcdm::b2a_hex(value);
}
void PrintTo(const PatternTestVariant& param, ostream* os) {
*os << ((param.mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode")
<< ", encrypt=" << param.pattern.encrypt
<< ", skip=" << param.pattern.skip;
}
} // namespace std
namespace wvoec {

View File

@@ -18,21 +18,7 @@ using namespace std;
// GTest requires PrintTo to be in the same namespace as the thing it prints,
// which is std::vector in this case.
namespace std {
struct PatternTestVariant {
PatternTestVariant(size_t encrypt, size_t skip, OEMCryptoCipherMode mode) {
this->pattern.encrypt = encrypt;
this->pattern.skip = skip;
this->pattern.offset = 0;
this->mode = mode;
}
OEMCrypto_CENCEncryptPatternDesc pattern;
OEMCryptoCipherMode mode;
};
void PrintTo(const vector<uint8_t>& value, ostream* os);
void PrintTo(const PatternTestVariant& param, ostream* os);
} // namespace std
namespace wvoec {

View File

@@ -36,10 +36,26 @@
#include "wv_cdm_constants.h"
#include "wv_keybox.h"
using namespace std;
using ::testing::WithParamInterface;
using ::testing::Bool;
using ::testing::Combine;
using ::testing::Range;
using ::testing::Values;
using ::testing::WithParamInterface;
using namespace std;
using std::tr1::tuple;
namespace std { // GTest wants PrintTo to be in the std namespace.
void PrintTo(const tuple<OEMCrypto_CENCEncryptPatternDesc,
OEMCryptoCipherMode, bool>& param, ostream* os) {
OEMCrypto_CENCEncryptPatternDesc pattern = std::tr1::get<0>(param);
OEMCryptoCipherMode mode = std::tr1::get<1>(param);
bool decrypt_inplace = std::tr1::get<2>(param);
*os << ((mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode")
<< ", encrypt=" << pattern.encrypt
<< ", skip=" << pattern.skip
<< ", decrypt in place = " << (decrypt_inplace ? "true":"false");
}
}
namespace wvoec {
@@ -1615,12 +1631,14 @@ struct SampleInitData {
class OEMCryptoSessionTestsDecryptTests
: public OEMCryptoSessionTests,
public WithParamInterface<PatternTestVariant> {
public WithParamInterface<tuple<OEMCrypto_CENCEncryptPatternDesc,
OEMCryptoCipherMode, bool> > {
protected:
virtual void SetUp() {
OEMCryptoSessionTests::SetUp();
pattern_ = GetParam().pattern;
cipher_mode_ = GetParam().mode;
pattern_ = std::tr1::get<0>(GetParam());
cipher_mode_ = std::tr1::get<1>(GetParam());
decrypt_inplace_ = std::tr1::get<2>(GetParam());
}
void FindTotalSize() {
@@ -1728,7 +1746,17 @@ class OEMCryptoSessionTestsDecryptTests
ASSERT_EQ(OEMCrypto_SUCCESS, sts);
// We decrypt each subsample.
vector<uint8_t> outputBuffer(total_size_ + 16, 0xaa);
vector<uint8_t> output_buffer(total_size_ + 16, 0xaa);
const uint8_t *input_buffer = NULL;
if (decrypt_inplace_) { // Use same buffer for input and output.
// Copy the useful data from encryptedData to output_buffer, which
// will be the same as input_buffer. Leave the 0xaa padding at the end.
for(int i=0; i < total_size_; i++) output_buffer[i] = encryptedData[i];
// Now let input_buffer point to the same data.
input_buffer = &output_buffer[0];
} else {
input_buffer = &encryptedData[0];
}
size_t buffer_offset = 0;
for (size_t i = 0; i < subsample_size_.size(); i++) {
OEMCrypto_CENCEncryptPatternDesc pattern = pattern_;
@@ -1739,7 +1767,7 @@ class OEMCryptoSessionTestsDecryptTests
uint8_t subsample_flags = 0;
if (subsample_size_[i].clear_size > 0) {
destBuffer.type = OEMCrypto_BufferType_Clear;
destBuffer.buffer.clear.address = &outputBuffer[buffer_offset];
destBuffer.buffer.clear.address = &output_buffer[buffer_offset];
destBuffer.buffer.clear.max_length = total_size_ - buffer_offset;
if (i == 0) subsample_flags |= OEMCrypto_FirstSubsample;
if ((i == subsample_size_.size() - 1) &&
@@ -1747,7 +1775,7 @@ class OEMCryptoSessionTestsDecryptTests
subsample_flags |= OEMCrypto_LastSubsample;
}
sts =
OEMCrypto_DecryptCENC(s.session_id(), &encryptedData[buffer_offset],
OEMCrypto_DecryptCENC(s.session_id(), input_buffer + buffer_offset,
subsample_size_[i].clear_size, is_encrypted,
sample_init_data_[i].iv, block_offset,
&destBuffer, &pattern, subsample_flags);
@@ -1756,7 +1784,7 @@ class OEMCryptoSessionTestsDecryptTests
}
if (subsample_size_[i].encrypted_size > 0) {
destBuffer.type = OEMCrypto_BufferType_Clear;
destBuffer.buffer.clear.address = &outputBuffer[buffer_offset];
destBuffer.buffer.clear.address = &output_buffer[buffer_offset];
destBuffer.buffer.clear.max_length = total_size_ - buffer_offset;
is_encrypted = true;
block_offset = sample_init_data_[i].block_offset;
@@ -1768,7 +1796,7 @@ class OEMCryptoSessionTestsDecryptTests
subsample_flags |= OEMCrypto_LastSubsample;
}
sts = OEMCrypto_DecryptCENC(
s.session_id(), &encryptedData[buffer_offset],
s.session_id(), input_buffer + buffer_offset,
subsample_size_[i].encrypted_size, is_encrypted,
sample_init_data_[i].iv, block_offset, &destBuffer, &pattern,
subsample_flags);
@@ -1782,13 +1810,14 @@ class OEMCryptoSessionTestsDecryptTests
buffer_offset += subsample_size_[i].encrypted_size;
}
}
EXPECT_EQ(0xaa, outputBuffer[total_size_]) << "Buffer overrun.";
outputBuffer.resize(total_size_);
EXPECT_EQ(unencryptedData, outputBuffer);
EXPECT_EQ(0xaa, output_buffer[total_size_]) << "Buffer overrun.";
output_buffer.resize(total_size_);
EXPECT_EQ(unencryptedData, output_buffer);
}
OEMCrypto_CENCEncryptPatternDesc pattern_;
OEMCryptoCipherMode cipher_mode_;
bool decrypt_inplace_; // If true, input and output buffers are the same.
vector<SampleSize> subsample_size_;
size_t total_size_;
vector<SampleInitData> sample_init_data_;
@@ -2026,35 +2055,51 @@ TEST_F(OEMCryptoSessionTests, DecryptUnencryptedNoKey) {
ASSERT_EQ(in_buffer, out_buffer);
}
// Used to construct a specific pattern.
OEMCrypto_CENCEncryptPatternDesc MakePattern(size_t encrypt, size_t skip) {
OEMCrypto_CENCEncryptPatternDesc pattern;
pattern.encrypt = encrypt;
pattern.skip = skip;
pattern.offset = 0; // offset is deprecated.
return pattern;
}
INSTANTIATE_TEST_CASE_P(CTRTests, OEMCryptoSessionTestsPartialBlockTests,
Values(PatternTestVariant(0, 0,
OEMCrypto_CipherMode_CTR)));
Combine(Values(MakePattern(0,0)),
Values(OEMCrypto_CipherMode_CTR),
Bool()));
INSTANTIATE_TEST_CASE_P(
CBCTests, OEMCryptoSessionTestsPartialBlockTests,
Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC),
PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC),
PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC),
PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC),
PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC)));
Combine(
Values(MakePattern(0, 0),
MakePattern(3, 7),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
MakePattern(9, 1),
MakePattern(1, 9),
MakePattern(1, 3),
MakePattern(2, 1)),
Values(OEMCrypto_CipherMode_CBC), Bool()));
INSTANTIATE_TEST_CASE_P(
CTRTests, OEMCryptoSessionTestsDecryptTests,
Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CTR),
PatternTestVariant(3, 7, OEMCrypto_CipherMode_CTR),
// Pattern length should be 10, but that is not guaranteed.
PatternTestVariant(1, 3, OEMCrypto_CipherMode_CTR),
PatternTestVariant(2, 1, OEMCrypto_CipherMode_CTR)));
Combine(
Values(MakePattern(0, 0),
MakePattern(3, 7),
// Pattern length should be 10, but that is not guaranteed.
MakePattern(1, 3),
MakePattern(2, 1)),
Values(OEMCrypto_CipherMode_CTR), Bool()));
INSTANTIATE_TEST_CASE_P(
CBCTests, OEMCryptoSessionTestsDecryptTests,
Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC),
PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC),
PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC),
// Pattern length should be 10, but that is not guaranteed.
PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC),
PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC)));
Combine(
Values(MakePattern(0, 0),
MakePattern(3, 7),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
MakePattern(9, 1),
MakePattern(1, 9),
// Pattern length should be 10, but that is not guaranteed.
MakePattern(1, 3),
MakePattern(2, 1)),
Values(OEMCrypto_CipherMode_CBC), Bool()));
TEST_F(OEMCryptoSessionTests, DecryptSecureToClear) {
Session s;
@@ -3892,6 +3937,25 @@ TEST_F(GenericCryptoTest, GenericKeyBadEncrypt) {
BadEncrypt(3, OEMCrypto_AES_CBC_128_NO_PADDING, buffer_size_);
}
TEST_F(GenericCryptoTest, GenericKeyEncryptSameBuffer) {
EncryptAndLoadKeys();
unsigned int key_index = 0;
vector<uint8_t> expected_encrypted;
EncryptBuffer(key_index, clear_buffer_, &expected_encrypted);
ASSERT_EQ(
OEMCrypto_SUCCESS,
OEMCrypto_SelectKey(session_.session_id(),
session_.license().keys[key_index].key_id,
session_.license().keys[key_index].key_id_length));
// Input and output are same buffer:
vector<uint8_t> buffer = clear_buffer_;
ASSERT_EQ(OEMCrypto_SUCCESS,
OEMCrypto_Generic_Encrypt(
session_.session_id(), &buffer[0], buffer.size(),
iv_, OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0]));
ASSERT_EQ(expected_encrypted, buffer);
}
TEST_F(GenericCryptoTest, GenericKeyDecrypt) {
EncryptAndLoadKeys();
unsigned int key_index = 1;
@@ -3910,6 +3974,24 @@ TEST_F(GenericCryptoTest, GenericKeyDecrypt) {
ASSERT_EQ(clear_buffer_, resultant);
}
TEST_F(GenericCryptoTest, GenericKeyDecryptSameBuffer) {
EncryptAndLoadKeys();
unsigned int key_index = 1;
vector<uint8_t> encrypted;
EncryptBuffer(key_index, clear_buffer_, &encrypted);
ASSERT_EQ(
OEMCrypto_SUCCESS,
OEMCrypto_SelectKey(session_.session_id(),
session_.license().keys[key_index].key_id,
session_.license().keys[key_index].key_id_length));
vector<uint8_t> buffer = encrypted;
ASSERT_EQ(OEMCrypto_SUCCESS,
OEMCrypto_Generic_Decrypt(
session_.session_id(), &buffer[0], buffer.size(), iv_,
OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0]));
ASSERT_EQ(clear_buffer_, buffer);
}
TEST_F(GenericCryptoTest, GenericSecureToClear) {
session_.license().keys[1].control.control_bits |= htonl(
wvoec_mock::kControlObserveDataPath | wvoec_mock::kControlDataPathSecure);