am 7aa99d4a: Squashed commit of 3 CLs related to provisioning retries

* commit '7aa99d4a36434af618db6240d4a74ea1ef641802':
  Squashed commit of 3 CLs related to provisioning retries
This commit is contained in:
Jeff Tinker
2013-05-07 10:25:30 -07:00
committed by Android Git Automerger
3 changed files with 150 additions and 36 deletions

View File

@@ -3,6 +3,7 @@
#ifndef CDM_BASE_CDM_ENGINE_H_
#define CDM_BASE_CDM_ENGINE_H_
#include "crypto_session.h"
#include "timer.h"
#include "wv_cdm_types.h"
@@ -105,7 +106,7 @@ class CdmEngine : public TimerHandler {
// private methods
// Cancel all sessions
bool CancelSessions();
void CleanupProvisioningSession(const CdmSessionId& cdm_session_id);
void CleanupProvisioningSession();
void ComposeJsonRequestAsQueryString(const std::string& message,
CdmProvisioningRequest* request);

View File

@@ -327,9 +327,29 @@ CdmResponseType CdmEngine::QueryKeyControlInfo(
return iter->second->QueryKeyControlInfo(key_info);
}
void CdmEngine::CleanupProvisioningSession(const CdmSessionId& cdm_session_id) {
CloseSession(cdm_session_id);
provisioning_session_ = NULL;
/*
* The certificate provisioning process creates a cdm and a crypto session.
* The lives of these sessions are short and therefore, not added to the
* CdmSessionMap. We need to explicitly delete these objects when error occurs
* or when we are done with provisioning.
*/
void CdmEngine::CleanupProvisioningSession() {
if (provisioning_session_) {
CryptoEngine* crypto_engine = CryptoEngine::GetInstance();
if (crypto_engine) {
CdmSessionId cdm_session_id = provisioning_session_->session_id();
CryptoSession* crypto_session =
crypto_engine->FindSession(cdm_session_id);
if (crypto_session) {
LOGV("delete crypto session for id=%s", cdm_session_id.c_str());
delete crypto_session;
} else {
LOGE("CleanupProvisioningSession: cannot find crypto_session");
}
}
delete provisioning_session_;
provisioning_session_ = NULL;
}
}
/*
@@ -382,8 +402,7 @@ CdmResponseType CdmEngine::GetProvisioningRequest(
default_url->assign(kDefaultProvisioningServerUrl);
if (provisioning_session_) {
LOGE("GetProvisioningRequest: duplicate provisioning request?");
return UNKNOWN_ERROR;
CleanupProvisioningSession();
}
//
@@ -418,7 +437,10 @@ CdmResponseType CdmEngine::GetProvisioningRequest(
return UNKNOWN_ERROR;
}
// TODO(edwinwong): replace this cdm session pointer with crypto session
// pointer if feasible
provisioning_session_ = cdm_session;
LOGV("provisioning session id=%s", cdm_session_id.c_str());
//
//---------------------------------------------------------------------------
@@ -430,7 +452,7 @@ CdmResponseType CdmEngine::GetProvisioningRequest(
std::string token;
if (!crypto_engine->GetToken(&token)) {
LOGE("GetProvisioningRequest: fails to get token");
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
client_id->set_token(token);
@@ -438,8 +460,7 @@ CdmResponseType CdmEngine::GetProvisioningRequest(
uint32_t nonce;
if (!crypto_session->GenerateNonce(&nonce)) {
LOGE("GetProvisioningRequest: fails to generate a nonce");
crypto_engine->DestroySession(cdm_session_id);
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
@@ -456,12 +477,12 @@ CdmResponseType CdmEngine::GetProvisioningRequest(
if (!crypto_session->PrepareRequest(serialized_message,
&request_signature, true)) {
request->clear();
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
if (request_signature.empty()) {
request->clear();
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
@@ -568,14 +589,14 @@ CdmResponseType CdmEngine::HandleProvisioningResponse(
SignedProvisioningMessage signed_response;
if (!signed_response.ParseFromString(serialized_signed_response)) {
LOGE("Fails to parse signed serialized response");
CleanupProvisioningSession(cdm_session_id);
return UNKNOWN_ERROR;
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
if (!signed_response.has_signature() || !signed_response.has_message()) {
LOGE("Invalid response - signature or message not found");
CleanupProvisioningSession(cdm_session_id);
return UNKNOWN_ERROR;
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
const std::string& signed_message = signed_response.message();
@@ -583,13 +604,13 @@ CdmResponseType CdmEngine::HandleProvisioningResponse(
if (!provisioning_response.ParseFromString(signed_message)) {
LOGE("Fails to parse signed message");
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
if (!provisioning_response.has_device_rsa_key()) {
LOGE("Invalid response - key not found");
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
@@ -607,7 +628,7 @@ CdmResponseType CdmEngine::HandleProvisioningResponse(
rsa_key_iv,
&wrapped_rsa_key)) {
LOGE("HandleProvisioningResponse: RewrapDeviceRSAKey fails");
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return UNKNOWN_ERROR;
}
@@ -616,10 +637,9 @@ CdmResponseType CdmEngine::HandleProvisioningResponse(
//
//---------------------------------------------------------------------------
// Closes the cdm session.
// Deletes cdm and crypto sessions created for provisioning.
//
CleanupProvisioningSession(cdm_session_id);
CleanupProvisioningSession();
return NO_ERROR;
}

View File

@@ -77,6 +77,81 @@ class WvCdmRequestLicenseTest : public testing::Test {
EXPECT_NE(0, static_cast<int>(server_url.size()));
}
void DumpResponse(const std::string& description,
const std::string& response) {
if (description.empty())
return;
LOGD("%s (%d bytes):", description.c_str(), response.size());
size_t remaining = response.size();
size_t portion = 0;
size_t start = 0;
while (remaining > 0) {
// LOGX may not get to empty its buffer if it is too large,
// pick an arbitrary small size to be safe
portion = (remaining < 512) ? remaining : 512;
LOGD("%s", response.substr(start, portion).c_str());
start += portion;
remaining -= portion;
}
LOGD("total bytes dumped(%d)", start);
}
// concatinates all chunks into one blob
// TODO (edwinwong) move this function to url_request class as GetMessageBody
void ConcatenateChunkedResponse(const std::string http_response,
std::string* message_body) {
if (http_response.empty())
return;
message_body->clear();
const std::string kChunkedTag = "Transfer-Encoding: chunked\r\n\r\n";
size_t chunked_tag_pos = http_response.find(kChunkedTag);
if (std::string::npos != chunked_tag_pos) {
// processes chunked encoding
size_t chunk_size = 0;
size_t chunk_size_pos = chunked_tag_pos + kChunkedTag.size();
sscanf(&http_response[chunk_size_pos], "%x", &chunk_size);
if (chunk_size > http_response.size()) {
// precaution, in case we misread chunk size
LOGE("invalid chunk size %u", chunk_size);
return;
}
/*
* searches for chunks
*
* header
* chunk size\r\n <-- chunk_size_pos @ beginning of chunk size
* chunk data\r\n <-- chunk_pos @ beginning of chunk data
* chunk size\r\n
* chunk data\r\n
* 0\r\n
*/
const std::string kCrLf = "\r\n";
size_t chunk_pos = http_response.find(kCrLf, chunk_size_pos) +
kCrLf.size();
message_body->assign(&http_response[0], chunk_size_pos);
while ((chunk_size > 0) && (std::string::npos != chunk_pos)) {
message_body->append(&http_response[chunk_pos], chunk_size);
// searches for next chunk
chunk_size_pos = chunk_pos + chunk_size + kCrLf.size();
sscanf(&http_response[chunk_size_pos], "%x", &chunk_size);
if (chunk_size > http_response.size()) {
// precaution, in case we misread chunk size
LOGE("invalid chunk size %u", chunk_size);
break;
}
chunk_pos = http_response.find(kCrLf, chunk_size_pos) + kCrLf.size();
}
} else {
// response is not chunked encoded
message_body->assign(http_response);
}
}
// posts a request and extracts the drm message from the response
std::string GetKeyRequestResponse(const std::string& server_url,
const std::string& client_auth,
@@ -88,14 +163,17 @@ class WvCdmRequestLicenseTest : public testing::Test {
}
url_request.PostRequest(key_msg_);
std::string response;
int resp_bytes = url_request.GetResponse(response);
LOGD("response:\r\n%s", response.c_str());
std::string http_response;
std::string message_body;
int resp_bytes = url_request.GetResponse(http_response);
if (resp_bytes) {
ConcatenateChunkedResponse(http_response, &message_body);
}
LOGD("end %d bytes response dump", resp_bytes);
// Youtube server returns 400 for invalid message while play server returns
// 500, so just test inequity here for invalid message
int status_code = url_request.GetStatusCode(response);
int status_code = url_request.GetStatusCode(message_body);
if (expected_response == 200) {
EXPECT_EQ(200, status_code);
} else {
@@ -105,7 +183,7 @@ class WvCdmRequestLicenseTest : public testing::Test {
std::string drm_msg;
if (200 == status_code) {
LicenseRequest lic_request;
lic_request.GetDrmMessage(response, drm_msg);
lic_request.GetDrmMessage(message_body, drm_msg);
LOGV("drm msg: %u bytes\r\n%s", drm_msg.size(),
HexEncode(reinterpret_cast<const uint8_t*>(drm_msg.data()),
drm_msg.size()).c_str());
@@ -123,24 +201,23 @@ class WvCdmRequestLicenseTest : public testing::Test {
}
url_request.PostCertRequestInQueryString(key_msg_);
std::string response;
int resp_bytes = url_request.GetResponse(response);
std::string http_response;
std::string message_body;
int resp_bytes = url_request.GetResponse(http_response);
if (resp_bytes) {
LOGD("size=%u, response start:\t\rn%s", response.size(),
response.substr(0, 1024).c_str());
LOGD("end:\r\n%s", response.substr(response.size() - 256).c_str());
ConcatenateChunkedResponse(http_response, &message_body);
}
LOGD("end %d bytes response dump", resp_bytes);
// Youtube server returns 400 for invalid message while play server returns
// 500, so just test inequity here for invalid message
int status_code = url_request.GetStatusCode(response);
int status_code = url_request.GetStatusCode(message_body);
if (expected_response == 200) {
EXPECT_EQ(200, status_code);
} else {
EXPECT_NE(200, status_code);
}
return response;
return message_body;
}
void VerifyKeyRequestResponse(const std::string& server_url,
@@ -170,12 +247,28 @@ TEST_F(WvCdmRequestLicenseTest, ProvisioningTest) {
decryptor_.OpenSession(g_key_system, &session_id_);
std::string provisioning_server_url = "";
decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url);
EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url));
EXPECT_STREQ(provisioning_server_url.data(), kDefaultProvisioningServerUrl.data());
std::string response = GetCertRequestResponse(kDefaultProvisioningServerUrl, 200);
if (!response.empty())
decryptor_.HandleProvisioningResponse(response);
EXPECT_NE(0, static_cast<int>(response.size()));
EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.HandleProvisioningResponse(response));
decryptor_.CloseSession(session_id_);
}
TEST_F(WvCdmRequestLicenseTest, ProvisioningRetryTest) {
decryptor_.OpenSession(g_key_system, &session_id_);
std::string provisioning_server_url = "";
EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url));
EXPECT_STREQ(provisioning_server_url.data(), kDefaultProvisioningServerUrl.data());
EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.GetProvisioningRequest(&key_msg_, &provisioning_server_url));
EXPECT_STREQ(provisioning_server_url.data(), kDefaultProvisioningServerUrl.data());
std::string response = GetCertRequestResponse(kDefaultProvisioningServerUrl, 200);
EXPECT_NE(0, static_cast<int>(response.size()));
EXPECT_EQ(wvcdm::NO_ERROR, decryptor_.HandleProvisioningResponse(response));
decryptor_.CloseSession(session_id_);
}