diff --git a/libwvdrmengine/cdm/core/include/cdm_engine.h b/libwvdrmengine/cdm/core/include/cdm_engine.h index f6c3afec..ae58eaaf 100644 --- a/libwvdrmengine/cdm/core/include/cdm_engine.h +++ b/libwvdrmengine/cdm/core/include/cdm_engine.h @@ -3,6 +3,7 @@ #ifndef CDM_BASE_CDM_ENGINE_H_ #define CDM_BASE_CDM_ENGINE_H_ +#include "crypto_session.h" #include "timer.h" #include "wv_cdm_types.h" @@ -105,7 +106,7 @@ class CdmEngine : public TimerHandler { // private methods // Cancel all sessions bool CancelSessions(); - void CleanupProvisioningSession(const CdmSessionId& cdm_session_id); + void CleanupProvisioningSession(); void ComposeJsonRequestAsQueryString(const std::string& message, CdmProvisioningRequest* request); diff --git a/libwvdrmengine/cdm/core/src/cdm_engine.cpp b/libwvdrmengine/cdm/core/src/cdm_engine.cpp index fe517a13..3a857f03 100644 --- a/libwvdrmengine/cdm/core/src/cdm_engine.cpp +++ b/libwvdrmengine/cdm/core/src/cdm_engine.cpp @@ -327,9 +327,29 @@ CdmResponseType CdmEngine::QueryKeyControlInfo( return iter->second->QueryKeyControlInfo(key_info); } -void CdmEngine::CleanupProvisioningSession(const CdmSessionId& cdm_session_id) { - CloseSession(cdm_session_id); - provisioning_session_ = NULL; +/* + * The certificate provisioning process creates a cdm and a crypto session. + * The lives of these sessions are short and therefore, not added to the + * CdmSessionMap. We need to explicitly delete these objects when error occurs + * or when we are done with provisioning. + */ +void CdmEngine::CleanupProvisioningSession() { + if (provisioning_session_) { + CryptoEngine* crypto_engine = CryptoEngine::GetInstance(); + if (crypto_engine) { + CdmSessionId cdm_session_id = provisioning_session_->session_id(); + CryptoSession* crypto_session = + crypto_engine->FindSession(cdm_session_id); + if (crypto_session) { + LOGV("delete crypto session for id=%s", cdm_session_id.c_str()); + delete crypto_session; + } else { + LOGE("CleanupProvisioningSession: cannot find crypto_session"); + } + } + delete provisioning_session_; + provisioning_session_ = NULL; + } } /* @@ -382,8 +402,7 @@ CdmResponseType CdmEngine::GetProvisioningRequest( default_url->assign(kDefaultProvisioningServerUrl); if (provisioning_session_) { - LOGE("GetProvisioningRequest: duplicate provisioning request?"); - return UNKNOWN_ERROR; + CleanupProvisioningSession(); } // @@ -418,7 +437,10 @@ CdmResponseType CdmEngine::GetProvisioningRequest( return UNKNOWN_ERROR; } + // TODO(edwinwong): replace this cdm session pointer with crypto session + // pointer if feasible provisioning_session_ = cdm_session; + LOGV("provisioning session id=%s", cdm_session_id.c_str()); // //--------------------------------------------------------------------------- @@ -430,7 +452,7 @@ CdmResponseType CdmEngine::GetProvisioningRequest( std::string token; if (!crypto_engine->GetToken(&token)) { LOGE("GetProvisioningRequest: fails to get token"); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } client_id->set_token(token); @@ -438,8 +460,7 @@ CdmResponseType CdmEngine::GetProvisioningRequest( uint32_t nonce; if (!crypto_session->GenerateNonce(&nonce)) { LOGE("GetProvisioningRequest: fails to generate a nonce"); - crypto_engine->DestroySession(cdm_session_id); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } @@ -456,12 +477,12 @@ CdmResponseType CdmEngine::GetProvisioningRequest( if (!crypto_session->PrepareRequest(serialized_message, &request_signature, true)) { request->clear(); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } if (request_signature.empty()) { request->clear(); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } @@ -568,14 +589,14 @@ CdmResponseType CdmEngine::HandleProvisioningResponse( SignedProvisioningMessage signed_response; if (!signed_response.ParseFromString(serialized_signed_response)) { LOGE("Fails to parse signed serialized response"); - CleanupProvisioningSession(cdm_session_id); - return UNKNOWN_ERROR; + CleanupProvisioningSession(); + return UNKNOWN_ERROR; } if (!signed_response.has_signature() || !signed_response.has_message()) { LOGE("Invalid response - signature or message not found"); - CleanupProvisioningSession(cdm_session_id); - return UNKNOWN_ERROR; + CleanupProvisioningSession(); + return UNKNOWN_ERROR; } const std::string& signed_message = signed_response.message(); @@ -583,13 +604,13 @@ CdmResponseType CdmEngine::HandleProvisioningResponse( if (!provisioning_response.ParseFromString(signed_message)) { LOGE("Fails to parse signed message"); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } if (!provisioning_response.has_device_rsa_key()) { LOGE("Invalid response - key not found"); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } @@ -607,7 +628,7 @@ CdmResponseType CdmEngine::HandleProvisioningResponse( rsa_key_iv, &wrapped_rsa_key)) { LOGE("HandleProvisioningResponse: RewrapDeviceRSAKey fails"); - CleanupProvisioningSession(cdm_session_id); + CleanupProvisioningSession(); return UNKNOWN_ERROR; } @@ -616,10 +637,9 @@ CdmResponseType CdmEngine::HandleProvisioningResponse( // //--------------------------------------------------------------------------- - // Closes the cdm session. + // Deletes cdm and crypto sessions created for provisioning. // - CleanupProvisioningSession(cdm_session_id); - + CleanupProvisioningSession(); return NO_ERROR; } diff --git a/libwvdrmengine/cdm/test/request_license_test.cpp b/libwvdrmengine/cdm/test/request_license_test.cpp index 73bc805d..d79f611f 100644 --- a/libwvdrmengine/cdm/test/request_license_test.cpp +++ b/libwvdrmengine/cdm/test/request_license_test.cpp @@ -77,6 +77,81 @@ class WvCdmRequestLicenseTest : public testing::Test { EXPECT_NE(0, static_cast(server_url.size())); } + void DumpResponse(const std::string& description, + const std::string& response) { + if (description.empty()) + return; + + LOGD("%s (%d bytes):", description.c_str(), response.size()); + + size_t remaining = response.size(); + size_t portion = 0; + size_t start = 0; + while (remaining > 0) { + // LOGX may not get to empty its buffer if it is too large, + // pick an arbitrary small size to be safe + portion = (remaining < 512) ? remaining : 512; + LOGD("%s", response.substr(start, portion).c_str()); + start += portion; + remaining -= portion; + } + LOGD("total bytes dumped(%d)", start); + } + + // concatinates all chunks into one blob + // TODO (edwinwong) move this function to url_request class as GetMessageBody + void ConcatenateChunkedResponse(const std::string http_response, + std::string* message_body) { + if (http_response.empty()) + return; + + message_body->clear(); + const std::string kChunkedTag = "Transfer-Encoding: chunked\r\n\r\n"; + size_t chunked_tag_pos = http_response.find(kChunkedTag); + if (std::string::npos != chunked_tag_pos) { + // processes chunked encoding + size_t chunk_size = 0; + size_t chunk_size_pos = chunked_tag_pos + kChunkedTag.size(); + sscanf(&http_response[chunk_size_pos], "%x", &chunk_size); + if (chunk_size > http_response.size()) { + // precaution, in case we misread chunk size + LOGE("invalid chunk size %u", chunk_size); + return; + } + + /* + * searches for chunks + * + * header + * chunk size\r\n <-- chunk_size_pos @ beginning of chunk size + * chunk data\r\n <-- chunk_pos @ beginning of chunk data + * chunk size\r\n + * chunk data\r\n + * 0\r\n + */ + const std::string kCrLf = "\r\n"; + size_t chunk_pos = http_response.find(kCrLf, chunk_size_pos) + + kCrLf.size(); + message_body->assign(&http_response[0], chunk_size_pos); + while ((chunk_size > 0) && (std::string::npos != chunk_pos)) { + message_body->append(&http_response[chunk_pos], chunk_size); + + // searches for next chunk + chunk_size_pos = chunk_pos + chunk_size + kCrLf.size(); + sscanf(&http_response[chunk_size_pos], "%x", &chunk_size); + if (chunk_size > http_response.size()) { + // precaution, in case we misread chunk size + LOGE("invalid chunk size %u", chunk_size); + break; + } + chunk_pos = http_response.find(kCrLf, chunk_size_pos) + kCrLf.size(); + } + } else { + // response is not chunked encoded + message_body->assign(http_response); + } + } + // posts a request and extracts the drm message from the response std::string GetKeyRequestResponse(const std::string& server_url, const std::string& client_auth, @@ -88,14 +163,17 @@ class WvCdmRequestLicenseTest : public testing::Test { } url_request.PostRequest(key_msg_); - std::string response; - int resp_bytes = url_request.GetResponse(response); - LOGD("response:\r\n%s", response.c_str()); + std::string http_response; + std::string message_body; + int resp_bytes = url_request.GetResponse(http_response); + if (resp_bytes) { + ConcatenateChunkedResponse(http_response, &message_body); + } LOGD("end %d bytes response dump", resp_bytes); // Youtube server returns 400 for invalid message while play server returns // 500, so just test inequity here for invalid message - int status_code = url_request.GetStatusCode(response); + int status_code = url_request.GetStatusCode(message_body); if (expected_response == 200) { EXPECT_EQ(200, status_code); } else { @@ -105,7 +183,7 @@ class WvCdmRequestLicenseTest : public testing::Test { std::string drm_msg; if (200 == status_code) { LicenseRequest lic_request; - lic_request.GetDrmMessage(response, drm_msg); + lic_request.GetDrmMessage(message_body, drm_msg); LOGV("drm msg: %u bytes\r\n%s", drm_msg.size(), HexEncode(reinterpret_cast(drm_msg.data()), drm_msg.size()).c_str()); @@ -123,24 +201,23 @@ class WvCdmRequestLicenseTest : public testing::Test { } url_request.PostCertRequestInQueryString(key_msg_); - std::string response; - int resp_bytes = url_request.GetResponse(response); + std::string http_response; + std::string message_body; + int resp_bytes = url_request.GetResponse(http_response); if (resp_bytes) { - LOGD("size=%u, response start:\t\rn%s", response.size(), - response.substr(0, 1024).c_str()); - LOGD("end:\r\n%s", response.substr(response.size() - 256).c_str()); + ConcatenateChunkedResponse(http_response, &message_body); } LOGD("end %d bytes response dump", resp_bytes); // Youtube server returns 400 for invalid message while play server returns // 500, so just test inequity here for invalid message - int status_code = url_request.GetStatusCode(response); + int status_code = url_request.GetStatusCode(message_body); if (expected_response == 200) { EXPECT_EQ(200, status_code); } else { EXPECT_NE(200, status_code); } - return response; + return message_body; } void VerifyKeyRequestResponse(const std::string& server_url, @@ -170,12 +247,28 @@ TEST_F(WvCdmRequestLicenseTest, ProvisioningTest) { decryptor_.OpenSession(g_key_system, &session_id_); std::string provisioning_server_url = ""; - decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url); + EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url)); EXPECT_STREQ(provisioning_server_url.data(), kDefaultProvisioningServerUrl.data()); std::string response = GetCertRequestResponse(kDefaultProvisioningServerUrl, 200); - if (!response.empty()) - decryptor_.HandleProvisioningResponse(response); + EXPECT_NE(0, static_cast(response.size())); + EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.HandleProvisioningResponse(response)); + decryptor_.CloseSession(session_id_); +} + +TEST_F(WvCdmRequestLicenseTest, ProvisioningRetryTest) { + decryptor_.OpenSession(g_key_system, &session_id_); + std::string provisioning_server_url = ""; + + EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url)); + EXPECT_STREQ(provisioning_server_url.data(), kDefaultProvisioningServerUrl.data()); + + EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url)); + EXPECT_STREQ(provisioning_server_url.data(), kDefaultProvisioningServerUrl.data()); + + std::string response = GetCertRequestResponse(kDefaultProvisioningServerUrl, 200); + EXPECT_NE(0, static_cast(response.size())); + EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.HandleProvisioningResponse(response)); decryptor_.CloseSession(session_id_); }