Merge "Rework Device File Matchers to Avoid Buffer Overflow"

This commit is contained in:
John Bruce
2020-06-25 21:58:58 +00:00
committed by Android (Google) Code Review

View File

@@ -9,6 +9,7 @@
#include <memory>
#include <string>
#include <vector>
#include "arraysize.h"
#include "cdm_random.h"
@@ -25,8 +26,6 @@ namespace {
const uint32_t kCertificateLen = 700;
const uint32_t kWrappedKeyLen = 500;
const uint32_t kProtobufEstimatedOverhead = 200;
const std::string kEmptyString;
// Structurally valid test certificate.
@@ -2025,6 +2024,7 @@ class MockFileSystem : public FileSystem {
// gmock methods
using ::testing::_;
using ::testing::AllArgs;
using ::testing::AllOf;
using ::testing::DoAll;
using ::testing::Eq;
@@ -2113,87 +2113,18 @@ class DeviceFilesDeleteMultipleUsageInfoTest
public ::testing::WithParamInterface<int> {};
MATCHER(IsCreateFileFlagSet, "") { return FileSystem::kCreate & arg; }
MATCHER_P(IsStrEq, str, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
return memcmp(arg, str.c_str(), str.size()) == 0;
MATCHER_P(StrAndLenEq, str, "") {
const std::string data(std::get<0>(arg), std::get<1>(arg));
return data == str;
}
MATCHER_P(ContainsAllElementsInVector, str_vector, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
size_t str_length = 0;
for (size_t i = 0; i < str_vector.size(); ++i) {
str_length += str_vector[i].size();
}
std::string data(arg, str_length + kProtobufEstimatedOverhead);
bool all_entries_found = true;
for (size_t i = 0; i < str_vector.size(); ++i) {
if (data.find(str_vector[i]) == std::string::npos) {
all_entries_found = false;
MATCHER_P(StrAndLenContains, str_vector, "") {
const std::string data(std::get<0>(arg), std::get<1>(arg));
for (const std::string& str : str_vector) {
if (data.find(str) == std::string::npos) {
return false;
}
}
return all_entries_found;
}
MATCHER_P2(Contains, str1, size, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
std::string data(arg, size + str1.size() + kProtobufEstimatedOverhead);
return (data.find(str1) != std::string::npos);
}
MATCHER_P3(Contains, str1, str2, size, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
std::string data(
arg, size + str1.size() + str2.size() + kProtobufEstimatedOverhead);
return (data.find(str1) != std::string::npos &&
data.find(str2) != std::string::npos);
}
MATCHER_P4(Contains, str1, str2, str3, size, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
std::string data(arg, size + str1.size() + str2.size() + str3.size() +
kProtobufEstimatedOverhead);
return (data.find(str1) != std::string::npos &&
data.find(str2) != std::string::npos &&
data.find(str3) != std::string::npos);
}
MATCHER_P6(Contains, str1, str2, str3, str4, str5, size, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
std::string data(arg, size + str1.size() + str2.size() + str3.size() +
str4.size() + str5.size() +
kProtobufEstimatedOverhead);
return (data.find(str1) != std::string::npos &&
data.find(str2) != std::string::npos &&
data.find(str3) != std::string::npos &&
data.find(str4) != std::string::npos &&
data.find(str5) != std::string::npos);
}
MATCHER_P8(Contains, str1, str2, str3, str4, str5, str6, map7, str8, "") {
// Estimating the length of data. We can have gmock provide length
// as well as pointer to data but that will introduce a dependency on tr1
size_t map7_len = 0;
CdmAppParameterMap::const_iterator itr = map7.begin();
for (itr = map7.begin(); itr != map7.end(); ++itr) {
map7_len += itr->first.length();
map7_len += itr->second.length();
}
std::string data(arg, str1.size() + str2.size() + str3.size() + str4.size() +
str5.size() + str6.size() + map7_len + str8.size() +
kProtobufEstimatedOverhead);
bool map7_entries_present = true;
for (itr = map7.begin(); itr != map7.end(); ++itr) {
map7_entries_present = map7_entries_present &&
data.find(itr->first) != std::string::npos &&
data.find(itr->second) != std::string::npos;
}
return (data.find(str1) != std::string::npos &&
data.find(str2) != std::string::npos &&
data.find(str3) != std::string::npos &&
data.find(str4) != std::string::npos &&
data.find(str5) != std::string::npos &&
data.find(str6) != std::string::npos && map7_entries_present &&
data.find(str8) != std::string::npos);
return true;
}
TEST_F(DeviceCertificateTest, StoreCertificate) {
@@ -2208,8 +2139,9 @@ TEST_F(DeviceCertificateTest, StoreCertificate) {
EXPECT_CALL(file_system,
DoOpen(StrEq(device_certificate_path), IsCreateFileFlagSet()))
.WillOnce(Return(file));
EXPECT_CALL(*file, Write(Contains(certificate, wrapped_private_key, 0),
Gt(certificate.size() + wrapped_private_key.size())))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(
std::vector<std::string>{certificate, wrapped_private_key})))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);
@@ -2291,8 +2223,9 @@ TEST_P(DeviceFilesSecurityLevelTest, SecurityLevel) {
EXPECT_CALL(file_system,
DoOpen(StrEq(device_certificate_path), IsCreateFileFlagSet()))
.WillOnce(Return(file));
EXPECT_CALL(*file, Write(Contains(certificate, wrapped_private_key, 0),
Gt(certificate.size() + wrapped_private_key.size())))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(
std::vector<std::string>{certificate, wrapped_private_key})))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);
@@ -2314,20 +2247,26 @@ TEST_P(DeviceFilesStoreTest, StoreLicense) {
CdmAppParameterMap app_parameters =
GetAppParameters(kLicenseTestData[license_num].app_parameters);
std::vector<std::string> expected_substrings{
kLicenseTestData[license_num].pssh_data,
kLicenseTestData[license_num].key_request,
kLicenseTestData[license_num].key_response,
kLicenseTestData[license_num].key_renewal_request,
kLicenseTestData[license_num].key_renewal_response,
kLicenseTestData[license_num].key_release_url,
kLicenseTestData[license_num].usage_entry,
};
for (const auto& iter : app_parameters) {
expected_substrings.push_back(iter.first);
expected_substrings.push_back(iter.second);
}
// Call to Open will return a unique_ptr, freeing this object.
MockFile* file = new MockFile();
EXPECT_CALL(file_system, DoOpen(StrEq(license_path), IsCreateFileFlagSet()))
.WillOnce(Return(file));
EXPECT_CALL(
*file,
Write(Contains(kLicenseTestData[license_num].pssh_data,
kLicenseTestData[license_num].key_request,
kLicenseTestData[license_num].key_response,
kLicenseTestData[license_num].key_renewal_request,
kLicenseTestData[license_num].key_renewal_response,
kLicenseTestData[license_num].key_release_url,
app_parameters, kLicenseTestData[license_num].usage_entry),
Gt(GetLicenseDataSize(kLicenseTestData[license_num]))))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(expected_substrings)))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);
@@ -2366,20 +2305,27 @@ TEST_F(DeviceFilesTest, StoreLicenses) {
CdmAppParameterMap app_parameters =
GetAppParameters(kLicenseTestData[i].app_parameters);
std::vector<std::string> expected_substrings{
kLicenseTestData[i].pssh_data,
kLicenseTestData[i].key_request,
kLicenseTestData[i].key_response,
kLicenseTestData[i].key_renewal_request,
kLicenseTestData[i].key_renewal_response,
kLicenseTestData[i].key_release_url,
kLicenseTestData[i].usage_entry,
};
for (const auto& iter : app_parameters) {
expected_substrings.push_back(iter.first);
expected_substrings.push_back(iter.second);
}
// Call to Open will return a unique_ptr, freeing this object.
MockFile* file = new MockFile();
EXPECT_CALL(file_system, DoOpen(StrEq(license_path), IsCreateFileFlagSet()))
.WillOnce(Return(file));
EXPECT_CALL(*file,
Write(Contains(kLicenseTestData[i].pssh_data,
kLicenseTestData[i].key_request,
kLicenseTestData[i].key_response,
kLicenseTestData[i].key_renewal_request,
kLicenseTestData[i].key_renewal_response,
kLicenseTestData[i].key_release_url,
app_parameters, kLicenseTestData[i].usage_entry),
Gt(GetLicenseDataSize(kLicenseTestData[i]))))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(expected_substrings)))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);
}
@@ -2539,8 +2485,8 @@ TEST_F(DeviceFilesTest, UpdateLicenseState) {
MockFile* file = new MockFile();
EXPECT_CALL(file_system, DoOpen(StrEq(license_path), IsCreateFileFlagSet()))
.WillOnce(Return(file));
EXPECT_CALL(*file, Write(IsStrEq(kLicenseUpdateTestData[i].file_data),
Eq(kLicenseUpdateTestData[i].file_data.size())))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenEq(kLicenseUpdateTestData[i].file_data)))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);
DeviceFiles::CdmLicenseData license_data{
@@ -2971,7 +2917,6 @@ TEST_P(DeviceFilesUsageInfoTest, Store) {
std::string file_name = DeviceFiles::GetUsageInfoFileName(app_id);
std::string path = device_base_path_ + file_name;
size_t usage_data_fields_length = 0;
std::vector<std::string> usage_data_fields;
std::vector<DeviceFiles::CdmUsageData> usage_data_list;
@@ -2985,18 +2930,12 @@ TEST_P(DeviceFilesUsageInfoTest, Store) {
usage_data_fields.push_back(kUsageInfoTestData[i].usage_data.license);
usage_data_fields.push_back(kUsageInfoTestData[i].usage_data.key_set_id);
usage_data_fields.push_back(kUsageInfoTestData[i].usage_data.usage_entry);
usage_data_fields_length +=
kUsageInfoTestData[i].usage_data.provider_session_token.size() +
kUsageInfoTestData[i].usage_data.license_request.size() +
kUsageInfoTestData[i].usage_data.license.size() +
kUsageInfoTestData[i].usage_data.key_set_id.size() +
kUsageInfoTestData[i].usage_data.usage_entry.size();
}
}
EXPECT_CALL(file_system, DoOpen(StrEq(path), _)).WillOnce(Return(file));
EXPECT_CALL(*file, Write(ContainsAllElementsInVector(usage_data_fields),
Gt(usage_data_fields_length)))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(usage_data_fields)))
.WillOnce(ReturnArg<1>());
DeviceFiles device_files(&file_system);
@@ -3238,7 +3177,6 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) {
std::string file_name = DeviceFiles::GetUsageInfoFileName(app_id);
std::string path = device_base_path_ + file_name;
size_t usage_data_fields_length = 0;
std::vector<std::string> usage_data_fields;
size_t max_index_by_app_id = 0;
@@ -3256,12 +3194,6 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) {
kUsageInfoTestData[i].usage_data.key_set_id);
usage_data_fields.push_back(
kUsageInfoTestData[i].usage_data.usage_entry);
usage_data_fields_length +=
kUsageInfoTestData[i].usage_data.provider_session_token.size() +
kUsageInfoTestData[i].usage_data.license_request.size() +
kUsageInfoTestData[i].usage_data.license.size() +
kUsageInfoTestData[i].usage_data.key_set_id.size() +
kUsageInfoTestData[i].usage_data.usage_entry.size();
}
}
}
@@ -3273,12 +3205,6 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) {
usage_data_fields.push_back(kUsageInfoUpdateTestData.license);
usage_data_fields.push_back(kUsageInfoUpdateTestData.key_set_id);
usage_data_fields.push_back(kUsageInfoUpdateTestData.usage_entry);
usage_data_fields_length +=
kUsageInfoTestData[index].usage_data.provider_session_token.size() +
kUsageInfoUpdateTestData.license_request.size() +
kUsageInfoUpdateTestData.license.size() +
kUsageInfoUpdateTestData.key_set_id.size() +
kUsageInfoUpdateTestData.usage_entry.size();
}
std::string file_data =
@@ -3305,14 +3231,14 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) {
.Times(2)
.WillOnce(Return(file))
.WillOnce(Return(next_file));
ON_CALL(*file, Write(ContainsAllElementsInVector(usage_data_fields),
Gt(usage_data_fields_length)))
ON_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(usage_data_fields)))
.WillByDefault(DoAll(InvokeWithoutArgs([&write_called]() -> void {
write_called = true;
}),
ReturnArg<1>()));
ON_CALL(*next_file, Write(ContainsAllElementsInVector(usage_data_fields),
Gt(usage_data_fields_length)))
ON_CALL(*next_file, Write(_, _))
.With(AllArgs(StrAndLenContains(usage_data_fields)))
.WillByDefault(DoAll(InvokeWithoutArgs([&write_called]() -> void {
write_called = true;
}),
@@ -3374,8 +3300,9 @@ TEST_P(DeviceFilesHlsAttributesTest, Store) {
EXPECT_CALL(file_system, Exists(StrEq(path))).WillRepeatedly(Return(true));
EXPECT_CALL(file_system, DoOpen(StrEq(path), _)).WillOnce(Return(file));
EXPECT_CALL(*file, Write(Contains(param->media_segment_iv, 0),
Gt(param->media_segment_iv.size())))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(
StrAndLenContains(std::vector<std::string>{param->media_segment_iv})))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);
@@ -3410,7 +3337,6 @@ TEST_P(DeviceFilesUsageTableTest, Store) {
MockFile* file = new MockFile();
int index = GetParam();
size_t entry_data_length = 0;
std::vector<std::string> entry_data;
std::vector<CdmUsageEntryInfo> usage_entry_info;
usage_entry_info.resize(index + 1);
@@ -3418,18 +3344,15 @@ TEST_P(DeviceFilesUsageTableTest, Store) {
usage_entry_info[i] = kUsageEntriesTestData[i];
entry_data.push_back(kUsageEntriesTestData[i].key_set_id);
entry_data.push_back(kUsageEntriesTestData[i].usage_info_file_name);
entry_data_length += kUsageEntriesTestData[i].key_set_id.size() +
kUsageEntriesTestData[i].usage_info_file_name.size();
}
entry_data.push_back(kUsageTableInfoTestData[index].usage_table_header);
entry_data_length += kUsageTableInfoTestData[index].usage_table_header.size();
std::string path = device_base_path_ + DeviceFiles::GetUsageTableFileName();
EXPECT_CALL(file_system, Exists(StrEq(path))).WillRepeatedly(Return(true));
EXPECT_CALL(file_system, DoOpen(StrEq(path), _)).WillOnce(Return(file));
EXPECT_CALL(*file, Write(ContainsAllElementsInVector(entry_data),
Gt(entry_data_length)))
EXPECT_CALL(*file, Write(_, _))
.With(AllArgs(StrAndLenContains(entry_data)))
.WillOnce(ReturnArg<1>());
EXPECT_CALL(*file, Read(_, _)).Times(0);