Merge "Filter out key set IDs based on ATSC mode." into udc-widevine-dev

This commit is contained in:
Alex Dale
2023-11-16 22:27:11 +00:00
committed by Android (Google) Code Review
2 changed files with 125 additions and 58 deletions

View File

@@ -162,6 +162,18 @@ bool isCsrAccessAllowed() {
return (uid == AID_ROOT || uid == AID_SYSTEM || uid == AID_SHELL);
}
bool IsAtscKeySetId(const CdmKeySetId& keySetId) {
if (keySetId.empty()) return false;
// Pre-installed licenses might not perfectly match ATSC_KEY_SET_ID_PREFIX.
// If "atsc" is in the license name, then it is safe to assume
// it is an ATSC license.
return keySetId.find("atsc") != std::string::npos ||
keySetId.find("ATSC") != std::string::npos;
}
bool IsNotAtscKeySetId(const CdmKeySetId& keySetId) {
return !IsAtscKeySetId(keySetId);
}
} // namespace
WVDrmPlugin::WVDrmPlugin(const android::sp<WvContentDecryptionModule>& cdm,
@@ -949,39 +961,53 @@ Status WVDrmPlugin::unprovisionDevice() {
::ndk::ScopedAStatus WVDrmPlugin::getOfflineLicenseKeySetIds(
vector<::aidl::android::hardware::drm::KeySetId>* _aidl_return) {
vector<vector<uint8_t>> keySetIds;
vector<KeySetId> keySetIdsVec;
_aidl_return->clear();
CdmIdentifier identifier;
auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier);
const auto status = mCdmIdentifierBuilder.getCdmIdentifier(&identifier);
if (status != Status::OK) {
*_aidl_return = keySetIdsVec;
return toNdkScopedAStatus(status);
}
vector<CdmSecurityLevel> levels = {wvcdm::kSecurityLevelL1,
wvcdm::kSecurityLevelL3};
const std::vector<CdmSecurityLevel> levels = {wvcdm::kSecurityLevelL1,
wvcdm::kSecurityLevelL3};
std::vector<CdmKeySetId> allKeySetIds;
CdmResponseType res(wvcdm::UNKNOWN_ERROR);
bool success = false;
for (auto level : levels) {
vector<CdmKeySetId> cdmKeySetIds;
res = mCDM->ListStoredLicenses(level, identifier, &cdmKeySetIds);
std::vector<CdmKeySetId> levelKeySetIds;
res = mCDM->ListStoredLicenses(level, identifier, &levelKeySetIds);
if (isCdmResponseTypeSuccess(res)) {
keySetIds.clear();
for (auto id : cdmKeySetIds) {
keySetIds.push_back(StrToVector(id));
}
KeySetId kid;
for (auto id : keySetIds) {
kid.keySetId = id;
keySetIdsVec.push_back(kid);
if (!isCdmResponseTypeSuccess(res)) continue;
success = true;
if (levelKeySetIds.empty()) continue;
if (allKeySetIds.empty()) {
allKeySetIds = std::move(levelKeySetIds);
} else {
allKeySetIds.reserve(allKeySetIds.size() + levelKeySetIds.size());
for (CdmKeySetId& keySetId : levelKeySetIds) {
allKeySetIds.push_back(std::move(keySetId));
}
}
}
*_aidl_return = keySetIdsVec;
return toNdkScopedAStatus(mapCdmResponseType(res));
if (!success) {
// Return whatever the last error was.
return toNdkScopedAStatus(mapCdmResponseType(res));
}
// Filter out key sets based on ATSC mode.
const auto isAllowedKeySetId =
mPropertySet.use_atsc_mode() ? IsAtscKeySetId : IsNotAtscKeySetId;
std::vector<KeySetId> keySetIds;
for (const CdmKeySetId& keySetId : allKeySetIds) {
if (isAllowedKeySetId(keySetId)) {
keySetIds.push_back(KeySetId{StrToVector(keySetId)});
}
}
*_aidl_return = std::move(keySetIds);
return toNdkScopedAStatus(mapCdmResponseType(wvcdm::NO_ERROR));
}
::ndk::ScopedAStatus WVDrmPlugin::getOfflineLicenseState(

View File

@@ -2853,49 +2853,90 @@ TEST_F(WVDrmPluginHalTest, DoesNotSetDecryptHashProperties) {
EXPECT_TRUE(ret.isOk());
}
TEST_F(WVDrmPluginHalTest, GetOfflineLicenseIds) {
const uint32_t kLicenseCount = 5;
TEST_F(WVDrmPluginHalTest, GetOfflineLicenseKeySetIds_NonAtscMode) {
const std::vector<CdmKeySetId> cdmKeySetIdsL1 = {
// Non-ATSC key set IDs
"ksid1111", "ksid2222", "ksid3333", "ksid4444",
// ATSC key set IDs.
"atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1",
"atsc_group1_profile7"};
const std::vector<CdmKeySetId> cdmKeySetIdsL3 = {
// Non-ATSC key set IDs
"ksid5555", "ksid6666", "ksid7777", "ksid8888",
// ATSC key set IDs.
"atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1",
"atsc_group3_profile7"};
// Expect non-ATSC key set IDs (order is important).
const std::vector<CdmKeySetId> expectedCdmKeySetIds = {
// From L1
"ksid1111", "ksid2222", "ksid3333", "ksid4444",
// From L3
"ksid5555", "ksid6666", "ksid7777", "ksid8888"};
uint8_t mockIdsRaw[kLicenseCount * 2][kKeySetIdSize];
FILE *fp = fopen("/dev/urandom", "r");
ASSERT_NE(fp, nullptr) << "Failed to open /dev/urandom";
for (uint32_t i = 0; i < kLicenseCount * 2; ++i) {
fread(mockIdsRaw[i], sizeof(uint8_t), kKeySetIdSize, fp);
}
fclose(fp);
std::vector<std::string> mockIdsL1;
for (uint32_t i = 0; i < kLicenseCount; ++i) {
mockIdsL1.push_back(
std::string(mockIdsRaw[i], mockIdsRaw[i] + kKeySetIdSize));
}
std::vector<std::string> mockIdsL3;
for (uint32_t i = 0; i < kLicenseCount; ++i) {
mockIdsL3.push_back(
std::string(mockIdsRaw[i + 5], mockIdsRaw[i + 5] + kKeySetIdSize));
}
EXPECT_CALL(*mCdm,
ListStoredLicenses(kSecurityLevelL1, HasOrigin(EMPTY_ORIGIN), _))
.WillOnce(DoAll(SetArgPointee<2>(mockIdsL1),
EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull()))
.WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1),
testing::Return(CdmResponseType(wvcdm::NO_ERROR))));
EXPECT_CALL(*mCdm,
ListStoredLicenses(kSecurityLevelL3, HasOrigin(EMPTY_ORIGIN), _))
.WillOnce(DoAll(SetArgPointee<2>(mockIdsL3),
EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull()))
.WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3),
testing::Return(CdmResponseType(wvcdm::NO_ERROR))));
std::vector<KeySetId> offlineIds;
auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineIds);
EXPECT_TRUE(ret.isOk());
size_t index = 0;
for (auto id : offlineIds) {
EXPECT_THAT(id.keySetId,
ElementsAreArray(mockIdsRaw[index++], kKeySetIdSize));
std::vector<KeySetId> offlineKeySetIds;
const auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineKeySetIds);
ASSERT_TRUE(ret.isOk());
std::vector<CdmKeySetId> offlineCdmKeySetIds;
for (const auto &keySetId : offlineKeySetIds) {
offlineCdmKeySetIds.emplace_back(keySetId.keySetId.begin(),
keySetId.keySetId.end());
}
EXPECT_EQ(kLicenseCount * 2, index);
EXPECT_EQ(expectedCdmKeySetIds.size(), offlineCdmKeySetIds.size());
EXPECT_EQ(expectedCdmKeySetIds, offlineCdmKeySetIds);
}
TEST_F(WVDrmPluginHalTest, GetOfflineLicenseKeySetIds_AtscMode) {
const std::vector<CdmKeySetId> cdmKeySetIdsL1 = {
// Non-ATSC key set IDs
"ksid1111", "ksid2222", "ksid3333", "ksid4444",
// ATSC key set IDs.
"atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1",
"atsc_group1_profile7"};
const std::vector<CdmKeySetId> cdmKeySetIdsL3 = {
// Non-ATSC key set IDs
"ksid5555", "ksid6666", "ksid7777", "ksid8888",
// ATSC key set IDs.
"atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1",
"atsc_group3_profile7"};
// Expect ATSC key set IDs (order is important).
const std::vector<CdmKeySetId> expectedCdmKeySetIds = {
// From L1
"atscksid1111", "atscksid2222", "atscksid3333", "atsc_group1_profile1",
"atsc_group1_profile7",
// From L3
"atscksid5555", "atscksid6666", "atscksid7777", "atsc_group3_profile1",
"atsc_group3_profile7"};
EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL1, _, NotNull()))
.WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL1),
testing::Return(CdmResponseType(wvcdm::NO_ERROR))));
EXPECT_CALL(*mCdm, ListStoredLicenses(kSecurityLevelL3, _, NotNull()))
.WillOnce(DoAll(SetArgPointee<2>(cdmKeySetIdsL3),
testing::Return(CdmResponseType(wvcdm::NO_ERROR))));
mPlugin->setPropertyString("atscMode", "enable");
std::vector<KeySetId> offlineKeySetIds;
const auto ret = mPlugin->getOfflineLicenseKeySetIds(&offlineKeySetIds);
ASSERT_TRUE(ret.isOk());
std::vector<CdmKeySetId> offlineCdmKeySetIds;
for (const auto &keySetId : offlineKeySetIds) {
offlineCdmKeySetIds.emplace_back(keySetId.keySetId.begin(),
keySetId.keySetId.end());
}
EXPECT_EQ(expectedCdmKeySetIds.size(), offlineCdmKeySetIds.size());
EXPECT_EQ(expectedCdmKeySetIds, offlineCdmKeySetIds);
}
TEST_F(WVDrmPluginHalTest, GetOfflineLicenseState) {