diff --git a/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp b/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp index bdda24b0..7b4da59a 100644 --- a/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp +++ b/libwvdrmengine/mediadrm/src/WVDrmPlugin.cpp @@ -1061,29 +1061,42 @@ Status WVDrmPlugin::unprovisionDevice() { ::ndk::ScopedAStatus WVDrmPlugin::removeOfflineLicense( const ::aidl::android::hardware::drm::KeySetId& in_keySetId) { - if (!in_keySetId.keySetId.size()) { + if (in_keySetId.keySetId.empty()) { return toNdkScopedAStatus(Status::BAD_VALUE); } CdmIdentifier identifier; - auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier); + const auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier); if (status != Status::OK) { return toNdkScopedAStatus(status); } - CdmResponseType res(wvcdm::UNKNOWN_ERROR); + const std::vector levels = {wvcdm::kSecurityLevelL1, + wvcdm::kSecurityLevelL3}; + const CdmKeySetId cdmKeySetId(in_keySetId.keySetId.begin(), + in_keySetId.keySetId.end()); - res = mCDM->RemoveOfflineLicense( - std::string(in_keySetId.keySetId.begin(), in_keySetId.keySetId.end()), - wvcdm::kSecurityLevelL1, identifier); - if (!isCdmResponseTypeSuccess(res)) { - CdmResponseType res = mCDM->RemoveOfflineLicense( - std::string(in_keySetId.keySetId.begin(), in_keySetId.keySetId.end()), - wvcdm::kSecurityLevelL3, identifier); - return toNdkScopedAStatus(mapCdmResponseType(res)); + for (const CdmSecurityLevel level : levels) { + std::vector keySetIds; + const CdmResponseType res = + mCDM->ListStoredLicenses(level, identifier, &keySetIds); + if (!isCdmResponseTypeSuccess(res)) { + // This could failure for several reasons, but none that are + // worth returning to the app at this time. + ALOGW("Failed to list stored licenses: res = %d", static_cast(res)); + continue; + } + // Check if exists. + if (keySetIds.empty() || std::find(keySetIds.begin(), keySetIds.end(), + cdmKeySetId) == keySetIds.end()) { + // Does not exist for this security level. + continue; + } + return toNdkScopedAStatus(mapCdmResponseType( + mCDM->RemoveOfflineLicense(cdmKeySetId, level, identifier))); } - - return toNdkScopedAStatus(Status::OK); + // Could only reach this state if the key set could not be found. + return toNdkScopedAStatus(Status::BAD_VALUE); } ::ndk::ScopedAStatus WVDrmPlugin::getPropertyString( diff --git a/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp b/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp index f1bac50b..226fc0d4 100644 --- a/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp +++ b/libwvdrmengine/mediadrm/test/WVDrmPlugin_hal_test.cpp @@ -2971,13 +2971,78 @@ TEST_F(WVDrmPluginHalTest, GetOfflineLicenseState) { ASSERT_EQ(OfflineLicenseState::UNKNOWN, result); } -TEST_F(WVDrmPluginHalTest, RemoveOfflineLicense) { - EXPECT_CALL( - *mCdm, RemoveOfflineLicense(_, kSecurityLevelL1, HasOrigin(EMPTY_ORIGIN))) - .Times(1); +TEST_F(WVDrmPluginHalTest, RemoveOfflineLicense_L1) { + // Key set to remove. + const CdmKeySetId cdmKeySetId = "ksidDEADBEAF"; + const KeySetId keySetId{ + std::vector(cdmKeySetId.begin(), cdmKeySetId.end())}; + // Desired key set ID found in L1. + const std::vector cdmKeySetIdsL1 = {"ksid1234", "ksid9876", + "ksid9999", cdmKeySetId, + "ksidBAD", "ksidCAFEB0BA"}; - auto ret = mPlugin->removeOfflineLicense(keySetId); - EXPECT_TRUE(ret.isOk()); + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + // Only call L1. + EXPECT_CALL(*mCdm, RemoveOfflineLicense(cdmKeySetId, kSecurityLevelL1, _)) + .WillOnce(testing::Return(CdmResponseType(wvcdm::NO_ERROR))); + EXPECT_CALL(*mCdm, RemoveOfflineLicense(_, kSecurityLevelL3, _)).Times(0); + + const auto status = mPlugin->removeOfflineLicense(keySetId); + ASSERT_TRUE(status.isOk()); +} + +TEST_F(WVDrmPluginHalTest, RemoveOfflineLicense_L3) { + // Key set to remove. + const CdmKeySetId cdmKeySetId = "ksidDEADBEAF"; + const KeySetId keySetId{ + std::vector(cdmKeySetId.begin(), cdmKeySetId.end())}; + // Desired key set ID is not found in L1. + const std::vector cdmKeySetIdsL1 = {"ksid1234", "ksid9876", + "ksid9999"}; + // Desired key set ID found in L3. + const std::vector cdmKeySetIdsL3 = { + "ksidDEADC0DE", "ksid1337", cdmKeySetId, "ksidBAD", "ksidCAFEB0BA"}; + + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + // Only call L3. + EXPECT_CALL(*mCdm, RemoveOfflineLicense(_, kSecurityLevelL1, _)).Times(0); + EXPECT_CALL(*mCdm, RemoveOfflineLicense(cdmKeySetId, kSecurityLevelL3, _)) + .WillOnce(testing::Return(CdmResponseType(wvcdm::NO_ERROR))); + + const auto status = mPlugin->removeOfflineLicense(keySetId); + ASSERT_TRUE(status.isOk()); +} + +TEST_F(WVDrmPluginHalTest, RemoveOfflineLicense_NotFound) { + // Key set to remove. + const CdmKeySetId cdmKeySetId = "ksidDEADBEAF"; + const KeySetId keySetId{ + std::vector(cdmKeySetId.begin(), cdmKeySetId.end())}; + // Desired key set ID is not found in L1. + const std::vector cdmKeySetIdsL1 = {"ksid1234", "ksid9876", + "ksid9999"}; + // Desired key set ID is not found in L3. + const std::vector cdmKeySetIdsL3 = {"ksidDEADC0DE", "ksid1337", + "ksidBAD", "ksidCAFEB0BA"}; + + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull())) + .WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3), + testing::Return(CdmResponseType(wvcdm::NO_ERROR)))); + // No call to RemoveOfflineLicense should be made. + EXPECT_CALL(*mCdm, RemoveOfflineLicense(_, _, _)).Times(0); + + const auto status = mPlugin->removeOfflineLicense(keySetId); + ASSERT_FALSE(status.isOk()); } TEST_F(WVDrmPluginHalTest, CanStoreAtscLicense) {