Simplify How Request ID Indices are Generated

(This is a merge of http://go/wvgerrit/70667)

Request ID Index generation has historically worked by incrementing a
shared variable in one place and reading it in another place and
trusting the fact that CdmLicense calls these operations in a certain
order and only once per session to give each session a unique value.
This patch cleans this up a bit, having each session store the current
Request ID Index at the same time as it stores its Request ID Base. This
guarantees that each CryptoSession will receive a unique but stable
combination of Base and ID rather than relying on the calling pattern.

Since all this generation happens during the same function, the full
Request ID can be generated up-front and stored, making
GenerateRequestId() no longer necessary.

This patch also simplifies the threading story around this shared state
by using a std::atomic<uint64_t>. Bringing the code that interacts with
the shared state together into one place and replacing it with atomic
operations will simplify locking around this code when CryptoSession
locking is revamped in a future patch.

Bug: 70889998
Bug: 118584039
Test: CE CDM Unit Tests
Test: Android Unit Tests
Change-Id: I12d2f6501f872f1973e5a9af5125ca03f23e5a56
This commit is contained in:
John W. Bruce
2019-01-18 16:24:59 -08:00
parent 700ee5160a
commit ca00dc7ae4
4 changed files with 20 additions and 29 deletions

View File

@@ -5,6 +5,7 @@
#ifndef WVCDM_CORE_CRYPTO_SESSION_H_
#define WVCDM_CORE_CRYPTO_SESSION_H_
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
@@ -88,7 +89,7 @@ class CryptoSession {
virtual CryptoSessionId oec_session_id() { return oec_session_id_; }
// Key request/response
virtual bool GenerateRequestId(std::string* req_id_str);
virtual const std::string& request_id() { return request_id_; }
virtual bool PrepareRequest(const std::string& key_deriv_message,
bool is_provisioning, std::string* signature);
virtual bool PrepareRenewalRequest(const std::string& message,
@@ -332,8 +333,8 @@ class CryptoSession {
static UsageTableHeader* usage_table_header_l1_;
static UsageTableHeader* usage_table_header_l3_;
uint64_t request_id_base_;
static uint64_t request_id_index_;
std::string request_id_;
static std::atomic<uint64_t> request_id_index_source_;
CdmCipherMode cipher_mode_;
uint32_t api_version_;

View File

@@ -78,9 +78,9 @@ namespace wvcdm {
std::mutex CryptoSession::crypto_lock_;
bool CryptoSession::initialized_ = false;
int CryptoSession::session_count_ = 0;
uint64_t CryptoSession::request_id_index_ = 0;
UsageTableHeader* CryptoSession::usage_table_header_l1_ = NULL;
UsageTableHeader* CryptoSession::usage_table_header_l3_ = NULL;
std::atomic<uint64_t> CryptoSession::request_id_index_source_(0);
size_t GetOffset(std::string message, std::string field) {
size_t pos = message.find(field);
@@ -167,7 +167,6 @@ CryptoSession::CryptoSession(metrics::CryptoMetrics* metrics)
is_usage_support_type_valid_(false),
usage_support_type_(kUnknownUsageSupport),
usage_table_header_(NULL),
request_id_base_(0),
cipher_mode_(kCipherModeCtr),
api_version_(0) {
assert(metrics);
@@ -749,10 +748,16 @@ CdmResponseType CryptoSession::Open(SecurityLevel requested_security_level) {
return LOAD_SYSTEM_ID_ERROR;
}
uint64_t request_id_base;
OEMCryptoResult random_sts = OEMCrypto_GetRandom(
reinterpret_cast<uint8_t*>(&request_id_base_), sizeof(request_id_base_));
reinterpret_cast<uint8_t*>(&request_id_base), sizeof(request_id_base));
metrics_->oemcrypto_get_random_.Increment(random_sts);
++request_id_index_;
uint64_t request_id_index =
request_id_index_source_.fetch_add(1, std::memory_order_relaxed);
request_id_ = HexEncode(reinterpret_cast<uint8_t*>(&request_id_base),
sizeof(request_id_base)) +
HexEncode(reinterpret_cast<uint8_t*>(&request_id_index),
sizeof(request_id_index));
if (!GetApiVersion(&api_version_)) {
LOGE("CryptoSession::Open: GetApiVersion failed");
@@ -819,21 +824,6 @@ void CryptoSession::Close() {
}
}
bool CryptoSession::GenerateRequestId(std::string* req_id_str) {
LOGV("CryptoSession::GenerateRequestId: Lock");
std::unique_lock<std::mutex> auto_lock(crypto_lock_);
if (!req_id_str) {
LOGE("CryptoSession::GenerateRequestId: No output destination provided.");
return false;
}
*req_id_str = HexEncode(reinterpret_cast<uint8_t*>(&request_id_base_),
sizeof(request_id_base_)) +
HexEncode(reinterpret_cast<uint8_t*>(&request_id_index_),
sizeof(request_id_index_));
return true;
}
bool CryptoSession::PrepareRequest(const std::string& message,
bool is_provisioning,
std::string* signature) {

View File

@@ -300,8 +300,7 @@ CdmResponseType CdmLicense::PrepareKeyRequest(
return KEY_MESSAGE;
}
std::string request_id;
crypto_session_->GenerateRequestId(&request_id);
const std::string& request_id = crypto_session_->request_id();
LicenseRequest license_request;
CdmResponseType status;

View File

@@ -138,7 +138,7 @@ class MockCryptoSession : public TestCryptoSession {
MockCryptoSession(metrics::CryptoMetrics* crypto_metrics)
: TestCryptoSession(crypto_metrics) { }
MOCK_METHOD0(IsOpen, bool());
MOCK_METHOD1(GenerateRequestId, bool(std::string*));
MOCK_METHOD0(request_id, const std::string&());
MOCK_METHOD1(UsageInformationSupport, bool(bool*));
MOCK_METHOD2(GetHdcpCapabilities, bool(HdcpCapability*, HdcpCapability*));
MOCK_METHOD1(GetSupportedCertificateTypes, bool(SupportedCertificateTypes*));
@@ -188,6 +188,7 @@ using ::testing::Eq;
using ::testing::NotNull;
using ::testing::PrintToStringParamName;
using ::testing::Return;
using ::testing::ReturnRef;
using ::testing::SetArgPointee;
using ::testing::UnorderedElementsAre;
using ::testing::Values;
@@ -310,8 +311,8 @@ TEST_F(CdmLicenseTest, PrepareKeyRequestValidation) {
EXPECT_CALL(*crypto_session_, IsOpen())
.WillOnce(Return(true));
EXPECT_CALL(*crypto_session_, GenerateRequestId(NotNull()))
.WillOnce(DoAll(SetArgPointee<0>(kCryptoRequestId), Return(true)));
EXPECT_CALL(*crypto_session_, request_id())
.WillOnce(ReturnRef(kCryptoRequestId));
EXPECT_CALL(*crypto_session_, UsageInformationSupport(NotNull()))
.WillOnce(
DoAll(SetArgPointee<0>(usage_information_support), Return(true)));
@@ -430,8 +431,8 @@ TEST_F(CdmLicenseTest, PrepareKeyRequestValidationV15) {
EXPECT_CALL(*crypto_session_, IsOpen())
.WillOnce(Return(true));
EXPECT_CALL(*crypto_session_, GenerateRequestId(NotNull()))
.WillOnce(DoAll(SetArgPointee<0>(kCryptoRequestId), Return(true)));
EXPECT_CALL(*crypto_session_, request_id())
.WillOnce(ReturnRef(kCryptoRequestId));
EXPECT_CALL(*crypto_session_, UsageInformationSupport(NotNull()))
.WillOnce(
DoAll(SetArgPointee<0>(usage_information_support), Return(true)));