From 804c0d470cdd2fe608462e9789902cabdd14859a Mon Sep 17 00:00:00 2001 From: "John W. Bruce" Date: Wed, 24 Jun 2020 14:27:09 -0700 Subject: [PATCH] Rework Device File Matchers to Avoid Buffer Overflow (This is a merge of http://go/wvgerrit/102104) The device file unit tests use some custom matchers that were written back when we didn't have C++11. Because gMock requires std::tuple to pass a pointer AND a length to a matcher, these matchers had to estimate the length of the file. This technically meant they were causing a benign buffer overrun sometimes. Since we have C++11 now, we can fix this by using a matcher over a std::pair of the pointer and length. I also took the opportunity to refactor the matchers a little. The old matchers had many very specific overloads and also collided with the names of some standard gMock matchers. Now there are just two more-general matchers with unique names. Test: CE CDM Unit Tests Test: Android Unit Tests Bug: 159463905 Change-Id: I758b140226bfe2bae6962ee5c64fd6af186b5819 --- .../cdm/core/test/device_files_unittest.cpp | 201 ++++++------------ 1 file changed, 62 insertions(+), 139 deletions(-) diff --git a/libwvdrmengine/cdm/core/test/device_files_unittest.cpp b/libwvdrmengine/cdm/core/test/device_files_unittest.cpp index 94564ea1..09d3f383 100644 --- a/libwvdrmengine/cdm/core/test/device_files_unittest.cpp +++ b/libwvdrmengine/cdm/core/test/device_files_unittest.cpp @@ -9,6 +9,7 @@ #include #include +#include #include "arraysize.h" #include "cdm_random.h" @@ -25,8 +26,6 @@ namespace { const uint32_t kCertificateLen = 700; const uint32_t kWrappedKeyLen = 500; -const uint32_t kProtobufEstimatedOverhead = 200; - const std::string kEmptyString; // Structurally valid test certificate. @@ -2025,6 +2024,7 @@ class MockFileSystem : public FileSystem { // gmock methods using ::testing::_; +using ::testing::AllArgs; using ::testing::AllOf; using ::testing::DoAll; using ::testing::Eq; @@ -2113,87 +2113,18 @@ class DeviceFilesDeleteMultipleUsageInfoTest public ::testing::WithParamInterface {}; MATCHER(IsCreateFileFlagSet, "") { return FileSystem::kCreate & arg; } -MATCHER_P(IsStrEq, str, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - return memcmp(arg, str.c_str(), str.size()) == 0; +MATCHER_P(StrAndLenEq, str, "") { + const std::string data(std::get<0>(arg), std::get<1>(arg)); + return data == str; } -MATCHER_P(ContainsAllElementsInVector, str_vector, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - size_t str_length = 0; - for (size_t i = 0; i < str_vector.size(); ++i) { - str_length += str_vector[i].size(); - } - std::string data(arg, str_length + kProtobufEstimatedOverhead); - bool all_entries_found = true; - for (size_t i = 0; i < str_vector.size(); ++i) { - if (data.find(str_vector[i]) == std::string::npos) { - all_entries_found = false; +MATCHER_P(StrAndLenContains, str_vector, "") { + const std::string data(std::get<0>(arg), std::get<1>(arg)); + for (const std::string& str : str_vector) { + if (data.find(str) == std::string::npos) { + return false; } } - return all_entries_found; -} -MATCHER_P2(Contains, str1, size, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - std::string data(arg, size + str1.size() + kProtobufEstimatedOverhead); - return (data.find(str1) != std::string::npos); -} -MATCHER_P3(Contains, str1, str2, size, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - std::string data( - arg, size + str1.size() + str2.size() + kProtobufEstimatedOverhead); - return (data.find(str1) != std::string::npos && - data.find(str2) != std::string::npos); -} -MATCHER_P4(Contains, str1, str2, str3, size, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - std::string data(arg, size + str1.size() + str2.size() + str3.size() + - kProtobufEstimatedOverhead); - return (data.find(str1) != std::string::npos && - data.find(str2) != std::string::npos && - data.find(str3) != std::string::npos); -} -MATCHER_P6(Contains, str1, str2, str3, str4, str5, size, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - std::string data(arg, size + str1.size() + str2.size() + str3.size() + - str4.size() + str5.size() + - kProtobufEstimatedOverhead); - return (data.find(str1) != std::string::npos && - data.find(str2) != std::string::npos && - data.find(str3) != std::string::npos && - data.find(str4) != std::string::npos && - data.find(str5) != std::string::npos); -} -MATCHER_P8(Contains, str1, str2, str3, str4, str5, str6, map7, str8, "") { - // Estimating the length of data. We can have gmock provide length - // as well as pointer to data but that will introduce a dependency on tr1 - size_t map7_len = 0; - CdmAppParameterMap::const_iterator itr = map7.begin(); - for (itr = map7.begin(); itr != map7.end(); ++itr) { - map7_len += itr->first.length(); - map7_len += itr->second.length(); - } - std::string data(arg, str1.size() + str2.size() + str3.size() + str4.size() + - str5.size() + str6.size() + map7_len + str8.size() + - kProtobufEstimatedOverhead); - bool map7_entries_present = true; - for (itr = map7.begin(); itr != map7.end(); ++itr) { - map7_entries_present = map7_entries_present && - data.find(itr->first) != std::string::npos && - data.find(itr->second) != std::string::npos; - } - return (data.find(str1) != std::string::npos && - data.find(str2) != std::string::npos && - data.find(str3) != std::string::npos && - data.find(str4) != std::string::npos && - data.find(str5) != std::string::npos && - data.find(str6) != std::string::npos && map7_entries_present && - data.find(str8) != std::string::npos); + return true; } TEST_F(DeviceCertificateTest, StoreCertificate) { @@ -2208,8 +2139,9 @@ TEST_F(DeviceCertificateTest, StoreCertificate) { EXPECT_CALL(file_system, DoOpen(StrEq(device_certificate_path), IsCreateFileFlagSet())) .WillOnce(Return(file)); - EXPECT_CALL(*file, Write(Contains(certificate, wrapped_private_key, 0), - Gt(certificate.size() + wrapped_private_key.size()))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains( + std::vector{certificate, wrapped_private_key}))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0); @@ -2291,8 +2223,9 @@ TEST_P(DeviceFilesSecurityLevelTest, SecurityLevel) { EXPECT_CALL(file_system, DoOpen(StrEq(device_certificate_path), IsCreateFileFlagSet())) .WillOnce(Return(file)); - EXPECT_CALL(*file, Write(Contains(certificate, wrapped_private_key, 0), - Gt(certificate.size() + wrapped_private_key.size()))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains( + std::vector{certificate, wrapped_private_key}))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0); @@ -2314,20 +2247,26 @@ TEST_P(DeviceFilesStoreTest, StoreLicense) { CdmAppParameterMap app_parameters = GetAppParameters(kLicenseTestData[license_num].app_parameters); + std::vector expected_substrings{ + kLicenseTestData[license_num].pssh_data, + kLicenseTestData[license_num].key_request, + kLicenseTestData[license_num].key_response, + kLicenseTestData[license_num].key_renewal_request, + kLicenseTestData[license_num].key_renewal_response, + kLicenseTestData[license_num].key_release_url, + kLicenseTestData[license_num].usage_entry, + }; + for (const auto& iter : app_parameters) { + expected_substrings.push_back(iter.first); + expected_substrings.push_back(iter.second); + } + // Call to Open will return a unique_ptr, freeing this object. MockFile* file = new MockFile(); EXPECT_CALL(file_system, DoOpen(StrEq(license_path), IsCreateFileFlagSet())) .WillOnce(Return(file)); - EXPECT_CALL( - *file, - Write(Contains(kLicenseTestData[license_num].pssh_data, - kLicenseTestData[license_num].key_request, - kLicenseTestData[license_num].key_response, - kLicenseTestData[license_num].key_renewal_request, - kLicenseTestData[license_num].key_renewal_response, - kLicenseTestData[license_num].key_release_url, - app_parameters, kLicenseTestData[license_num].usage_entry), - Gt(GetLicenseDataSize(kLicenseTestData[license_num])))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains(expected_substrings))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0); @@ -2366,20 +2305,27 @@ TEST_F(DeviceFilesTest, StoreLicenses) { CdmAppParameterMap app_parameters = GetAppParameters(kLicenseTestData[i].app_parameters); + std::vector expected_substrings{ + kLicenseTestData[i].pssh_data, + kLicenseTestData[i].key_request, + kLicenseTestData[i].key_response, + kLicenseTestData[i].key_renewal_request, + kLicenseTestData[i].key_renewal_response, + kLicenseTestData[i].key_release_url, + kLicenseTestData[i].usage_entry, + }; + for (const auto& iter : app_parameters) { + expected_substrings.push_back(iter.first); + expected_substrings.push_back(iter.second); + } + // Call to Open will return a unique_ptr, freeing this object. MockFile* file = new MockFile(); EXPECT_CALL(file_system, DoOpen(StrEq(license_path), IsCreateFileFlagSet())) .WillOnce(Return(file)); - EXPECT_CALL(*file, - Write(Contains(kLicenseTestData[i].pssh_data, - kLicenseTestData[i].key_request, - kLicenseTestData[i].key_response, - kLicenseTestData[i].key_renewal_request, - kLicenseTestData[i].key_renewal_response, - kLicenseTestData[i].key_release_url, - app_parameters, kLicenseTestData[i].usage_entry), - Gt(GetLicenseDataSize(kLicenseTestData[i])))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains(expected_substrings))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0); } @@ -2539,8 +2485,8 @@ TEST_F(DeviceFilesTest, UpdateLicenseState) { MockFile* file = new MockFile(); EXPECT_CALL(file_system, DoOpen(StrEq(license_path), IsCreateFileFlagSet())) .WillOnce(Return(file)); - EXPECT_CALL(*file, Write(IsStrEq(kLicenseUpdateTestData[i].file_data), - Eq(kLicenseUpdateTestData[i].file_data.size()))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenEq(kLicenseUpdateTestData[i].file_data))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0); DeviceFiles::CdmLicenseData license_data{ @@ -2971,7 +2917,6 @@ TEST_P(DeviceFilesUsageInfoTest, Store) { std::string file_name = DeviceFiles::GetUsageInfoFileName(app_id); std::string path = device_base_path_ + file_name; - size_t usage_data_fields_length = 0; std::vector usage_data_fields; std::vector usage_data_list; @@ -2985,18 +2930,12 @@ TEST_P(DeviceFilesUsageInfoTest, Store) { usage_data_fields.push_back(kUsageInfoTestData[i].usage_data.license); usage_data_fields.push_back(kUsageInfoTestData[i].usage_data.key_set_id); usage_data_fields.push_back(kUsageInfoTestData[i].usage_data.usage_entry); - usage_data_fields_length += - kUsageInfoTestData[i].usage_data.provider_session_token.size() + - kUsageInfoTestData[i].usage_data.license_request.size() + - kUsageInfoTestData[i].usage_data.license.size() + - kUsageInfoTestData[i].usage_data.key_set_id.size() + - kUsageInfoTestData[i].usage_data.usage_entry.size(); } } EXPECT_CALL(file_system, DoOpen(StrEq(path), _)).WillOnce(Return(file)); - EXPECT_CALL(*file, Write(ContainsAllElementsInVector(usage_data_fields), - Gt(usage_data_fields_length))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains(usage_data_fields))) .WillOnce(ReturnArg<1>()); DeviceFiles device_files(&file_system); @@ -3238,7 +3177,6 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) { std::string file_name = DeviceFiles::GetUsageInfoFileName(app_id); std::string path = device_base_path_ + file_name; - size_t usage_data_fields_length = 0; std::vector usage_data_fields; size_t max_index_by_app_id = 0; @@ -3256,12 +3194,6 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) { kUsageInfoTestData[i].usage_data.key_set_id); usage_data_fields.push_back( kUsageInfoTestData[i].usage_data.usage_entry); - usage_data_fields_length += - kUsageInfoTestData[i].usage_data.provider_session_token.size() + - kUsageInfoTestData[i].usage_data.license_request.size() + - kUsageInfoTestData[i].usage_data.license.size() + - kUsageInfoTestData[i].usage_data.key_set_id.size() + - kUsageInfoTestData[i].usage_data.usage_entry.size(); } } } @@ -3273,12 +3205,6 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) { usage_data_fields.push_back(kUsageInfoUpdateTestData.license); usage_data_fields.push_back(kUsageInfoUpdateTestData.key_set_id); usage_data_fields.push_back(kUsageInfoUpdateTestData.usage_entry); - usage_data_fields_length += - kUsageInfoTestData[index].usage_data.provider_session_token.size() + - kUsageInfoUpdateTestData.license_request.size() + - kUsageInfoUpdateTestData.license.size() + - kUsageInfoUpdateTestData.key_set_id.size() + - kUsageInfoUpdateTestData.usage_entry.size(); } std::string file_data = @@ -3305,14 +3231,14 @@ TEST_P(DeviceFilesUsageInfoTest, UpdateUsageInfo) { .Times(2) .WillOnce(Return(file)) .WillOnce(Return(next_file)); - ON_CALL(*file, Write(ContainsAllElementsInVector(usage_data_fields), - Gt(usage_data_fields_length))) + ON_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains(usage_data_fields))) .WillByDefault(DoAll(InvokeWithoutArgs([&write_called]() -> void { write_called = true; }), ReturnArg<1>())); - ON_CALL(*next_file, Write(ContainsAllElementsInVector(usage_data_fields), - Gt(usage_data_fields_length))) + ON_CALL(*next_file, Write(_, _)) + .With(AllArgs(StrAndLenContains(usage_data_fields))) .WillByDefault(DoAll(InvokeWithoutArgs([&write_called]() -> void { write_called = true; }), @@ -3374,8 +3300,9 @@ TEST_P(DeviceFilesHlsAttributesTest, Store) { EXPECT_CALL(file_system, Exists(StrEq(path))).WillRepeatedly(Return(true)); EXPECT_CALL(file_system, DoOpen(StrEq(path), _)).WillOnce(Return(file)); - EXPECT_CALL(*file, Write(Contains(param->media_segment_iv, 0), - Gt(param->media_segment_iv.size()))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs( + StrAndLenContains(std::vector{param->media_segment_iv}))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0); @@ -3410,7 +3337,6 @@ TEST_P(DeviceFilesUsageTableTest, Store) { MockFile* file = new MockFile(); int index = GetParam(); - size_t entry_data_length = 0; std::vector entry_data; std::vector usage_entry_info; usage_entry_info.resize(index + 1); @@ -3418,18 +3344,15 @@ TEST_P(DeviceFilesUsageTableTest, Store) { usage_entry_info[i] = kUsageEntriesTestData[i]; entry_data.push_back(kUsageEntriesTestData[i].key_set_id); entry_data.push_back(kUsageEntriesTestData[i].usage_info_file_name); - entry_data_length += kUsageEntriesTestData[i].key_set_id.size() + - kUsageEntriesTestData[i].usage_info_file_name.size(); } entry_data.push_back(kUsageTableInfoTestData[index].usage_table_header); - entry_data_length += kUsageTableInfoTestData[index].usage_table_header.size(); std::string path = device_base_path_ + DeviceFiles::GetUsageTableFileName(); EXPECT_CALL(file_system, Exists(StrEq(path))).WillRepeatedly(Return(true)); EXPECT_CALL(file_system, DoOpen(StrEq(path), _)).WillOnce(Return(file)); - EXPECT_CALL(*file, Write(ContainsAllElementsInVector(entry_data), - Gt(entry_data_length))) + EXPECT_CALL(*file, Write(_, _)) + .With(AllArgs(StrAndLenContains(entry_data))) .WillOnce(ReturnArg<1>()); EXPECT_CALL(*file, Read(_, _)).Times(0);