diff --git a/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp b/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp index 1859a517..bb56f52f 100644 --- a/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp +++ b/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp @@ -162,6 +162,18 @@ bool isCsrAccessAllowed() { return (uid == AID_ROOT || uid == AID_SYSTEM || uid == AID_SHELL); } +bool IsAtscKeySetId(const CdmKeySetId& keySetId) { + if (keySetId.empty()) return false; + // Pre-installed licenses might not perfectly match ATSC_KEY_SET_ID_PREFIX. + // If "atsc" is in the license name, then it is safe to assume + // it is an ATSC license. + return keySetId.find("atsc") != std::string::npos || + keySetId.find("ATSC") != std::string::npos; +} + +bool IsNotAtscKeySetId(const CdmKeySetId& keySetId) { + return !IsAtscKeySetId(keySetId); +} } // namespace WVDrmPlugin::WVDrmPlugin(const android::sp& cdm, @@ -949,39 +961,53 @@ Status WVDrmPlugin::unprovisionDevice() { ::ndk::ScopedAStatus WVDrmPlugin::getOfflineLicenseKeySetIds( vector<::aidl::android::hardware::drm::KeySetId>* _aidl_return) { - vector> keySetIds; - vector keySetIdsVec; - + _aidl_return->clear(); CdmIdentifier identifier; - auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier); + const auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier); if (status != Status::OK) { - *_aidl_return = keySetIdsVec; return toNdkScopedAStatus(status); } - vector levels = {wvcdm::kSecurityLevelL1, - wvcdm::kSecurityLevelL3}; + const std::vector levels = {wvcdm::kSecurityLevelL1, + wvcdm::kSecurityLevelL3}; + std::vector allKeySetIds; CdmResponseType res(wvcdm::UNKNOWN_ERROR); - + bool success = false; for (auto level : levels) { - vector cdmKeySetIds; - res = mCDM->ListStoredLicenses(level, identifier, &cdmKeySetIds); + std::vector levelKeySetIds; + res = mCDM->ListStoredLicenses(level, identifier, &levelKeySetIds); - if (isCdmResponseTypeSuccess(res)) { - keySetIds.clear(); - for (auto id : cdmKeySetIds) { - keySetIds.push_back(StrToVector(id)); - } - KeySetId kid; - for (auto id : keySetIds) { - kid.keySetId = id; - keySetIdsVec.push_back(kid); + if (!isCdmResponseTypeSuccess(res)) continue; + success = true; + if (levelKeySetIds.empty()) continue; + if (allKeySetIds.empty()) { + allKeySetIds = std::move(levelKeySetIds); + } else { + allKeySetIds.reserve(allKeySetIds.size() + levelKeySetIds.size()); + for (CdmKeySetId& keySetId : levelKeySetIds) { + allKeySetIds.push_back(std::move(keySetId)); } } } - *_aidl_return = keySetIdsVec; - return toNdkScopedAStatus(mapCdmResponseType(res)); + + if (!success) { + // Return whatever the last error was. + return toNdkScopedAStatus(mapCdmResponseType(res)); + } + + // Filter out key sets based on ATSC mode. + const auto isAllowedKeySetId = + mPropertySet.use_atsc_mode() ? IsAtscKeySetId : IsNotAtscKeySetId; + std::vector keySetIds; + for (const CdmKeySetId& keySetId : allKeySetIds) { + if (isAllowedKeySetId(keySetId)) { + keySetIds.push_back(KeySetId{StrToVector(keySetId)}); + } + } + + *_aidl_return = std::move(keySetIds); + return toNdkScopedAStatus(mapCdmResponseType(wvcdm::NO_ERROR)); } ::ndk::ScopedAStatus WVDrmPlugin::getOfflineLicenseState( diff --git a/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp b/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp index c4cd6102..f1bac50b 100644 --- a/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp +++ b/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp @@ -2853,49 +2853,90 @@ TEST_F(WVDrmPluginHalTest, DoesNotSetDecryptHashProperties) { EXPECT_TRUE(ret.isOk()); } -TEST_F(WVDrmPluginHalTest, GetOfflineLicenseIds) { - const uint32_t kLicenseCount = 5; +TEST_F(WVDrmPluginHalTest, GetOfflineLicenseKeySetIds_NonAtscMode) { + const std::vector cdmKeySetIdsL1 = { + // Non-ATSC key set IDs + "ksid1111", "ksid2222", "ksid3333", "ksid4444", + // ATSC key set IDs. + "atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1", + "atsc_group1_profile7"}; + const std::vector cdmKeySetIdsL3 = { + // Non-ATSC key set IDs + "ksid5555", "ksid6666", "ksid7777", "ksid8888", + // ATSC key set IDs. + "atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1", + "atsc_group3_profile7"}; + // Expect non-ATSC key set IDs (order is important). + const std::vector expectedCdmKeySetIds = { + // From L1 + "ksid1111", "ksid2222", "ksid3333", "ksid4444", + // From L3 + "ksid5555", "ksid6666", "ksid7777", "ksid8888"}; - uint8_t mockIdsRaw[kLicenseCount * 2][kKeySetIdSize]; - FILE *fp = fopen("/dev/urandom", "r"); - ASSERT_NE(fp, nullptr) << "Failed to open /dev/urandom"; - for (uint32_t i = 0; i < kLicenseCount * 2; ++i) { - fread(mockIdsRaw[i], sizeof(uint8_t), kKeySetIdSize, fp); - } - fclose(fp); - - std::vector mockIdsL1; - for (uint32_t i = 0; i < kLicenseCount; ++i) { - mockIdsL1.push_back( - std::string(mockIdsRaw[i], mockIdsRaw[i] + kKeySetIdSize)); - } - - std::vector mockIdsL3; - for (uint32_t i = 0; i < kLicenseCount; ++i) { - mockIdsL3.push_back( - std::string(mockIdsRaw[i + 5], mockIdsRaw[i + 5] + kKeySetIdSize)); - } - - EXPECT_CALL(*mCdm, - ListStoredLicenses(kSecurityLevelL1, HasOrigin(EMPTY_ORIGIN), _)) - .WillOnce(DoAll(SetArgPointee<2>(mockIdsL1), + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); - EXPECT_CALL(*mCdm, - ListStoredLicenses(kSecurityLevelL3, HasOrigin(EMPTY_ORIGIN), _)) - .WillOnce(DoAll(SetArgPointee<2>(mockIdsL3), + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3), testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); - std::vector offlineIds; - auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineIds); - EXPECT_TRUE(ret.isOk()); - - size_t index = 0; - for (auto id : offlineIds) { - EXPECT_THAT(id.keySetId, - ElementsAreArray(mockIdsRaw[index++], kKeySetIdSize)); + std::vector offlineKeySetIds; + const auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineKeySetIds); + ASSERT_TRUE(ret.isOk()); + std::vector offlineCdmKeySetIds; + for (const auto &keySetId : offlineKeySetIds) { + offlineCdmKeySetIds.emplace_back(keySetId.keySetId.begin(), + keySetId.keySetId.end()); } - EXPECT_EQ(kLicenseCount * 2, index); + + EXPECT_EQ(expectedCdmKeySetIds.size(), offlineCdmKeySetIds.size()); + EXPECT_EQ(expectedCdmKeySetIds, offlineCdmKeySetIds); +} + +TEST_F(WVDrmPluginHalTest, GetOfflineLicenseKeySetIds_AtscMode) { + const std::vector cdmKeySetIdsL1 = { + // Non-ATSC key set IDs + "ksid1111", "ksid2222", "ksid3333", "ksid4444", + // ATSC key set IDs. + "atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1", + "atsc_group1_profile7"}; + const std::vector cdmKeySetIdsL3 = { + // Non-ATSC key set IDs + "ksid5555", "ksid6666", "ksid7777", "ksid8888", + // ATSC key set IDs. + "atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1", + "atsc_group3_profile7"}; + // Expect ATSC key set IDs (order is important). + const std::vector expectedCdmKeySetIds = { + // From L1 + "atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1", + "atsc_group1_profile7", + // From L3 + "atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1", + "atsc_group3_profile7"}; + + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + + mPlugin->setPropertyString("atscMode", "enable"); + + std::vector offlineKeySetIds; + const auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineKeySetIds); + ASSERT_TRUE(ret.isOk()); + std::vector offlineCdmKeySetIds; + for (const auto &keySetId : offlineKeySetIds) { + offlineCdmKeySetIds.emplace_back(keySetId.keySetId.begin(), + keySetId.keySetId.end()); + } + + EXPECT_EQ(expectedCdmKeySetIds.size(), offlineCdmKeySetIds.size()); + EXPECT_EQ(expectedCdmKeySetIds, offlineCdmKeySetIds); } TEST_F(WVDrmPluginHalTest, GetOfflineLicenseState) {