Files
android/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.cpp
Alex Dale f6f5099604 Restructured reference root of trust (2/3 DRM Cert)
[ Merge of http://go/wvgerrit/115551 ]

This change is the second part of a three-part change for restructuring
the root of trust used by the reference implementation.

The use of RSA_shared_ptr has been replaced with the standard library
std::shared_ptr using the RsaPrivateKey wrapper class.  The
AuthenticationRoot class now uses this for the built-in DRM cert key.

RSA decryption and signature operations within the session context are
now performed by the RsaPrivateKey class.  This has reduced the code size
and complexity within the reference and testbed, focusing their
implementation on key policy rather than on mechanics.

Bug: 168544740
Bug: 135283522
Test: oemcrypto_unittests ce_cdm_tests
Change-Id: Ic743a529a9858f3182290d8bcf5e1633737b005b
2021-03-24 19:14:17 -07:00

1169 lines
39 KiB
C++

// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary
// source code may only be used and distributed under the Widevine License
// Agreement.
//
// Reference implementation of OEMCrypto APIs
//
#include "oemcrypto_rsa_key.h"
#include <assert.h>
#include <string.h>
#include <netinet/in.h>
#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include "OEMCryptoCENC.h"
#include "log.h"
#include "oemcrypto_types.h"
#include "scoped_object.h"
namespace wvoec_ref {
namespace {
// Estimated size, in bytes, of a serialized RSA private key for the largest
// supported field (3072 bits). This is only an initial estimate; the actual
// encoding may be larger.
// The private key consists of:
// n : Modulus : byteSize(n) ~= bitSize(n)/8
// e : Public exponent : 4 bytes
// d : Private exponent : ~byteSize(n)
// p : Prime 1 : ~byteSize(n)/2
// q : Prime 2 : ~byteSize(n)/2
// d (mod p-1) : Exponent 1 : ~byteSize(n)/2
// d (mod q-1) : Exponent 2 : ~byteSize(n)/2
// q^-1 (mod p) : Coefficient : ~byteSize(n)/2
// And the ASN.1 tags for each component (roughly 25 bytes).
constexpr size_t kPrivateKeySize = (3072 / 8) * 5 + 25;
// Estimated size, in bytes, of a serialized RSA public key for the largest
// supported field (3072 bits).
// The public key consists of:
// n : Modulus : byteSize(n) ~= bitSize(n)/8
// e : Public exponent : 4 bytes
// And the ASN.1 tags + outer structure.
constexpr size_t kPublicKeySize = (3072 / 8) + 100;
// 128 bit key, intended to be used with CMAC-AES-128.
constexpr size_t kRsaSessionKeySize = 16;
// Encryption key used by OEMCrypto session for encrypting and
// decrypting data.
constexpr size_t kEncryptionKeySize = wvoec::KEY_SIZE;
// Salt length used by OEMCrypto's RSASSA-PSS implementation.
// See description of kRsaPssDefault for more information.
constexpr size_t kPssSaltLength = 20;
// Maximum message size, in bytes, accepted for PKCS#1 CAST signatures.
// Requirement of CAST receivers.
constexpr size_t kRsaPkcs1CastMaxMessageSize = 83;
// Scoped owners for OpenSSL objects; each alias pairs an OpenSSL type with
// the function used to free it.
using ScopedBio = ScopedObject<BIO, BIO_vfree>;
using ScopedBigNum = ScopedObject<BIGNUM, BN_free>;
using ScopedEvpMdCtx = ScopedObject<EVP_MD_CTX, EVP_MD_CTX_free>;
using ScopedEvpPkey = ScopedObject<EVP_PKEY, EVP_PKEY_free>;
using ScopedRsaKey = ScopedObject<RSA, RSA_free>;
using ScopedRsaPrivKeyInfo =
ScopedObject<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_free>;
// Maps the actual bit length of an RSA modulus to the nominal field size it
// represents; the real bit length can differ from the nominal size by a few
// bits. Returns kRsa2048Bit for 1801-2199 bits, kRsa3072Bit for 2801-3199
// bits, and kRsaFieldUnknown for anything else.
RsaFieldSize RealBitSizeToFieldSize(int bits) {
if (bits > 1800 && bits < 2200) {
return kRsa2048Bit;
}
if (bits > 2800 && bits < 3200) {
return kRsa3072Bit;
}
return kRsaFieldUnknown;
}
} // namespace
// Returns a human-readable label for |field_size|. Values that are not one of
// the named enumerators are rendered as "Unknown(<numeric value>)".
std::string RsaFieldSizeToString(RsaFieldSize field_size) {
  const char* label = nullptr;
  switch (field_size) {
    case kRsa2048Bit:
      label = "RSA-2048";
      break;
    case kRsa3072Bit:
      label = "RSA-3072";
      break;
    case kRsaFieldUnknown:
      label = "Unknown";
      break;
  }
  if (label != nullptr) {
    return label;
  }
  return "Unknown(" + std::to_string(static_cast<int>(field_size)) + ")";
}
bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key) {
if (public_key == nullptr) {
LOGE("Public key is null");
return false;
}
if (private_key == nullptr) {
LOGE("Private key is null");
return false;
}
// Step 1: Extract public key components.
const BIGNUM* public_n = nullptr;
const BIGNUM* public_e = nullptr;
const BIGNUM* d = nullptr;
RSA_get0_key(public_key, &public_n, &public_e, &d);
if (public_n == nullptr || public_e == nullptr) {
LOGE("Failed to get RSA public key components");
return false;
}
// Step 2: Extract private key components.
const BIGNUM* private_n = nullptr;
const BIGNUM* private_e = nullptr;
RSA_get0_key(private_key, &private_n, &private_e, &d);
if (private_n == nullptr || private_e == nullptr) {
LOGE("Failed to get RSA private key components");
return false;
}
// Step 3: Compare RSA components.
if (BN_cmp(public_n, private_n)) {
LOGD("RSA modulos do not match");
return false;
}
if (BN_cmp(public_e, private_e)) {
LOGD("RSA exponents do not match");
return false;
}
return true;
}
// ===== ===== ===== RSA Public Key ===== ===== =====
// static
std::unique_ptr<RsaPublicKey> RsaPublicKey::New(
const RsaPrivateKey& private_key) {
std::unique_ptr<RsaPublicKey> key(new RsaPublicKey());
if (!key->InitFromPrivateKey(private_key)) {
LOGE("Failed to initialize public key from private key");
key.reset();
}
return key;
}
// static
std::unique_ptr<RsaPublicKey> RsaPublicKey::Load(const uint8_t* buffer,
size_t length) {
std::unique_ptr<RsaPublicKey> key;
if (buffer == nullptr) {
LOGE("Provided public key buffer is null");
return key;
}
if (length == 0) {
LOGE("Provided public key buffer is zero length");
return key;
}
key.reset(new RsaPublicKey());
if (!key->InitFromBuffer(buffer, length)) {
LOGE("Failed to initialize public key from buffer");
key.reset();
}
return key;
}
// static
std::unique_ptr<RsaPublicKey> RsaPublicKey::Load(const std::string& buffer) {
if (buffer.empty()) {
LOGE("Provided public key buffer is empty");
return std::unique_ptr<RsaPublicKey>();
}
return Load(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
}
// static
// Convenience overload of Load() taking the DER-encoded key as a byte vector.
// Returns null for an empty vector.
std::unique_ptr<RsaPublicKey> RsaPublicKey::Load(
    const std::vector<uint8_t>& buffer) {
  if (buffer.empty()) {
    LOGE("Provided public key buffer is empty");
    return std::unique_ptr<RsaPublicKey>();
  }
  return Load(buffer.data(), buffer.size());
}
// Returns true if |private_key| has the same field size as this key and its
// modulus and public exponent match this key's.
bool RsaPublicKey::IsMatchingPrivateKey(
    const RsaPrivateKey& private_key) const {
  return private_key.field_size() == field_size_ &&
         RsaKeysAreMatchingPair(GetRsaKey(), private_key.GetRsaKey());
}
// Writes the DER-encoded SubjectPublicKeyInfo form of this key to |buffer|.
// If |buffer| is null or too small, sets |*buffer_size| to the required size
// and returns OEMCrypto_ERROR_SHORT_BUFFER. On success, |*buffer_size| is the
// number of bytes written.
OEMCryptoResult RsaPublicKey::Serialize(uint8_t* buffer,
                                        size_t* buffer_size) const {
  if (buffer_size == nullptr) {
    LOGE("Output buffer size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (buffer == nullptr && *buffer_size > 0) {
    LOGE("Output buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  uint8_t* encoded = nullptr;
  const int encoded_len = i2d_RSA_PUBKEY(key_, &encoded);
  if (encoded_len < 0) {
    LOGE("Public key serialization failed");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (encoded == nullptr) {
    LOGE("Encoded key is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // From here on |encoded| is owned by this function and freed exactly once
  // at the end, whatever the outcome.
  OEMCryptoResult outcome = OEMCrypto_SUCCESS;
  if (encoded_len == 0) {
    LOGE("Unexpected DER encoded size");
    outcome = OEMCrypto_ERROR_UNKNOWN_FAILURE;
  } else {
    const size_t needed = static_cast<size_t>(encoded_len);
    const bool fits = buffer != nullptr && *buffer_size >= needed;
    if (fits) {
      memcpy(buffer, encoded, needed);
    } else {
      outcome = OEMCrypto_ERROR_SHORT_BUFFER;
    }
    *buffer_size = needed;
  }
  OPENSSL_free(encoded);
  return outcome;
}
std::vector<uint8_t> RsaPublicKey::Serialize() const {
size_t key_size = kPublicKeySize;
std::vector<uint8_t> key_data(key_size, 0);
const OEMCryptoResult res = Serialize(key_data.data(), &key_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to serialize public key: result = %d", static_cast<int>(res));
key_data.clear();
} else {
key_data.resize(key_size);
}
return key_data;
}
// Verifies |signature| over |message| with the requested |algorithm|.
// A non-empty signature is required; |message| may be null only when
// |message_length| is zero. Returns OEMCrypto_SUCCESS only for a valid
// signature.
OEMCryptoResult RsaPublicKey::VerifySignature(
    const uint8_t* message, size_t message_length, const uint8_t* signature,
    size_t signature_length, RsaSignatureAlgorithm algorithm) const {
  const bool has_signature = signature != nullptr && signature_length != 0;
  if (!has_signature) {
    LOGE("Signature is missing");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool message_is_bad = message == nullptr && message_length > 0;
  if (message_is_bad) {
    LOGE("Bad message data");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (algorithm == kRsaPssDefault) {
    return VerifySignaturePss(message, message_length, signature,
                              signature_length);
  }
  if (algorithm == kRsaPkcs1Cast) {
    return VerifySignaturePkcs1Cast(message, message_length, signature,
                                    signature_length);
  }
  LOGE("Unknown RSA signature algorithm: %d", static_cast<int>(algorithm));
  return OEMCrypto_ERROR_UNKNOWN_FAILURE;
}
// std::string overload of VerifySignature(); rejects an empty signature.
OEMCryptoResult RsaPublicKey::VerifySignature(
    const std::string& message, const std::string& signature,
    RsaSignatureAlgorithm algorithm) const {
  if (!signature.empty()) {
    const uint8_t* const message_bytes =
        reinterpret_cast<const uint8_t*>(message.data());
    const uint8_t* const signature_bytes =
        reinterpret_cast<const uint8_t*>(signature.data());
    return VerifySignature(message_bytes, message.size(), signature_bytes,
                           signature.size(), algorithm);
  }
  LOGE("Signature should not be empty");
  return OEMCrypto_ERROR_INVALID_CONTEXT;
}
// Byte-vector overload of VerifySignature(); rejects an empty signature.
OEMCryptoResult RsaPublicKey::VerifySignature(
    const std::vector<uint8_t>& message, const std::vector<uint8_t>& signature,
    RsaSignatureAlgorithm algorithm) const {
  if (!signature.empty()) {
    return VerifySignature(message.data(), message.size(), signature.data(),
                           signature.size(), algorithm);
  }
  LOGE("Signature should not be empty");
  return OEMCrypto_ERROR_INVALID_CONTEXT;
}
// Encrypts |session_key|, which must be exactly kRsaSessionKeySize bytes,
// with RSAES-OAEP into |enc_session_key|. |*enc_session_key_size| is the
// capacity on input and the ciphertext size on output.
OEMCryptoResult RsaPublicKey::EncryptSessionKey(
    const uint8_t* session_key, size_t session_key_size,
    uint8_t* enc_session_key, size_t* enc_session_key_size) const {
  if (!session_key) {
    LOGE("Session key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool size_ok = session_key_size == kRsaSessionKeySize;
  if (!size_ok) {
    LOGE("Unexpected session key size: expected = %zu, actual = %zu",
         kRsaSessionKeySize, session_key_size);
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (!enc_session_key_size) {
    LOGE("Output encrypted session key size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool missing_output = !enc_session_key && *enc_session_key_size > 0;
  if (missing_output) {
    LOGE("Output encrypted session key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  return EncryptOaep(session_key, session_key_size, enc_session_key,
                     enc_session_key_size);
}
std::vector<uint8_t> RsaPublicKey::EncryptSessionKey(
const std::vector<uint8_t>& session_key) const {
if (session_key.empty()) {
LOGE("Session key is empty");
return std::vector<uint8_t>();
}
size_t enc_session_key_size = static_cast<size_t>(RSA_size(key_));
std::vector<uint8_t> enc_session_key(enc_session_key_size);
const OEMCryptoResult res =
EncryptSessionKey(session_key.data(), session_key.size(),
enc_session_key.data(), &enc_session_key_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to encrypt session key: result = %d", static_cast<int>(res));
enc_session_key.clear();
} else {
enc_session_key.resize(enc_session_key_size);
}
return enc_session_key;
}
std::vector<uint8_t> RsaPublicKey::EncryptSessionKey(
const std::string& session_key) const {
if (session_key.empty()) {
LOGE("Session key is empty");
return std::vector<uint8_t>();
}
size_t enc_session_key_size = static_cast<size_t>(RSA_size(key_));
std::vector<uint8_t> enc_session_key(enc_session_key_size);
const OEMCryptoResult res = EncryptSessionKey(
reinterpret_cast<const uint8_t*>(session_key.data()), session_key.size(),
enc_session_key.data(), &enc_session_key_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to encrypt session key: result = %d", static_cast<int>(res));
enc_session_key.clear();
} else {
enc_session_key.resize(enc_session_key_size);
}
return enc_session_key;
}
// Encrypts |encryption_key|, which must be exactly kEncryptionKeySize bytes,
// with RSAES-OAEP into |enc_encryption_key|. |*enc_encryption_key_size| is
// the capacity on input and the ciphertext size on output.
OEMCryptoResult RsaPublicKey::EncryptEncryptionKey(
    const uint8_t* encryption_key, size_t encryption_key_size,
    uint8_t* enc_encryption_key, size_t* enc_encryption_key_size) const {
  if (!encryption_key) {
    LOGE("Encryption key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool size_ok = encryption_key_size == kEncryptionKeySize;
  if (!size_ok) {
    LOGE("Unexpected encryption key size: expected = %zu, actual = %zu",
         kEncryptionKeySize, encryption_key_size);
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (!enc_encryption_key_size) {
    LOGE("Output encrypted encryption key size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool missing_output =
      !enc_encryption_key && *enc_encryption_key_size > 0;
  if (missing_output) {
    LOGE("Output encrypted encryption key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  return EncryptOaep(encryption_key, encryption_key_size, enc_encryption_key,
                     enc_encryption_key_size);
}
// Vector convenience overload: returns the RSAES-OAEP ciphertext of
// |encryption_key| (which must be kEncryptionKeySize bytes), or an empty
// vector on failure.
std::vector<uint8_t> RsaPublicKey::EncryptEncryptionKey(
    const std::vector<uint8_t>& encryption_key) const {
  if (encryption_key.empty()) {
    LOGE("Encryption key is empty");
    return std::vector<uint8_t>();
  }
  size_t enc_encryption_key_size = static_cast<size_t>(RSA_size(key_));
  std::vector<uint8_t> enc_encryption_key(enc_encryption_key_size);
  // Delegate to EncryptEncryptionKey() (not EncryptSessionKey()) so the
  // input length is validated against kEncryptionKeySize.
  const OEMCryptoResult res = EncryptEncryptionKey(
      encryption_key.data(), encryption_key.size(), enc_encryption_key.data(),
      &enc_encryption_key_size);
  if (res != OEMCrypto_SUCCESS) {
    LOGE("Failed to encrypt encryption key: result = %d",
         static_cast<int>(res));
    enc_encryption_key.clear();
  } else {
    enc_encryption_key.resize(enc_encryption_key_size);
  }
  return enc_encryption_key;
}
// std::string convenience overload: returns the RSAES-OAEP ciphertext of
// |encryption_key| (which must be kEncryptionKeySize bytes), or an empty
// vector on failure.
std::vector<uint8_t> RsaPublicKey::EncryptEncryptionKey(
    const std::string& encryption_key) const {
  if (encryption_key.empty()) {
    LOGE("Encryption key is empty");
    return std::vector<uint8_t>();
  }
  size_t enc_encryption_key_size = static_cast<size_t>(RSA_size(key_));
  std::vector<uint8_t> enc_encryption_key(enc_encryption_key_size);
  // Delegate to EncryptEncryptionKey() (not EncryptSessionKey()) so the
  // input length is validated against kEncryptionKeySize.
  const OEMCryptoResult res = EncryptEncryptionKey(
      reinterpret_cast<const uint8_t*>(encryption_key.data()),
      encryption_key.size(), enc_encryption_key.data(),
      &enc_encryption_key_size);
  if (res != OEMCrypto_SUCCESS) {
    LOGE("Failed to encrypt encryption key: result = %d",
         static_cast<int>(res));
    enc_encryption_key.clear();
  } else {
    enc_encryption_key.resize(enc_encryption_key_size);
  }
  return enc_encryption_key;
}
// Frees the owned OpenSSL RSA key (if any) and resets the key metadata.
RsaPublicKey::~RsaPublicKey() {
  RSA* const owned = key_;
  key_ = nullptr;
  if (owned != nullptr) {
    RSA_free(owned);
  }
  field_size_ = kRsaFieldUnknown;
  allowed_schemes_ = 0;
}
bool RsaPublicKey::InitFromBuffer(const uint8_t* buffer, size_t length) {
// Step 1: Deserialize SubjectPublicKeyInfo as RSA key.
const uint8_t* tp = buffer;
ScopedRsaKey key(d2i_RSA_PUBKEY(nullptr, &tp, length));
if (!key) {
LOGE("Failed to parse key");
return false;
}
// Step 2: Verify key.
const int bits = RSA_bits(key.get());
field_size_ = RealBitSizeToFieldSize(bits);
if (field_size_ == kRsaFieldUnknown) {
LOGE("Unsupported RSA key size: bits = %d", bits);
return false;
}
allowed_schemes_ = kSign_RSASSA_PSS;
key_ = key.release();
return true;
}
bool RsaPublicKey::InitFromPrivateKey(const RsaPrivateKey& private_key) {
ScopedRsaKey key(RSA_new());
if (!key) {
LOGE("Failed to allocate key");
return false;
}
const BIGNUM* n = nullptr;
const BIGNUM* e = nullptr;
const BIGNUM* d = nullptr;
RSA_get0_key(private_key.GetRsaKey(), &n, &e, &d);
if (n == nullptr || e == nullptr) {
LOGE("Failed to get RSA private key components");
return false;
}
ScopedBigNum dub_n(BN_dup(n));
ScopedBigNum dub_e(BN_dup(e));
if (!dub_n || !dub_e) {
LOGE("Failed to duplicate RSA public key components");
return false;
}
if (!RSA_set0_key(key.get(), dub_n.get(), dub_e.get(), nullptr)) {
LOGE("Failed to RSA public key components");
return false;
}
// Ownership has transferred to the RSA key.
dub_n.release();
dub_e.release();
key_ = key.release();
allowed_schemes_ = private_key.allowed_schemes();
field_size_ = private_key.field_size();
return true;
}
// Verifies an RSASSA-PSS signature over |message| using SHA-1 as the digest,
// MGF1 with SHA-1, and a salt length of kPssSaltLength. Returns
// OEMCrypto_SUCCESS for a valid signature and
// OEMCrypto_ERROR_SIGNATURE_FAILURE for an invalid one.
OEMCryptoResult RsaPublicKey::VerifySignaturePss(
    const uint8_t* message, size_t message_length, const uint8_t* signature,
    size_t signature_length) const {
  if ((allowed_schemes_ & kSign_RSASSA_PSS) == 0) {
    LOGE("RSA key cannot verify using PSS");
    return OEMCrypto_ERROR_INVALID_RSA_KEY;
  }
  // Wrap |key_| in an EVP_PKEY; set1 takes its own reference, so |key_|
  // remains owned by this object.
  ScopedEvpPkey evp_key(EVP_PKEY_new());
  if (!evp_key) {
    LOGE("Failed to allocate PKEY");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!EVP_PKEY_set1_RSA(evp_key.get(), key_)) {
    LOGE("Failed to set PKEY RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  ScopedEvpMdCtx verify_ctx(EVP_MD_CTX_new());
  if (!verify_ctx) {
    LOGE("Failed to allocate MD CTX");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // |pss_ctx| is owned by |verify_ctx| and must not be freed separately.
  EVP_PKEY_CTX* pss_ctx = nullptr;
  if (EVP_DigestVerifyInit(verify_ctx.get(), &pss_ctx, EVP_sha1(), nullptr,
                           evp_key.get()) != 1) {
    LOGE("Failed to initialize MD CTX for verification");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (pss_ctx == nullptr) {
    LOGE("PKEY CTX is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_padding(pss_ctx, RSA_PKCS1_PSS_PADDING) != 1) {
    LOGE("Failed to set PSS padding");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pss_ctx, kPssSaltLength) != 1) {
    LOGE("Failed to set PSS salt length");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(pss_ctx, EVP_sha1()) != 1) {
    LOGE("Failed to set PSS MGF1 MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // An empty message is valid; only feed the digest when there is data.
  if (message_length > 0 &&
      EVP_DigestVerifyUpdate(verify_ctx.get(), message, message_length) != 1) {
    LOGE("Failed to update MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const int verdict =
      EVP_DigestVerifyFinal(verify_ctx.get(), signature, signature_length);
  if (verdict < 0) {
    LOGE("Failed to perform RSASSA-PSS-VERIFY");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (verdict == 0) {
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  return OEMCrypto_SUCCESS;
}
// Verifies a CAST-style PKCS#1 (block type 1) signature. The message itself
// (at most kRsaPkcs1CastMaxMessageSize bytes) is the signed payload: the
// signature is valid when the recovered payload equals |message| exactly.
OEMCryptoResult RsaPublicKey::VerifySignaturePkcs1Cast(
    const uint8_t* message, size_t message_length, const uint8_t* signature,
    size_t signature_length) const {
  if ((allowed_schemes_ & kSign_PKCS1_Block1) == 0) {
    LOGE("RSA key cannot verify using PKCS1");
    return OEMCrypto_ERROR_INVALID_RSA_KEY;
  }
  if (message_length > kRsaPkcs1CastMaxMessageSize) {
    LOGE("Message is too large for CAST PKCS1 signature: size = %zu",
         message_length);
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  std::vector<uint8_t> recovered(RSA_size(key_));
  const int recovered_len =
      RSA_public_decrypt(signature_length, signature, recovered.data(), key_,
                         RSA_PKCS1_PADDING);
  if (recovered_len <= 0) {
    LOGE("Failed to perform public decryption");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (static_cast<size_t>(recovered_len) != message_length) {
    LOGD("Digest size does not match");
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  // |message_length| equals |recovered_len| > 0 here; VerifySignature()
  // rejects a null |message| paired with a non-zero length.
  if (memcmp(message, recovered.data(), message_length) != 0) {
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  return OEMCrypto_SUCCESS;
}
// RSAES-OAEP encrypts |message| with this key. The ciphertext is produced
// into a modulus-sized scratch buffer, then copied to |enc_message| if
// |*enc_message_length| is large enough. Otherwise
// OEMCrypto_ERROR_SHORT_BUFFER is returned. In both cases
// |*enc_message_length| is set to the ciphertext size.
OEMCryptoResult RsaPublicKey::EncryptOaep(const uint8_t* message,
                                          size_t message_size,
                                          uint8_t* enc_message,
                                          size_t* enc_message_length) const {
  std::vector<uint8_t> scratch(RSA_size(key_));
  const int written = RSA_public_encrypt(message_size, message, scratch.data(),
                                         key_, RSA_PKCS1_OAEP_PADDING);
  if (written < 0) {
    LOGE("Failed to perform RSAES-OAEP encrypt");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const size_t ciphertext_len = static_cast<size_t>(written);
  const bool fits = *enc_message_length >= ciphertext_len;
  *enc_message_length = ciphertext_len;
  if (!fits) {
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  memcpy(enc_message, scratch.data(), ciphertext_len);
  return OEMCrypto_SUCCESS;
}
// ===== ===== ===== RSA Private Key ===== ===== =====
// static
// Generates a fresh private key of the given |field_size|. Returns null if the
// size is unsupported or key generation fails.
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::New(RsaFieldSize field_size) {
  std::unique_ptr<RsaPrivateKey> generated(new RsaPrivateKey());
  if (generated->InitFromFieldSize(field_size)) {
    return generated;
  }
  LOGE("Failed to initialize private key from field size");
  return nullptr;
}
// static
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::Load(const uint8_t* buffer,
size_t length) {
std::unique_ptr<RsaPrivateKey> key;
if (buffer == nullptr) {
LOGE("Provided private key buffer is null");
return key;
}
if (length == 0) {
LOGE("Provided private key buffer is zero length");
return key;
}
key.reset(new RsaPrivateKey());
if (!key->InitFromBuffer(buffer, length)) {
LOGE("Failed to initialize private key from buffer");
key.reset();
}
return key;
}
// static
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::Load(const std::string& buffer) {
if (buffer.empty()) {
LOGE("Provided private key buffer is empty");
return std::unique_ptr<RsaPrivateKey>();
}
return Load(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
}
// static
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::Load(
const std::vector<uint8_t>& buffer) {
if (buffer.empty()) {
LOGE("Provided private key buffer is empty");
return std::unique_ptr<RsaPrivateKey>();
}
return Load(buffer.data(), buffer.size());
}
std::unique_ptr<RsaPublicKey> RsaPrivateKey::MakePublicKey() const {
return RsaPublicKey::New(*this);
}
bool RsaPrivateKey::IsMatchingPublicKey(const RsaPublicKey& public_key) const {
return RsaKeysAreMatchingPair(public_key.GetRsaKey(), GetRsaKey());
}
// Serializes the key as DER-encoded PKCS#8 PrivateKeyInfo into |buffer|.
// When |explicit_schemes_| is set, the output is prefixed with the 4 bytes
// "SIGN" followed by |allowed_schemes_| as a 32-bit big-endian value.
// If |*buffer_size| is too small, it is set to the required size and
// OEMCrypto_ERROR_SHORT_BUFFER is returned.
OEMCryptoResult RsaPrivateKey::Serialize(uint8_t* buffer,
                                         size_t* buffer_size) const {
  if (buffer_size == nullptr) {
    LOGE("Output buffer size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (buffer == nullptr && *buffer_size > 0) {
    LOGE("Output buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  // RSA -> EVP_PKEY (set1 adds a reference; |key_| stays owned by us).
  ScopedEvpPkey evp_key(EVP_PKEY_new());
  if (!evp_key) {
    LOGE("Failed to allocate EVP");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!EVP_PKEY_set1_RSA(evp_key.get(), key_)) {
    LOGE("Failed to set EVP RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // EVP_PKEY -> PKCS#8 structure.
  ScopedRsaPrivKeyInfo pkcs8(EVP_PKEY2PKCS8(evp_key.get()));
  if (!pkcs8) {
    LOGE("Failed to convert RSA key to PKCS8 info");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // PKCS#8 structure -> DER bytes held in a memory BIO.
  ScopedBio mem_bio(BIO_new(BIO_s_mem()));
  if (!mem_bio) {
    LOGE("Failed to allocate IO buffer for RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!i2d_PKCS8_PRIV_KEY_INFO_bio(mem_bio.get(), pkcs8.get())) {
    LOGE("Failed to serialize RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  char* der = nullptr;
  const long der_len = BIO_get_mem_data(mem_bio.get(), &der);
  if (der_len < 0) {
    LOGE("Failed to get RSA key size");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const size_t der_size = static_cast<size_t>(der_len);
  const size_t header_size = explicit_schemes_ ? 8 : 0;
  const size_t total_size = der_size + header_size;
  if (*buffer_size < total_size) {
    *buffer_size = total_size;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  *buffer_size = total_size;
  uint8_t* out = buffer;
  if (explicit_schemes_) {
    const uint32_t schemes_be = htonl(allowed_schemes_);
    memcpy(out, "SIGN", 4);
    memcpy(out + 4, &schemes_be, 4);
    out += 8;
  }
  memcpy(out, der, der_size);
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPrivateKey::Serialize() const {
size_t key_size = kPrivateKeySize;
std::vector<uint8_t> key_data(key_size, 0);
OEMCryptoResult res = Serialize(key_data.data(), &key_size);
if (res == OEMCrypto_ERROR_SHORT_BUFFER) {
LOGD(
"Actual RSA private key size is larger than expected: "
"expected = %zu, actual = %zu",
kPrivateKeySize, key_size);
key_data.resize(key_size);
res = Serialize(key_data.data(), &key_size);
}
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to serialize private key: result = %d", static_cast<int>(res));
key_data.clear();
} else {
key_data.resize(key_size);
}
return key_data;
}
// Signs |message| with |algorithm| into |signature|. |*signature_length| is
// the capacity on input and the signature size on output. |message| may be
// null only when |message_length| is zero.
OEMCryptoResult RsaPrivateKey::GenerateSignature(
    const uint8_t* message, size_t message_length,
    RsaSignatureAlgorithm algorithm, uint8_t* signature,
    size_t* signature_length) const {
  if (!signature_length) {
    LOGE("Output signature size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (!signature && *signature_length > 0) {
    LOGE("Output signature is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (!message && message_length > 0) {
    LOGE("Invalid message data");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (algorithm == kRsaPssDefault) {
    return GenerateSignaturePss(message, message_length, signature,
                                signature_length);
  }
  if (algorithm == kRsaPkcs1Cast) {
    return GenerateSignaturePkcs1Cast(message, message_length, signature,
                                      signature_length);
  }
  LOGE("Unknown RSA signature algorithm: %d", static_cast<int>(algorithm));
  return OEMCrypto_ERROR_UNKNOWN_FAILURE;
}
std::vector<uint8_t> RsaPrivateKey::GenerateSignature(
const std::string& message, RsaSignatureAlgorithm algorithm) const {
size_t signature_size = SignatureSize();
std::vector<uint8_t> signature(signature_size);
const OEMCryptoResult res = GenerateSignature(
reinterpret_cast<const uint8_t*>(message.data()), message.size(),
algorithm, signature.data(), &signature_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to generate signature: result = %d", static_cast<int>(res));
signature.clear();
} else {
signature.resize(signature_size);
}
return signature;
}
std::vector<uint8_t> RsaPrivateKey::GenerateSignature(
const std::vector<uint8_t>& message,
RsaSignatureAlgorithm algorithm) const {
size_t signature_size = SignatureSize();
std::vector<uint8_t> signature(signature_size, 0);
const OEMCryptoResult res =
GenerateSignature(message.data(), message.size(), algorithm,
signature.data(), &signature_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to generate signature: result = %d", static_cast<int>(res));
signature.clear();
} else {
signature.resize(signature_size);
}
return signature;
}
// Returns the size, in bytes, of signatures produced by this key (the RSA
// modulus size reported by RSA_size()).
size_t RsaPrivateKey::SignatureSize() const {
  const auto modulus_bytes = RSA_size(key_);
  return static_cast<size_t>(modulus_bytes);
}
// Decrypts an RSAES-OAEP encrypted session key into |session_key|.
// |*session_key_size| must be at least kRsaSessionKeySize. Otherwise it is
// set to that value and OEMCrypto_ERROR_SHORT_BUFFER is returned. On success
// it is set to kRsaSessionKeySize.
OEMCryptoResult RsaPrivateKey::DecryptSessionKey(
    const uint8_t* enc_session_key, size_t enc_session_key_size,
    uint8_t* session_key, size_t* session_key_size) const {
  if (enc_session_key == nullptr || enc_session_key_size == 0) {
    LOGE("Encrypted session key is %s", enc_session_key ? "empty" : "null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (session_key_size == nullptr) {
    LOGE("Output session key size buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (session_key == nullptr && *session_key_size > 0) {
    LOGE("Output session key buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (*session_key_size < kRsaSessionKeySize) {
    *session_key_size = kRsaSessionKeySize;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  const OEMCryptoResult status = DecryptOaep(
      enc_session_key, enc_session_key_size, session_key, kRsaSessionKeySize);
  if (status != OEMCrypto_SUCCESS) {
    LOGE("Failed to decrypt session key");
    return status;
  }
  *session_key_size = kRsaSessionKeySize;
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPrivateKey::DecryptSessionKey(
const std::vector<uint8_t>& enc_session_key) const {
size_t session_key_size = kRsaSessionKeySize;
std::vector<uint8_t> session_key(session_key_size, 0);
const OEMCryptoResult res =
DecryptSessionKey(enc_session_key.data(), enc_session_key.size(),
session_key.data(), &session_key_size);
if (res != OEMCrypto_SUCCESS) {
session_key.clear();
} else {
session_key.resize(session_key_size);
}
return session_key;
}
std::vector<uint8_t> RsaPrivateKey::DecryptSessionKey(
const std::string& enc_session_key) const {
size_t session_key_size = kRsaSessionKeySize;
std::vector<uint8_t> session_key(session_key_size, 0);
const OEMCryptoResult res = DecryptSessionKey(
reinterpret_cast<const uint8_t*>(enc_session_key.data()),
enc_session_key.size(), session_key.data(), &session_key_size);
if (res != OEMCrypto_SUCCESS) {
session_key.clear();
} else {
session_key.resize(session_key_size);
}
return session_key;
}
// Returns the length, in bytes, of session keys decrypted by this class.
size_t RsaPrivateKey::SessionKeyLength() const {
  const size_t length = kRsaSessionKeySize;
  return length;
}
// Decrypts an RSAES-OAEP encrypted encryption key into |encryption_key|.
// |*encryption_key_size| must be at least kEncryptionKeySize. Otherwise it is
// set to that value and OEMCrypto_ERROR_SHORT_BUFFER is returned. On success
// it is set to kEncryptionKeySize.
OEMCryptoResult RsaPrivateKey::DecryptEncryptionKey(
    const uint8_t* enc_encryption_key, size_t enc_encryption_key_size,
    uint8_t* encryption_key, size_t* encryption_key_size) const {
  if (enc_encryption_key == nullptr || enc_encryption_key_size == 0) {
    LOGE("Encrypted encryption key is %s",
         enc_encryption_key ? "empty" : "null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (encryption_key_size == nullptr) {
    LOGE("Output encryption key size buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (encryption_key == nullptr && *encryption_key_size > 0) {
    LOGE("Output encryption key buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (*encryption_key_size < kEncryptionKeySize) {
    *encryption_key_size = kEncryptionKeySize;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  const OEMCryptoResult status =
      DecryptOaep(enc_encryption_key, enc_encryption_key_size, encryption_key,
                  kEncryptionKeySize);
  if (status != OEMCrypto_SUCCESS) {
    LOGE("Failed to decrypt encryption key");
    return status;
  }
  *encryption_key_size = kEncryptionKeySize;
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPrivateKey::DecryptEncryptionKey(
const std::vector<uint8_t>& enc_encryption_key) const {
size_t encryption_key_size = kEncryptionKeySize;
std::vector<uint8_t> encryption_key(encryption_key_size, 0);
const OEMCryptoResult res =
DecryptEncryptionKey(enc_encryption_key.data(), enc_encryption_key.size(),
encryption_key.data(), &encryption_key_size);
if (res != OEMCrypto_SUCCESS) {
encryption_key.clear();
} else {
encryption_key.resize(encryption_key_size);
}
return encryption_key;
}
std::vector<uint8_t> RsaPrivateKey::DecryptEncryptionKey(
const std::string& enc_encryption_key) const {
size_t encryption_key_size = kEncryptionKeySize;
std::vector<uint8_t> encryption_key(encryption_key_size, 0);
const OEMCryptoResult res = DecryptEncryptionKey(
reinterpret_cast<const uint8_t*>(enc_encryption_key.data()),
enc_encryption_key.size(), encryption_key.data(), &encryption_key_size);
if (res != OEMCrypto_SUCCESS) {
encryption_key.clear();
} else {
encryption_key.resize(encryption_key_size);
}
return encryption_key;
}
// Frees the owned OpenSSL RSA key (if any) and resets the key metadata.
RsaPrivateKey::~RsaPrivateKey() {
  RSA* const owned = key_;
  key_ = nullptr;
  if (owned != nullptr) {
    RSA_free(owned);
  }
  field_size_ = kRsaFieldUnknown;
  explicit_schemes_ = false;
  allowed_schemes_ = 0;
}
// Parses a serialized RSA private key: DER-encoded PKCS#8 PrivateKeyInfo,
// optionally preceded by an 8-byte header of "SIGN" followed by a 32-bit
// big-endian bitmask of allowed signature schemes. Without the header, the
// key is restricted to RSASSA-PSS. The key is checked with RSA_check_key()
// and must map to a supported field size.
bool RsaPrivateKey::InitFromBuffer(const uint8_t* buffer, size_t length) {
  // Guarantees the 8-byte optional header below can be read safely.
  if (length < 8) {
    LOGE("Private key is too small: length = %zu", length);
    return false;
  }
  ScopedBio bio;
  // Check allowed scheme type.
  if (!memcmp("SIGN", buffer, 4)) {
    uint32_t allowed_schemes;
    memcpy(&allowed_schemes, &buffer[4], 4);
    allowed_schemes_ = ntohl(allowed_schemes);
    bio.reset(BIO_new_mem_buf(&buffer[8], length - 8));
    explicit_schemes_ = true;
  } else {
    allowed_schemes_ = kSign_RSASSA_PSS;
    bio.reset(BIO_new_mem_buf(buffer, length));
  }
  if (!bio) {
    LOGE("Failed to allocate BIO buffer");
    return false;
  }
  // Step 1: Deserializes PKCS8 RSA private key info.
  ScopedRsaPrivKeyInfo priv_info(
      d2i_PKCS8_PRIV_KEY_INFO_bio(bio.get(), nullptr));
  if (!priv_info) {
    LOGE("Failed to parse private key");
    return false;
  }
  // Step 2: Convert to RSA key.
  ScopedEvpPkey pkey(EVP_PKCS82PKEY(priv_info.get()));
  if (!pkey) {
    LOGE("Failed to convert PKCS8 to EVP");
    return false;
  }
  ScopedRsaKey key(EVP_PKEY_get1_RSA(pkey.get()));
  if (!key) {
    LOGE("Failed to get RSA key");
    return false;
  }
  // RSA_check_key(): 1 = valid, 0 = invalid parameters, -1 = check error.
  const int check = RSA_check_key(key.get());
  if (check == 0) {
    LOGE("RSA key parameters are invalid");
    return false;
  } else if (check == -1) {
    LOGE("Failed to check RSA key");
    return false;
  }
  // Step 3: Verify field width.
  const int bits = RSA_bits(key.get());
  field_size_ = RealBitSizeToFieldSize(bits);
  if (field_size_ == kRsaFieldUnknown) {
    LOGE("Unsupported RSA key size: bits = %d", bits);
    return false;
  }
  key_ = key.release();
  return true;
}
// Generates a new RSA key with public exponent RSA_F4 and a modulus of
// |field_size| bits (the enum value is passed directly as the bit count).
// Only kRsa2048Bit and kRsa3072Bit are accepted. The generated key is
// restricted to RSASSA-PSS.
bool RsaPrivateKey::InitFromFieldSize(RsaFieldSize field_size) {
  const bool supported =
      field_size == kRsa2048Bit || field_size == kRsa3072Bit;
  if (!supported) {
    LOGE("Unsupported RSA field size: bits = %d", static_cast<int>(field_size));
    return false;
  }
  ScopedBigNum public_exponent(BN_new());
  if (!public_exponent) {
    LOGE("Failed to allocate RSA exponent");
    return false;
  }
  if (!BN_set_word(public_exponent.get(), RSA_F4)) {
    LOGE("Failed to set RSA exponent");
    return false;
  }
  ScopedRsaKey generated(RSA_new());
  if (!generated) {
    LOGE("Failed to allocate RSA key");
    return false;
  }
  const int modulus_bits = static_cast<int>(field_size);
  if (!RSA_generate_key_ex(generated.get(), modulus_bits,
                           public_exponent.get(), nullptr)) {
    LOGE("Failed to generate RSA key");
    return false;
  }
  key_ = generated.release();
  allowed_schemes_ = kSign_RSASSA_PSS;
  field_size_ = field_size;
  return true;
}
// Signs |message| with RSASSA-PSS using this private key. The digest is
// SHA-1, MGF1 uses SHA-1, and the salt length is kPssSaltLength.
//
// On entry |*signature_length| is the capacity of |signature|. If it is too
// small, it is set to the required size and OEMCrypto_ERROR_SHORT_BUFFER is
// returned. Otherwise EVP_DigestSignFinal() writes the signature and updates
// |*signature_length|. Returns OEMCrypto_ERROR_INVALID_RSA_KEY if the key is
// not permitted to sign with PSS.
OEMCryptoResult RsaPrivateKey::GenerateSignaturePss(
    const uint8_t* message, size_t message_length, uint8_t* signature,
    size_t* signature_length) const {
  if ((allowed_schemes_ & kSign_RSASSA_PSS) == 0) {
    LOGE("RSA key cannot sign using PSS");
    return OEMCrypto_ERROR_INVALID_RSA_KEY;
  }

  // Wrap the raw RSA key in an EVP_PKEY for the EVP digest-sign API.
  ScopedEvpPkey evp_key(EVP_PKEY_new());
  if (!evp_key) {
    LOGE("Failed to allocate PKEY");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!EVP_PKEY_set1_RSA(evp_key.get(), key_)) {
    LOGE("Failed to set PKEY RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }

  // Prepare a SHA-1 signing context.
  ScopedEvpMdCtx sign_ctx(EVP_MD_CTX_new());
  if (!sign_ctx) {
    LOGE("Failed to allocate MD CTX");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // |key_ctx| is owned by |sign_ctx|; it must not be freed here.
  EVP_PKEY_CTX* key_ctx = nullptr;
  if (EVP_DigestSignInit(sign_ctx.get(), &key_ctx, EVP_sha1(), nullptr,
                         evp_key.get()) != 1) {
    LOGE("Failed to initialize MD CTX for signing");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (key_ctx == nullptr) {
    LOGE("PKEY CTX is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }

  // Apply the OEMCrypto RSASSA-PSS parameters.
  if (EVP_PKEY_CTX_set_rsa_padding(key_ctx, RSA_PKCS1_PSS_PADDING) != 1) {
    LOGE("Failed to set PSS padding");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_pss_saltlen(key_ctx, kPssSaltLength) != 1) {
    LOGE("Failed to set PSS salt length");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(key_ctx, EVP_sha1()) != 1) {
    LOGE("Failed to set PSS MGF1 MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }

  // Feed the message; the update call is skipped for an empty message.
  if (message_length != 0 &&
      EVP_DigestSignUpdate(sign_ctx.get(), message, message_length) != 1) {
    LOGE("Failed to update MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }

  // Query the signature size, then check the caller's buffer against it.
  size_t required_length = 0;
  if (EVP_DigestSignFinal(sign_ctx.get(), nullptr, &required_length) != 1) {
    LOGE("Failed to determine signature length");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (*signature_length < required_length) {
    *signature_length = required_length;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }

  // Produce the signature into the caller's buffer.
  if (EVP_DigestSignFinal(sign_ctx.get(), signature, signature_length) != 1) {
    LOGE("Failed to perform RSASSA-PSS-SIGN");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  return OEMCrypto_SUCCESS;
}
// Produces a "CAST" signature. |message| is not hashed here. It is padded with
// PKCS#1 v1.5 padding (RSA_PKCS1_PADDING) and transformed with the RSA private
// key via RSA_private_encrypt().
//
// Requires the key to allow kSign_PKCS1_Block1, and |message_length| to be at
// most kRsaPkcs1CastMaxMessageSize. If |*signature_length| is smaller than the
// modulus size, it is set to that size and OEMCrypto_ERROR_SHORT_BUFFER is
// returned. On success |*signature_length| is set to the modulus size.
OEMCryptoResult RsaPrivateKey::GenerateSignaturePkcs1Cast(
    const uint8_t* message, size_t message_length, uint8_t* signature,
    size_t* signature_length) const {
  const bool pkcs1_allowed = (allowed_schemes_ & kSign_PKCS1_Block1) != 0;
  if (!pkcs1_allowed) {
    LOGE("RSA key cannot sign PKCS1");
    return OEMCrypto_ERROR_INVALID_RSA_KEY;
  }
  if (message_length > kRsaPkcs1CastMaxMessageSize) {
    LOGE("Message is too large for CAST PKCS1 signature: size = %zu",
         message_length);
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }

  // The signature is always exactly one modulus in length.
  const size_t required_size = static_cast<size_t>(RSA_size(key_));
  if (*signature_length < required_size) {
    *signature_length = required_size;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }

  if (RSA_private_encrypt(message_length, message, signature, key_,
                          RSA_PKCS1_PADDING) < 0) {
    LOGE("Failed to perform private encryption");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  *signature_length = required_size;
  return OEMCrypto_SUCCESS;
}
// Decrypts |enc_message| with RSAES-OAEP using this private key, and copies
// exactly |expected_message_length| plaintext bytes into |message|.
//
// |enc_message_size| must equal the RSA modulus size in bytes, as required by
// RFC 8017 section 7.1.2. The call fails with OEMCrypto_ERROR_UNKNOWN_FAILURE
// if the ciphertext size is wrong, if decryption fails, or if the recovered
// plaintext is not exactly |expected_message_length| bytes. |message| is only
// written on success.
OEMCryptoResult RsaPrivateKey::DecryptOaep(
    const uint8_t* enc_message, size_t enc_message_size, uint8_t* message,
    size_t expected_message_length) const {
  const size_t modulus_size = static_cast<size_t>(RSA_size(key_));
  // Step 0: Validate the ciphertext length before handing it to OpenSSL.
  // Some RSA_private_decrypt() versions take the length as an int, so an
  // unchecked size_t could be silently truncated to a smaller value.
  if (enc_message_size != modulus_size) {
    LOGE("Unexpected encrypted message size: expected = %zu, actual = %zu",
         modulus_size, enc_message_size);
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // Step 1: Decrypt using RSAES-OAEP into a modulus-sized scratch buffer.
  std::vector<uint8_t> decrypt_buffer(modulus_size);
  const int dec_res =
      RSA_private_decrypt(enc_message_size, enc_message, decrypt_buffer.data(),
                          key_, RSA_PKCS1_OAEP_PADDING);
  OEMCryptoResult result = OEMCrypto_SUCCESS;
  if (dec_res < 0) {
    LOGE("Failed to perform RSAES-OAEP decrypt");
    result = OEMCrypto_ERROR_UNKNOWN_FAILURE;
  } else if (static_cast<size_t>(dec_res) != expected_message_length) {
    LOGE("Unexpected key size: expected = %zu, actual = %d",
         expected_message_length, dec_res);
    result = OEMCrypto_ERROR_UNKNOWN_FAILURE;
  } else {
    // Step 2: Copy decrypted data key.
    memcpy(message, decrypt_buffer.data(), expected_message_length);
  }
  // Step 3: The scratch buffer may hold decrypted key material. Scrub it on
  // every path before the vector releases its heap storage.
  OPENSSL_cleanse(decrypt_buffer.data(), decrypt_buffer.size());
  return result;
}
} // namespace wvoec_ref