// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary // source code may only be used and distributed under the Widevine License // Agreement. // // Reference implementation of OEMCrypto APIs // #include "oemcrypto_rsa_key.h" #include #include #include #include #include #include #include #include #include #include "OEMCryptoCENC.h" #include "log.h" #include "oemcrypto_types.h" #include "scoped_object.h" namespace wvoec_ref { namespace { // Estimated size of an RSA private key. // The private key constists of: // n : Modulos : byteSize(n) ~= bitSize(n)/8 // e : Public exponent : 4 bytes // d : Private exponent : ~byteSize(n) // p : Prime 1 : ~byteSize(n)/2 // q : Prime 2 : ~byteSize(n)/2 // d (mod p-1) : Exponent 1 : ~byteSize(n)/2 // d (mod q-1) : Exponent 2 : ~byteSize(n)/2 // q^-1 (mod p) : Coefficient : ~byteSize(n)/2 // And the ASN.1 tags for each component (roughly 25 bytes). constexpr size_t kPrivateKeySize = (3072 / 8) * 5 + 25; // Estimated size of an RSA public key. // The public key constists of: // The private key constists of: // n : Modulos : byteSize(n) ~= bitSize(n)/8 // e : Public exponent : 4 bytes // And the ASN.1 tags + outer structure. constexpr size_t kPublicKeySize = (3072 / 8) + 100; // 128 bit key, intended to be used with CMAC-AES-128. constexpr size_t kRsaSessionKeySize = 16; // Encryption key used by OEMCrypto session for encrypting and // decrypting data. constexpr size_t kEncryptionKeySize = wvoec::KEY_SIZE; // Salt length used by OEMCrypto's RSASSA-PSS implementation. // See description of kRsaPssDefault for more information. constexpr size_t kPssSaltLength = 20; // Requirement of CAST receivers. constexpr size_t kRsaPkcs1CastMaxMessageSize = 83; using ScopedBio = ScopedObject; using ScopedBigNum = ScopedObject; using ScopedEvpMdCtx = ScopedObject; using ScopedEvpPkey = ScopedObject; using ScopedRsaKey = ScopedObject; using ScopedRsaPrivKeyInfo = ScopedObject; // Estimates the RSA rough field size from the real bit size of the // RSA modulos. The actual bit length could vary by a few bits. RsaFieldSize RealBitSizeToFieldSize(int bits) { if (bits > 1800 && bits < 2200) { return kRsa2048Bit; } if (bits > 2800 && bits < 3200) { return kRsa3072Bit; } return kRsaFieldUnknown; } } // namespace std::string RsaFieldSizeToString(RsaFieldSize field_size) { switch (field_size) { case kRsa2048Bit: return "RSA-2048"; case kRsa3072Bit: return "RSA-3072"; case kRsaFieldUnknown: return "Unknown"; } return "Unknown(" + std::to_string(static_cast(field_size)) + ")"; } bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key) { if (public_key == nullptr) { LOGE("Public key is null"); return false; } if (private_key == nullptr) { LOGE("Private key is null"); return false; } // Step 1: Extract public key components. const BIGNUM* public_n = nullptr; const BIGNUM* public_e = nullptr; const BIGNUM* d = nullptr; RSA_get0_key(public_key, &public_n, &public_e, &d); if (public_n == nullptr || public_e == nullptr) { LOGE("Failed to get RSA public key components"); return false; } // Step 2: Extract private key components. const BIGNUM* private_n = nullptr; const BIGNUM* private_e = nullptr; RSA_get0_key(private_key, &private_n, &private_e, &d); if (private_n == nullptr || private_e == nullptr) { LOGE("Failed to get RSA private key components"); return false; } // Step 3: Compare RSA components. if (BN_cmp(public_n, private_n)) { LOGD("RSA modulos do not match"); return false; } if (BN_cmp(public_e, private_e)) { LOGD("RSA exponents do not match"); return false; } return true; } // ===== ===== ===== RSA Public Key ===== ===== ===== // static std::unique_ptr RsaPublicKey::New( const RsaPrivateKey& private_key) { std::unique_ptr key(new RsaPublicKey()); if (!key->InitFromPrivateKey(private_key)) { LOGE("Failed to initialize public key from private key"); key.reset(); } return key; } // static std::unique_ptr RsaPublicKey::Load(const uint8_t* buffer, size_t length) { std::unique_ptr key; if (buffer == nullptr) { LOGE("Provided public key buffer is null"); return key; } if (length == 0) { LOGE("Provided public key buffer is zero length"); return key; } key.reset(new RsaPublicKey()); if (!key->InitFromBuffer(buffer, length)) { LOGE("Failed to initialize public key from buffer"); key.reset(); } return key; } // static std::unique_ptr RsaPublicKey::Load(const std::string& buffer) { if (buffer.empty()) { LOGE("Provided public key buffer is empty"); return std::unique_ptr(); } return Load(reinterpret_cast(buffer.data()), buffer.size()); } // static std::unique_ptr RsaPublicKey::Load( const std::vector& buffer) { std::unique_ptr key; if (buffer.empty()) { LOGE("Provided public key buffer is empty"); return std::unique_ptr(); } return Load(buffer.data(), buffer.size()); } bool RsaPublicKey::IsMatchingPrivateKey( const RsaPrivateKey& private_key) const { if (private_key.field_size() != field_size_) { return false; } return RsaKeysAreMatchingPair(GetRsaKey(), private_key.GetRsaKey()); } OEMCryptoResult RsaPublicKey::Serialize(uint8_t* buffer, size_t* buffer_size) const { if (buffer_size == nullptr) { LOGE("Output buffer size is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (buffer == nullptr && *buffer_size > 0) { LOGE("Output buffer is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } uint8_t* der_key = nullptr; const int der_res = i2d_RSA_PUBKEY(key_, &der_key); if (der_res < 0) { LOGE("Public key serialization failed"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (der_key == nullptr) { LOGE("Encoded key is unexpectedly null"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (der_res == 0) { LOGE("Unexpected DER encoded size"); OPENSSL_free(der_key); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } const size_t required_size = static_cast(der_res); if (buffer == nullptr || *buffer_size < required_size) { *buffer_size = required_size; OPENSSL_free(der_key); return OEMCrypto_ERROR_SHORT_BUFFER; } memcpy(buffer, der_key, required_size); *buffer_size = required_size; OPENSSL_free(der_key); return OEMCrypto_SUCCESS; } std::vector RsaPublicKey::Serialize() const { size_t key_size = kPublicKeySize; std::vector key_data(key_size, 0); const OEMCryptoResult res = Serialize(key_data.data(), &key_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to serialize public key: result = %d", static_cast(res)); key_data.clear(); } else { key_data.resize(key_size); } return key_data; } OEMCryptoResult RsaPublicKey::VerifySignature( const uint8_t* message, size_t message_length, const uint8_t* signature, size_t signature_length, RsaSignatureAlgorithm algorithm) const { if (signature == nullptr || signature_length == 0) { LOGE("Signature is missing"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (message == nullptr && message_length > 0) { LOGE("Bad message data"); return OEMCrypto_ERROR_INVALID_CONTEXT; } switch (algorithm) { case kRsaPssDefault: return VerifySignaturePss(message, message_length, signature, signature_length); case kRsaPkcs1Cast: return VerifySignaturePkcs1Cast(message, message_length, signature, signature_length); } LOGE("Unknown RSA signature algorithm: %d", static_cast(algorithm)); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } OEMCryptoResult RsaPublicKey::VerifySignature( const std::string& message, const std::string& signature, RsaSignatureAlgorithm algorithm) const { if (signature.empty()) { LOGE("Signature should not be empty"); return OEMCrypto_ERROR_INVALID_CONTEXT; } return VerifySignature(reinterpret_cast(message.data()), message.size(), reinterpret_cast(signature.data()), signature.size(), algorithm); } OEMCryptoResult RsaPublicKey::VerifySignature( const std::vector& message, const std::vector& signature, RsaSignatureAlgorithm algorithm) const { if (signature.empty()) { LOGE("Signature should not be empty"); return OEMCrypto_ERROR_INVALID_CONTEXT; } return VerifySignature(message.data(), message.size(), signature.data(), signature.size(), algorithm); } OEMCryptoResult RsaPublicKey::EncryptSessionKey( const uint8_t* session_key, size_t session_key_size, uint8_t* enc_session_key, size_t* enc_session_key_size) const { if (session_key == nullptr) { LOGE("Session key is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (session_key_size != kRsaSessionKeySize) { LOGE("Unexpected session key size: expected = %zu, actual = %zu", kRsaSessionKeySize, session_key_size); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (enc_session_key_size == nullptr) { LOGE("Output encrypted session key size is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (enc_session_key == nullptr && *enc_session_key_size > 0) { LOGE("Output encrypted session key is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } return EncryptOaep(session_key, session_key_size, enc_session_key, enc_session_key_size); } std::vector RsaPublicKey::EncryptSessionKey( const std::vector& session_key) const { if (session_key.empty()) { LOGE("Session key is empty"); return std::vector(); } size_t enc_session_key_size = static_cast(RSA_size(key_)); std::vector enc_session_key(enc_session_key_size); const OEMCryptoResult res = EncryptSessionKey(session_key.data(), session_key.size(), enc_session_key.data(), &enc_session_key_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to encrypt session key: result = %d", static_cast(res)); enc_session_key.clear(); } else { enc_session_key.resize(enc_session_key_size); } return enc_session_key; } std::vector RsaPublicKey::EncryptSessionKey( const std::string& session_key) const { if (session_key.empty()) { LOGE("Session key is empty"); return std::vector(); } size_t enc_session_key_size = static_cast(RSA_size(key_)); std::vector enc_session_key(enc_session_key_size); const OEMCryptoResult res = EncryptSessionKey( reinterpret_cast(session_key.data()), session_key.size(), enc_session_key.data(), &enc_session_key_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to encrypt session key: result = %d", static_cast(res)); enc_session_key.clear(); } else { enc_session_key.resize(enc_session_key_size); } return enc_session_key; } OEMCryptoResult RsaPublicKey::EncryptEncryptionKey( const uint8_t* encryption_key, size_t encryption_key_size, uint8_t* enc_encryption_key, size_t* enc_encryption_key_size) const { if (encryption_key == nullptr) { LOGE("Encryption key is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (encryption_key_size != kEncryptionKeySize) { LOGE("Unexpected encryption key size: expected = %zu, actual = %zu", kEncryptionKeySize, encryption_key_size); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (enc_encryption_key_size == nullptr) { LOGE("Output encrypted encryption key size is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (enc_encryption_key == nullptr && *enc_encryption_key_size > 0) { LOGE("Output encrypted encryption key is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } return EncryptOaep(encryption_key, encryption_key_size, enc_encryption_key, enc_encryption_key_size); } std::vector RsaPublicKey::EncryptEncryptionKey( const std::vector& encryption_key) const { if (encryption_key.empty()) { LOGE("Session key is empty"); return std::vector(); } size_t enc_encryption_key_size = static_cast(RSA_size(key_)); std::vector enc_encryption_key(enc_encryption_key_size); const OEMCryptoResult res = EncryptSessionKey(encryption_key.data(), encryption_key.size(), enc_encryption_key.data(), &enc_encryption_key_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to encrypt encryption key: result = %d", static_cast(res)); enc_encryption_key.clear(); } else { enc_encryption_key.resize(enc_encryption_key_size); } return enc_encryption_key; } std::vector RsaPublicKey::EncryptEncryptionKey( const std::string& encryption_key) const { if (encryption_key.empty()) { LOGE("Session key is empty"); return std::vector(); } size_t enc_encryption_key_size = static_cast(RSA_size(key_)); std::vector enc_encryption_key(enc_encryption_key_size); const OEMCryptoResult res = EncryptSessionKey(reinterpret_cast(encryption_key.data()), encryption_key.size(), enc_encryption_key.data(), &enc_encryption_key_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to encrypt encryption key: result = %d", static_cast(res)); enc_encryption_key.clear(); } else { enc_encryption_key.resize(enc_encryption_key_size); } return enc_encryption_key; } RsaPublicKey::~RsaPublicKey() { if (key_ != nullptr) { RSA_free(key_); key_ = nullptr; } allowed_schemes_ = 0; field_size_ = kRsaFieldUnknown; } bool RsaPublicKey::InitFromBuffer(const uint8_t* buffer, size_t length) { // Step 1: Deserialize SubjectPublicKeyInfo as RSA key. const uint8_t* tp = buffer; ScopedRsaKey key(d2i_RSA_PUBKEY(nullptr, &tp, length)); if (!key) { LOGE("Failed to parse key"); return false; } // Step 2: Verify key. const int bits = RSA_bits(key.get()); field_size_ = RealBitSizeToFieldSize(bits); if (field_size_ == kRsaFieldUnknown) { LOGE("Unsupported RSA key size: bits = %d", bits); return false; } allowed_schemes_ = kSign_RSASSA_PSS; key_ = key.release(); return true; } bool RsaPublicKey::InitFromPrivateKey(const RsaPrivateKey& private_key) { ScopedRsaKey key(RSA_new()); if (!key) { LOGE("Failed to allocate key"); return false; } const BIGNUM* n = nullptr; const BIGNUM* e = nullptr; const BIGNUM* d = nullptr; RSA_get0_key(private_key.GetRsaKey(), &n, &e, &d); if (n == nullptr || e == nullptr) { LOGE("Failed to get RSA private key components"); return false; } ScopedBigNum dub_n(BN_dup(n)); ScopedBigNum dub_e(BN_dup(e)); if (!dub_n || !dub_e) { LOGE("Failed to duplicate RSA public key components"); return false; } if (!RSA_set0_key(key.get(), dub_n.get(), dub_e.get(), nullptr)) { LOGE("Failed to RSA public key components"); return false; } // Ownership has transferred to the RSA key. dub_n.release(); dub_e.release(); key_ = key.release(); allowed_schemes_ = private_key.allowed_schemes(); field_size_ = private_key.field_size(); return true; } OEMCryptoResult RsaPublicKey::VerifySignaturePss( const uint8_t* message, size_t message_length, const uint8_t* signature, size_t signature_length) const { // Step 0: Ensure the signature algorithm is supported by key. if (!(allowed_schemes_ & kSign_RSASSA_PSS)) { LOGE("RSA key cannot verify using PSS"); return OEMCrypto_ERROR_INVALID_RSA_KEY; } // Step 1: Create a high-level key from RSA key. ScopedEvpPkey pkey(EVP_PKEY_new()); if (!pkey) { LOGE("Failed to allocate PKEY"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (!EVP_PKEY_set1_RSA(pkey.get(), key_)) { LOGE("Failed to set PKEY RSA key"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2a: Setup a EVP MD CTX for PSS Verification. ScopedEvpMdCtx md_ctx = EVP_MD_CTX_new(); if (!md_ctx) { LOGE("Failed to allocate MD CTX"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } EVP_PKEY_CTX* pkey_ctx = nullptr; // Ownership is maintained by |md_ctx| int res = EVP_DigestVerifyInit(md_ctx.get(), &pkey_ctx, EVP_sha1(), nullptr, pkey.get()); if (res != 1) { LOGE("Failed to initialize MD CTX for verification"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (pkey_ctx == nullptr) { LOGE("PKEY CTX is unexpectedly null"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2b: Configure OEMCrypto RSASSA-PSS options. res = EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING); if (res != 1) { LOGE("Failed to set PSS padding"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } res = EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, kPssSaltLength); if (res != 1) { LOGE("Failed to set PSS salt length"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } res = EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, EVP_sha1()); if (res != 1) { LOGE("Failed to set PSS MGF1 MD"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 3: Digest message. if (message_length > 0) { res = EVP_DigestVerifyUpdate(md_ctx.get(), message, message_length); if (res != 1) { LOGE("Failed to update MD"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } } // Step 4: Verify message. res = EVP_DigestVerifyFinal(md_ctx.get(), signature, signature_length); if (res < 0) { LOGE("Failed to perform RSASSA-PSS-VERIFY"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } return res ? OEMCrypto_SUCCESS : OEMCrypto_ERROR_SIGNATURE_FAILURE; } OEMCryptoResult RsaPublicKey::VerifySignaturePkcs1Cast( const uint8_t* message, size_t message_length, const uint8_t* signature, size_t signature_length) const { // Step 0: Ensure the signature algorithm is supported by key. if (!(allowed_schemes_ & kSign_PKCS1_Block1)) { LOGE("RSA key cannot verify using PKCS1"); return OEMCrypto_ERROR_INVALID_RSA_KEY; } if (message_length > kRsaPkcs1CastMaxMessageSize) { LOGE("Message is too large for CAST PKCS1 signature: size = %zu", message_length); return OEMCrypto_ERROR_SIGNATURE_FAILURE; } // Step 1: Convert encrypted blob into digest. std::vector digest(RSA_size(key_)); const int res = RSA_public_decrypt(signature_length, signature, digest.data(), key_, RSA_PKCS1_PADDING); if (res <= 0) { LOGE("Failed to perform public decryption"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2: Compare digests. const size_t digest_size = static_cast(res); if (digest_size != message_length) { LOGD("Digest size does not match"); return OEMCrypto_ERROR_SIGNATURE_FAILURE; } digest.resize(digest_size); for (size_t i = 0; i < digest_size; i++) { if (message[i] != digest[i]) { return OEMCrypto_ERROR_SIGNATURE_FAILURE; } } return OEMCrypto_SUCCESS; } OEMCryptoResult RsaPublicKey::EncryptOaep(const uint8_t* message, size_t message_size, uint8_t* enc_message, size_t* enc_message_length) const { // Step 1: Encrypt using RSAES-OAEP. std::vector encrypt_buffer(RSA_size(key_)); const int enc_res = RSA_public_encrypt(message_size, message, encrypt_buffer.data(), key_, RSA_PKCS1_OAEP_PADDING); if (enc_res < 0) { LOGE("Failed to perform RSAES-OAEP encrypt"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2: Copy encrypted data key. const size_t enc_size = static_cast(enc_res); if (*enc_message_length < enc_size) { *enc_message_length = enc_size; return OEMCrypto_ERROR_SHORT_BUFFER; } *enc_message_length = enc_size; memcpy(enc_message, encrypt_buffer.data(), enc_size); return OEMCrypto_SUCCESS; } // ===== ===== ===== RSA Private Key ===== ===== ===== // static std::unique_ptr RsaPrivateKey::New(RsaFieldSize field_size) { std::unique_ptr key(new RsaPrivateKey()); if (!key->InitFromFieldSize(field_size)) { LOGE("Failed to initialize private key from field size"); key.reset(); } return key; } // static std::unique_ptr RsaPrivateKey::Load(const uint8_t* buffer, size_t length) { std::unique_ptr key; if (buffer == nullptr) { LOGE("Provided private key buffer is null"); return key; } if (length == 0) { LOGE("Provided private key buffer is zero length"); return key; } key.reset(new RsaPrivateKey()); if (!key->InitFromBuffer(buffer, length)) { LOGE("Failed to initialize private key from buffer"); key.reset(); } return key; } // static std::unique_ptr RsaPrivateKey::Load(const std::string& buffer) { if (buffer.empty()) { LOGE("Provided private key buffer is empty"); return std::unique_ptr(); } return Load(reinterpret_cast(buffer.data()), buffer.size()); } // static std::unique_ptr RsaPrivateKey::Load( const std::vector& buffer) { if (buffer.empty()) { LOGE("Provided private key buffer is empty"); return std::unique_ptr(); } return Load(buffer.data(), buffer.size()); } std::unique_ptr RsaPrivateKey::MakePublicKey() const { return RsaPublicKey::New(*this); } bool RsaPrivateKey::IsMatchingPublicKey(const RsaPublicKey& public_key) const { return RsaKeysAreMatchingPair(public_key.GetRsaKey(), GetRsaKey()); } OEMCryptoResult RsaPrivateKey::Serialize(uint8_t* buffer, size_t* buffer_size) const { if (buffer_size == nullptr) { LOGE("Output buffer size is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (buffer == nullptr && *buffer_size > 0) { LOGE("Output buffer is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } // Step 1: Convert RSA key to EVP. ScopedEvpPkey pkey(EVP_PKEY_new()); if (!pkey) { LOGE("Failed to allocate EVP"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (!EVP_PKEY_set1_RSA(pkey.get(), key_)) { LOGE("Failed to set EVP RSA key"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2: Convert RSA EVP to PKCS8 format. ScopedRsaPrivKeyInfo priv_info(EVP_PKEY2PKCS8(pkey.get())); if (!priv_info) { LOGE("Failed to convert RSA key to PKCS8 info"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 3: Serialize PKCS8 to DER encoding. ScopedBio bio(BIO_new(BIO_s_mem())); if (!bio) { LOGE("Failed to allocate IO buffer for RSA key"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (!i2d_PKCS8_PRIV_KEY_INFO_bio(bio.get(), priv_info.get())) { LOGE("Failed to serialize RSA key"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 4: Determine key size and copy. char* key_ptr = nullptr; const long key_size = BIO_get_mem_data(bio.get(), &key_ptr); if (key_size < 0) { LOGE("Failed to get RSA key size"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } const size_t required_size = static_cast(key_size) + (explicit_schemes_ ? 8 : 0); if (*buffer_size < required_size) { *buffer_size = required_size; return OEMCrypto_ERROR_SHORT_BUFFER; } *buffer_size = required_size; if (explicit_schemes_) { memcpy(buffer, "SIGN", 4); const uint32_t allowed_schemes = htonl(allowed_schemes_); memcpy(&buffer[4], &allowed_schemes, 4); memcpy(&buffer[8], key_ptr, required_size - 8); } else { memcpy(buffer, key_ptr, required_size); } return OEMCrypto_SUCCESS; } std::vector RsaPrivateKey::Serialize() const { size_t key_size = kPrivateKeySize; std::vector key_data(key_size, 0); OEMCryptoResult res = Serialize(key_data.data(), &key_size); if (res == OEMCrypto_ERROR_SHORT_BUFFER) { LOGD( "Actual RSA private key size is larger than expected: " "expected = %zu, actual = %zu", kPrivateKeySize, key_size); key_data.resize(key_size); res = Serialize(key_data.data(), &key_size); } if (res != OEMCrypto_SUCCESS) { LOGE("Failed to serialize private key: result = %d", static_cast(res)); key_data.clear(); } else { key_data.resize(key_size); } return key_data; } OEMCryptoResult RsaPrivateKey::GenerateSignature( const uint8_t* message, size_t message_length, RsaSignatureAlgorithm algorithm, uint8_t* signature, size_t* signature_length) const { if (signature_length == nullptr) { LOGE("Output signature size is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (signature == nullptr && *signature_length > 0) { LOGE("Output signature is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (message == nullptr && message_length > 0) { LOGE("Invalid message data"); return OEMCrypto_ERROR_INVALID_CONTEXT; } switch (algorithm) { case kRsaPssDefault: return GenerateSignaturePss(message, message_length, signature, signature_length); case kRsaPkcs1Cast: return GenerateSignaturePkcs1Cast(message, message_length, signature, signature_length); } LOGE("Unknown RSA signature algorithm: %d", static_cast(algorithm)); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } std::vector RsaPrivateKey::GenerateSignature( const std::string& message, RsaSignatureAlgorithm algorithm) const { size_t signature_size = SignatureSize(); std::vector signature(signature_size); const OEMCryptoResult res = GenerateSignature( reinterpret_cast(message.data()), message.size(), algorithm, signature.data(), &signature_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to generate signature: result = %d", static_cast(res)); signature.clear(); } else { signature.resize(signature_size); } return signature; } std::vector RsaPrivateKey::GenerateSignature( const std::vector& message, RsaSignatureAlgorithm algorithm) const { size_t signature_size = SignatureSize(); std::vector signature(signature_size, 0); const OEMCryptoResult res = GenerateSignature(message.data(), message.size(), algorithm, signature.data(), &signature_size); if (res != OEMCrypto_SUCCESS) { LOGE("Failed to generate signature: result = %d", static_cast(res)); signature.clear(); } else { signature.resize(signature_size); } return signature; } size_t RsaPrivateKey::SignatureSize() const { return RSA_size(key_); } OEMCryptoResult RsaPrivateKey::DecryptSessionKey( const uint8_t* enc_session_key, size_t enc_session_key_size, uint8_t* session_key, size_t* session_key_size) const { if (enc_session_key == nullptr || enc_session_key_size == 0) { LOGE("Encrypted session key is %s", enc_session_key ? "empty" : "null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (session_key_size == nullptr) { LOGE("Output session key size buffer is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (session_key == nullptr && *session_key_size > 0) { LOGE("Output session key buffer is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (*session_key_size < kRsaSessionKeySize) { *session_key_size = kRsaSessionKeySize; return OEMCrypto_ERROR_SHORT_BUFFER; } const OEMCryptoResult result = DecryptOaep( enc_session_key, enc_session_key_size, session_key, kRsaSessionKeySize); if (result == OEMCrypto_SUCCESS) { *session_key_size = kRsaSessionKeySize; } else { LOGE("Failed to decrypt session key"); } return result; } std::vector RsaPrivateKey::DecryptSessionKey( const std::vector& enc_session_key) const { size_t session_key_size = kRsaSessionKeySize; std::vector session_key(session_key_size, 0); const OEMCryptoResult res = DecryptSessionKey(enc_session_key.data(), enc_session_key.size(), session_key.data(), &session_key_size); if (res != OEMCrypto_SUCCESS) { session_key.clear(); } else { session_key.resize(session_key_size); } return session_key; } std::vector RsaPrivateKey::DecryptSessionKey( const std::string& enc_session_key) const { size_t session_key_size = kRsaSessionKeySize; std::vector session_key(session_key_size, 0); const OEMCryptoResult res = DecryptSessionKey( reinterpret_cast(enc_session_key.data()), enc_session_key.size(), session_key.data(), &session_key_size); if (res != OEMCrypto_SUCCESS) { session_key.clear(); } else { session_key.resize(session_key_size); } return session_key; } size_t RsaPrivateKey::SessionKeyLength() const { return kRsaSessionKeySize; } OEMCryptoResult RsaPrivateKey::DecryptEncryptionKey( const uint8_t* enc_encryption_key, size_t enc_encryption_key_size, uint8_t* encryption_key, size_t* encryption_key_size) const { if (enc_encryption_key == nullptr || enc_encryption_key_size == 0) { LOGE("Encrypted encryption key is %s", enc_encryption_key ? "empty" : "null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (encryption_key_size == nullptr) { LOGE("Output encryption key size buffer is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (encryption_key == nullptr && *encryption_key_size > 0) { LOGE("Output encryption key buffer is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } if (*encryption_key_size < kEncryptionKeySize) { *encryption_key_size = kEncryptionKeySize; return OEMCrypto_ERROR_SHORT_BUFFER; } const OEMCryptoResult result = DecryptOaep(enc_encryption_key, enc_encryption_key_size, encryption_key, kEncryptionKeySize); if (result == OEMCrypto_SUCCESS) { *encryption_key_size = kEncryptionKeySize; } else { LOGE("Failed to decrypt encryption key"); } return result; } std::vector RsaPrivateKey::DecryptEncryptionKey( const std::vector& enc_encryption_key) const { size_t encryption_key_size = kEncryptionKeySize; std::vector encryption_key(encryption_key_size, 0); const OEMCryptoResult res = DecryptEncryptionKey(enc_encryption_key.data(), enc_encryption_key.size(), encryption_key.data(), &encryption_key_size); if (res != OEMCrypto_SUCCESS) { encryption_key.clear(); } else { encryption_key.resize(encryption_key_size); } return encryption_key; } std::vector RsaPrivateKey::DecryptEncryptionKey( const std::string& enc_encryption_key) const { size_t encryption_key_size = kEncryptionKeySize; std::vector encryption_key(encryption_key_size, 0); const OEMCryptoResult res = DecryptEncryptionKey( reinterpret_cast(enc_encryption_key.data()), enc_encryption_key.size(), encryption_key.data(), &encryption_key_size); if (res != OEMCrypto_SUCCESS) { encryption_key.clear(); } else { encryption_key.resize(encryption_key_size); } return encryption_key; } RsaPrivateKey::~RsaPrivateKey() { if (key_ != nullptr) { RSA_free(key_); key_ = nullptr; } allowed_schemes_ = 0; explicit_schemes_ = false; field_size_ = kRsaFieldUnknown; } bool RsaPrivateKey::InitFromBuffer(const uint8_t* buffer, size_t length) { if (length < 8) { LOGE("Public key is too small: length = %zu", length); return false; } ScopedBio bio; // Check allowed scheme type. if (!memcmp("SIGN", buffer, 4)) { uint32_t allowed_schemes; memcpy(&allowed_schemes, reinterpret_cast(&buffer[4]), 4); allowed_schemes_ = ntohl(allowed_schemes); bio.reset(BIO_new_mem_buf(&buffer[8], length - 8)); explicit_schemes_ = true; } else { allowed_schemes_ = kSign_RSASSA_PSS; bio.reset(BIO_new_mem_buf(buffer, length)); } if (!bio) { LOGE("Failed to allocate BIO buffer"); return false; } // Step 1: Deserializes PKCS8 RSA private key info. ScopedRsaPrivKeyInfo priv_info( d2i_PKCS8_PRIV_KEY_INFO_bio(bio.get(), nullptr)); if (!priv_info) { LOGE("Failed to parse private key"); return false; } // Step 2: Convert to RSA key. ScopedEvpPkey pkey(EVP_PKCS82PKEY(priv_info.get())); if (!pkey) { LOGE("Failed to convert PKCS8 to EVP"); return false; } ScopedRsaKey key(EVP_PKEY_get1_RSA(pkey.get())); if (!key) { LOGE("Failed to get RSA key"); return false; } const int check = RSA_check_key(key.get()); if (check == 0) { LOGE("RSA key parameters are invalid"); return false; } else if (check == -1) { LOGE("Failed to check RSA key"); return false; } // Step 3: Verify field width. const int bits = RSA_bits(key.get()); field_size_ = RealBitSizeToFieldSize(bits); if (field_size_ == kRsaFieldUnknown) { LOGE("Unsupported RSA key size: bits = %d", bits); return false; } key_ = key.release(); return true; } bool RsaPrivateKey::InitFromFieldSize(RsaFieldSize field_size) { if (field_size != kRsa2048Bit && field_size != kRsa3072Bit) { LOGE("Unsupported RSA field size: bits = %d", static_cast(field_size)); return false; } // Step 1: Create exponent. ScopedBigNum exp(BN_new()); if (!exp) { LOGE("Failed to allocate RSA exponent"); return false; } if (!BN_set_word(exp.get(), RSA_F4)) { LOGE("Failed to set RSA exponent"); return false; } // Step 2: Generate RSA key. ScopedRsaKey key(RSA_new()); if (!key) { LOGE("Failed to allocate RSA key"); return false; } if (!RSA_generate_key_ex(key.get(), static_cast(field_size), exp.get(), nullptr)) { LOGE("Failed to generate RSA key"); return false; } key_ = key.release(); allowed_schemes_ = kSign_RSASSA_PSS; field_size_ = field_size; return true; } OEMCryptoResult RsaPrivateKey::GenerateSignaturePss( const uint8_t* message, size_t message_length, uint8_t* signature, size_t* signature_length) const { // Step 0: Ensure the signature algorithm is supported by key. if (!(allowed_schemes_ & kSign_RSASSA_PSS)) { LOGE("RSA key cannot sign using PSS"); return OEMCrypto_ERROR_INVALID_RSA_KEY; } // Step 1: Create a high-level key from RSA key. ScopedEvpPkey pkey(EVP_PKEY_new()); if (!pkey) { LOGE("Failed to allocate PKEY"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (!EVP_PKEY_set1_RSA(pkey.get(), key_)) { LOGE("Failed to set PKEY RSA key"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2a: Setup a EVP MD CTX for PSS Signature Generation. ScopedEvpMdCtx md_ctx(EVP_MD_CTX_new()); if (!md_ctx) { LOGE("Failed to allocate MD CTX"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } EVP_PKEY_CTX* pkey_ctx = nullptr; // Ownership is maintained by |md_ctx| int res = EVP_DigestSignInit(md_ctx.get(), &pkey_ctx, EVP_sha1(), nullptr, pkey.get()); if (res != 1) { LOGE("Failed to initialize MD CTX for signing"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (pkey_ctx == nullptr) { LOGE("PKEY CTX is unexpectedly null"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2b: Configure OEMCrypto RSASSA-PSS options. res = EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING); if (res != 1) { LOGE("Failed to set PSS padding"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } res = EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, kPssSaltLength); if (res != 1) { LOGE("Failed to set PSS salt length"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } res = EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, EVP_sha1()); if (res != 1) { LOGE("Failed to set PSS MGF1 MD"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 3: Digest message. if (message_length > 0) { res = EVP_DigestSignUpdate(md_ctx.get(), message, message_length); if (res != 1) { LOGE("Failed to update MD"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } } // Step 4a: Determine size of signature. size_t actual_signature_length = 0; res = EVP_DigestSignFinal(md_ctx.get(), nullptr, &actual_signature_length); if (res != 1) { LOGE("Failed to determine signature length"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (*signature_length < actual_signature_length) { *signature_length = actual_signature_length; return OEMCrypto_ERROR_SHORT_BUFFER; } // Step 4b: Generate signature. res = EVP_DigestSignFinal(md_ctx.get(), signature, signature_length); if (res != 1) { LOGE("Failed to perform RSASSA-PSS-SIGN"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } return OEMCrypto_SUCCESS; } OEMCryptoResult RsaPrivateKey::GenerateSignaturePkcs1Cast( const uint8_t* message, size_t message_length, uint8_t* signature, size_t* signature_length) const { // Step 0: Ensure the signature algorithm is supported by key. if (!(allowed_schemes_ & kSign_PKCS1_Block1)) { LOGE("RSA key cannot sign PKCS1"); return OEMCrypto_ERROR_INVALID_RSA_KEY; } if (message_length > kRsaPkcs1CastMaxMessageSize) { LOGE("Message is too large for CAST PKCS1 signature: size = %zu", message_length); return OEMCrypto_ERROR_SIGNATURE_FAILURE; } // Step 1: Ensure signature buffer is large enough. const size_t expected_signature_size = static_cast(RSA_size(key_)); if (*signature_length < expected_signature_size) { *signature_length = expected_signature_size; return OEMCrypto_ERROR_SHORT_BUFFER; } // Step 2: Encrypt with PKCS1 padding. const int enc_res = RSA_private_encrypt(message_length, message, signature, key_, RSA_PKCS1_PADDING); if (enc_res < 0) { LOGE("Failed to perform private encryption"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } *signature_length = expected_signature_size; return OEMCrypto_SUCCESS; } OEMCryptoResult RsaPrivateKey::DecryptOaep( const uint8_t* enc_message, size_t enc_message_size, uint8_t* message, size_t expected_message_length) const { // Step 1: Decrypt using RSAES-OAEP. std::vector decrypt_buffer(RSA_size(key_)); const int dec_res = RSA_private_decrypt(enc_message_size, enc_message, decrypt_buffer.data(), key_, RSA_PKCS1_OAEP_PADDING); if (dec_res < 0) { LOGE("Failed to perform RSAES-OAEP decrypt"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } if (static_cast(dec_res) != expected_message_length) { LOGE("Unexpected key size: expected = %zu, actual = %d", expected_message_length, dec_res); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2: Copy decrypted data key. memcpy(message, decrypt_buffer.data(), expected_message_length); return OEMCrypto_SUCCESS; } } // namespace wvoec_ref