Files
android/libwvdrmengine/oemcrypto/util/src/oemcrypto_rsa_key.cpp
Cong Lin e8add8eed8 Sync oemcrypto files from cdm udc-dev to Android
Changes included in this CL:

166806: Update OEMCrypto_GetDeviceInformation() | https://widevine-internal-review.googlesource.com/c/cdm/+/166806
166808: Update Android L3 after OEMCrypto_GetDeviceInformation() signature changes | https://widevine-internal-review.googlesource.com/c/cdm/+/166808
166809: Decode device info and write it to CSR payload | https://widevine-internal-review.googlesource.com/c/cdm/+/166809
167158: Fix Android include path and copy_files | https://widevine-internal-review.googlesource.com/c/cdm/+/167158
167159: Fix common typos and use inclusive language suggested by Android linter | https://widevine-internal-review.googlesource.com/c/cdm/+/167159

165618: Explicitly state python3 where needed. | https://widevine-internal-review.googlesource.com/c/cdm/+/165618

166757: Update Android.bp for Android | https://widevine-internal-review.googlesource.com/c/cdm/+/166757
164993: Refactor basic oemcrypto unit tests | https://widevine-internal-review.googlesource.com/c/cdm/+/164993
164978: Update OEMCrypto Unit Test Docs | https://widevine-internal-review.googlesource.com/c/cdm/+/164978
166941: Update make files for OEMCrypto | https://widevine-internal-review.googlesource.com/c/cdm/+/166941

165279: Refactor license unit tests | https://widevine-internal-review.googlesource.com/c/cdm/+/165279
165318: Refactor provisioning unit tests | https://widevine-internal-review.googlesource.com/c/cdm/+/165318
164800: Add extra check for renew on license load unit test | https://widevine-internal-review.googlesource.com/c/cdm/+/164800
165860: Remove duplicate definition of MaybeHex() | https://widevine-internal-review.googlesource.com/c/cdm/+/165860

164889: Updated CoreCommonRequestFromMessage and fix test | https://widevine-internal-review.googlesource.com/c/cdm/+/164889
164967: Add OPK pre-hook and post-hook error codes | https://widevine-internal-review.googlesource.com/c/cdm/+/164967
165140: Add hidden device_id_length to v18 provisioning message | https://widevine-internal-review.googlesource.com/c/cdm/+/165140
165204: Fix memory leak in oemcrypto test | https://widevine-internal-review.googlesource.com/c/cdm/+/165204

165958: Fix oemcrypto_generic_verify_fuzz mutator signature offset | https://widevine-internal-review.googlesource.com/c/cdm/+/165958

166037: Support SHA-256 in OEMCrypto Session Util | https://widevine-internal-review.googlesource.com/c/cdm/+/166037

Test: Run GtsMediaTests on Pixel 7
Bug: 270612144

Change-Id: Iff0820a2de7d043a820470a130af65b0dcadb759
2023-02-28 11:21:05 -08:00

1333 lines
44 KiB
C++

// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary
// source code may only be used and distributed under the Widevine License
// Agreement.
//
// Reference implementation utilities of OEMCrypto APIs
//
#include "oemcrypto_rsa_key.h"

#include <assert.h>
#include <netinet/in.h>
#include <string.h>

#include <limits>

#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>

#include "OEMCryptoCENC.h"
#include "log.h"
#include "oemcrypto_types.h"
#include "scoped_object.h"
namespace wvoec {
namespace util {
namespace {
// Estimated size of a DER-encoded RSA private key, sized for a 3072-bit key.
// The private key consists of:
//   n            : Modulus          : byteSize(n) ~= bitSize(n)/8
//   e            : Public exponent  : 4 bytes
//   d            : Private exponent : ~byteSize(n)
//   p            : Prime 1          : ~byteSize(n)/2
//   q            : Prime 2          : ~byteSize(n)/2
//   d (mod p-1)  : Exponent 1       : ~byteSize(n)/2
//   d (mod q-1)  : Exponent 2       : ~byteSize(n)/2
//   q^-1 (mod p) : Coefficient      : ~byteSize(n)/2
// And the ASN.1 tags for each component (roughly 25 bytes).
// This is only an initial buffer estimate; RsaPrivateKey::Serialize()
// retries with the exact size if the estimate is too small.
constexpr size_t kPrivateKeySize = (3072 / 8) * 5 + 25;
// Estimated size of a DER-encoded RSA public key (SubjectPublicKeyInfo),
// sized for a 3072-bit key.
// The public key consists of:
//   n : Modulus         : byteSize(n) ~= bitSize(n)/8
//   e : Public exponent : 4 bytes
// And the ASN.1 tags + outer structure.
constexpr size_t kPublicKeySize = (3072 / 8) + 100;
// 128 bit key, intended to be used with CMAC-AES-128.
constexpr size_t kRsaSessionKeySize = 16;
// Encryption key used by OEMCrypto session for encrypting and
// decrypting data.
constexpr size_t kEncryptionKeySize = wvoec::KEY_SIZE;
// Salt length used by OEMCrypto's RSASSA-PSS implementation.
// See description of kRsaPssDefault for more information.
constexpr size_t kPssSaltLength = 20;
// Maximum message length accepted for PKCS#1 (CAST) signature
// verification; a requirement of CAST receivers.
constexpr size_t kRsaPkcs1CastMaxMessageSize = 83;
// RAII wrappers that release OpenSSL/BoringSSL objects with the matching
// free function.
using ScopedBigNum = ScopedObject<BIGNUM, BN_free>;
using ScopedBio = ScopedObject<BIO, BIO_vfree>;
using ScopedEvpMdCtx = ScopedObject<EVP_MD_CTX, EVP_MD_CTX_free>;
using ScopedEvpPkey = ScopedObject<EVP_PKEY, EVP_PKEY_free>;
using ScopedPrivateKeyInfo =
    ScopedObject<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_free>;
using ScopedRsaKey = ScopedObject<RSA, RSA_free>;
// Maps the exact bit length of an RSA modulus onto one of the supported
// nominal field sizes. A modulus may be a few bits shorter or longer than
// its nominal size, so a tolerance window surrounds each nominal size.
// Returns kRsaFieldUnknown when |bits| is outside every window.
RsaFieldSize RealBitSizeToFieldSize(int bits) {
  const bool near_2048 = (1800 < bits) && (bits < 2200);
  if (near_2048) return kRsa2048Bit;
  const bool near_3072 = (2800 < bits) && (bits < 3200);
  return near_3072 ? kRsa3072Bit : kRsaFieldUnknown;
}
// Returns true when |allowed_schemes| enables at least one signing scheme
// known to this module (RSASSA-PSS or PKCS#1 block type 1). Bits outside
// the known set are ignored rather than rejected.
bool IsValidAllowedSchemes(uint32_t allowed_schemes) {
  constexpr uint32_t kKnownSchemes = kSign_RSASSA_PSS | kSign_PKCS1_Block1;
  const uint32_t known_bits = allowed_schemes & kKnownSchemes;
  return known_bits != 0;
}
// Parses a serialized RSA private key: a DER-encoded PKCS#8 PrivateKeyInfo,
// optionally preceded by an 8-byte header consisting of the ASCII tag
// "SIGN" and a 32-bit big-endian mask of allowed signing schemes.
//
// Parameters:
//   buffer, length   - serialized key. Inputs shorter than 8 bytes are
//                      rejected.
//   key              - receives the parsed RSA key on success.
//   allowed_schemes  - receives the scheme mask from the header, or
//                      kSign_RSASSA_PSS when there is no header.
//   explicit_schemes - receives whether the "SIGN" header was present.
//   field_size       - receives the nominal RSA field size.
// Returns true on success. The key must be RSA, pass RSA_check_key(), and
// be one of the supported field sizes. Out-parameters may be partially
// written on failure.
bool ParseRsaPrivateKeyInfo(const uint8_t* buffer, size_t length,
                            ScopedRsaKey* key, uint32_t* allowed_schemes,
                            bool* explicit_schemes, RsaFieldSize* field_size) {
  if (length < 8) {
    LOGE("Private key is too small: length = %zu", length);
    return false;
  }
  // BIO_new_mem_buf() takes an int length and treats a negative value as
  // "NUL-terminated string". Reject inputs that do not fit in an int so the
  // casts below cannot wrap.
  if (length > static_cast<size_t>(std::numeric_limits<int>::max())) {
    LOGE("Private key is too large: length = %zu", length);
    return false;
  }
  ScopedBio bio;
  // Check allowed scheme type.
  if (memcmp("SIGN", buffer, 4) == 0) {
    uint32_t allowed_schemes_bno = 0;  // Network (big-endian) byte order.
    memcpy(&allowed_schemes_bno, &buffer[4], 4);
    *allowed_schemes = ntohl(allowed_schemes_bno);
    if (!IsValidAllowedSchemes(*allowed_schemes)) {
      LOGE("Invalid allowed schemes value: allowed_schemes = %08x",
           *allowed_schemes);
      return false;
    }
    bio.reset(BIO_new_mem_buf(&buffer[8], static_cast<int>(length - 8)));
    *explicit_schemes = true;
  } else {
    *allowed_schemes = kSign_RSASSA_PSS;
    bio.reset(BIO_new_mem_buf(buffer, static_cast<int>(length)));
    *explicit_schemes = false;
  }
  if (!bio) {
    LOGE("Failed to allocate BIO buffer");
    return false;
  }
  // Step 1: Deserializes PKCS8 PrivateKeyInfo containing an RSA key.
  ScopedPrivateKeyInfo priv_info(
      d2i_PKCS8_PRIV_KEY_INFO_bio(bio.get(), nullptr));
  if (!priv_info) {
    LOGE("Failed to parse private key");
    return false;
  }
  // Step 2: Convert to RSA key.
  ScopedEvpPkey pkey(EVP_PKCS82PKEY(priv_info.get()));
  if (!pkey) {
    LOGE("Failed to convert PKCS8 to EVP");
    return false;
  }
  const int key_type = EVP_PKEY_base_id(pkey.get());
  if (key_type != EVP_PKEY_RSA) {
    LOGE("Decoded private key is not RSA");
    return false;
  }
  key->reset(EVP_PKEY_get1_RSA(pkey.get()));
  if (!*key) {
    LOGE("Failed to get RSA key");
    return false;
  }
  // Step 3: Verify key parameters and field width.
  const int check = RSA_check_key(key->get());
  if (check == 0) {
    LOGE("RSA key parameters are invalid");
    return false;
  } else if (check == -1) {
    LOGE("Failed to check RSA key");
    return false;
  }
  const int bits = RSA_bits(key->get());
  *field_size = RealBitSizeToFieldSize(bits);
  if (*field_size == kRsaFieldUnknown) {
    LOGE("Unsupported RSA key size: bits = %d", bits);
    return false;
  }
  return true;
}
// Deleter adapter so OPENSSL_free() can release byte buffers held by a
// ScopedObject<uint8_t, ...>.
void OpensslFreeU8(uint8_t* ptr) {
  OPENSSL_free(ptr);
}
} // namespace
// Returns a human-readable name for |field_size|. Values outside the
// enumeration are rendered as "Unknown(<numeric value>)".
std::string RsaFieldSizeToString(RsaFieldSize field_size) {
  const char* label = nullptr;
  switch (field_size) {
    case kRsa2048Bit:
      label = "RSA-2048";
      break;
    case kRsa3072Bit:
      label = "RSA-3072";
      break;
    case kRsaFieldUnknown:
      label = "Unknown";
      break;
  }
  if (label != nullptr) {
    return label;
  }
  return "Unknown(" + std::to_string(static_cast<int>(field_size)) + ")";
}
bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key) {
if (public_key == nullptr) {
LOGE("Public key is null");
return false;
}
if (private_key == nullptr) {
LOGE("Private key is null");
return false;
}
// Step 1: Extract public key components.
const BIGNUM* public_n = nullptr;
const BIGNUM* public_e = nullptr;
const BIGNUM* d = nullptr;
RSA_get0_key(public_key, &public_n, &public_e, &d);
if (public_n == nullptr || public_e == nullptr) {
LOGE("Failed to get RSA public key components");
return false;
}
// Step 2: Extract private key components.
const BIGNUM* private_n = nullptr;
const BIGNUM* private_e = nullptr;
RSA_get0_key(private_key, &private_n, &private_e, &d);
if (private_n == nullptr || private_e == nullptr) {
LOGE("Failed to get RSA private key components");
return false;
}
// Step 3: Compare RSA components.
if (BN_cmp(public_n, private_n)) {
LOGD("RSA modulos do not match");
return false;
}
if (BN_cmp(public_e, private_e)) {
LOGD("RSA exponents do not match");
return false;
}
return true;
}
// ===== ===== ===== RSA Public Key ===== ===== =====
// static
std::unique_ptr<RsaPublicKey> RsaPublicKey::New(
const RsaPrivateKey& private_key) {
std::unique_ptr<RsaPublicKey> key(new RsaPublicKey());
if (!key->InitFromPrivateKey(private_key)) {
LOGE("Failed to initialize public key from private key");
key.reset();
}
return key;
}
// static
// Creates a public key from a copy of the public components of
// |rsa_handle|, restricted to |allowed_schemes|. |rsa_handle| is only read.
// Returns null if the handle is null, the scheme mask enables no known
// scheme, or initialization fails.
std::unique_ptr<RsaPublicKey> RsaPublicKey::FromSslHandle(
    const RSA* rsa_handle, uint32_t allowed_schemes) {
  if (rsa_handle == nullptr) {
    LOGE("Provided OpenSSL/BoringSSL RSA key is null");
    return nullptr;
  }
  if (!IsValidAllowedSchemes(allowed_schemes)) {
    LOGE("Invalid |allowed_schemes| value: allowed_schemes = %08x",
         allowed_schemes);
    return nullptr;
  }
  std::unique_ptr<RsaPublicKey> public_key(new RsaPublicKey());
  if (public_key->InitFromSslHandle(rsa_handle, allowed_schemes)) {
    return public_key;
  }
  LOGE("Failed to initialize public key from OpenSSL/BoringSSL RSA handle");
  return nullptr;
}
// static
// Loads a public key from a DER-encoded SubjectPublicKeyInfo of |length|
// bytes. Returns null for empty/null input or if parsing fails.
std::unique_ptr<RsaPublicKey> RsaPublicKey::Load(const uint8_t* buffer,
                                                 size_t length) {
  if (buffer == nullptr) {
    LOGE("Provided public key buffer is null");
    return nullptr;
  }
  if (length == 0) {
    LOGE("Provided public key buffer is zero length");
    return nullptr;
  }
  std::unique_ptr<RsaPublicKey> public_key(new RsaPublicKey());
  if (public_key->InitFromSubjectPublicKeyInfo(buffer, length)) {
    return public_key;
  }
  LOGE("Failed to initialize public key from SubjectPublicKeyInfo");
  return nullptr;
}
// static
std::unique_ptr<RsaPublicKey> RsaPublicKey::Load(const std::string& buffer) {
if (buffer.empty()) {
LOGE("Provided public key buffer is empty");
return nullptr;
}
return Load(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
}
// static
// std::vector convenience overload of Load(const uint8_t*, size_t).
std::unique_ptr<RsaPublicKey> RsaPublicKey::Load(
    const std::vector<uint8_t>& buffer) {
  if (!buffer.empty()) {
    return Load(buffer.data(), buffer.size());
  }
  LOGE("Provided public key buffer is empty");
  return nullptr;
}
// static
// Extracts the public key from a serialized RSA private key (optionally
// "SIGN"-prefixed PKCS#8 PrivateKeyInfo). The private material is not
// retained. Returns null on empty/null input or parse failure.
std::unique_ptr<RsaPublicKey> RsaPublicKey::LoadPrivateKeyInfo(
    const uint8_t* buffer, size_t length) {
  if (buffer == nullptr) {
    LOGE("Provided public key buffer is null");
    return nullptr;
  }
  if (length == 0) {
    LOGE("Provided public key buffer is zero length");
    return nullptr;
  }
  std::unique_ptr<RsaPublicKey> public_key(new RsaPublicKey());
  if (public_key->InitFromPrivateKeyInfo(buffer, length)) {
    return public_key;
  }
  LOGE("Failed to initialize public key from PrivateKeyInfo");
  return nullptr;
}
// static
std::unique_ptr<RsaPublicKey> RsaPublicKey::LoadPrivateKeyInfo(
const std::string& buffer) {
if (buffer.empty()) {
LOGE("Provided public key buffer is empty");
return nullptr;
}
return LoadPrivateKeyInfo(reinterpret_cast<const uint8_t*>(buffer.data()),
buffer.size());
}
// static
// std::vector convenience overload of
// LoadPrivateKeyInfo(const uint8_t*, size_t).
std::unique_ptr<RsaPublicKey> RsaPublicKey::LoadPrivateKeyInfo(
    const std::vector<uint8_t>& buffer) {
  if (!buffer.empty()) {
    return LoadPrivateKeyInfo(buffer.data(), buffer.size());
  }
  LOGE("Provided public key buffer is empty");
  return nullptr;
}
// Returns true if |private_key| is the private half of this public key.
// A differing nominal field size short-circuits the component comparison.
bool RsaPublicKey::IsMatchingPrivateKey(
    const RsaPrivateKey& private_key) const {
  return private_key.field_size() == field_size_ &&
         RsaKeysAreMatchingPair(GetRsaKey(), private_key.GetRsaKey());
}
std::vector<uint8_t> RsaPrivateKey::GetPrivateExponent() const {
const BIGNUM* d = RSA_get0_d(key_);
if (d == nullptr) {
LOGE("Private exponent must not be null");
return {};
}
// Get the required length for the data.
const size_t length = BN_num_bytes(d);
if (length <= 0) {
LOGE("Private exponent length must be positive");
return {};
}
std::vector<uint8_t> serialized_private_exponent(length, 0);
if (static_cast<size_t>(BN_bn2bin(d, serialized_private_exponent.data())) !=
length) {
LOGE("Failed to convert the private exponent");
return {};
}
return serialized_private_exponent;
}
// Writes the public key to |buffer| as a DER-encoded SubjectPublicKeyInfo
// (produced by i2d_RSA_PUBKEY()).
//
// On entry |*buffer_size| is the capacity of |buffer|. On return it holds
// the encoded size. If |buffer| is null or too small, the function returns
// OEMCrypto_ERROR_SHORT_BUFFER and |*buffer_size| is set to the required
// size.
OEMCryptoResult RsaPublicKey::Serialize(uint8_t* buffer,
                                        size_t* buffer_size) const {
  if (buffer_size == nullptr) {
    LOGE("Output buffer size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (*buffer_size > 0 && buffer == nullptr) {
    LOGE("Output buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  uint8_t* encoded_raw = nullptr;
  const int encoded_length = i2d_RSA_PUBKEY(key_, &encoded_raw);
  if (encoded_length < 0) {
    LOGE("Public key serialization failed");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // Take ownership of the library-allocated encoding right away.
  ScopedObject<uint8_t, OpensslFreeU8> encoded(encoded_raw);
  encoded_raw = nullptr;
  if (!encoded) {
    LOGE("Encoded key is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (encoded_length == 0) {
    LOGE("Unexpected DER encoded size");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const size_t encoded_size = static_cast<size_t>(encoded_length);
  const bool fits = buffer != nullptr && *buffer_size >= encoded_size;
  *buffer_size = encoded_size;
  if (!fits) {
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  memcpy(buffer, encoded.get(), encoded_size);
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPublicKey::Serialize() const {
size_t key_size = kPublicKeySize;
std::vector<uint8_t> key_data(key_size, 0);
const OEMCryptoResult res = Serialize(key_data.data(), &key_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to serialize public key: result = %d", static_cast<int>(res));
key_data.clear();
} else {
key_data.resize(key_size);
}
return key_data;
}
// Verifies |signature| over |message| using the requested RSA signature
// |algorithm|. |hash_algorithm| is only consulted for RSASSA-PSS.
// A null |message| is allowed only when |message_length| is zero.
OEMCryptoResult RsaPublicKey::VerifySignature(
    const uint8_t* message, size_t message_length, const uint8_t* signature,
    size_t signature_length, RsaSignatureAlgorithm algorithm,
    OEMCrypto_SignatureHashAlgorithm hash_algorithm) const {
  const bool signature_missing = signature == nullptr || signature_length == 0;
  if (signature_missing) {
    LOGE("Signature is missing");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (message_length > 0 && message == nullptr) {
    LOGE("Bad message data");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (algorithm == kRsaPssDefault) {
    return VerifySignaturePss(message, message_length, signature,
                              signature_length, hash_algorithm);
  }
  if (algorithm == kRsaPkcs1Cast) {
    return VerifySignaturePkcs1Cast(message, message_length, signature,
                                    signature_length);
  }
  LOGE("Unknown RSA signature algorithm: %d", static_cast<int>(algorithm));
  return OEMCrypto_ERROR_UNKNOWN_FAILURE;
}
// std::string convenience overload of the pointer-based VerifySignature().
OEMCryptoResult RsaPublicKey::VerifySignature(
    const std::string& message, const std::string& signature,
    RsaSignatureAlgorithm algorithm,
    OEMCrypto_SignatureHashAlgorithm hash_algorithm) const {
  if (signature.empty()) {
    LOGE("Signature should not be empty");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const uint8_t* const message_bytes =
      reinterpret_cast<const uint8_t*>(message.data());
  const uint8_t* const signature_bytes =
      reinterpret_cast<const uint8_t*>(signature.data());
  return VerifySignature(message_bytes, message.size(), signature_bytes,
                         signature.size(), algorithm, hash_algorithm);
}
// std::vector convenience overload of the pointer-based VerifySignature().
OEMCryptoResult RsaPublicKey::VerifySignature(
    const std::vector<uint8_t>& message, const std::vector<uint8_t>& signature,
    RsaSignatureAlgorithm algorithm,
    OEMCrypto_SignatureHashAlgorithm hash_algorithm) const {
  if (!signature.empty()) {
    return VerifySignature(message.data(), message.size(), signature.data(),
                           signature.size(), algorithm, hash_algorithm);
  }
  LOGE("Signature should not be empty");
  return OEMCrypto_ERROR_INVALID_CONTEXT;
}
// Encrypts a kRsaSessionKeySize-byte session key with RSAES-OAEP.
// On entry |*enc_session_key_size| is the output capacity. On return it
// holds the ciphertext size (or the required size on SHORT_BUFFER).
OEMCryptoResult RsaPublicKey::EncryptSessionKey(
    const uint8_t* session_key, size_t session_key_size,
    uint8_t* enc_session_key, size_t* enc_session_key_size) const {
  if (session_key == nullptr) {
    LOGE("Session key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (session_key_size != kRsaSessionKeySize) {
    LOGE("Unexpected session key size: expected = %zu, actual = %zu",
         kRsaSessionKeySize, session_key_size);
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (enc_session_key_size == nullptr) {
    LOGE("Output encrypted session key size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool has_capacity = *enc_session_key_size > 0;
  if (has_capacity && enc_session_key == nullptr) {
    LOGE("Output encrypted session key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  return EncryptOaep(session_key, session_key_size, enc_session_key,
                     enc_session_key_size);
}
std::vector<uint8_t> RsaPublicKey::EncryptSessionKey(
const std::vector<uint8_t>& session_key) const {
if (session_key.empty()) {
LOGE("Session key is empty");
return std::vector<uint8_t>();
}
size_t enc_session_key_size = static_cast<size_t>(RSA_size(key_));
std::vector<uint8_t> enc_session_key(enc_session_key_size);
const OEMCryptoResult res =
EncryptSessionKey(session_key.data(), session_key.size(),
enc_session_key.data(), &enc_session_key_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to encrypt session key: result = %d", static_cast<int>(res));
enc_session_key.clear();
} else {
enc_session_key.resize(enc_session_key_size);
}
return enc_session_key;
}
std::vector<uint8_t> RsaPublicKey::EncryptSessionKey(
const std::string& session_key) const {
if (session_key.empty()) {
LOGE("Session key is empty");
return std::vector<uint8_t>();
}
size_t enc_session_key_size = static_cast<size_t>(RSA_size(key_));
std::vector<uint8_t> enc_session_key(enc_session_key_size);
const OEMCryptoResult res = EncryptSessionKey(
reinterpret_cast<const uint8_t*>(session_key.data()), session_key.size(),
enc_session_key.data(), &enc_session_key_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to encrypt session key: result = %d", static_cast<int>(res));
enc_session_key.clear();
} else {
enc_session_key.resize(enc_session_key_size);
}
return enc_session_key;
}
// Encrypts a kEncryptionKeySize-byte encryption key with RSAES-OAEP.
// On entry |*enc_encryption_key_size| is the output capacity. On return it
// holds the ciphertext size (or the required size on SHORT_BUFFER).
OEMCryptoResult RsaPublicKey::EncryptEncryptionKey(
    const uint8_t* encryption_key, size_t encryption_key_size,
    uint8_t* enc_encryption_key, size_t* enc_encryption_key_size) const {
  if (encryption_key == nullptr) {
    LOGE("Encryption key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (encryption_key_size != kEncryptionKeySize) {
    LOGE("Unexpected encryption key size: expected = %zu, actual = %zu",
         kEncryptionKeySize, encryption_key_size);
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (enc_encryption_key_size == nullptr) {
    LOGE("Output encrypted encryption key size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const bool has_capacity = *enc_encryption_key_size > 0;
  if (has_capacity && enc_encryption_key == nullptr) {
    LOGE("Output encrypted encryption key is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  return EncryptOaep(encryption_key, encryption_key_size, enc_encryption_key,
                     enc_encryption_key_size);
}
// std::vector convenience overload of EncryptEncryptionKey(). Returns the
// RSAES-OAEP ciphertext, or an empty vector on failure.
//
// Delegates to the pointer-based EncryptEncryptionKey() (previously, by
// mistake, EncryptSessionKey()) so the kEncryptionKeySize check applies
// instead of the session-key size check.
std::vector<uint8_t> RsaPublicKey::EncryptEncryptionKey(
    const std::vector<uint8_t>& encryption_key) const {
  if (encryption_key.empty()) {
    LOGE("Encryption key is empty");
    return std::vector<uint8_t>();
  }
  // The output buffer is sized to the modulus length, RSA_size().
  size_t enc_encryption_key_size = static_cast<size_t>(RSA_size(key_));
  std::vector<uint8_t> enc_encryption_key(enc_encryption_key_size);
  const OEMCryptoResult res =
      EncryptEncryptionKey(encryption_key.data(), encryption_key.size(),
                           enc_encryption_key.data(), &enc_encryption_key_size);
  if (res != OEMCrypto_SUCCESS) {
    LOGE("Failed to encrypt encryption key: result = %d",
         static_cast<int>(res));
    enc_encryption_key.clear();
  } else {
    enc_encryption_key.resize(enc_encryption_key_size);
  }
  return enc_encryption_key;
}
// std::string convenience overload of EncryptEncryptionKey(). Returns the
// RSAES-OAEP ciphertext, or an empty vector on failure.
//
// Delegates to the pointer-based EncryptEncryptionKey() (previously, by
// mistake, EncryptSessionKey()) so the kEncryptionKeySize check applies
// instead of the session-key size check.
std::vector<uint8_t> RsaPublicKey::EncryptEncryptionKey(
    const std::string& encryption_key) const {
  if (encryption_key.empty()) {
    LOGE("Encryption key is empty");
    return std::vector<uint8_t>();
  }
  // The output buffer is sized to the modulus length, RSA_size().
  size_t enc_encryption_key_size = static_cast<size_t>(RSA_size(key_));
  std::vector<uint8_t> enc_encryption_key(enc_encryption_key_size);
  const OEMCryptoResult res = EncryptEncryptionKey(
      reinterpret_cast<const uint8_t*>(encryption_key.data()),
      encryption_key.size(), enc_encryption_key.data(),
      &enc_encryption_key_size);
  if (res != OEMCrypto_SUCCESS) {
    LOGE("Failed to encrypt encryption key: result = %d",
         static_cast<int>(res));
    enc_encryption_key.clear();
  } else {
    enc_encryption_key.resize(enc_encryption_key_size);
  }
  return enc_encryption_key;
}
// Releases the owned RSA key (if any) and resets the bookkeeping fields.
RsaPublicKey::~RsaPublicKey() {
  RSA* const owned_key = key_;
  key_ = nullptr;
  if (owned_key != nullptr) {
    RSA_free(owned_key);
  }
  field_size_ = kRsaFieldUnknown;
  allowed_schemes_ = 0;
}
bool RsaPublicKey::InitFromSubjectPublicKeyInfo(const uint8_t* buffer,
size_t length) {
// Step 1: Deserialize SubjectPublicKeyInfo as RSA key.
const uint8_t* tp = buffer;
ScopedRsaKey key(d2i_RSA_PUBKEY(nullptr, &tp, length));
if (!key) {
LOGE("Failed to parse key");
return false;
}
// Step 2: Verify key.
const int bits = RSA_bits(key.get());
field_size_ = RealBitSizeToFieldSize(bits);
if (field_size_ == kRsaFieldUnknown) {
LOGE("Unsupported RSA key size: bits = %d", bits);
return false;
}
allowed_schemes_ = kSign_RSASSA_PSS;
key_ = key.release();
return true;
}
bool RsaPublicKey::InitFromPrivateKeyInfo(const uint8_t* buffer,
size_t length) {
ScopedRsaKey private_key;
bool explicit_schemes = false;
if (!ParseRsaPrivateKeyInfo(buffer, length, &private_key, &allowed_schemes_,
&explicit_schemes, &field_size_)) {
return false;
}
// Need to strip the private key information.
return InitFromSslHandle(private_key.get(), allowed_schemes_);
}
// Initializes from the public half and allowed schemes of |private_key|.
bool RsaPublicKey::InitFromPrivateKey(const RsaPrivateKey& private_key) {
  const uint32_t schemes = private_key.allowed_schemes();
  return InitFromSslHandle(private_key.GetRsaKey(), schemes);
}
bool RsaPublicKey::InitFromSslHandle(const RSA* rsa_handle,
uint32_t allowed_schemes) {
ScopedRsaKey key(RSA_new());
if (!key) {
LOGE("Failed to allocate key");
return false;
}
const BIGNUM* n = nullptr;
const BIGNUM* e = nullptr;
const BIGNUM* d = nullptr;
RSA_get0_key(rsa_handle, &n, &e, &d);
if (n == nullptr || e == nullptr) {
LOGE("Failed to get RSA private key components");
return false;
}
ScopedBigNum dub_n(BN_dup(n));
ScopedBigNum dub_e(BN_dup(e));
if (!dub_n || !dub_e) {
LOGE("Failed to duplicate RSA public key components");
return false;
}
if (!RSA_set0_key(key.get(), dub_n.get(), dub_e.get(), nullptr)) {
LOGE("Failed to RSA public key components");
return false;
}
// Ownership has transferred to the RSA key.
dub_n.release();
dub_e.release();
key_ = key.release();
allowed_schemes_ = allowed_schemes;
const int bits = RSA_bits(rsa_handle);
field_size_ = RealBitSizeToFieldSize(bits);
if (field_size_ == kRsaFieldUnknown) {
LOGE("Unsupported RSA key size: bits = %d", bits);
return false;
}
return true;
}
// Verifies an RSASSA-PSS signature over |message|.
// The message digest is selected by |hash_algorithm|. The salt length is
// fixed at kPssSaltLength, and MGF1 always uses SHA-1.
// Returns OEMCrypto_ERROR_SIGNATURE_FAILURE for a well-formed but
// non-matching signature, and INVALID_KEY if PSS is not allowed.
OEMCryptoResult RsaPublicKey::VerifySignaturePss(
    const uint8_t* message, size_t message_length, const uint8_t* signature,
    size_t signature_length,
    OEMCrypto_SignatureHashAlgorithm hash_algorithm) const {
  if ((allowed_schemes_ & kSign_RSASSA_PSS) == 0) {
    LOGE("RSA key cannot verify using PSS");
    return OEMCrypto_ERROR_INVALID_KEY;
  }
  // Wrap the RSA key in an EVP_PKEY for the EVP digest-verify API.
  ScopedEvpPkey evp_key(EVP_PKEY_new());
  if (!evp_key) {
    LOGE("Failed to allocate PKEY");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!EVP_PKEY_set1_RSA(evp_key.get(), key_)) {
    LOGE("Failed to set PKEY RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const EVP_MD* message_digest = nullptr;
  switch (hash_algorithm) {
    case OEMCrypto_SHA1:
      message_digest = EVP_sha1();
      break;
    case OEMCrypto_SHA2_256:
      message_digest = EVP_sha256();
      break;
    case OEMCrypto_SHA2_384:
      message_digest = EVP_sha384();
      break;
    case OEMCrypto_SHA2_512:
      message_digest = EVP_sha512();
      break;
  }
  if (message_digest == nullptr) {
    LOGE("Unrecognized hash algorithm %d", static_cast<int>(hash_algorithm));
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  ScopedEvpMdCtx verify_ctx(EVP_MD_CTX_new());
  if (!verify_ctx) {
    LOGE("Failed to allocate MD CTX");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // |key_ctx| is owned by |verify_ctx|; it must not be freed here.
  EVP_PKEY_CTX* key_ctx = nullptr;
  if (EVP_DigestVerifyInit(verify_ctx.get(), &key_ctx, message_digest,
                           nullptr, evp_key.get()) != 1) {
    LOGE("Failed to initialize MD CTX for verification");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (key_ctx == nullptr) {
    LOGE("PKEY CTX is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // OEMCrypto's RSASSA-PSS parameters.
  if (EVP_PKEY_CTX_set_rsa_padding(key_ctx, RSA_PKCS1_PSS_PADDING) != 1) {
    LOGE("Failed to set PSS padding");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_pss_saltlen(
          key_ctx, static_cast<int>(kPssSaltLength)) != 1) {
    LOGE("Failed to set PSS salt length");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(key_ctx, EVP_sha1()) != 1) {
    LOGE("Failed to set PSS MGF1 MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // An empty message is valid; there is simply nothing to feed the digest.
  if (message_length > 0 &&
      EVP_DigestVerifyUpdate(verify_ctx.get(), message, message_length) != 1) {
    LOGE("Failed to update MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const int verdict =
      EVP_DigestVerifyFinal(verify_ctx.get(), signature, signature_length);
  if (verdict < 0) {
    LOGE("Failed to perform RSASSA-PSS-VERIFY");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  return verdict == 0 ? OEMCrypto_ERROR_SIGNATURE_FAILURE : OEMCrypto_SUCCESS;
}
// Verifies a CAST-style PKCS#1 v1.5 (block type 1) signature. The signature
// is opened with the public key, and the recovered payload must equal
// |message| byte-for-byte; no hashing is applied here. Messages longer than
// kRsaPkcs1CastMaxMessageSize are rejected.
OEMCryptoResult RsaPublicKey::VerifySignaturePkcs1Cast(
    const uint8_t* message, size_t message_length, const uint8_t* signature,
    size_t signature_length) const {
  if ((allowed_schemes_ & kSign_PKCS1_Block1) == 0) {
    LOGE("RSA key cannot verify using PKCS1");
    return OEMCrypto_ERROR_INVALID_KEY;
  }
  if (message_length > kRsaPkcs1CastMaxMessageSize) {
    LOGE("Message is too large for CAST PKCS1 signature: size = %zu",
         message_length);
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  std::vector<uint8_t> recovered(RSA_size(key_));
  const int recovered_length =
      RSA_public_decrypt(static_cast<int>(signature_length), signature,
                         recovered.data(), key_, RSA_PKCS1_PADDING);
  if (recovered_length <= 0) {
    LOGE("Failed to perform public decryption");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const size_t recovered_size = static_cast<size_t>(recovered_length);
  if (recovered_size != message_length) {
    LOGD("Digest size does not match");
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  size_t matched = 0;
  while (matched < recovered_size && recovered[matched] == message[matched]) {
    ++matched;
  }
  return matched == recovered_size ? OEMCrypto_SUCCESS
                                   : OEMCrypto_ERROR_SIGNATURE_FAILURE;
}
// Encrypts |message| with RSAES-OAEP into |enc_message|.
// On entry |*enc_message_length| is the output capacity. On return it holds
// the ciphertext size. Returns SHORT_BUFFER (with the required size) if the
// capacity is insufficient. Pointer validation is assumed to have been done
// by the public callers.
OEMCryptoResult RsaPublicKey::EncryptOaep(const uint8_t* message,
                                          size_t message_size,
                                          uint8_t* enc_message,
                                          size_t* enc_message_length) const {
  // Encrypt into a scratch buffer sized to the modulus, RSA_size().
  std::vector<uint8_t> ciphertext(RSA_size(key_));
  const int ciphertext_length =
      RSA_public_encrypt(static_cast<int>(message_size), message,
                         ciphertext.data(), key_, RSA_PKCS1_OAEP_PADDING);
  if (ciphertext_length < 0) {
    LOGE("Failed to perform RSAES-OAEP encrypt");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  const size_t ciphertext_size = static_cast<size_t>(ciphertext_length);
  const bool fits = *enc_message_length >= ciphertext_size;
  *enc_message_length = ciphertext_size;
  if (!fits) {
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  memcpy(enc_message, ciphertext.data(), ciphertext_size);
  return OEMCrypto_SUCCESS;
}
// ===== ===== ===== RSA Private Key ===== ===== =====
// static
// Generates a new private key of the given nominal |field_size|.
// Returns null on failure.
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::New(RsaFieldSize field_size) {
  std::unique_ptr<RsaPrivateKey> private_key(new RsaPrivateKey());
  if (private_key->InitFromFieldSize(field_size)) {
    return private_key;
  }
  LOGE("Failed to initialize private key from field size");
  return nullptr;
}
// static
// Loads a private key from a serialized (optionally "SIGN"-prefixed)
// PKCS#8 PrivateKeyInfo. Returns null on empty/null input or parse failure.
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::Load(const uint8_t* buffer,
                                                   size_t length) {
  if (buffer == nullptr) {
    LOGE("Provided private key buffer is null");
    return nullptr;
  }
  if (length == 0) {
    LOGE("Provided private key buffer is zero length");
    return nullptr;
  }
  std::unique_ptr<RsaPrivateKey> private_key(new RsaPrivateKey());
  if (private_key->InitFromPrivateKeyInfo(buffer, length)) {
    return private_key;
  }
  LOGE("Failed to initialize private key from PrivateKeyInfo");
  return nullptr;
}
// static
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::Load(const std::string& buffer) {
if (buffer.empty()) {
LOGE("Provided private key buffer is empty");
return std::unique_ptr<RsaPrivateKey>();
}
return Load(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
}
// static
std::unique_ptr<RsaPrivateKey> RsaPrivateKey::Load(
const std::vector<uint8_t>& buffer) {
if (buffer.empty()) {
LOGE("Provided private key buffer is empty");
return std::unique_ptr<RsaPrivateKey>();
}
return Load(buffer.data(), buffer.size());
}
std::unique_ptr<RsaPublicKey> RsaPrivateKey::MakePublicKey() const {
return RsaPublicKey::New(*this);
}
bool RsaPrivateKey::IsMatchingPublicKey(const RsaPublicKey& public_key) const {
return RsaKeysAreMatchingPair(public_key.GetRsaKey(), GetRsaKey());
}
// Serializes the private key as a DER-encoded PKCS#8 PrivateKeyInfo. When
// |explicit_schemes_| is set, the output is prefixed with "SIGN" and the
// 32-bit big-endian allowed-scheme mask (8 bytes total).
//
// On entry |*buffer_size| is the capacity of |buffer|. On return it holds
// the serialized size. Returns OEMCrypto_ERROR_SHORT_BUFFER (with the
// required size) if the capacity is insufficient.
OEMCryptoResult RsaPrivateKey::Serialize(uint8_t* buffer,
                                         size_t* buffer_size) const {
  if (buffer_size == nullptr) {
    LOGE("Output buffer size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (*buffer_size > 0 && buffer == nullptr) {
    LOGE("Output buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  // RSA -> EVP_PKEY.
  ScopedEvpPkey evp_key(EVP_PKEY_new());
  if (!evp_key) {
    LOGE("Failed to allocate EVP");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!EVP_PKEY_set1_RSA(evp_key.get(), key_)) {
    LOGE("Failed to set EVP RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // EVP_PKEY -> PKCS#8 PrivateKeyInfo.
  ScopedPrivateKeyInfo pkcs8(EVP_PKEY2PKCS8(evp_key.get()));
  if (!pkcs8) {
    LOGE("Failed to convert RSA key to PKCS8 info");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // PKCS#8 -> DER, written into an in-memory BIO.
  ScopedBio der_bio(BIO_new(BIO_s_mem()));
  if (!der_bio) {
    LOGE("Failed to allocate IO buffer for RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!i2d_PKCS8_PRIV_KEY_INFO_bio(der_bio.get(), pkcs8.get())) {
    LOGE("Failed to serialize RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  char* der_ptr = nullptr;
  const long der_length = BIO_get_mem_data(der_bio.get(), &der_ptr);
  if (der_length < 0) {
    LOGE("Failed to get RSA key size");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (der_ptr == nullptr) {
    LOGE("Encoded key is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  constexpr size_t kSchemeHeaderSize = 8;  // "SIGN" + 32-bit scheme mask.
  const size_t der_size = static_cast<size_t>(der_length);
  const size_t header_size = explicit_schemes_ ? kSchemeHeaderSize : 0;
  const size_t required_size = der_size + header_size;
  const bool fits = *buffer_size >= required_size;
  *buffer_size = required_size;
  if (!fits) {
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  uint8_t* cursor = buffer;
  if (explicit_schemes_) {
    const uint32_t schemes_be = htonl(allowed_schemes_);
    memcpy(cursor, "SIGN", 4);
    memcpy(cursor + 4, &schemes_be, 4);
    cursor += kSchemeHeaderSize;
  }
  memcpy(cursor, der_ptr, der_size);
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPrivateKey::Serialize() const {
size_t key_size = kPrivateKeySize;
std::vector<uint8_t> key_data(key_size, 0);
OEMCryptoResult res = Serialize(key_data.data(), &key_size);
if (res == OEMCrypto_ERROR_SHORT_BUFFER) {
LOGD(
"Actual RSA private key size is larger than expected: "
"expected = %zu, actual = %zu",
kPrivateKeySize, key_size);
key_data.resize(key_size);
res = Serialize(key_data.data(), &key_size);
}
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to serialize private key: result = %d", static_cast<int>(res));
key_data.clear();
} else {
key_data.resize(key_size);
}
return key_data;
}
// Signs |message| using the requested RSA signature |algorithm|.
// On entry |*signature_length| is the output capacity. A null |message| is
// allowed only when |message_length| is zero.
OEMCryptoResult RsaPrivateKey::GenerateSignature(
    const uint8_t* message, size_t message_length,
    RsaSignatureAlgorithm algorithm, uint8_t* signature,
    size_t* signature_length) const {
  if (signature_length == nullptr) {
    LOGE("Output signature size is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (*signature_length > 0 && signature == nullptr) {
    LOGE("Output signature is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (message_length > 0 && message == nullptr) {
    LOGE("Invalid message data");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (algorithm == kRsaPssDefault) {
    return GenerateSignaturePss(message, message_length, signature,
                                signature_length);
  }
  if (algorithm == kRsaPkcs1Cast) {
    return GenerateSignaturePkcs1Cast(message, message_length, signature,
                                      signature_length);
  }
  LOGE("Unknown RSA signature algorithm: %d", static_cast<int>(algorithm));
  return OEMCrypto_ERROR_UNKNOWN_FAILURE;
}
std::vector<uint8_t> RsaPrivateKey::GenerateSignature(
const std::string& message, RsaSignatureAlgorithm algorithm) const {
size_t signature_size = SignatureSize();
std::vector<uint8_t> signature(signature_size);
const OEMCryptoResult res = GenerateSignature(
reinterpret_cast<const uint8_t*>(message.data()), message.size(),
algorithm, signature.data(), &signature_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to generate signature: result = %d", static_cast<int>(res));
signature.clear();
} else {
signature.resize(signature_size);
}
return signature;
}
std::vector<uint8_t> RsaPrivateKey::GenerateSignature(
const std::vector<uint8_t>& message,
RsaSignatureAlgorithm algorithm) const {
size_t signature_size = SignatureSize();
std::vector<uint8_t> signature(signature_size, 0);
const OEMCryptoResult res =
GenerateSignature(message.data(), message.size(), algorithm,
signature.data(), &signature_size);
if (res != OEMCrypto_SUCCESS) {
LOGE("Failed to generate signature: result = %d", static_cast<int>(res));
signature.clear();
} else {
signature.resize(signature_size);
}
return signature;
}
// Returns RSA_size() of the underlying key, i.e. the modulus length in bytes.
size_t RsaPrivateKey::SignatureSize() const {
  const size_t modulus_bytes = RSA_size(key_);
  return modulus_bytes;
}
// RSAES-OAEP-decrypts |enc_session_key| into |session_key|.
//
// |session_key_size| is in/out: on input it is the capacity of |session_key|;
// it is set to kRsaSessionKeySize on success or when the buffer is too small
// (OEMCrypto_ERROR_SHORT_BUFFER). Null/empty inputs and a null output buffer
// with non-zero capacity yield OEMCrypto_ERROR_INVALID_CONTEXT.
OEMCryptoResult RsaPrivateKey::DecryptSessionKey(
    const uint8_t* enc_session_key, size_t enc_session_key_size,
    uint8_t* session_key, size_t* session_key_size) const {
  const bool has_input =
      enc_session_key != nullptr && enc_session_key_size != 0;
  if (!has_input) {
    LOGE("Encrypted session key is %s",
         enc_session_key == nullptr ? "null" : "empty");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (!session_key_size) {
    LOGE("Output session key size buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const size_t capacity = *session_key_size;
  if (!session_key && capacity != 0) {
    LOGE("Output session key buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (capacity < kRsaSessionKeySize) {
    // Report the size the caller needs to allocate.
    *session_key_size = kRsaSessionKeySize;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  const OEMCryptoResult status = DecryptOaep(
      enc_session_key, enc_session_key_size, session_key, kRsaSessionKeySize);
  if (status != OEMCrypto_SUCCESS) {
    LOGE("Failed to decrypt session key");
    return status;
  }
  *session_key_size = kRsaSessionKeySize;
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPrivateKey::DecryptSessionKey(
const std::vector<uint8_t>& enc_session_key) const {
size_t session_key_size = kRsaSessionKeySize;
std::vector<uint8_t> session_key(session_key_size, 0);
const OEMCryptoResult res =
DecryptSessionKey(enc_session_key.data(), enc_session_key.size(),
session_key.data(), &session_key_size);
if (res != OEMCrypto_SUCCESS) {
session_key.clear();
} else {
session_key.resize(session_key_size);
}
return session_key;
}
std::vector<uint8_t> RsaPrivateKey::DecryptSessionKey(
const std::string& enc_session_key) const {
size_t session_key_size = kRsaSessionKeySize;
std::vector<uint8_t> session_key(session_key_size, 0);
const OEMCryptoResult res = DecryptSessionKey(
reinterpret_cast<const uint8_t*>(enc_session_key.data()),
enc_session_key.size(), session_key.data(), &session_key_size);
if (res != OEMCrypto_SUCCESS) {
session_key.clear();
} else {
session_key.resize(session_key_size);
}
return session_key;
}
// Returns the fixed session key length (kRsaSessionKeySize) that
// DecryptSessionKey() produces.
size_t RsaPrivateKey::SessionKeyLength() const {
  const size_t length = kRsaSessionKeySize;
  return length;
}
// RSAES-OAEP-decrypts |enc_encryption_key| into |encryption_key|.
//
// |encryption_key_size| is in/out: on input it is the capacity of
// |encryption_key|; it is set to kEncryptionKeySize on success or when the
// buffer is too small (OEMCrypto_ERROR_SHORT_BUFFER). Null/empty inputs and a
// null output buffer with non-zero capacity yield
// OEMCrypto_ERROR_INVALID_CONTEXT.
OEMCryptoResult RsaPrivateKey::DecryptEncryptionKey(
    const uint8_t* enc_encryption_key, size_t enc_encryption_key_size,
    uint8_t* encryption_key, size_t* encryption_key_size) const {
  const bool has_input =
      enc_encryption_key != nullptr && enc_encryption_key_size != 0;
  if (!has_input) {
    LOGE("Encrypted encryption key is %s",
         enc_encryption_key == nullptr ? "null" : "empty");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (!encryption_key_size) {
    LOGE("Output encryption key size buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  const size_t capacity = *encryption_key_size;
  if (!encryption_key && capacity != 0) {
    LOGE("Output encryption key buffer is null");
    return OEMCrypto_ERROR_INVALID_CONTEXT;
  }
  if (capacity < kEncryptionKeySize) {
    // Report the size the caller needs to allocate.
    *encryption_key_size = kEncryptionKeySize;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  const OEMCryptoResult status =
      DecryptOaep(enc_encryption_key, enc_encryption_key_size, encryption_key,
                  kEncryptionKeySize);
  if (status != OEMCrypto_SUCCESS) {
    LOGE("Failed to decrypt encryption key");
    return status;
  }
  *encryption_key_size = kEncryptionKeySize;
  return OEMCrypto_SUCCESS;
}
std::vector<uint8_t> RsaPrivateKey::DecryptEncryptionKey(
const std::vector<uint8_t>& enc_encryption_key) const {
size_t encryption_key_size = kEncryptionKeySize;
std::vector<uint8_t> encryption_key(encryption_key_size, 0);
const OEMCryptoResult res =
DecryptEncryptionKey(enc_encryption_key.data(), enc_encryption_key.size(),
encryption_key.data(), &encryption_key_size);
if (res != OEMCrypto_SUCCESS) {
encryption_key.clear();
} else {
encryption_key.resize(encryption_key_size);
}
return encryption_key;
}
std::vector<uint8_t> RsaPrivateKey::DecryptEncryptionKey(
const std::string& enc_encryption_key) const {
size_t encryption_key_size = kEncryptionKeySize;
std::vector<uint8_t> encryption_key(encryption_key_size, 0);
const OEMCryptoResult res = DecryptEncryptionKey(
reinterpret_cast<const uint8_t*>(enc_encryption_key.data()),
enc_encryption_key.size(), encryption_key.data(), &encryption_key_size);
if (res != OEMCrypto_SUCCESS) {
encryption_key.clear();
} else {
encryption_key.resize(encryption_key_size);
}
return encryption_key;
}
// Releases the owned OpenSSL RSA key (if any) and resets the key metadata.
RsaPrivateKey::~RsaPrivateKey() {
  if (key_) {
    RSA_free(key_);
    key_ = nullptr;
  }
  // Return the scheme/size metadata to its unset values.
  field_size_ = kRsaFieldUnknown;
  explicit_schemes_ = false;
  allowed_schemes_ = 0;
}
bool RsaPrivateKey::InitFromPrivateKeyInfo(const uint8_t* buffer,
size_t length) {
ScopedRsaKey key;
if (!ParseRsaPrivateKeyInfo(buffer, length, &key, &allowed_schemes_,
&explicit_schemes_, &field_size_)) {
return false;
}
key_ = key.release();
return true;
}
// Generates a fresh RSA key of |field_size| bits with public exponent RSA_F4.
// Only kRsa2048Bit and kRsa3072Bit are accepted. On success, this object owns
// the key, is restricted to the RSASSA-PSS scheme, and records |field_size|.
// Returns false on an unsupported size or any OpenSSL failure.
bool RsaPrivateKey::InitFromFieldSize(RsaFieldSize field_size) {
  const bool supported_size =
      field_size == kRsa2048Bit || field_size == kRsa3072Bit;
  if (!supported_size) {
    LOGE("Unsupported RSA field size: bits = %d", static_cast<int>(field_size));
    return false;
  }
  // Public exponent.
  ScopedBigNum public_exponent(BN_new());
  if (!public_exponent) {
    LOGE("Failed to allocate RSA exponent");
    return false;
  }
  if (!BN_set_word(public_exponent.get(), RSA_F4)) {
    LOGE("Failed to set RSA exponent");
    return false;
  }
  // Key generation; |field_size| holds the modulus length in bits.
  ScopedRsaKey generated(RSA_new());
  if (!generated) {
    LOGE("Failed to allocate RSA key");
    return false;
  }
  const int modulus_bits = static_cast<int>(field_size);
  if (!RSA_generate_key_ex(generated.get(), modulus_bits,
                           public_exponent.get(), nullptr)) {
    LOGE("Failed to generate RSA key");
    return false;
  }
  key_ = generated.release();
  allowed_schemes_ = kSign_RSASSA_PSS;
  field_size_ = field_size;
  return true;
}
// Produces an RSASSA-PSS signature over |message| with SHA-1 as both the
// message digest and the MGF1 digest, and a salt length of kPssSaltLength.
//
// |signature_length| is in/out: on input it is the capacity of |signature|.
// If that is too small, it is set to the required size and
// OEMCrypto_ERROR_SHORT_BUFFER is returned. Returns OEMCrypto_ERROR_INVALID_KEY
// when this key is not allowed to sign with PSS.
OEMCryptoResult RsaPrivateKey::GenerateSignaturePss(
    const uint8_t* message, size_t message_length, uint8_t* signature,
    size_t* signature_length) const {
  if ((allowed_schemes_ & kSign_RSASSA_PSS) == 0) {
    LOGE("RSA key cannot sign using PSS");
    return OEMCrypto_ERROR_INVALID_KEY;
  }
  // Wrap the raw RSA key in an EVP_PKEY for the EVP signing API.
  ScopedEvpPkey evp_key(EVP_PKEY_new());
  if (!evp_key) {
    LOGE("Failed to allocate PKEY");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!EVP_PKEY_set1_RSA(evp_key.get(), key_)) {
    LOGE("Failed to set PKEY RSA key");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  ScopedEvpMdCtx digest_ctx(EVP_MD_CTX_new());
  if (!digest_ctx) {
    LOGE("Failed to allocate MD CTX");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // |signing_ctx| is owned by |digest_ctx|; it must not be freed here.
  EVP_PKEY_CTX* signing_ctx = nullptr;
  if (EVP_DigestSignInit(digest_ctx.get(), &signing_ctx, EVP_sha1(), nullptr,
                         evp_key.get()) != 1) {
    LOGE("Failed to initialize MD CTX for signing");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (!signing_ctx) {
    LOGE("PKEY CTX is unexpectedly null");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // PSS parameters: padding mode, salt length, MGF1 digest.
  if (EVP_PKEY_CTX_set_rsa_padding(signing_ctx, RSA_PKCS1_PSS_PADDING) != 1) {
    LOGE("Failed to set PSS padding");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_pss_saltlen(signing_ctx, kPssSaltLength) != 1) {
    LOGE("Failed to set PSS salt length");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(signing_ctx, EVP_sha1()) != 1) {
    LOGE("Failed to set PSS MGF1 MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // An empty message is signed without feeding any data to the digest.
  if (message_length != 0 &&
      EVP_DigestSignUpdate(digest_ctx.get(), message, message_length) != 1) {
    LOGE("Failed to update MD");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // First pass with a null output only queries the required length.
  size_t required_length = 0;
  if (EVP_DigestSignFinal(digest_ctx.get(), nullptr, &required_length) != 1) {
    LOGE("Failed to determine signature length");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  if (required_length > *signature_length) {
    *signature_length = required_length;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  if (EVP_DigestSignFinal(digest_ctx.get(), signature, signature_length) !=
      1) {
    LOGE("Failed to perform RSASSA-PSS-SIGN");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  return OEMCrypto_SUCCESS;
}
// Produces a CAST-style "signature" by applying RSA_private_encrypt with
// PKCS#1 padding directly to |message| (no digest is applied here).
//
// |message_length| must not exceed kRsaPkcs1CastMaxMessageSize.
// |signature_length| is in/out: on input it is the capacity of |signature|;
// it is set to RSA_size(key_) on success or when the buffer is too small
// (OEMCrypto_ERROR_SHORT_BUFFER). Returns OEMCrypto_ERROR_INVALID_KEY when
// this key is not allowed to use the PKCS1 block-1 scheme.
OEMCryptoResult RsaPrivateKey::GenerateSignaturePkcs1Cast(
    const uint8_t* message, size_t message_length, uint8_t* signature,
    size_t* signature_length) const {
  if ((allowed_schemes_ & kSign_PKCS1_Block1) == 0) {
    LOGE("RSA key cannot sign PKCS1");
    return OEMCrypto_ERROR_INVALID_KEY;
  }
  // This bound also keeps the int conversion below in range.
  if (message_length > kRsaPkcs1CastMaxMessageSize) {
    LOGE("Message is too large for CAST PKCS1 signature: size = %zu",
         message_length);
    return OEMCrypto_ERROR_SIGNATURE_FAILURE;
  }
  const size_t modulus_bytes = static_cast<size_t>(RSA_size(key_));
  if (modulus_bytes > *signature_length) {
    *signature_length = modulus_bytes;
    return OEMCrypto_ERROR_SHORT_BUFFER;
  }
  const int written =
      RSA_private_encrypt(static_cast<int>(message_length), message, signature,
                          key_, RSA_PKCS1_PADDING);
  if (written < 0) {
    LOGE("Failed to perform private encryption");
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  *signature_length = modulus_bytes;
  return OEMCrypto_SUCCESS;
}
// Decrypts |enc_message| with RSAES-OAEP using this private key and copies
// exactly |expected_message_length| plaintext bytes into |message|.
//
// |message| must be able to hold |expected_message_length| bytes; the callers
// in this file check this before calling.
// Returns OEMCrypto_SUCCESS on success. Returns OEMCrypto_ERROR_UNKNOWN_FAILURE
// if the ciphertext is longer than the modulus, if decryption fails, or if the
// plaintext length differs from |expected_message_length|.
OEMCryptoResult RsaPrivateKey::DecryptOaep(
    const uint8_t* enc_message, size_t enc_message_size, uint8_t* message,
    size_t expected_message_length) const {
  const size_t modulus_size = static_cast<size_t>(RSA_size(key_));
  // A valid ciphertext is never longer than the modulus, so oversized input is
  // rejected here. This check also guarantees the size_t -> int narrowing
  // below is lossless. Without it, a length above INT_MAX could wrap to a small
  // positive value, and the trailing bytes would be silently ignored.
  if (enc_message_size > modulus_size) {
    LOGE("Encrypted message is too large: size = %zu, max = %zu",
         enc_message_size, modulus_size);
    return OEMCrypto_ERROR_UNKNOWN_FAILURE;
  }
  // Step 1: Decrypt using RSAES-OAEP.
  std::vector<uint8_t> decrypt_buffer(modulus_size);
  const int dec_res =
      RSA_private_decrypt(static_cast<int>(enc_message_size), enc_message,
                          decrypt_buffer.data(), key_, RSA_PKCS1_OAEP_PADDING);
  OEMCryptoResult result = OEMCrypto_SUCCESS;
  if (dec_res < 0) {
    LOGE("Failed to perform RSAES-OAEP decrypt");
    result = OEMCrypto_ERROR_UNKNOWN_FAILURE;
  } else if (static_cast<size_t>(dec_res) != expected_message_length) {
    LOGE("Unexpected key size: expected = %zu, actual = %d",
         expected_message_length, dec_res);
    result = OEMCrypto_ERROR_UNKNOWN_FAILURE;
  } else {
    // Step 2: Copy decrypted data key.
    memcpy(message, decrypt_buffer.data(), expected_message_length);
  }
  // Step 3: Scrub the plaintext key material from the temporary heap buffer
  // before it is freed, on every path. The writes go through a volatile
  // pointer so the compiler cannot drop them as dead stores before the
  // deallocation.
  volatile uint8_t* scrub = decrypt_buffer.data();
  for (size_t i = 0; i < decrypt_buffer.size(); ++i) {
    scrub[i] = 0;
  }
  return result;
}
} // namespace util
} // namespace wvoec