Merge "Source and destination buffers may point to same buffer"

This commit is contained in:
Fred Gylys-Colwell
2017-03-02 04:44:11 +00:00
committed by Android (Google) Code Review
4 changed files with 122 additions and 58 deletions

View File

@@ -35,17 +35,9 @@ using namespace std;
// GTest requires PrintTo to be in the same namespace as the thing it prints,
// which is std::vector in this case.
namespace std {
void PrintTo(const vector<uint8_t>& value, ostream* os) {
*os << wvcdm::b2a_hex(value);
}
void PrintTo(const PatternTestVariant& param, ostream* os) {
*os << ((param.mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode")
<< ", encrypt=" << param.pattern.encrypt
<< ", skip=" << param.pattern.skip;
}
} // namespace std
namespace wvoec {

View File

@@ -18,21 +18,7 @@ using namespace std;
// GTest requires PrintTo to be in the same namespace as the thing it prints,
// which is std::vector in this case.
namespace std {
struct PatternTestVariant {
PatternTestVariant(size_t encrypt, size_t skip, OEMCryptoCipherMode mode) {
this->pattern.encrypt = encrypt;
this->pattern.skip = skip;
this->pattern.offset = 0;
this->mode = mode;
}
OEMCrypto_CENCEncryptPatternDesc pattern;
OEMCryptoCipherMode mode;
};
void PrintTo(const vector<uint8_t>& value, ostream* os);
void PrintTo(const PatternTestVariant& param, ostream* os);
} // namespace std
namespace wvoec {

View File

@@ -36,10 +36,26 @@
#include "wv_cdm_constants.h"
#include "wv_keybox.h"
using namespace std;
using ::testing::WithParamInterface;
using ::testing::Bool;
using ::testing::Combine;
using ::testing::Range;
using ::testing::Values;
using ::testing::WithParamInterface;
using namespace std;
using std::tr1::tuple;
namespace std { // GTest wants PrintTo to be in the std namespace.
void PrintTo(const tuple<OEMCrypto_CENCEncryptPatternDesc,
OEMCryptoCipherMode, bool>& param, ostream* os) {
OEMCrypto_CENCEncryptPatternDesc pattern = std::tr1::get<0>(param);
OEMCryptoCipherMode mode = std::tr1::get<1>(param);
bool decrypt_inplace = std::tr1::get<2>(param);
*os << ((mode == OEMCrypto_CipherMode_CTR) ? "CTR mode" : "CBC mode")
<< ", encrypt=" << pattern.encrypt
<< ", skip=" << pattern.skip
<< ", decrypt in place = " << (decrypt_inplace ? "true":"false");
}
}
namespace wvoec {
@@ -1615,12 +1631,14 @@ struct SampleInitData {
class OEMCryptoSessionTestsDecryptTests
: public OEMCryptoSessionTests,
public WithParamInterface<PatternTestVariant> {
public WithParamInterface<tuple<OEMCrypto_CENCEncryptPatternDesc,
OEMCryptoCipherMode, bool> > {
protected:
virtual void SetUp() {
OEMCryptoSessionTests::SetUp();
pattern_ = GetParam().pattern;
cipher_mode_ = GetParam().mode;
pattern_ = std::tr1::get<0>(GetParam());
cipher_mode_ = std::tr1::get<1>(GetParam());
decrypt_inplace_ = std::tr1::get<2>(GetParam());
}
void FindTotalSize() {
@@ -1728,7 +1746,17 @@ class OEMCryptoSessionTestsDecryptTests
ASSERT_EQ(OEMCrypto_SUCCESS, sts);
// We decrypt each subsample.
vector<uint8_t> outputBuffer(total_size_ + 16, 0xaa);
vector<uint8_t> output_buffer(total_size_ + 16, 0xaa);
const uint8_t *input_buffer = NULL;
if (decrypt_inplace_) { // Use same buffer for input and output.
// Copy the useful data from encryptedData to output_buffer, which
// will be the same as input_buffer. Leave the 0xaa padding at the end.
for(int i=0; i < total_size_; i++) output_buffer[i] = encryptedData[i];
// Now let input_buffer point to the same data.
input_buffer = &output_buffer[0];
} else {
input_buffer = &encryptedData[0];
}
size_t buffer_offset = 0;
for (size_t i = 0; i < subsample_size_.size(); i++) {
OEMCrypto_CENCEncryptPatternDesc pattern = pattern_;
@@ -1739,7 +1767,7 @@ class OEMCryptoSessionTestsDecryptTests
uint8_t subsample_flags = 0;
if (subsample_size_[i].clear_size > 0) {
destBuffer.type = OEMCrypto_BufferType_Clear;
destBuffer.buffer.clear.address = &outputBuffer[buffer_offset];
destBuffer.buffer.clear.address = &output_buffer[buffer_offset];
destBuffer.buffer.clear.max_length = total_size_ - buffer_offset;
if (i == 0) subsample_flags |= OEMCrypto_FirstSubsample;
if ((i == subsample_size_.size() - 1) &&
@@ -1747,7 +1775,7 @@ class OEMCryptoSessionTestsDecryptTests
subsample_flags |= OEMCrypto_LastSubsample;
}
sts =
OEMCrypto_DecryptCENC(s.session_id(), &encryptedData[buffer_offset],
OEMCrypto_DecryptCENC(s.session_id(), input_buffer + buffer_offset,
subsample_size_[i].clear_size, is_encrypted,
sample_init_data_[i].iv, block_offset,
&destBuffer, &pattern, subsample_flags);
@@ -1756,7 +1784,7 @@ class OEMCryptoSessionTestsDecryptTests
}
if (subsample_size_[i].encrypted_size > 0) {
destBuffer.type = OEMCrypto_BufferType_Clear;
destBuffer.buffer.clear.address = &outputBuffer[buffer_offset];
destBuffer.buffer.clear.address = &output_buffer[buffer_offset];
destBuffer.buffer.clear.max_length = total_size_ - buffer_offset;
is_encrypted = true;
block_offset = sample_init_data_[i].block_offset;
@@ -1768,7 +1796,7 @@ class OEMCryptoSessionTestsDecryptTests
subsample_flags |= OEMCrypto_LastSubsample;
}
sts = OEMCrypto_DecryptCENC(
s.session_id(), &encryptedData[buffer_offset],
s.session_id(), input_buffer + buffer_offset,
subsample_size_[i].encrypted_size, is_encrypted,
sample_init_data_[i].iv, block_offset, &destBuffer, &pattern,
subsample_flags);
@@ -1782,13 +1810,14 @@ class OEMCryptoSessionTestsDecryptTests
buffer_offset += subsample_size_[i].encrypted_size;
}
}
EXPECT_EQ(0xaa, outputBuffer[total_size_]) << "Buffer overrun.";
outputBuffer.resize(total_size_);
EXPECT_EQ(unencryptedData, outputBuffer);
EXPECT_EQ(0xaa, output_buffer[total_size_]) << "Buffer overrun.";
output_buffer.resize(total_size_);
EXPECT_EQ(unencryptedData, output_buffer);
}
OEMCrypto_CENCEncryptPatternDesc pattern_;
OEMCryptoCipherMode cipher_mode_;
bool decrypt_inplace_; // If true, input and output buffers are the same.
vector<SampleSize> subsample_size_;
size_t total_size_;
vector<SampleInitData> sample_init_data_;
@@ -2026,35 +2055,51 @@ TEST_F(OEMCryptoSessionTests, DecryptUnencryptedNoKey) {
ASSERT_EQ(in_buffer, out_buffer);
}
// Used to construct a specific pattern.
OEMCrypto_CENCEncryptPatternDesc MakePattern(size_t encrypt, size_t skip) {
OEMCrypto_CENCEncryptPatternDesc pattern;
pattern.encrypt = encrypt;
pattern.skip = skip;
pattern.offset = 0; // offset is deprecated.
return pattern;
}
INSTANTIATE_TEST_CASE_P(CTRTests, OEMCryptoSessionTestsPartialBlockTests,
Values(PatternTestVariant(0, 0,
OEMCrypto_CipherMode_CTR)));
Combine(Values(MakePattern(0,0)),
Values(OEMCrypto_CipherMode_CTR),
Bool()));
INSTANTIATE_TEST_CASE_P(
CBCTests, OEMCryptoSessionTestsPartialBlockTests,
Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC),
PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC),
PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC),
PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC),
PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC)));
Combine(
Values(MakePattern(0, 0),
MakePattern(3, 7),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
MakePattern(9, 1),
MakePattern(1, 9),
MakePattern(1, 3),
MakePattern(2, 1)),
Values(OEMCrypto_CipherMode_CBC), Bool()));
INSTANTIATE_TEST_CASE_P(
CTRTests, OEMCryptoSessionTestsDecryptTests,
Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CTR),
PatternTestVariant(3, 7, OEMCrypto_CipherMode_CTR),
// Pattern length should be 10, but that is not guaranteed.
PatternTestVariant(1, 3, OEMCrypto_CipherMode_CTR),
PatternTestVariant(2, 1, OEMCrypto_CipherMode_CTR)));
Combine(
Values(MakePattern(0, 0),
MakePattern(3, 7),
// Pattern length should be 10, but that is not guaranteed.
MakePattern(1, 3),
MakePattern(2, 1)),
Values(OEMCrypto_CipherMode_CTR), Bool()));
INSTANTIATE_TEST_CASE_P(
CBCTests, OEMCryptoSessionTestsDecryptTests,
Values(PatternTestVariant(0, 0, OEMCrypto_CipherMode_CBC),
PatternTestVariant(3, 7, OEMCrypto_CipherMode_CBC),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
PatternTestVariant(9, 1, OEMCrypto_CipherMode_CBC),
PatternTestVariant(1, 9, OEMCrypto_CipherMode_CBC),
// Pattern length should be 10, but that is not guaranteed.
PatternTestVariant(1, 3, OEMCrypto_CipherMode_CBC),
PatternTestVariant(2, 1, OEMCrypto_CipherMode_CBC)));
Combine(
Values(MakePattern(0, 0),
MakePattern(3, 7),
// HLS Edge case. We should follow the CENC spec, not HLS spec.
MakePattern(9, 1),
MakePattern(1, 9),
// Pattern length should be 10, but that is not guaranteed.
MakePattern(1, 3),
MakePattern(2, 1)),
Values(OEMCrypto_CipherMode_CBC), Bool()));
TEST_F(OEMCryptoSessionTests, DecryptSecureToClear) {
Session s;
@@ -3892,6 +3937,25 @@ TEST_F(GenericCryptoTest, GenericKeyBadEncrypt) {
BadEncrypt(3, OEMCrypto_AES_CBC_128_NO_PADDING, buffer_size_);
}
TEST_F(GenericCryptoTest, GenericKeyEncryptSameBuffer) {
EncryptAndLoadKeys();
unsigned int key_index = 0;
vector<uint8_t> expected_encrypted;
EncryptBuffer(key_index, clear_buffer_, &expected_encrypted);
ASSERT_EQ(
OEMCrypto_SUCCESS,
OEMCrypto_SelectKey(session_.session_id(),
session_.license().keys[key_index].key_id,
session_.license().keys[key_index].key_id_length));
// Input and output are same buffer:
vector<uint8_t> buffer = clear_buffer_;
ASSERT_EQ(OEMCrypto_SUCCESS,
OEMCrypto_Generic_Encrypt(
session_.session_id(), &buffer[0], buffer.size(),
iv_, OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0]));
ASSERT_EQ(expected_encrypted, buffer);
}
TEST_F(GenericCryptoTest, GenericKeyDecrypt) {
EncryptAndLoadKeys();
unsigned int key_index = 1;
@@ -3910,6 +3974,24 @@ TEST_F(GenericCryptoTest, GenericKeyDecrypt) {
ASSERT_EQ(clear_buffer_, resultant);
}
TEST_F(GenericCryptoTest, GenericKeyDecryptSameBuffer) {
EncryptAndLoadKeys();
unsigned int key_index = 1;
vector<uint8_t> encrypted;
EncryptBuffer(key_index, clear_buffer_, &encrypted);
ASSERT_EQ(
OEMCrypto_SUCCESS,
OEMCrypto_SelectKey(session_.session_id(),
session_.license().keys[key_index].key_id,
session_.license().keys[key_index].key_id_length));
vector<uint8_t> buffer = encrypted;
ASSERT_EQ(OEMCrypto_SUCCESS,
OEMCrypto_Generic_Decrypt(
session_.session_id(), &buffer[0], buffer.size(), iv_,
OEMCrypto_AES_CBC_128_NO_PADDING, &buffer[0]));
ASSERT_EQ(clear_buffer_, buffer);
}
TEST_F(GenericCryptoTest, GenericSecureToClear) {
session_.license().keys[1].control.control_bits |= htonl(
wvoec_mock::kControlObserveDataPath | wvoec_mock::kControlDataPathSecure);