From 1183ae813f62a5eafd08c3d15db1a0d03030e537 Mon Sep 17 00:00:00 2001 From: Alex Dale Date: Thu, 9 Nov 2023 17:25:15 -0800 Subject: [PATCH] Filter out key set IDs based on ATSC mode. [ Partial cherry-pick of http://go/wvgerrit/185854 ] Certain GTS tests do not fully consider restrictions on ATSC devices. In particular, GTS assumes if there are any key set IDs returned to the app via the MediaDrm API, then the device must already be provisioned. ATSC license are special in that they may be available, but the CDM is not provisioned while outside of ATCS mode. To work around this assumption made by GTS, we filter out ATSC licenses returned by getOfflineLicenseKeySetIds() when the device is not in ATSC mode, and filter out non-ATSC license when it is in ATSC mode. This is only a soft enforcement mechanism as calling the API with a valid ATSC license while outside ATSC mode (or a non-TSC license in ATSC mode) will continue to result in the failures experienced by certain OEMs. Bug: 301910628 Bug: 291181955 Bug: 296300842 Bug: 302612540 Test: MediaDrmParameterizedTests GTS on oriole Merged from https://widevine-internal-review.googlesource.com/187610 Merged from https://widevine-internal-review.googlesource.com/187831 Change-Id: Id1508571ebb5c466f43bca99a2d79dc402a2134f --- libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp | 68 +++++++---- .../mediadrm/test/WVDrmPlugin_hal_test.cpp | 115 ++++++++++++------ 2 files changed, 125 insertions(+), 58 deletions(-) diff --git a/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp b/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp index 7002f4ae..bdda24b0 100644 --- a/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp +++ b/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp @@ -160,6 +160,18 @@ bool isRootOrShell() { return (uid == AID_ROOT || uid == AID_SHELL); } +bool IsAtscKeySetId(const CdmKeySetId& keySetId) { + if (keySetId.empty()) return false; + // Pre-installed licenses might not perfectly match ATSC_KEY_SET_ID_PREFIX. + // If "atsc" is in the license name, then it is safe to assume + // it is an ATSC license. + return keySetId.find("atsc") != std::string::npos || + keySetId.find("ATSC") != std::string::npos; +} + +bool IsNotAtscKeySetId(const CdmKeySetId& keySetId) { + return !IsAtscKeySetId(keySetId); +} } // namespace WVDrmPlugin::WVDrmPlugin(const android::sp& cdm, @@ -947,39 +959,53 @@ Status WVDrmPlugin::unprovisionDevice() { ::ndk::ScopedAStatus WVDrmPlugin::getOfflineLicenseKeySetIds( vector<::aidl::android::hardware::drm::KeySetId>* _aidl_return) { - vector> keySetIds; - vector keySetIdsVec; - + _aidl_return->clear(); CdmIdentifier identifier; - auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier); + const auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier); if (status != Status::OK) { - *_aidl_return = keySetIdsVec; return toNdkScopedAStatus(status); } - vector levels = {wvcdm::kSecurityLevelL1, - wvcdm::kSecurityLevelL3}; + const std::vector levels = {wvcdm::kSecurityLevelL1, + wvcdm::kSecurityLevelL3}; + std::vector allKeySetIds; CdmResponseType res(wvcdm::UNKNOWN_ERROR); - + bool success = false; for (auto level : levels) { - vector cdmKeySetIds; - res = mCDM->ListStoredLicenses(level, identifier, &cdmKeySetIds); + std::vector levelKeySetIds; + res = mCDM->ListStoredLicenses(level, identifier, &levelKeySetIds); - if (isCdmResponseTypeSuccess(res)) { - keySetIds.clear(); - for (auto id : cdmKeySetIds) { - keySetIds.push_back(StrToVector(id)); - } - KeySetId kid; - for (auto id : keySetIds) { - kid.keySetId = id; - keySetIdsVec.push_back(kid); + if (!isCdmResponseTypeSuccess(res)) continue; + success = true; + if (levelKeySetIds.empty()) continue; + if (allKeySetIds.empty()) { + allKeySetIds = std::move(levelKeySetIds); + } else { + allKeySetIds.reserve(allKeySetIds.size() + levelKeySetIds.size()); + for (CdmKeySetId& keySetId : levelKeySetIds) { + allKeySetIds.push_back(std::move(keySetId)); } } } - *_aidl_return = keySetIdsVec; - return toNdkScopedAStatus(mapCdmResponseType(res)); + + if (!success) { + // Return whatever the last error was. + return toNdkScopedAStatus(mapCdmResponseType(res)); + } + + // Filter out key sets based on ATSC mode. + const auto isAllowedKeySetId = + mPropertySet.use_atsc_mode() ? IsAtscKeySetId : IsNotAtscKeySetId; + std::vector keySetIds; + for (const CdmKeySetId& keySetId : allKeySetIds) { + if (isAllowedKeySetId(keySetId)) { + keySetIds.push_back(KeySetId{StrToVector(keySetId)}); + } + } + + *_aidl_return = std::move(keySetIds); + return toNdkScopedAStatus(mapCdmResponseType(wvcdm::NO_ERROR)); } ::ndk::ScopedAStatus WVDrmPlugin::getOfflineLicenseState( diff --git a/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp b/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp index c4cd6102..f1bac50b 100644 --- a/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp +++ b/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp @@ -2853,49 +2853,90 @@ TEST_F(WVDrmPluginHalTest, DoesNotSetDecryptHashProperties) { EXPECT_TRUE(ret.isOk()); } -TEST_F(WVDrmPluginHalTest, GetOfflineLicenseIds) { - const uint32_t kLicenseCount = 5; +TEST_F(WVDrmPluginHalTest, GetOfflineLicenseKeySetIds_NonAtscMode) { + const std::vector cdmKeySetIdsL1 = { + // Non-ATSC key set IDs + "ksid1111", "ksid2222", "ksid3333", "ksid4444", + // ATSC key set IDs. + "atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1", + "atsc_group1_profile7"}; + const std::vector cdmKeySetIdsL3 = { + // Non-ATSC key set IDs + "ksid5555", "ksid6666", "ksid7777", "ksid8888", + // ATSC key set IDs. + "atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1", + "atsc_group3_profile7"}; + // Expect non-ATSC key set IDs (order is important). + const std::vector expectedCdmKeySetIds = { + // From L1 + "ksid1111", "ksid2222", "ksid3333", "ksid4444", + // From L3 + "ksid5555", "ksid6666", "ksid7777", "ksid8888"}; - uint8_t mockIdsRaw[kLicenseCount * 2][kKeySetIdSize]; - FILE *fp = fopen("/dev/urandom", "r"); - ASSERT_NE(fp, nullptr) << "Failed to open /dev/urandom"; - for (uint32_t i = 0; i < kLicenseCount * 2; ++i) { - fread(mockIdsRaw[i], sizeof(uint8_t), kKeySetIdSize, fp); - } - fclose(fp); - - std::vector mockIdsL1; - for (uint32_t i = 0; i < kLicenseCount; ++i) { - mockIdsL1.push_back( - std::string(mockIdsRaw[i], mockIdsRaw[i] + kKeySetIdSize)); - } - - std::vector mockIdsL3; - for (uint32_t i = 0; i < kLicenseCount; ++i) { - mockIdsL3.push_back( - std::string(mockIdsRaw[i + 5], mockIdsRaw[i + 5] + kKeySetIdSize)); - } - - EXPECT_CALL(*mCdm, - ListStoredLicenses(kSecurityLevelL1, HasOrigin(EMPTY_ORIGIN), _)) - .WillOnce(DoAll(SetArgPointee<2>(mockIdsL1), + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); - EXPECT_CALL(*mCdm, - ListStoredLicenses(kSecurityLevelL3, HasOrigin(EMPTY_ORIGIN), _)) - .WillOnce(DoAll(SetArgPointee<2>(mockIdsL3), + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3), testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); - std::vector offlineIds; - auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineIds); - EXPECT_TRUE(ret.isOk()); - - size_t index = 0; - for (auto id : offlineIds) { - EXPECT_THAT(id.keySetId, - ElementsAreArray(mockIdsRaw[index++], kKeySetIdSize)); + std::vector offlineKeySetIds; + const auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineKeySetIds); + ASSERT_TRUE(ret.isOk()); + std::vector offlineCdmKeySetIds; + for (const auto &keySetId : offlineKeySetIds) { + offlineCdmKeySetIds.emplace_back(keySetId.keySetId.begin(), + keySetId.keySetId.end()); } - EXPECT_EQ(kLicenseCount * 2, index); + + EXPECT_EQ(expectedCdmKeySetIds.size(), offlineCdmKeySetIds.size()); + EXPECT_EQ(expectedCdmKeySetIds, offlineCdmKeySetIds); +} + +TEST_F(WVDrmPluginHalTest, GetOfflineLicenseKeySetIds_AtscMode) { + const std::vector cdmKeySetIdsL1 = { + // Non-ATSC key set IDs + "ksid1111", "ksid2222", "ksid3333", "ksid4444", + // ATSC key set IDs. + "atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1", + "atsc_group1_profile7"}; + const std::vector cdmKeySetIdsL3 = { + // Non-ATSC key set IDs + "ksid5555", "ksid6666", "ksid7777", "ksid8888", + // ATSC key set IDs. + "atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1", + "atsc_group3_profile7"}; + // Expect ATSC key set IDs (order is important). + const std::vector expectedCdmKeySetIds = { + // From L1 + "atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1", + "atsc_group1_profile7", + // From L3 + "atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1", + "atsc_group3_profile7"}; + + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + + mPlugin->setPropertyString("atscMode", "enable"); + + std::vector offlineKeySetIds; + const auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineKeySetIds); + ASSERT_TRUE(ret.isOk()); + std::vector offlineCdmKeySetIds; + for (const auto &keySetId : offlineKeySetIds) { + offlineCdmKeySetIds.emplace_back(keySetId.keySetId.begin(), + keySetId.keySetId.end()); + } + + EXPECT_EQ(expectedCdmKeySetIds.size(), offlineCdmKeySetIds.size()); + EXPECT_EQ(expectedCdmKeySetIds, offlineCdmKeySetIds); } TEST_F(WVDrmPluginHalTest, GetOfflineLicenseState) {