Reference code for ECC operations.
[ Merge of http://go/wvgerrit/113750 ] This introduces two classes EccPublicKey and EccPrivateKey which perform all ECC-specific crypto operations. The main operations required by ECC are: - Load/serialize keys from/to X.509 DER formats - Generate ECC signatures - Verify ECC signatures - Derive session keys used by other OEMCrypto operations These new classes still need to be plugged into rest of the reference OEMCrypto implementation. Bug: 135283522 Test: Future CL Change-Id: Id071cad9129f95a6eb08662322154ba7d1548d40
This commit is contained in:
781
libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.cpp
Normal file
781
libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.cpp
Normal file
@@ -0,0 +1,781 @@
|
||||
// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary
|
||||
// source code may only be used and distributed under the Widevine License
|
||||
// Agreement.
|
||||
//
|
||||
// Reference implementation of OEMCrypto APIs
|
||||
//
|
||||
#include <assert.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#include <openssl/bn.h>
|
||||
#include <openssl/crypto.h>
|
||||
#include <openssl/ec.h>
|
||||
#include <openssl/ecdsa.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <openssl/x509.h>
|
||||
|
||||
#include "log.h"
|
||||
#include "oemcrypto_ecc_key.h"
|
||||
#include "scoped_object.h"
|
||||
|
||||
namespace wvoec_ref {
|
||||
namespace {
|
||||
// Estimated max size (in bytes) of a serialized ECC key (public or
|
||||
// private). These values are based on rough calculations for
|
||||
// secp521r1 (largest of the supported curves) and should be slightly
|
||||
// larger needed.
|
||||
constexpr size_t kPrivateKeySize = 230;
|
||||
constexpr size_t kPublicKeySize = 164;
|
||||
|
||||
// 256 bit key, intended to be used with CMAC-AES-256.
|
||||
constexpr size_t kEccSessionKeySize = 32;
|
||||
|
||||
using ScopedBigNum = ScopedObject<BIGNUM, BN_free>;
|
||||
using ScopedBigNumCtx = ScopedObject<BN_CTX, BN_CTX_free>;
|
||||
using ScopedEcKey = ScopedObject<EC_KEY, EC_KEY_free>;
|
||||
using ScopedSigPoint = ScopedObject<ECDSA_SIG, ECDSA_SIG_free>;
|
||||
using ScopedEvpMdCtx = ScopedObject<EVP_MD_CTX, EVP_MD_CTX_free>;
|
||||
|
||||
const EC_GROUP* GetEcGroup(EccCurve curve) {
|
||||
// Creating a named EC_GROUP is an expensive operation, and they
|
||||
// are always used in a manner which does not transfer ownership.
|
||||
// Maintaining a process-wide set of supported EC groups reduces
|
||||
// the overhead of group operations.
|
||||
static std::mutex group_mutex;
|
||||
static EC_GROUP* group_256 = nullptr;
|
||||
static EC_GROUP* group_384 = nullptr;
|
||||
static EC_GROUP* group_521 = nullptr;
|
||||
std::lock_guard<std::mutex> group_lock(group_mutex);
|
||||
switch (curve) {
|
||||
case kEccSecp256r1: {
|
||||
if (group_256 == nullptr) {
|
||||
LOGD("Creating secp256r1 group");
|
||||
// The curve secp256r1 was originally named prime256v1
|
||||
// in the X9.62 specification.
|
||||
group_256 = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
|
||||
assert(group_256 != nullptr);
|
||||
}
|
||||
return group_256;
|
||||
}
|
||||
case kEccSecp384r1: {
|
||||
if (group_384 == nullptr) {
|
||||
LOGD("Creating secp384r1 group");
|
||||
group_384 = EC_GROUP_new_by_curve_name(NID_secp384r1);
|
||||
assert(group_384 != nullptr);
|
||||
}
|
||||
return group_384;
|
||||
}
|
||||
case kEccSecp521r1: {
|
||||
if (group_521 == nullptr) {
|
||||
LOGD("Creating secp521r1 group");
|
||||
group_521 = EC_GROUP_new_by_curve_name(NID_secp521r1);
|
||||
assert(group_521 != nullptr);
|
||||
}
|
||||
return group_521;
|
||||
}
|
||||
default:
|
||||
LOGE("Cannot get EC group for unknown curve: curve = %d",
|
||||
static_cast<int>(curve));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Determines which of the supported ECC curves the provided |key|
|
||||
// belongs to.
|
||||
//
|
||||
// This is intended to be used on keys that have been deserialized
|
||||
// from an ASN.1 structure which may have contained a key which is
|
||||
// supported by OpenSSL/BoringSSL but not necessarily by OEMCrypto.
|
||||
//
|
||||
// If the key group is unknown to OEMCrypto or if an error occurs,
|
||||
// kEccCurveUnknown is returned.
|
||||
EccCurve GetCurveFromKeyGroup(const EC_KEY* key) {
|
||||
ScopedBigNumCtx ctx(BN_CTX_new());
|
||||
if (!ctx) {
|
||||
LOGE("Failed to allocate BN ctx");
|
||||
return kEccCurveUnknown;
|
||||
}
|
||||
const EC_GROUP* group = EC_KEY_get0_group(key);
|
||||
if (group == nullptr) {
|
||||
LOGE("Provided key does not have a group");
|
||||
return kEccCurveUnknown;
|
||||
}
|
||||
int rc = EC_GROUP_cmp(group, GetEcGroup(kEccSecp256r1), ctx.get());
|
||||
if (rc == 0) {
|
||||
return kEccSecp256r1;
|
||||
}
|
||||
if (rc == -1) {
|
||||
LOGE("Error occurred while checking against secp256r1");
|
||||
return kEccCurveUnknown;
|
||||
}
|
||||
|
||||
rc = EC_GROUP_cmp(group, GetEcGroup(kEccSecp384r1), ctx.get());
|
||||
if (rc == 0) {
|
||||
return kEccSecp384r1;
|
||||
}
|
||||
if (rc == -1) {
|
||||
LOGE("Error occurred while checking against secp384r1");
|
||||
return kEccCurveUnknown;
|
||||
}
|
||||
|
||||
rc = EC_GROUP_cmp(group, GetEcGroup(kEccSecp521r1), ctx.get());
|
||||
if (rc == 0) {
|
||||
return kEccSecp521r1;
|
||||
}
|
||||
if (rc == -1) {
|
||||
LOGE("Error occurred while checking against secp521r1");
|
||||
return kEccCurveUnknown;
|
||||
}
|
||||
|
||||
LOGW("Unsupported curve group");
|
||||
return kEccCurveUnknown;
|
||||
}
|
||||
|
||||
// Compares the public EC points of both keys to see if they are the
|
||||
// equal.
|
||||
// Both |public_key| and |private_key| must be of the same group.
|
||||
bool IsMatchingKeyPair(const EC_KEY* public_key, const EC_KEY* private_key) {
|
||||
ScopedBigNumCtx ctx(BN_CTX_new());
|
||||
if (!ctx) {
|
||||
LOGE("Failed to allocate BN ctx");
|
||||
return false;
|
||||
}
|
||||
// Returns: 1 if not equal, 0 if equal, -1 if error.
|
||||
const int res = EC_POINT_cmp(EC_KEY_get0_group(public_key),
|
||||
EC_KEY_get0_public_key(public_key),
|
||||
EC_KEY_get0_public_key(private_key), ctx.get());
|
||||
if (res == -1) {
|
||||
LOGE("Error occurred comparing keys");
|
||||
}
|
||||
return res == 0;
|
||||
}
|
||||
|
||||
// Performs a SHA2 digest on the provided |message| and outputs the
|
||||
// computed hash to |digest|.
|
||||
// The digest algorithm used depends on which curve is used.
|
||||
// - secp256r1 -> SHA-256
|
||||
// - secp384r1 -> SHA-384
|
||||
// - secp521r1 -> SHA-512
|
||||
// This function assumes that all parameters are valid.
|
||||
// Returns true on success, false otherwise.
|
||||
bool DigestMessage(EccCurve curve, const uint8_t* message, size_t message_size,
|
||||
std::vector<uint8_t>* digest) {
|
||||
const EVP_MD* md_engine = nullptr;
|
||||
switch (curve) {
|
||||
case kEccSecp256r1: {
|
||||
md_engine = EVP_sha256();
|
||||
break;
|
||||
}
|
||||
case kEccSecp384r1: {
|
||||
md_engine = EVP_sha384();
|
||||
break;
|
||||
}
|
||||
case kEccSecp521r1: {
|
||||
md_engine = EVP_sha512();
|
||||
break;
|
||||
}
|
||||
case kEccCurveUnknown:
|
||||
// This case is to suppress compiler warnings. It will never
|
||||
// occur.
|
||||
break;
|
||||
}
|
||||
if (md_engine == nullptr) {
|
||||
LOGE("Failed to get MD engine: curve = %d", static_cast<int>(curve));
|
||||
return false;
|
||||
}
|
||||
|
||||
ScopedEvpMdCtx md_ctx(EVP_MD_CTX_new());
|
||||
if (!md_ctx) {
|
||||
LOGE("Failed to create MD CTX");
|
||||
return false;
|
||||
}
|
||||
if (!EVP_DigestInit_ex(md_ctx.get(), md_engine, nullptr)) {
|
||||
LOGE("Failed to init MD CTX");
|
||||
return false;
|
||||
}
|
||||
if (message_size > 0 &&
|
||||
!EVP_DigestUpdate(md_ctx.get(), message, message_size)) {
|
||||
LOGE("Failed to update");
|
||||
return false;
|
||||
}
|
||||
digest->resize(EVP_MD_CTX_size(md_ctx.get()), 0);
|
||||
const int res = EVP_DigestFinal_ex(md_ctx.get(), digest->data(), nullptr);
|
||||
if (!res) {
|
||||
LOGE("Failed to finalize");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// This KDF function is defined by OEMCrypto ECC specification.
|
||||
// Function signature is based on the |kdf| parameter of
|
||||
// ECDH_compute_key(). This function assumes that all pointer
|
||||
// parameters are not null.
|
||||
extern "C" void* WidevineEccKdf(const void* secret, size_t secret_length,
|
||||
void* key, size_t* key_size) {
|
||||
if (*key_size < kEccSessionKeySize) {
|
||||
LOGE("Output buffer is too small: required = %zu, size = %zu",
|
||||
kEccSessionKeySize, *key_size);
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<uint8_t> digest;
|
||||
if (!DigestMessage(kEccSecp256r1 /* SHA-256 */,
|
||||
reinterpret_cast<const uint8_t*>(secret), secret_length,
|
||||
&digest)) {
|
||||
LOGE("Cannot derive key: Failed to hash secret");
|
||||
return nullptr;
|
||||
}
|
||||
if (digest.size() != kEccSessionKeySize) {
|
||||
LOGE("Unexpected hash size: actual = %zu, expected = %zu", digest.size(),
|
||||
kEccSessionKeySize);
|
||||
return nullptr;
|
||||
}
|
||||
*key_size = kEccSessionKeySize;
|
||||
memcpy(key, digest.data(), *key_size);
|
||||
return key;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::string EccCurveToString(EccCurve curve) {
|
||||
switch (curve) {
|
||||
case kEccSecp256r1:
|
||||
return "secp256r1";
|
||||
case kEccSecp384r1:
|
||||
return "secp384r1";
|
||||
case kEccSecp521r1:
|
||||
return "secp521r1";
|
||||
case kEccCurveUnknown:
|
||||
return "Unknown";
|
||||
}
|
||||
return "Unknown(" + std::to_string(static_cast<int>(curve)) + ")";
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPublicKey> EccPublicKey::New(
|
||||
const EccPrivateKey& private_key) {
|
||||
std::unique_ptr<EccPublicKey> key(new EccPublicKey());
|
||||
if (!key->InitFromPrivateKey(private_key)) {
|
||||
LOGE("Failed to initialize public key from private key");
|
||||
key.reset();
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPublicKey> EccPublicKey::Load(const uint8_t* buffer,
|
||||
size_t length) {
|
||||
std::unique_ptr<EccPublicKey> key;
|
||||
if (buffer == nullptr) {
|
||||
LOGE("Provided public key buffer is null");
|
||||
return key;
|
||||
}
|
||||
if (length == 0) {
|
||||
LOGE("Provided public key buffer is zero length");
|
||||
return key;
|
||||
}
|
||||
key.reset(new EccPublicKey());
|
||||
if (!key->InitFromBuffer(buffer, length)) {
|
||||
LOGE("Failed to initialize public key from buffer");
|
||||
key.reset();
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPublicKey> EccPublicKey::Load(const std::string& buffer) {
|
||||
if (buffer.empty()) {
|
||||
LOGE("Provided public key buffer is empty");
|
||||
return std::unique_ptr<EccPublicKey>();
|
||||
}
|
||||
return Load(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPublicKey> EccPublicKey::Load(
|
||||
const std::vector<uint8_t>& buffer) {
|
||||
if (buffer.empty()) {
|
||||
LOGE("Provided public key buffer is empty");
|
||||
return std::unique_ptr<EccPublicKey>();
|
||||
}
|
||||
return Load(buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
bool EccPublicKey::IsMatchingPrivateKey(
|
||||
const EccPrivateKey& private_key) const {
|
||||
if (private_key.curve() != curve_) {
|
||||
return false;
|
||||
}
|
||||
return IsMatchingKeyPair(GetEcKey(), private_key.GetEcKey());
|
||||
}
|
||||
|
||||
OEMCryptoResult EccPublicKey::Serialize(uint8_t* buffer,
|
||||
size_t* buffer_size) const {
|
||||
if (buffer_size == nullptr) {
|
||||
LOGE("Output buffer size is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (buffer == nullptr && *buffer_size > 0) {
|
||||
LOGE("Output buffer is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
|
||||
uint8_t* der_key = nullptr;
|
||||
const int der_res = i2d_EC_PUBKEY(key_, &der_key);
|
||||
if (der_res < 0) {
|
||||
LOGE("Public key serialization failed");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
if (der_key == nullptr) {
|
||||
LOGE("Encoded key is unexpectedly null");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
if (der_res == 0) {
|
||||
LOGE("Unexpected DER encoded size");
|
||||
OPENSSL_free(der_key);
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
const size_t required_size = static_cast<size_t>(der_res);
|
||||
if (buffer == nullptr || *buffer_size < required_size) {
|
||||
*buffer_size = required_size;
|
||||
OPENSSL_free(der_key);
|
||||
return OEMCrypto_ERROR_SHORT_BUFFER;
|
||||
}
|
||||
memcpy(buffer, der_key, required_size);
|
||||
*buffer_size = required_size;
|
||||
OPENSSL_free(der_key);
|
||||
return OEMCrypto_SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> EccPublicKey::Serialize() const {
|
||||
size_t key_size = kPublicKeySize;
|
||||
std::vector<uint8_t> key_data(key_size, 0);
|
||||
const OEMCryptoResult res = Serialize(key_data.data(), &key_size);
|
||||
if (res != OEMCrypto_SUCCESS) {
|
||||
LOGE("Failed to serialize public key: result = %d", static_cast<int>(res));
|
||||
key_data.clear();
|
||||
} else {
|
||||
key_data.resize(key_size);
|
||||
}
|
||||
return key_data;
|
||||
}
|
||||
|
||||
OEMCryptoResult EccPublicKey::VerifySignature(const uint8_t* message,
|
||||
size_t message_length,
|
||||
const uint8_t* signature,
|
||||
size_t signature_length) const {
|
||||
if (signature == nullptr || signature_length == 0) {
|
||||
LOGE("Signature is missing");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (message == nullptr && message_length > 0) {
|
||||
LOGE("Bad message data");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
// Step 1: Parse signature.
|
||||
const uint8_t* tp = signature;
|
||||
ScopedSigPoint sig_point(d2i_ECDSA_SIG(nullptr, &tp, signature_length));
|
||||
if (!sig_point) {
|
||||
LOGE("Failed to parse signature");
|
||||
// Most likely an invalid signature than an OpenSSL error.
|
||||
return OEMCrypto_ERROR_SIGNATURE_FAILURE;
|
||||
}
|
||||
// Step 2: Hash message
|
||||
std::vector<uint8_t> digest;
|
||||
if (!DigestMessage(curve_, message, message_length, &digest)) {
|
||||
LOGE("Failed to digest message");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
// Step 3: Verify signature
|
||||
const int res =
|
||||
ECDSA_do_verify(digest.data(), digest.size(), sig_point.get(), key_);
|
||||
if (res == -1) {
|
||||
LOGE("Error occurred checking signature");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
if (res == 0) {
|
||||
LOGD("Signature did not match");
|
||||
return OEMCrypto_ERROR_SIGNATURE_FAILURE;
|
||||
}
|
||||
return OEMCrypto_SUCCESS;
|
||||
}
|
||||
|
||||
OEMCryptoResult EccPublicKey::VerifySignature(
|
||||
const std::string& message, const std::string& signature) const {
|
||||
if (signature.empty()) {
|
||||
LOGE("Signature should not be empty");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
return VerifySignature(
|
||||
reinterpret_cast<const uint8_t*>(message.data()), message.size(),
|
||||
reinterpret_cast<const uint8_t*>(signature.data()), signature.size());
|
||||
}
|
||||
|
||||
OEMCryptoResult EccPublicKey::VerifySignature(
|
||||
const std::vector<uint8_t>& message,
|
||||
const std::vector<uint8_t>& signature) const {
|
||||
if (signature.empty()) {
|
||||
LOGE("Signature should not be empty");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
return VerifySignature(message.data(), message.size(), signature.data(),
|
||||
signature.size());
|
||||
}
|
||||
|
||||
EccPublicKey::~EccPublicKey() {
|
||||
if (key_ != nullptr) {
|
||||
EC_KEY_free(key_);
|
||||
key_ = nullptr;
|
||||
}
|
||||
curve_ = kEccCurveUnknown;
|
||||
}
|
||||
|
||||
bool EccPublicKey::InitFromBuffer(const uint8_t* buffer, size_t length) {
|
||||
// Deserialize SubjectPublicKeyInfo
|
||||
const uint8_t* tp = buffer;
|
||||
ScopedEcKey key(d2i_EC_PUBKEY(nullptr, &tp, length));
|
||||
if (!key) {
|
||||
LOGE("Failed to parse key");
|
||||
return false;
|
||||
}
|
||||
curve_ = GetCurveFromKeyGroup(key.get());
|
||||
if (curve_ == kEccCurveUnknown) {
|
||||
LOGE("Failed to determine key group");
|
||||
return false;
|
||||
}
|
||||
// Required flags for IETF compliance.
|
||||
EC_KEY_set_asn1_flag(key.get(), OPENSSL_EC_NAMED_CURVE);
|
||||
EC_KEY_set_conv_form(key.get(), POINT_CONVERSION_UNCOMPRESSED);
|
||||
key_ = key.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EccPublicKey::InitFromPrivateKey(const EccPrivateKey& private_key) {
|
||||
ScopedEcKey key(EC_KEY_new());
|
||||
if (!key) {
|
||||
LOGE("Failed to allocate key");
|
||||
return false;
|
||||
}
|
||||
if (!EC_KEY_set_group(key.get(), EC_KEY_get0_group(private_key.GetEcKey()))) {
|
||||
LOGE("Failed to set group");
|
||||
return false;
|
||||
}
|
||||
if (!EC_KEY_set_public_key(key.get(),
|
||||
EC_KEY_get0_public_key(private_key.GetEcKey()))) {
|
||||
LOGE("Failed to set public point");
|
||||
return false;
|
||||
}
|
||||
curve_ = private_key.curve();
|
||||
// Required flags for IETF compliance.
|
||||
EC_KEY_set_asn1_flag(key.get(), OPENSSL_EC_NAMED_CURVE);
|
||||
EC_KEY_set_conv_form(key.get(), POINT_CONVERSION_UNCOMPRESSED);
|
||||
key_ = key.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPrivateKey> EccPrivateKey::New(EccCurve curve) {
|
||||
std::unique_ptr<EccPrivateKey> key(new EccPrivateKey());
|
||||
if (!key->InitFromCurve(curve)) {
|
||||
LOGE("Failed to initialize private key from curve");
|
||||
key.reset();
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPrivateKey> EccPrivateKey::Load(const uint8_t* buffer,
|
||||
size_t length) {
|
||||
std::unique_ptr<EccPrivateKey> key;
|
||||
if (buffer == nullptr) {
|
||||
LOGE("Provided private key buffer is null");
|
||||
return key;
|
||||
}
|
||||
if (length == 0) {
|
||||
LOGE("Provided private key buffer is zero length");
|
||||
return key;
|
||||
}
|
||||
key.reset(new EccPrivateKey());
|
||||
if (!key->InitFromBuffer(buffer, length)) {
|
||||
LOGE("Failed to initialize private key from buffer");
|
||||
key.reset();
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPrivateKey> EccPrivateKey::Load(const std::string& buffer) {
|
||||
if (buffer.empty()) {
|
||||
LOGE("Provided private key buffer is empty");
|
||||
return std::unique_ptr<EccPrivateKey>();
|
||||
}
|
||||
return Load(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<EccPrivateKey> EccPrivateKey::Load(
|
||||
const std::vector<uint8_t>& buffer) {
|
||||
if (buffer.empty()) {
|
||||
LOGE("Provided private key buffer is empty");
|
||||
return std::unique_ptr<EccPrivateKey>();
|
||||
}
|
||||
return Load(buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
std::unique_ptr<EccPublicKey> EccPrivateKey::MakePublicKey() const {
|
||||
return EccPublicKey::New(*this);
|
||||
}
|
||||
|
||||
bool EccPrivateKey::IsMatchingPublicKey(const EccPublicKey& public_key) const {
|
||||
if (public_key.curve() != curve_) {
|
||||
return false;
|
||||
}
|
||||
return IsMatchingKeyPair(public_key.GetEcKey(), GetEcKey());
|
||||
}
|
||||
|
||||
OEMCryptoResult EccPrivateKey::Serialize(uint8_t* buffer,
|
||||
size_t* buffer_size) const {
|
||||
if (buffer_size == nullptr) {
|
||||
LOGE("Output buffer size is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (buffer == nullptr && *buffer_size > 0) {
|
||||
LOGE("Output buffer is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
uint8_t* der_key = nullptr;
|
||||
const int der_res = i2d_ECPrivateKey(key_, &der_key);
|
||||
if (der_res < 0) {
|
||||
LOGE("Private key serialization failed");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
if (der_key == nullptr) {
|
||||
LOGE("Encoded key is unexpectedly null");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
if (der_res == 0) {
|
||||
LOGE("Unexpected DER encoded size");
|
||||
OPENSSL_free(der_key);
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
const size_t required_size = static_cast<size_t>(der_res);
|
||||
if (*buffer_size < required_size || buffer == nullptr) {
|
||||
*buffer_size = required_size;
|
||||
OPENSSL_free(der_key);
|
||||
return OEMCrypto_ERROR_SHORT_BUFFER;
|
||||
}
|
||||
memcpy(buffer, der_key, required_size);
|
||||
*buffer_size = required_size;
|
||||
OPENSSL_free(der_key);
|
||||
return OEMCrypto_SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> EccPrivateKey::Serialize() const {
|
||||
size_t key_size = kPrivateKeySize;
|
||||
std::vector<uint8_t> key_data(key_size, 0);
|
||||
const OEMCryptoResult res = Serialize(key_data.data(), &key_size);
|
||||
if (res != OEMCrypto_SUCCESS) {
|
||||
LOGE("Failed to serialize private key: result = %d", static_cast<int>(res));
|
||||
key_data.clear();
|
||||
} else {
|
||||
key_data.resize(key_size);
|
||||
}
|
||||
return key_data;
|
||||
}
|
||||
|
||||
OEMCryptoResult EccPrivateKey::GenerateSignature(
|
||||
const uint8_t* message, size_t message_length, uint8_t* signature,
|
||||
size_t* signature_length) const {
|
||||
if (signature_length == nullptr) {
|
||||
LOGE("Output signature size is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (signature == nullptr && *signature_length > 0) {
|
||||
LOGE("Output signature is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (message == nullptr && message_length > 0) {
|
||||
LOGE("Invalid message data");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
const size_t expected_signature_length = ECDSA_size(key_);
|
||||
if (*signature_length < expected_signature_length) {
|
||||
*signature_length = expected_signature_length;
|
||||
return OEMCrypto_ERROR_SHORT_BUFFER;
|
||||
}
|
||||
|
||||
// Step 1: Hash message.
|
||||
std::vector<uint8_t> digest;
|
||||
if (!DigestMessage(curve_, message, message_length, &digest)) {
|
||||
LOGE("Failed to digest message");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
// Step 2: Generate signature point.
|
||||
ScopedSigPoint sig_point(ECDSA_do_sign(digest.data(), digest.size(), key_));
|
||||
if (!sig_point) {
|
||||
LOGE("Failed to perform ECDSA");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
// Step 3: Serialize
|
||||
std::vector<uint8_t> temp(expected_signature_length);
|
||||
uint8_t* sig_ptr = temp.data();
|
||||
const int res = i2d_ECDSA_SIG(sig_point.get(), &sig_ptr);
|
||||
if (res <= 0) {
|
||||
LOGE("Failed to serialize signature");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
const size_t required_size = static_cast<size_t>(res);
|
||||
if (signature == nullptr || *signature_length < required_size) {
|
||||
*signature_length = required_size;
|
||||
return OEMCrypto_ERROR_SHORT_BUFFER;
|
||||
}
|
||||
memcpy(signature, temp.data(), required_size);
|
||||
*signature_length = required_size;
|
||||
return OEMCrypto_SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> EccPrivateKey::GenerateSignature(
|
||||
const std::string& message) const {
|
||||
size_t signature_size = SignatureSize();
|
||||
std::vector<uint8_t> signature(signature_size, 0);
|
||||
const OEMCryptoResult res =
|
||||
GenerateSignature(reinterpret_cast<const uint8_t*>(message.data()),
|
||||
message.size(), signature.data(), &signature_size);
|
||||
if (res != OEMCrypto_SUCCESS) {
|
||||
LOGE("Failed to generate signature: result = %d", static_cast<int>(res));
|
||||
signature.clear();
|
||||
} else {
|
||||
signature.resize(signature_size);
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> EccPrivateKey::GenerateSignature(
|
||||
const std::vector<uint8_t>& message) const {
|
||||
size_t signature_size = SignatureSize();
|
||||
std::vector<uint8_t> signature(signature_size, 0);
|
||||
const OEMCryptoResult res = GenerateSignature(
|
||||
message.data(), message.size(), signature.data(), &signature_size);
|
||||
if (res != OEMCrypto_SUCCESS) {
|
||||
LOGE("Failed to generate signature: result = %d", static_cast<int>(res));
|
||||
signature.clear();
|
||||
} else {
|
||||
signature.resize(signature_size);
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
size_t EccPrivateKey::SignatureSize() const { return ECDSA_size(key_); }
|
||||
|
||||
OEMCryptoResult EccPrivateKey::DeriveSessionKey(
|
||||
const EccPublicKey& public_key, uint8_t* session_key,
|
||||
size_t* session_key_size) const {
|
||||
if (public_key.curve() != curve_) {
|
||||
LOGE("Incompatible ECC keys: public = %s, private = %s",
|
||||
EccCurveToString(public_key.curve()).c_str(),
|
||||
EccCurveToString(curve_).c_str());
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (session_key_size == nullptr) {
|
||||
LOGE("Output session key size buffer is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (session_key == nullptr && *session_key_size > 0) {
|
||||
LOGE("Output session key buffer is null");
|
||||
return OEMCrypto_ERROR_INVALID_CONTEXT;
|
||||
}
|
||||
if (*session_key_size < kEccSessionKeySize) {
|
||||
*session_key_size = kEccSessionKeySize;
|
||||
return OEMCrypto_ERROR_SHORT_BUFFER;
|
||||
}
|
||||
const int res = ECDH_compute_key(
|
||||
session_key, kEccSessionKeySize,
|
||||
EC_KEY_get0_public_key(public_key.GetEcKey()), key_, WidevineEccKdf);
|
||||
if (res < 0) {
|
||||
LOGE("ECDH error occurred");
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
if (static_cast<size_t>(res) != kEccSessionKeySize) {
|
||||
LOGE("Unexpected key size: size = %d", res);
|
||||
return OEMCrypto_ERROR_UNKNOWN_FAILURE;
|
||||
}
|
||||
*session_key_size = kEccSessionKeySize;
|
||||
return OEMCrypto_SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> EccPrivateKey::DeriveSessionKey(
|
||||
const EccPublicKey& public_key) const {
|
||||
size_t session_key_size = kEccSessionKeySize;
|
||||
std::vector<uint8_t> session_key(session_key_size, 0);
|
||||
const OEMCryptoResult res =
|
||||
DeriveSessionKey(public_key, session_key.data(), &session_key_size);
|
||||
if (res != OEMCrypto_SUCCESS) {
|
||||
LOGE("Failed to derive session key: result = %d", static_cast<int>(res));
|
||||
session_key.clear();
|
||||
} else {
|
||||
session_key.resize(session_key_size);
|
||||
}
|
||||
return session_key;
|
||||
}
|
||||
|
||||
size_t EccPrivateKey::SessionKeyLength() const { return kEccSessionKeySize; }
|
||||
|
||||
EccPrivateKey::~EccPrivateKey() {
|
||||
if (key_ != nullptr) {
|
||||
EC_KEY_free(key_);
|
||||
key_ = nullptr;
|
||||
}
|
||||
curve_ = kEccCurveUnknown;
|
||||
}
|
||||
|
||||
bool EccPrivateKey::InitFromBuffer(const uint8_t* buffer, size_t length) {
|
||||
// Deserialize ECPrivateKey
|
||||
const uint8_t* tp = buffer;
|
||||
ScopedEcKey key(d2i_ECPrivateKey(nullptr, &tp, length));
|
||||
if (!key) {
|
||||
LOGE("Failed to parse key");
|
||||
return false;
|
||||
}
|
||||
curve_ = GetCurveFromKeyGroup(key.get());
|
||||
if (curve_ == kEccCurveUnknown) {
|
||||
LOGE("Failed to determine key group");
|
||||
return false;
|
||||
}
|
||||
// Required flags for IETF compliance.
|
||||
EC_KEY_set_asn1_flag(key.get(), OPENSSL_EC_NAMED_CURVE);
|
||||
EC_KEY_set_conv_form(key.get(), POINT_CONVERSION_UNCOMPRESSED);
|
||||
key_ = key.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EccPrivateKey::InitFromCurve(EccCurve curve) {
|
||||
const EC_GROUP* group = GetEcGroup(curve);
|
||||
if (group == nullptr) {
|
||||
LOGE("Failed to get ECC group");
|
||||
return false;
|
||||
}
|
||||
ScopedEcKey key(EC_KEY_new());
|
||||
if (!key) {
|
||||
LOGE("Failed to allocate key");
|
||||
return false;
|
||||
}
|
||||
if (!EC_KEY_set_group(key.get(), group)) {
|
||||
LOGE("Failed to set group");
|
||||
return false;
|
||||
}
|
||||
// Generate random key.
|
||||
if (!EC_KEY_generate_key(key.get())) {
|
||||
LOGE("Failed to generate random key");
|
||||
return false;
|
||||
}
|
||||
curve_ = curve;
|
||||
// Required flags for IETF compliance.
|
||||
EC_KEY_set_asn1_flag(key.get(), OPENSSL_EC_NAMED_CURVE);
|
||||
EC_KEY_set_conv_form(key.get(), POINT_CONVERSION_UNCOMPRESSED);
|
||||
key_ = key.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace wvoec_ref
|
||||
245
libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.h
Normal file
245
libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.h
Normal file
@@ -0,0 +1,245 @@
|
||||
// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary
|
||||
// source code may only be used and distributed under the Widevine License
|
||||
// Agreement.
|
||||
//
|
||||
// Reference implementation of OEMCrypto APIs
|
||||
//
|
||||
#ifndef OEMCRYPTO_ECC_KEY_H_
|
||||
#define OEMCRYPTO_ECC_KEY_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <openssl/ec.h>
|
||||
|
||||
#include "OEMCryptoCENCCommon.h"
|
||||
|
||||
namespace wvoec_ref {
|
||||
|
||||
enum EccCurve {
|
||||
kEccCurveUnknown = 0,
|
||||
kEccSecp256r1 = 256,
|
||||
kEccSecp384r1 = 384,
|
||||
kEccSecp521r1 = 521
|
||||
};
|
||||
|
||||
// Returns the string representation of the provided curve.
|
||||
// Intended for logging purposes.
|
||||
std::string EccCurveToString(EccCurve curve);
|
||||
|
||||
class EccPrivateKey;
|
||||
|
||||
class EccPublicKey {
|
||||
public:
|
||||
// Creates a new public key equivalent of the provided private key.
|
||||
static std::unique_ptr<EccPublicKey> New(const EccPrivateKey& private_key);
|
||||
|
||||
// Loads a serialized EC public key.
|
||||
// The provided |buffer| must contain a valid ASN.1 DER encoded
|
||||
// SubjectPublicKey. Only supported curves by this API are those
|
||||
// enumerated by EccCurve.
|
||||
//
|
||||
// buffer: SubjectPublicKeyInfo = {
|
||||
// algorithm: AlgorithmIdentifier = {
|
||||
// algorithm: OID = id-ecPublicKey,
|
||||
// parameters: ECParameters = {
|
||||
// namedCurve: OID = secp256r1 | secp384r1 | secp521r1
|
||||
// }
|
||||
// },
|
||||
// subjectPublicKey: BIT STRING = ... -- SEC1 encoded ECPoint
|
||||
// }
|
||||
//
|
||||
// Failure will occur if the provided |buffer| does not contain a
|
||||
// valid SubjectPublicKey, or if the specified curve is not
|
||||
// supported.
|
||||
static std::unique_ptr<EccPublicKey> Load(const uint8_t* buffer,
|
||||
size_t length);
|
||||
static std::unique_ptr<EccPublicKey> Load(const std::string& buffer);
|
||||
static std::unique_ptr<EccPublicKey> Load(const std::vector<uint8_t>& buffer);
|
||||
|
||||
EccCurve curve() const { return curve_; }
|
||||
const EC_KEY* GetEcKey() const { return key_; }
|
||||
|
||||
// Checks if the provided |private_key| is the EC private key of this
|
||||
// public key.
|
||||
bool IsMatchingPrivateKey(const EccPrivateKey& private_key) const;
|
||||
|
||||
// Serializes the public key into an ASN.1 DER encoded SubjectPublicKey
|
||||
// representation.
|
||||
// On success, |*buffer_size| is populated with the number of bytes
|
||||
// written to |buffer|, and OEMCrypto_SUCCESS is returned.
|
||||
// If the provided |*buffer_size| is too small, ERROR_SHORT_BUFFER
|
||||
// is returned and |*buffer_size| is set to the required buffer size.
|
||||
OEMCryptoResult Serialize(uint8_t* buffer, size_t* buffer_size) const;
|
||||
// Same as above, except directly returns the serialized key.
|
||||
// Returns an empty vector on error.
|
||||
std::vector<uint8_t> Serialize() const;
|
||||
|
||||
// Verifies the |signature| matches the provided |message| by the
|
||||
// private equivalent of this public key.
|
||||
// The |signature| should be a valid ASN.1 DER encoded
|
||||
// ECDSA-Sig-Value.
|
||||
// This implementation uses ECDSA with the following digest
|
||||
// algorithms for the supported curve types.
|
||||
// - SHA-256 / secp256r1
|
||||
// - SHA-384 / secp384r1 (optional support)
|
||||
// - SHA-512 / secp521r1 (optional support)
|
||||
// Returns:
|
||||
// OEMCrypto_SUCCESS if signature is valid
|
||||
// OEMCrypto_ERROR_SIGNATURE_FAILURE if the signature is invalid
|
||||
// Any other result indicates an unexpected error
|
||||
OEMCryptoResult VerifySignature(const uint8_t* message, size_t message_length,
|
||||
const uint8_t* signature,
|
||||
size_t signature_length) const;
|
||||
OEMCryptoResult VerifySignature(const std::string& message,
|
||||
const std::string& signature) const;
|
||||
OEMCryptoResult VerifySignature(const std::vector<uint8_t>& message,
|
||||
const std::vector<uint8_t>& signature) const;
|
||||
|
||||
~EccPublicKey();
|
||||
|
||||
EccPublicKey(const EccPublicKey&) = delete;
|
||||
EccPublicKey(EccPublicKey&&) = delete;
|
||||
const EccPublicKey& operator=(const EccPublicKey&) = delete;
|
||||
EccPublicKey& operator=(EccPublicKey&&) = delete;
|
||||
|
||||
private:
|
||||
EccPublicKey() {}
|
||||
|
||||
// Initializes the public key object using the provided |buffer|.
|
||||
// In case of any failure, false is return and the key should be
|
||||
// discarded.
|
||||
bool InitFromBuffer(const uint8_t* buffer, size_t length);
|
||||
// Initializes the public key object from a private.
|
||||
bool InitFromPrivateKey(const EccPrivateKey& private_key);
|
||||
|
||||
// OpenSSL/BoringSSL implementation of an ECC key.
|
||||
// As a public key, this will only have key point initialized.
|
||||
EC_KEY* key_ = nullptr;
|
||||
EccCurve curve_ = kEccCurveUnknown;
|
||||
};
|
||||
|
||||
class EccPrivateKey {
|
||||
public:
|
||||
// Creates a new, pseudorandom ECC private key belonging to the
|
||||
// curve specified.
|
||||
static std::unique_ptr<EccPrivateKey> New(EccCurve curve);
|
||||
|
||||
// Loads a serialized ECC private key.
|
||||
// The provided |buffer| must contain a valid ASN.1 DER encoded
|
||||
// ECPrivateKey. Only supported curves by this API are those
|
||||
// enumerated by EccCurve.
|
||||
//
|
||||
// buffer: ECPrivateKey = {
|
||||
// version: INTEGER = ecPrivateKeyVer1,
|
||||
// privateKey: OCTET STRING = ..., -- I2OSP of private key point
|
||||
// parameters: ECParameters = {
|
||||
// namedCurve: OID = secp256r1 | secp384r1 | secp521r1
|
||||
// },
|
||||
// publicKey: BIT STRING OPTIONAL = ... -- SEC1 encoded ECPoint
|
||||
// }
|
||||
// Note: If the public key is not included, then it is computed from
|
||||
// the private key.
|
||||
//
|
||||
// Failure will occur if the provided |buffer| does not contain a
|
||||
// valid ECPrivateKey, or if the specified curve is not supported.
|
||||
static std::unique_ptr<EccPrivateKey> Load(const uint8_t* buffer,
|
||||
size_t length);
|
||||
static std::unique_ptr<EccPrivateKey> Load(const std::string& buffer);
|
||||
static std::unique_ptr<EccPrivateKey> Load(
|
||||
const std::vector<uint8_t>& buffer);
|
||||
|
||||
// Creates a new ECC public key of this private key.
|
||||
// Equivalent to calling EccPublicKey::New with this private
|
||||
// key.
|
||||
std::unique_ptr<EccPublicKey> MakePublicKey() const;
|
||||
|
||||
EccCurve curve() const { return curve_; }
|
||||
const EC_KEY* GetEcKey() const { return key_; }
|
||||
|
||||
// Checks if the provided |public_key| is the EC public key of this
|
||||
// private key.
|
||||
bool IsMatchingPublicKey(const EccPublicKey& public_key) const;
|
||||
|
||||
// Serializes the private key into an ASN.1 DER encoded ECPrivateKey
|
||||
// representation.
|
||||
// On success, |*buffer_size| is populated with the number of bytes
|
||||
// written to |buffer|, and SUCCESS is returned.
|
||||
// If the provided |*buffer_size| is too small,
|
||||
// OEMCrypto_ERROR_SHORT_BUFFER is returned and |*buffer_size| is
|
||||
// set to the required buffer size.
|
||||
OEMCryptoResult Serialize(uint8_t* buffer, size_t* buffer_size) const;
|
||||
// Same as above, except directly returns the serialized key.
|
||||
// Returns an empty vector on error.
|
||||
std::vector<uint8_t> Serialize() const;
|
||||
|
||||
// Signs the provided |message| and serializes the signature
|
||||
// point to |signature| as a ASN.1 DER encoded ECDSA-Sig-Value.
|
||||
// This implementation uses ECDSA with the following digest
|
||||
// algorithms for the supported curve types.
|
||||
// - SHA-256 / secp256r1
|
||||
// - SHA-384 / secp384r1 (optional support)
|
||||
// - SHA-512 / secp521r1 (optional support)
|
||||
// On success, |*signature_length| is populated with the number of
|
||||
// bytes written to |signature|, and SUCCESS is returned.
|
||||
// If the provided |*signature_length| is too small,
|
||||
// OEMCrypto_ERROR_SHORT_BUFFER is returned and |*signature_length|
|
||||
// is set to the required signature size.
|
||||
OEMCryptoResult GenerateSignature(const uint8_t* message,
|
||||
size_t message_length, uint8_t* signature,
|
||||
size_t* signature_length) const;
|
||||
// Same as above, except directly returns the serialized signature.
|
||||
// Returns an empty vector on error.
|
||||
std::vector<uint8_t> GenerateSignature(
|
||||
const std::vector<uint8_t>& message) const;
|
||||
std::vector<uint8_t> GenerateSignature(const std::string& message) const;
|
||||
// Returns an upper bound for the signature size. May be larger than
|
||||
// the actual signature generated by GenerateSignature().
|
||||
size_t SignatureSize() const;
|
||||
|
||||
// Derives the OEMCrypto session key used for deriving other keys.
|
||||
// The provided public key must be of the same curve.
|
||||
// On success, |*session_key_size| is populated with the number of
|
||||
// bytes written to |session_key|, and OEMCrypto_SUCCESS is returned.
|
||||
// If the provided |*session_key_size| is too small,
|
||||
// OEMCrypto_ERROR_SHORT_BUFFER is returned and |*session_key_size|
|
||||
// is set to the required buffer size.
|
||||
OEMCryptoResult DeriveSessionKey(const EccPublicKey& public_key,
|
||||
uint8_t* session_key,
|
||||
size_t* session_key_size) const;
|
||||
// Same as above, except directly returns the derived key.
|
||||
std::vector<uint8_t> DeriveSessionKey(const EccPublicKey& public_key) const;
|
||||
// Returns the byte length of the symmetric key that would be derived
|
||||
// by DeriveSymmetricKey().
|
||||
size_t SessionKeyLength() const;
|
||||
|
||||
~EccPrivateKey();
|
||||
|
||||
EccPrivateKey(const EccPrivateKey&) = delete;
|
||||
EccPrivateKey(EccPrivateKey&&) = delete;
|
||||
const EccPrivateKey& operator=(const EccPrivateKey&) = delete;
|
||||
EccPrivateKey& operator=(EccPrivateKey&&) = delete;
|
||||
|
||||
private:
|
||||
EccPrivateKey() {}
|
||||
|
||||
// Initializes the public key object using the provided |buffer|.
|
||||
// In case of any failure, false is return and the key should be
|
||||
// discarded.
|
||||
bool InitFromBuffer(const uint8_t* buffer, size_t length);
|
||||
// Generates a new key based on the provided curve.
|
||||
bool InitFromCurve(EccCurve curve);
|
||||
|
||||
// OpenSSL/BoringSSL implementation of an ECC key.
|
||||
// The public point of the key will always be present.
|
||||
EC_KEY* key_ = nullptr;
|
||||
EccCurve curve_ = kEccCurveUnknown;
|
||||
};
|
||||
|
||||
} // namespace wvoec_ref
|
||||
|
||||
#endif // OEMCRYPTO_EC_KEY_H_
|
||||
71
libwvdrmengine/oemcrypto/ref/src/scoped_object.h
Normal file
71
libwvdrmengine/oemcrypto/ref/src/scoped_object.h
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary
|
||||
// source code may only be used and distributed under the Widevine License
|
||||
// Agreement.
|
||||
//
|
||||
// Reference implementation of OEMCrypto APIs
|
||||
//
|
||||
#ifndef SCOPED_OBJECT_H_
|
||||
#define SCOPED_OBJECT_H_
|
||||
|
||||
namespace wvoec_ref {
|
||||
|
||||
// A generic wrapper around pointer. This allows for automatic
|
||||
// memory clean up when the ScopedObject variable goes out of scope.
|
||||
// This is intended to be used with OpenSSL/BoringSSL structs.
|
||||
template <typename Type, void Destructor(Type*)>
|
||||
class ScopedObject {
|
||||
public:
|
||||
ScopedObject() : ptr_(nullptr) {}
|
||||
ScopedObject(Type* ptr) : ptr_(ptr) {}
|
||||
~ScopedObject() {
|
||||
if (ptr_) {
|
||||
Destructor(ptr_);
|
||||
ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy construction and assignment are not allowed.
|
||||
ScopedObject(const ScopedObject& other) = delete;
|
||||
ScopedObject& operator=(const ScopedObject& other) = delete;
|
||||
|
||||
// Move construction and assignment are allowed.
|
||||
ScopedObject(ScopedObject&& other) : ptr_(other.ptr_) {
|
||||
other.ptr_ = nullptr;
|
||||
}
|
||||
ScopedObject& operator=(ScopedObject&& other) {
|
||||
if (ptr_) {
|
||||
Destructor(ptr_);
|
||||
}
|
||||
ptr_ = other.ptr_;
|
||||
other.ptr_ = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
explicit operator bool() const { return ptr_ != nullptr; }
|
||||
|
||||
Type& operator*() { return *ptr_; }
|
||||
Type* get() const { return ptr_; }
|
||||
Type* operator->() const { return ptr_; }
|
||||
|
||||
// Releasing the pointer will remove the responsibility of the
|
||||
// ScopedObject to clean up the pointer.
|
||||
Type* release() {
|
||||
Type* temp = ptr_;
|
||||
ptr_ = nullptr;
|
||||
return temp;
|
||||
}
|
||||
|
||||
void reset(Type* ptr = nullptr) {
|
||||
if (ptr_) {
|
||||
Destructor(ptr_);
|
||||
}
|
||||
ptr_ = ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
Type* ptr_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace wvoec_ref
|
||||
|
||||
#endif // SCOPED_OBJECT_H_
|
||||
Reference in New Issue
Block a user