diff --git a/libwvdrmengine/cdm/core/include/cdm_engine.h b/libwvdrmengine/cdm/core/include/cdm_engine.h index 15448ad0..cf321551 100644 --- a/libwvdrmengine/cdm/core/include/cdm_engine.h +++ b/libwvdrmengine/cdm/core/include/cdm_engine.h @@ -41,12 +41,33 @@ class CdmEngine { virtual CdmResponseType CloseKeySetSession(const CdmKeySetId& key_set_id); // License related methods - // Construct a valid license request + + // Construct a valid license request. The arguments are used as follows: + // session_id: The Session ID of the session the request is being generated + // for. This is ignored for license release requests. + // key_set_id: The Key Set ID of the key set the request is being generated + // for. This is ignored except for license release requests. + // init_data: The initialization data from the media file, which is used to + // build the key request. This is ignored for release and renewal + // requests. + // license_type: The type of license being requested. Never ignored. + // app_parameters: Additional, application-specific parameters that factor + // into the request generation. This is ignored for release + // and renewal requests. + // key_request: This must be non-null and point to a CdmKeyMessage. The buffer + // will have its contents replaced with the key request. + // server_url: This must be non-null and point to a string. The string will + // have its contents replaced with the default URL (if one is + // known) to send this key request to. + // key_set_id_out: May be null. If it is non-null, the CdmKeySetId pointed to + // will have its contents replaced with the key set ID of the + // session. Note that for non-offline license requests, the + // key set ID is empty, so the CdmKeySetId will be cleared. virtual CdmResponseType GenerateKeyRequest( const CdmSessionId& session_id, const CdmKeySetId& key_set_id, const InitializationData& init_data, const CdmLicenseType license_type, CdmAppParameterMap& app_parameters, CdmKeyMessage* key_request, - std::string* server_url); + std::string* server_url, CdmKeySetId* key_set_id_out); // Accept license response and extract key info. virtual CdmResponseType AddKey(const CdmSessionId& session_id, diff --git a/libwvdrmengine/cdm/core/include/cdm_session.h b/libwvdrmengine/cdm/core/include/cdm_session.h index 1635ec6a..39f11da6 100644 --- a/libwvdrmengine/cdm/core/include/cdm_session.h +++ b/libwvdrmengine/cdm/core/include/cdm_session.h @@ -36,7 +36,7 @@ class CdmSession { virtual CdmResponseType GenerateKeyRequest( const InitializationData& init_data, const CdmLicenseType license_type, const CdmAppParameterMap& app_parameters, CdmKeyMessage* key_request, - std::string* server_url); + std::string* server_url, CdmKeySetId* key_set_id); // AddKey() - Accept license response and extract key info. virtual CdmResponseType AddKey(const CdmKeyResponse& key_response, diff --git a/libwvdrmengine/cdm/core/include/device_files.h b/libwvdrmengine/cdm/core/include/device_files.h index a02e5404..4e9058e1 100644 --- a/libwvdrmengine/cdm/core/include/device_files.h +++ b/libwvdrmengine/cdm/core/include/device_files.h @@ -59,6 +59,7 @@ class DeviceFiles { virtual bool DeleteAllFiles(); virtual bool DeleteAllLicenses(); virtual bool LicenseExists(const std::string& key_set_id); + virtual bool ReserveLicenseId(const std::string& key_set_id); virtual bool StoreUsageInfo(const std::string& provider_session_token, const CdmKeyMessage& key_request, @@ -80,8 +81,9 @@ class DeviceFiles { CdmKeyResponse* license_response); private: - bool StoreFile(const char* name, const std::string& serialized_file); - bool RetrieveFile(const char* name, std::string* serialized_file); + bool StoreFileWithHash(const char* name, const std::string& serialized_file); + bool StoreFileRaw(const char* name, const std::string& serialized_file); + bool RetrieveHashedFile(const char* name, std::string* serialized_file); // Certificate and offline licenses are now stored in security // level specific directories. In an earlier version they were @@ -92,6 +94,7 @@ class DeviceFiles { static std::string GetCertificateFileName(); static std::string GetLicenseFileNameExtension(); static std::string GetUsageInfoFileName(const std::string& app_id); + static std::string GetBlankFileData(); void SetTestFile(File* file); #if defined(UNIT_TEST) FRIEND_TEST(DeviceFilesSecurityLevelTest, SecurityLevel); @@ -99,6 +102,7 @@ class DeviceFiles { FRIEND_TEST(DeviceFilesStoreTest, StoreLicense); FRIEND_TEST(DeviceFilesTest, DeleteLicense); FRIEND_TEST(DeviceFilesTest, ReadCertificate); + FRIEND_TEST(DeviceFilesTest, ReserveLicenseIds); FRIEND_TEST(DeviceFilesTest, RetrieveLicenses); FRIEND_TEST(DeviceFilesTest, SecurityLevelPathBackwardCompatibility); FRIEND_TEST(DeviceFilesTest, StoreLicenses); diff --git a/libwvdrmengine/cdm/core/src/cdm_engine.cpp b/libwvdrmengine/cdm/core/src/cdm_engine.cpp index cf1d1c9a..6bc1fa55 100644 --- a/libwvdrmengine/cdm/core/src/cdm_engine.cpp +++ b/libwvdrmengine/cdm/core/src/cdm_engine.cpp @@ -174,7 +174,8 @@ CdmResponseType CdmEngine::GenerateKeyRequest( const CdmLicenseType license_type, CdmAppParameterMap& app_parameters, CdmKeyMessage* key_request, - std::string* server_url) { + std::string* server_url, + CdmKeySetId* key_set_id_out) { LOGI("CdmEngine::GenerateKeyRequest"); CdmSessionId id = session_id; @@ -227,7 +228,7 @@ CdmResponseType CdmEngine::GenerateKeyRequest( sts = iter->second->GenerateKeyRequest(init_data, license_type, app_parameters, key_request, - server_url); + server_url, key_set_id_out); if (KEY_MESSAGE != sts) { if (sts == NEED_PROVISIONING) { diff --git a/libwvdrmengine/cdm/core/src/cdm_session.cpp b/libwvdrmengine/cdm/core/src/cdm_session.cpp index acfc2847..5deb51eb 100644 --- a/libwvdrmengine/cdm/core/src/cdm_session.cpp +++ b/libwvdrmengine/cdm/core/src/cdm_session.cpp @@ -194,7 +194,7 @@ CdmResponseType CdmSession::RestoreUsageSession( CdmResponseType CdmSession::GenerateKeyRequest( const InitializationData& init_data, const CdmLicenseType license_type, const CdmAppParameterMap& app_parameters, CdmKeyMessage* key_request, - std::string* server_url) { + std::string* server_url, CdmKeySetId* key_set_id) { if (crypto_session_.get() == NULL) { LOGW("CdmSession::GenerateKeyRequest: Invalid crypto session"); return UNKNOWN_ERROR; @@ -229,6 +229,10 @@ CdmResponseType CdmSession::GenerateKeyRequest( LOGW("CdmSession::GenerateKeyRequest: init data absent"); return KEY_ERROR; } + if (is_offline_ && !GenerateKeySetId(&key_set_id_)) { + LOGE("CdmSession::GenerateKeyRequest: Unable to generate key set ID"); + return UNKNOWN_ERROR; + } if (!license_parser_->PrepareKeyRequest(init_data, license_type, app_parameters, session_id_, @@ -242,6 +246,7 @@ CdmResponseType CdmSession::GenerateKeyRequest( offline_release_server_url_ = *server_url; } + if (key_set_id) *key_set_id = key_set_id_; return KEY_MESSAGE; } } @@ -277,7 +282,7 @@ CdmResponseType CdmSession::AddKey(const CdmKeyResponse& key_response, if (sts != NO_ERROR) return sts; } - *key_set_id = key_set_id_; + if (key_set_id) *key_set_id = key_set_id_; return KEY_ADDED; } } @@ -465,13 +470,15 @@ bool CdmSession::GenerateKeySetId(CdmKeySetId* key_set_id) { key_set_id->clear(); } } + // Reserve the license ID to avoid collisions. + file_handle_->ReserveLicenseId(*key_set_id); return true; } CdmResponseType CdmSession::StoreLicense() { if (is_offline_) { - if (!GenerateKeySetId(&key_set_id_)) { - LOGE("CdmSession::StoreLicense: Unable to generate key set Id"); + if (key_set_id_.empty()) { + LOGE("CdmSession::StoreLicense: No key set ID"); return UNKNOWN_ERROR; } diff --git a/libwvdrmengine/cdm/core/src/device_files.cpp b/libwvdrmengine/cdm/core/src/device_files.cpp index 0bee6012..8d499535 100644 --- a/libwvdrmengine/cdm/core/src/device_files.cpp +++ b/libwvdrmengine/cdm/core/src/device_files.cpp @@ -41,6 +41,8 @@ const char* kSecurityLevelPathCompatibilityExclusionList[] = {"ay64.dat"}; size_t kSecurityLevelPathCompatibilityExclusionListSize = sizeof(kSecurityLevelPathCompatibilityExclusionList) / sizeof(*kSecurityLevelPathCompatibilityExclusionList); +// Some platforms cannot store a truly blank file, so we use a W for Widevine. +const char kBlankFileData[] = "W"; bool Hash(const std::string& data, std::string* hash) { if (!hash) return false; @@ -102,7 +104,7 @@ bool DeviceFiles::StoreCertificate(const std::string& certificate, std::string serialized_file; file.SerializeToString(&serialized_file); - return StoreFile(kCertificateFileName, serialized_file); + return StoreFileWithHash(kCertificateFileName, serialized_file); } bool DeviceFiles::RetrieveCertificate(std::string* certificate, @@ -117,7 +119,7 @@ bool DeviceFiles::RetrieveCertificate(std::string* certificate, } std::string serialized_file; - if (!RetrieveFile(kCertificateFileName, &serialized_file)) + if (!RetrieveHashedFile(kCertificateFileName, &serialized_file)) return false; video_widevine_client::sdk::File file; @@ -195,7 +197,7 @@ bool DeviceFiles::StoreLicense(const std::string& key_set_id, file.SerializeToString(&serialized_file); std::string file_name = key_set_id + kLicenseFileNameExt; - return StoreFile(file_name.c_str(), serialized_file); + return StoreFileWithHash(file_name.c_str(), serialized_file); } bool DeviceFiles::RetrieveLicense(const std::string& key_set_id, @@ -214,7 +216,7 @@ bool DeviceFiles::RetrieveLicense(const std::string& key_set_id, std::string serialized_file; std::string file_name = key_set_id + kLicenseFileNameExt; - if (!RetrieveFile(file_name.c_str(), &serialized_file)) return false; + if (!RetrieveHashedFile(file_name.c_str(), &serialized_file)) return false; video_widevine_client::sdk::File file; if (!file.ParseFromString(serialized_file)) { @@ -320,7 +322,7 @@ bool DeviceFiles::LicenseExists(const std::string& key_set_id) { std::string path; if (!Properties::GetDeviceFilesBasePath(security_level_, &path)) { - LOGW("DeviceFiles::StoreFile: Unable to get base path"); + LOGW("DeviceFiles::LicenseExists: Unable to get base path"); return false; } path.append(key_set_id); @@ -329,6 +331,16 @@ bool DeviceFiles::LicenseExists(const std::string& key_set_id) { return file_->Exists(path); } +bool DeviceFiles::ReserveLicenseId(const std::string& key_set_id) { + if (!initialized_) { + LOGW("DeviceFiles::ReserveLicenseId: not initialized"); + return false; + } + + std::string file_name = key_set_id + kLicenseFileNameExt; + return StoreFileRaw(file_name.c_str(), kBlankFileData); +} + bool DeviceFiles::StoreUsageInfo(const std::string& provider_session_token, const CdmKeyMessage& key_request, const CdmKeyResponse& key_response, @@ -341,7 +353,7 @@ bool DeviceFiles::StoreUsageInfo(const std::string& provider_session_token, std::string serialized_file; video_widevine_client::sdk::File file; std::string file_name = GetUsageInfoFileName(app_id); - if (!RetrieveFile(file_name.c_str(), &serialized_file)) { + if (!RetrieveHashedFile(file_name.c_str(), &serialized_file)) { file.set_type(video_widevine_client::sdk::File::USAGE_INFO); file.set_version(video_widevine_client::sdk::File::VERSION_1); } else { @@ -362,7 +374,7 @@ bool DeviceFiles::StoreUsageInfo(const std::string& provider_session_token, key_response.size()); file.SerializeToString(&serialized_file); - return StoreFile(file_name.c_str(), serialized_file); + return StoreFileWithHash(file_name.c_str(), serialized_file); } bool DeviceFiles::DeleteUsageInfo(const std::string& app_id, @@ -374,7 +386,7 @@ bool DeviceFiles::DeleteUsageInfo(const std::string& app_id, std::string serialized_file; std::string file_name = GetUsageInfoFileName(app_id); - if (!RetrieveFile(file_name.c_str(), &serialized_file)) return false; + if (!RetrieveHashedFile(file_name.c_str(), &serialized_file)) return false; video_widevine_client::sdk::File file; if (!file.ParseFromString(serialized_file)) { @@ -406,7 +418,7 @@ bool DeviceFiles::DeleteUsageInfo(const std::string& app_id, sessions->RemoveLast(); file.SerializeToString(&serialized_file); - return StoreFile(file_name.c_str(), serialized_file); + return StoreFileWithHash(file_name.c_str(), serialized_file); } bool DeviceFiles::DeleteAllUsageInfoForApp(const std::string& app_id) { @@ -442,7 +454,7 @@ bool DeviceFiles::RetrieveUsageInfo( std::string serialized_file; std::string file_name = GetUsageInfoFileName(app_id); - if (!RetrieveFile(file_name.c_str(), &serialized_file)) { + if (!RetrieveHashedFile(file_name.c_str(), &serialized_file)) { std::string path; if (!Properties::GetDeviceFilesBasePath(security_level_, &path)) { return false; @@ -484,7 +496,7 @@ bool DeviceFiles::RetrieveUsageInfo(const std::string& app_id, } std::string serialized_file; std::string file_name = GetUsageInfoFileName(app_id); - if (!RetrieveFile(file_name.c_str(), &serialized_file)) return false; + if (!RetrieveHashedFile(file_name.c_str(), &serialized_file)) return false; video_widevine_client::sdk::File file; if (!file.ParseFromString(serialized_file)) { @@ -510,22 +522,22 @@ bool DeviceFiles::RetrieveUsageInfo(const std::string& app_id, return true; } -bool DeviceFiles::StoreFile(const char* name, - const std::string& serialized_file) { +bool DeviceFiles::StoreFileWithHash(const char* name, + const std::string& serialized_file) { if (!file_.get()) { - LOGW("DeviceFiles::StoreFile: Invalid file handle"); + LOGW("DeviceFiles::StoreFileWithHash: Invalid file handle"); return false; } if (!name) { - LOGW("DeviceFiles::StoreFile: Unspecified file name parameter"); + LOGW("DeviceFiles::StoreFileWithHash: Unspecified file name parameter"); return false; } // calculate SHA hash std::string hash; if (!Hash(serialized_file, &hash)) { - LOGW("DeviceFiles::StoreFile: Hash computation failed"); + LOGW("DeviceFiles::StoreFileWithHash: Hash computation failed"); return false; } @@ -537,9 +549,24 @@ bool DeviceFiles::StoreFile(const char* name, std::string serialized_hash_file; hash_file.SerializeToString(&serialized_hash_file); + return StoreFileRaw(name, serialized_hash_file); +} + +bool DeviceFiles::StoreFileRaw(const char* name, + const std::string& serialized_file) { + if (!file_.get()) { + LOGW("DeviceFiles::StoreFileRaw: Invalid file handle"); + return false; + } + + if (!name) { + LOGW("DeviceFiles::StoreFileRaw: Unspecified file name parameter"); + return false; + } + std::string path; if (!Properties::GetDeviceFilesBasePath(security_level_, &path)) { - LOGW("DeviceFiles::StoreFile: Unable to get base path"); + LOGW("DeviceFiles::StoreFileRaw: Unable to get base path"); return false; } @@ -550,40 +577,40 @@ bool DeviceFiles::StoreFile(const char* name, path += name; if (!file_->Open(path, File::kCreate | File::kTruncate | File::kBinary)) { - LOGW("DeviceFiles::StoreFile: File open failed: %s", path.c_str()); + LOGW("DeviceFiles::StoreFileRaw: File open failed: %s", path.c_str()); return false; } - ssize_t bytes = file_->Write(serialized_hash_file.data(), - serialized_hash_file.size()); + ssize_t bytes = file_->Write(serialized_file.data(), + serialized_file.size()); file_->Close(); - if (bytes != static_cast(serialized_hash_file.size())) { - LOGW("DeviceFiles::StoreFile: write failed: (actual: %d, expected: %d)", + if (bytes != static_cast(serialized_file.size())) { + LOGW("DeviceFiles::StoreFileRaw: write failed: (actual: %d, expected: %d)", bytes, - serialized_hash_file.size()); + serialized_file.size()); return false; } - LOGV("DeviceFiles::StoreFile: success: %s (%db)", + LOGV("DeviceFiles::StoreFileRaw: success: %s (%db)", path.c_str(), - serialized_hash_file.size()); + serialized_file.size()); return true; } -bool DeviceFiles::RetrieveFile(const char* name, std::string* serialized_file) { +bool DeviceFiles::RetrieveHashedFile(const char* name, std::string* serialized_file) { if (!file_.get()) { - LOGW("DeviceFiles::RetrieveFile: Invalid file handle"); + LOGW("DeviceFiles::RetrieveHashedFile: Invalid file handle"); return false; } if (!name) { - LOGW("DeviceFiles::RetrieveFile: Unspecified file name parameter"); + LOGW("DeviceFiles::RetrieveHashedFile: Unspecified file name parameter"); return false; } if (!serialized_file) { - LOGW("DeviceFiles::RetrieveFile: Unspecified serialized_file parameter"); + LOGW("DeviceFiles::RetrieveHashedFile: Unspecified serialized_file parameter"); return false; } @@ -596,13 +623,13 @@ bool DeviceFiles::RetrieveFile(const char* name, std::string* serialized_file) { path += name; if (!file_->Exists(path)) { - LOGW("DeviceFiles::RetrieveFile: %s does not exist", path.c_str()); + LOGW("DeviceFiles::RetrieveHashedFile: %s does not exist", path.c_str()); return false; } ssize_t bytes = file_->FileSize(path); if (bytes <= 0) { - LOGW("DeviceFiles::RetrieveFile: File size invalid: %s", path.c_str()); + LOGW("DeviceFiles::RetrieveHashedFile: File size invalid: %s", path.c_str()); // Remove the corrupted file so the caller will not get the same error // when trying to access the file repeatedly, causing the system to stall. file_->Remove(path); @@ -619,27 +646,27 @@ bool DeviceFiles::RetrieveFile(const char* name, std::string* serialized_file) { file_->Close(); if (bytes != static_cast(serialized_hash_file.size())) { - LOGW("DeviceFiles::RetrieveFile: read failed"); + LOGW("DeviceFiles::RetrieveHashedFile: read failed"); return false; } - LOGV("DeviceFiles::RetrieveFile: success: %s (%db)", path.c_str(), + LOGV("DeviceFiles::RetrieveHashedFile: success: %s (%db)", path.c_str(), serialized_hash_file.size()); HashedFile hash_file; if (!hash_file.ParseFromString(serialized_hash_file)) { - LOGW("DeviceFiles::RetrieveFile: Unable to parse hash file"); + LOGW("DeviceFiles::RetrieveHashedFile: Unable to parse hash file"); return false; } std::string hash; if (!Hash(hash_file.file(), &hash)) { - LOGW("DeviceFiles::RetrieveFile: Hash computation failed"); + LOGW("DeviceFiles::RetrieveHashedFile: Hash computation failed"); return false; } if (hash.compare(hash_file.hash())) { - LOGW("DeviceFiles::RetrieveFile: Hash mismatch"); + LOGW("DeviceFiles::RetrieveHashedFile: Hash mismatch"); // Remove the corrupted file so the caller will not get the same error // when trying to access the file repeatedly, causing the system to stall. file_->Remove(path); @@ -733,6 +760,8 @@ std::string DeviceFiles::GetUsageInfoFileName(const std::string& app_id) { std::string(kUsageInfoFileNameExt); } +std::string DeviceFiles::GetBlankFileData() { return kBlankFileData; } + void DeviceFiles::SetTestFile(File* file) { file_.reset(file); test_file_ = true; diff --git a/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp b/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp index 380fe919..78a02987 100644 --- a/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp +++ b/libwvdrmengine/cdm/core/test/cdm_engine_test.cpp @@ -94,7 +94,8 @@ class WvCdmEngineTest : public testing::Test { kLicenseTypeStreaming, app_parameters, &key_msg_, - &server_url)); + &server_url, + NULL)); } void GenerateRenewalRequest() { diff --git a/libwvdrmengine/cdm/core/test/device_files_unittest.cpp b/libwvdrmengine/cdm/core/test/device_files_unittest.cpp index 4bbb1968..53ac7093 100644 --- a/libwvdrmengine/cdm/core/test/device_files_unittest.cpp +++ b/libwvdrmengine/cdm/core/test/device_files_unittest.cpp @@ -16,6 +16,7 @@ using ::testing::AllOf; using ::testing::Eq; using ::testing::Gt; using ::testing::HasSubstr; +using ::testing::InSequence; using ::testing::NotNull; using ::testing::Return; using ::testing::ReturnArg; @@ -1745,6 +1746,36 @@ TEST_F(DeviceFilesTest, DeleteLicense) { EXPECT_FALSE(device_files.LicenseExists(license_test_data[0].key_set_id)); } +TEST_F(DeviceFilesTest, ReserveLicenseIds) { + MockFile file; + EXPECT_CALL(file, IsDirectory(StrEq(device_base_path_))) + .Times(kNumberOfLicenses) + .WillRepeatedly(Return(true)); + EXPECT_CALL(file, CreateDirectory(_)).Times(0); + + for (size_t i = 0; i < kNumberOfLicenses; ++i) { + std::string license_path = device_base_path_ + + license_test_data[i].key_set_id + + DeviceFiles::GetLicenseFileNameExtension(); + InSequence calls; + EXPECT_CALL(file, Open(StrEq(license_path), + AllOf(IsCreateFileFlagSet(), IsBinaryFileFlagSet()))) + .WillOnce(Return(true)); + EXPECT_CALL(file, Write(StrEq(DeviceFiles::GetBlankFileData()), + DeviceFiles::GetBlankFileData().size())) + .WillOnce(ReturnArg<1>()); + EXPECT_CALL(file, Close()); + } + EXPECT_CALL(file, Read(_, _)).Times(0); + + DeviceFiles device_files; + EXPECT_TRUE(device_files.Init(kSecurityLevelL1)); + device_files.SetTestFile(&file); + for (size_t i = 0; i < kNumberOfLicenses; i++) { + EXPECT_TRUE(device_files.ReserveLicenseId(license_test_data[i].key_set_id)); + } +} + TEST_P(DeviceFilesUsageInfoTest, Read) { MockFile file; std::string app_id; // TODO(fredgc): expand tests. diff --git a/libwvdrmengine/cdm/src/wv_content_decryption_module.cpp b/libwvdrmengine/cdm/src/wv_content_decryption_module.cpp index c5fd34a1..b59c6154 100644 --- a/libwvdrmengine/cdm/src/wv_content_decryption_module.cpp +++ b/libwvdrmengine/cdm/src/wv_content_decryption_module.cpp @@ -79,7 +79,7 @@ CdmResponseType WvContentDecryptionModule::GenerateKeyRequest( sts = cdm_engine_->GenerateKeyRequest(session_id, key_set_id, initialization_data, license_type, app_parameters, key_request, - server_url); + server_url, NULL); switch(license_type) { case kLicenseTypeRelease: