diff --git a/libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.h b/libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.h index b5ff7b9e..442adc12 100644 --- a/libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.h +++ b/libwvdrmengine/oemcrypto/ref/src/oemcrypto_ecc_key.h @@ -70,10 +70,10 @@ class EccPublicKey { // Serializes the public key into an ASN.1 DER encoded SubjectPublicKey // representation. - // On success, |*buffer_size| is populated with the number of bytes + // On success, |buffer_size| is populated with the number of bytes // written to |buffer|, and OEMCrypto_SUCCESS is returned. - // If the provided |*buffer_size| is too small, ERROR_SHORT_BUFFER - // is returned and |*buffer_size| is set to the required buffer size. + // If the provided |buffer_size| is too small, ERROR_SHORT_BUFFER + // is returned and |buffer_size| is set to the required buffer size. OEMCryptoResult Serialize(uint8_t* buffer, size_t* buffer_size) const; // Same as above, except directly returns the serialized key. // Returns an empty vector on error. @@ -167,10 +167,10 @@ class EccPrivateKey { // Serializes the private key into an ASN.1 DER encoded ECPrivateKey // representation. - // On success, |*buffer_size| is populated with the number of bytes + // On success, |buffer_size| is populated with the number of bytes // written to |buffer|, and SUCCESS is returned. - // If the provided |*buffer_size| is too small, - // OEMCrypto_ERROR_SHORT_BUFFER is returned and |*buffer_size| is + // If the provided |buffer_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |buffer_size| is // set to the required buffer size. OEMCryptoResult Serialize(uint8_t* buffer, size_t* buffer_size) const; // Same as above, except directly returns the serialized key. @@ -184,10 +184,10 @@ class EccPrivateKey { // - SHA-256 / secp256r1 // - SHA-384 / secp384r1 (optional support) // - SHA-512 / secp521r1 (optional support) - // On success, |*signature_length| is populated with the number of + // On success, |signature_length| is populated with the number of // bytes written to |signature|, and SUCCESS is returned. - // If the provided |*signature_length| is too small, - // OEMCrypto_ERROR_SHORT_BUFFER is returned and |*signature_length| + // If the provided |signature_length| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |signature_length| // is set to the required signature size. OEMCryptoResult GenerateSignature(const uint8_t* message, size_t message_length, uint8_t* signature, @@ -203,10 +203,10 @@ class EccPrivateKey { // Derives the OEMCrypto session key used for deriving other keys. // The provided public key must be of the same curve. - // On success, |*session_key_size| is populated with the number of + // On success, |session_key_size| is populated with the number of // bytes written to |session_key|, and OEMCrypto_SUCCESS is returned. - // If the provided |*session_key_size| is too small, - // OEMCrypto_ERROR_SHORT_BUFFER is returned and |*session_key_size| + // If the provided |session_key_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |session_key_size| // is set to the required buffer size. OEMCryptoResult DeriveSessionKey(const EccPublicKey& public_key, uint8_t* session_key, diff --git a/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.cpp b/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.cpp new file mode 100644 index 00000000..71aea9b4 --- /dev/null +++ b/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.cpp @@ -0,0 +1,1169 @@ +// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation of OEMCrypto APIs +// +#include "oemcrypto_rsa_key.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "OEMCryptoCENC.h" +#include "log.h" +#include "oemcrypto_types.h" +#include "scoped_object.h" + +namespace wvoec_ref { +namespace { +// Estimated size of an RSA private key. +// The private key constists of: +// n : Modulos : byteSize(n) ~= bitSize(n)/8 +// e : Public exponent : 4 bytes +// d : Private exponent : ~byteSize(n) +// p : Prime 1 : ~byteSize(n)/2 +// q : Prime 2 : ~byteSize(n)/2 +// d (mod p-1) : Exponent 1 : ~byteSize(n)/2 +// d (mod q-1) : Exponent 2 : ~byteSize(n)/2 +// q^-1 (mod p) : Coefficient : ~byteSize(n)/2 +// And the ASN.1 tags for each component (roughly 25 bytes). +constexpr size_t kPrivateKeySize = (3072 / 8) * 5 + 25; +// Estimated size of an RSA public key. +// The public key constists of: +// The private key constists of: +// n : Modulos : byteSize(n) ~= bitSize(n)/8 +// e : Public exponent : 4 bytes +// And the ASN.1 tags + outer structure. +constexpr size_t kPublicKeySize = (3072 / 8) + 100; + +// 128 bit key, intended to be used with CMAC-AES-128. +constexpr size_t kRsaSessionKeySize = 16; +// Encryption key used by OEMCrypto session for encrypting and +// decrypting data. +constexpr size_t kEncryptionKeySize = wvoec::KEY_SIZE; + +// Salt length used by OEMCrypto's RSASSA-PSS implementation. +// See description of kRsaPssDefault for more information. +constexpr size_t kPssSaltLength = 20; +// Requirement of CAST receivers. +constexpr size_t kRsaPkcs1CastMaxMessageSize = 83; + +using ScopedBio = ScopedObject; +using ScopedBigNum = ScopedObject; +using ScopedEvpMdCtx = ScopedObject; +using ScopedEvpPkey = ScopedObject; +using ScopedRsaKey = ScopedObject; +using ScopedRsaPrivKeyInfo = + ScopedObject; + +// Estimates the RSA rough field size from the real bit size of the +// RSA modulos. The actual bit length could vary by a few bits. +RsaFieldSize RealBitSizeToFieldSize(int bits) { + if (bits > 1800 && bits < 2200) { + return kRsa2048Bit; + } + if (bits > 2800 && bits < 3200) { + return kRsa3072Bit; + } + return kRsaFieldUnknown; +} +} // namespace + +std::string RsaFieldSizeToString(RsaFieldSize field_size) { + switch (field_size) { + case kRsa2048Bit: + return "RSA-2048"; + case kRsa3072Bit: + return "RSA-3072"; + case kRsaFieldUnknown: + return "Unknown"; + } + return "Unknown(" + std::to_string(static_cast(field_size)) + ")"; +} + +bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key) { + if (public_key == nullptr) { + LOGE("Public key is null"); + return false; + } + if (private_key == nullptr) { + LOGE("Private key is null"); + return false; + } + // Step 1: Extract public key components. + const BIGNUM* public_n = nullptr; + const BIGNUM* public_e = nullptr; + const BIGNUM* d = nullptr; + RSA_get0_key(public_key, &public_n, &public_e, &d); + if (public_n == nullptr || public_e == nullptr) { + LOGE("Failed to get RSA public key components"); + return false; + } + // Step 2: Extract private key components. + const BIGNUM* private_n = nullptr; + const BIGNUM* private_e = nullptr; + RSA_get0_key(private_key, &private_n, &private_e, &d); + if (private_n == nullptr || private_e == nullptr) { + LOGE("Failed to get RSA private key components"); + return false; + } + // Step 3: Compare RSA components. + if (BN_cmp(public_n, private_n)) { + LOGD("RSA modulos do not match"); + return false; + } + if (BN_cmp(public_e, private_e)) { + LOGD("RSA exponents do not match"); + return false; + } + return true; +} + +// ===== ===== ===== RSA Public Key ===== ===== ===== + +// static +std::unique_ptr RsaPublicKey::New( + const RsaPrivateKey& private_key) { + std::unique_ptr key(new RsaPublicKey()); + if (!key->InitFromPrivateKey(private_key)) { + LOGE("Failed to initialize public key from private key"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr RsaPublicKey::Load(const uint8_t* buffer, + size_t length) { + std::unique_ptr key; + if (buffer == nullptr) { + LOGE("Provided public key buffer is null"); + return key; + } + if (length == 0) { + LOGE("Provided public key buffer is zero length"); + return key; + } + key.reset(new RsaPublicKey()); + if (!key->InitFromBuffer(buffer, length)) { + LOGE("Failed to initialize public key from buffer"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr RsaPublicKey::Load(const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return std::unique_ptr(); + } + return Load(reinterpret_cast(buffer.data()), buffer.size()); +} + +// static +std::unique_ptr RsaPublicKey::Load( + const std::vector& buffer) { + std::unique_ptr key; + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return std::unique_ptr(); + } + return Load(buffer.data(), buffer.size()); +} + +bool RsaPublicKey::IsMatchingPrivateKey( + const RsaPrivateKey& private_key) const { + if (private_key.field_size() != field_size_) { + return false; + } + return RsaKeysAreMatchingPair(GetRsaKey(), private_key.GetRsaKey()); +} + +OEMCryptoResult RsaPublicKey::Serialize(uint8_t* buffer, + size_t* buffer_size) const { + if (buffer_size == nullptr) { + LOGE("Output buffer size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (buffer == nullptr && *buffer_size > 0) { + LOGE("Output buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + + uint8_t* der_key = nullptr; + const int der_res = i2d_RSA_PUBKEY(key_, &der_key); + if (der_res < 0) { + LOGE("Public key serialization failed"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (der_key == nullptr) { + LOGE("Encoded key is unexpectedly null"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (der_res == 0) { + LOGE("Unexpected DER encoded size"); + OPENSSL_free(der_key); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + const size_t required_size = static_cast(der_res); + if (buffer == nullptr || *buffer_size < required_size) { + *buffer_size = required_size; + OPENSSL_free(der_key); + return OEMCrypto_ERROR_SHORT_BUFFER; + } + memcpy(buffer, der_key, required_size); + *buffer_size = required_size; + OPENSSL_free(der_key); + return OEMCrypto_SUCCESS; +} + +std::vector RsaPublicKey::Serialize() const { + size_t key_size = kPublicKeySize; + std::vector key_data(key_size, 0); + const OEMCryptoResult res = Serialize(key_data.data(), &key_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to serialize public key: result = %d", static_cast(res)); + key_data.clear(); + } else { + key_data.resize(key_size); + } + return key_data; +} + +OEMCryptoResult RsaPublicKey::VerifySignature( + const uint8_t* message, size_t message_length, const uint8_t* signature, + size_t signature_length, RsaSignatureAlgorithm algorithm) const { + if (signature == nullptr || signature_length == 0) { + LOGE("Signature is missing"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (message == nullptr && message_length > 0) { + LOGE("Bad message data"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + switch (algorithm) { + case kRsaPssDefault: + return VerifySignaturePss(message, message_length, signature, + signature_length); + case kRsaPkcs1Cast: + return VerifySignaturePkcs1Cast(message, message_length, signature, + signature_length); + } + LOGE("Unknown RSA signature algorithm: %d", static_cast(algorithm)); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; +} + +OEMCryptoResult RsaPublicKey::VerifySignature( + const std::string& message, const std::string& signature, + RsaSignatureAlgorithm algorithm) const { + if (signature.empty()) { + LOGE("Signature should not be empty"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + return VerifySignature(reinterpret_cast(message.data()), + message.size(), + reinterpret_cast(signature.data()), + signature.size(), algorithm); +} + +OEMCryptoResult RsaPublicKey::VerifySignature( + const std::vector& message, const std::vector& signature, + RsaSignatureAlgorithm algorithm) const { + if (signature.empty()) { + LOGE("Signature should not be empty"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + return VerifySignature(message.data(), message.size(), signature.data(), + signature.size(), algorithm); +} + +OEMCryptoResult RsaPublicKey::EncryptSessionKey( + const uint8_t* session_key, size_t session_key_size, + uint8_t* enc_session_key, size_t* enc_session_key_size) const { + if (session_key == nullptr) { + LOGE("Session key is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (session_key_size != kRsaSessionKeySize) { + LOGE("Unexpected session key size: expected = %zu, actual = %zu", + kRsaSessionKeySize, session_key_size); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (enc_session_key_size == nullptr) { + LOGE("Output encrypted session key size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (enc_session_key == nullptr && *enc_session_key_size > 0) { + LOGE("Output encrypted session key is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + return EncryptOaep(session_key, session_key_size, enc_session_key, + enc_session_key_size); +} + +std::vector RsaPublicKey::EncryptSessionKey( + const std::vector& session_key) const { + if (session_key.empty()) { + LOGE("Session key is empty"); + return std::vector(); + } + size_t enc_session_key_size = static_cast(RSA_size(key_)); + std::vector enc_session_key(enc_session_key_size); + const OEMCryptoResult res = + EncryptSessionKey(session_key.data(), session_key.size(), + enc_session_key.data(), &enc_session_key_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to encrypt session key: result = %d", static_cast(res)); + enc_session_key.clear(); + } else { + enc_session_key.resize(enc_session_key_size); + } + return enc_session_key; +} + +std::vector RsaPublicKey::EncryptSessionKey( + const std::string& session_key) const { + if (session_key.empty()) { + LOGE("Session key is empty"); + return std::vector(); + } + size_t enc_session_key_size = static_cast(RSA_size(key_)); + std::vector enc_session_key(enc_session_key_size); + const OEMCryptoResult res = EncryptSessionKey( + reinterpret_cast(session_key.data()), session_key.size(), + enc_session_key.data(), &enc_session_key_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to encrypt session key: result = %d", static_cast(res)); + enc_session_key.clear(); + } else { + enc_session_key.resize(enc_session_key_size); + } + return enc_session_key; +} + +OEMCryptoResult RsaPublicKey::EncryptEncryptionKey( + const uint8_t* encryption_key, size_t encryption_key_size, + uint8_t* enc_encryption_key, size_t* enc_encryption_key_size) const { + if (encryption_key == nullptr) { + LOGE("Encryption key is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (encryption_key_size != kEncryptionKeySize) { + LOGE("Unexpected encryption key size: expected = %zu, actual = %zu", + kEncryptionKeySize, encryption_key_size); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (enc_encryption_key_size == nullptr) { + LOGE("Output encrypted encryption key size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (enc_encryption_key == nullptr && *enc_encryption_key_size > 0) { + LOGE("Output encrypted encryption key is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + return EncryptOaep(encryption_key, encryption_key_size, enc_encryption_key, + enc_encryption_key_size); +} + +std::vector RsaPublicKey::EncryptEncryptionKey( + const std::vector& encryption_key) const { + if (encryption_key.empty()) { + LOGE("Session key is empty"); + return std::vector(); + } + size_t enc_encryption_key_size = static_cast(RSA_size(key_)); + std::vector enc_encryption_key(enc_encryption_key_size); + const OEMCryptoResult res = + EncryptSessionKey(encryption_key.data(), encryption_key.size(), + enc_encryption_key.data(), &enc_encryption_key_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to encrypt encryption key: result = %d", + static_cast(res)); + enc_encryption_key.clear(); + } else { + enc_encryption_key.resize(enc_encryption_key_size); + } + return enc_encryption_key; +} + +std::vector RsaPublicKey::EncryptEncryptionKey( + const std::string& encryption_key) const { + if (encryption_key.empty()) { + LOGE("Session key is empty"); + return std::vector(); + } + size_t enc_encryption_key_size = static_cast(RSA_size(key_)); + std::vector enc_encryption_key(enc_encryption_key_size); + const OEMCryptoResult res = + EncryptSessionKey(reinterpret_cast(encryption_key.data()), + encryption_key.size(), enc_encryption_key.data(), + &enc_encryption_key_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to encrypt encryption key: result = %d", + static_cast(res)); + enc_encryption_key.clear(); + } else { + enc_encryption_key.resize(enc_encryption_key_size); + } + return enc_encryption_key; +} + +RsaPublicKey::~RsaPublicKey() { + if (key_ != nullptr) { + RSA_free(key_); + key_ = nullptr; + } + allowed_schemes_ = 0; + field_size_ = kRsaFieldUnknown; +} + +bool RsaPublicKey::InitFromBuffer(const uint8_t* buffer, size_t length) { + // Step 1: Deserialize SubjectPublicKeyInfo as RSA key. + const uint8_t* tp = buffer; + ScopedRsaKey key(d2i_RSA_PUBKEY(nullptr, &tp, length)); + if (!key) { + LOGE("Failed to parse key"); + return false; + } + // Step 2: Verify key. + const int bits = RSA_bits(key.get()); + field_size_ = RealBitSizeToFieldSize(bits); + if (field_size_ == kRsaFieldUnknown) { + LOGE("Unsupported RSA key size: bits = %d", bits); + return false; + } + allowed_schemes_ = kSign_RSASSA_PSS; + key_ = key.release(); + return true; +} + +bool RsaPublicKey::InitFromPrivateKey(const RsaPrivateKey& private_key) { + ScopedRsaKey key(RSA_new()); + if (!key) { + LOGE("Failed to allocate key"); + return false; + } + const BIGNUM* n = nullptr; + const BIGNUM* e = nullptr; + const BIGNUM* d = nullptr; + RSA_get0_key(private_key.GetRsaKey(), &n, &e, &d); + if (n == nullptr || e == nullptr) { + LOGE("Failed to get RSA private key components"); + return false; + } + ScopedBigNum dub_n(BN_dup(n)); + ScopedBigNum dub_e(BN_dup(e)); + if (!dub_n || !dub_e) { + LOGE("Failed to duplicate RSA public key components"); + return false; + } + if (!RSA_set0_key(key.get(), dub_n.get(), dub_e.get(), nullptr)) { + LOGE("Failed to RSA public key components"); + return false; + } + // Ownership has transferred to the RSA key. + dub_n.release(); + dub_e.release(); + + key_ = key.release(); + allowed_schemes_ = private_key.allowed_schemes(); + field_size_ = private_key.field_size(); + return true; +} + +OEMCryptoResult RsaPublicKey::VerifySignaturePss( + const uint8_t* message, size_t message_length, const uint8_t* signature, + size_t signature_length) const { + // Step 0: Ensure the signature algorithm is supported by key. + if (!(allowed_schemes_ & kSign_RSASSA_PSS)) { + LOGE("RSA key cannot verify using PSS"); + return OEMCrypto_ERROR_INVALID_RSA_KEY; + } + // Step 1: Create a high-level key from RSA key. + ScopedEvpPkey pkey(EVP_PKEY_new()); + if (!pkey) { + LOGE("Failed to allocate PKEY"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (!EVP_PKEY_set1_RSA(pkey.get(), key_)) { + LOGE("Failed to set PKEY RSA key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2a: Setup a EVP MD CTX for PSS Verification. + ScopedEvpMdCtx md_ctx = EVP_MD_CTX_new(); + if (!md_ctx) { + LOGE("Failed to allocate MD CTX"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + EVP_PKEY_CTX* pkey_ctx = nullptr; // Ownership is maintained by |md_ctx| + int res = EVP_DigestVerifyInit(md_ctx.get(), &pkey_ctx, EVP_sha1(), nullptr, + pkey.get()); + if (res != 1) { + LOGE("Failed to initialize MD CTX for verification"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (pkey_ctx == nullptr) { + LOGE("PKEY CTX is unexpectedly null"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2b: Configure OEMCrypto RSASSA-PSS options. + res = EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING); + if (res != 1) { + LOGE("Failed to set PSS padding"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + res = EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, kPssSaltLength); + if (res != 1) { + LOGE("Failed to set PSS salt length"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + res = EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, EVP_sha1()); + if (res != 1) { + LOGE("Failed to set PSS MGF1 MD"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 3: Digest message. + if (message_length > 0) { + res = EVP_DigestVerifyUpdate(md_ctx.get(), message, message_length); + if (res != 1) { + LOGE("Failed to update MD"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + } + // Step 4: Verify message. + res = EVP_DigestVerifyFinal(md_ctx.get(), signature, signature_length); + if (res < 0) { + LOGE("Failed to perform RSASSA-PSS-VERIFY"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + return res ? OEMCrypto_SUCCESS : OEMCrypto_ERROR_SIGNATURE_FAILURE; +} + +OEMCryptoResult RsaPublicKey::VerifySignaturePkcs1Cast( + const uint8_t* message, size_t message_length, const uint8_t* signature, + size_t signature_length) const { + // Step 0: Ensure the signature algorithm is supported by key. + if (!(allowed_schemes_ & kSign_PKCS1_Block1)) { + LOGE("RSA key cannot verify using PKCS1"); + return OEMCrypto_ERROR_INVALID_RSA_KEY; + } + if (message_length > kRsaPkcs1CastMaxMessageSize) { + LOGE("Message is too large for CAST PKCS1 signature: size = %zu", + message_length); + return OEMCrypto_ERROR_SIGNATURE_FAILURE; + } + // Step 1: Convert encrypted blob into digest. + std::vector digest(RSA_size(key_)); + const int res = RSA_public_decrypt(signature_length, signature, digest.data(), + key_, RSA_PKCS1_PADDING); + if (res <= 0) { + LOGE("Failed to perform public decryption"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2: Compare digests. + const size_t digest_size = static_cast(res); + if (digest_size != message_length) { + LOGD("Digest size does not match"); + return OEMCrypto_ERROR_SIGNATURE_FAILURE; + } + digest.resize(digest_size); + for (size_t i = 0; i < digest_size; i++) { + if (message[i] != digest[i]) { + return OEMCrypto_ERROR_SIGNATURE_FAILURE; + } + } + return OEMCrypto_SUCCESS; +} + +OEMCryptoResult RsaPublicKey::EncryptOaep(const uint8_t* message, + size_t message_size, + uint8_t* enc_message, + size_t* enc_message_length) const { + // Step 1: Encrypt using RSAES-OAEP. + std::vector encrypt_buffer(RSA_size(key_)); + const int enc_res = + RSA_public_encrypt(message_size, message, encrypt_buffer.data(), key_, + RSA_PKCS1_OAEP_PADDING); + if (enc_res < 0) { + LOGE("Failed to perform RSAES-OAEP encrypt"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2: Copy encrypted data key. + const size_t enc_size = static_cast(enc_res); + if (*enc_message_length < enc_size) { + *enc_message_length = enc_size; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + *enc_message_length = enc_size; + memcpy(enc_message, encrypt_buffer.data(), enc_size); + return OEMCrypto_SUCCESS; +} + +// ===== ===== ===== RSA Private Key ===== ===== ===== + +// static +std::unique_ptr RsaPrivateKey::New(RsaFieldSize field_size) { + std::unique_ptr key(new RsaPrivateKey()); + if (!key->InitFromFieldSize(field_size)) { + LOGE("Failed to initialize private key from field size"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr RsaPrivateKey::Load(const uint8_t* buffer, + size_t length) { + std::unique_ptr key; + if (buffer == nullptr) { + LOGE("Provided private key buffer is null"); + return key; + } + if (length == 0) { + LOGE("Provided private key buffer is zero length"); + return key; + } + key.reset(new RsaPrivateKey()); + if (!key->InitFromBuffer(buffer, length)) { + LOGE("Failed to initialize private key from buffer"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr RsaPrivateKey::Load(const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return std::unique_ptr(); + } + return Load(reinterpret_cast(buffer.data()), buffer.size()); +} + +// static +std::unique_ptr RsaPrivateKey::Load( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return std::unique_ptr(); + } + return Load(buffer.data(), buffer.size()); +} + +std::unique_ptr RsaPrivateKey::MakePublicKey() const { + return RsaPublicKey::New(*this); +} + +bool RsaPrivateKey::IsMatchingPublicKey(const RsaPublicKey& public_key) const { + return RsaKeysAreMatchingPair(public_key.GetRsaKey(), GetRsaKey()); +} + +OEMCryptoResult RsaPrivateKey::Serialize(uint8_t* buffer, + size_t* buffer_size) const { + if (buffer_size == nullptr) { + LOGE("Output buffer size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (buffer == nullptr && *buffer_size > 0) { + LOGE("Output buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + // Step 1: Convert RSA key to EVP. + ScopedEvpPkey pkey(EVP_PKEY_new()); + if (!pkey) { + LOGE("Failed to allocate EVP"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (!EVP_PKEY_set1_RSA(pkey.get(), key_)) { + LOGE("Failed to set EVP RSA key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2: Convert RSA EVP to PKCS8 format. + ScopedRsaPrivKeyInfo priv_info(EVP_PKEY2PKCS8(pkey.get())); + if (!priv_info) { + LOGE("Failed to convert RSA key to PKCS8 info"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 3: Serialize PKCS8 to DER encoding. + ScopedBio bio(BIO_new(BIO_s_mem())); + if (!bio) { + LOGE("Failed to allocate IO buffer for RSA key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (!i2d_PKCS8_PRIV_KEY_INFO_bio(bio.get(), priv_info.get())) { + LOGE("Failed to serialize RSA key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 4: Determine key size and copy. + char* key_ptr = nullptr; + const long key_size = BIO_get_mem_data(bio.get(), &key_ptr); + if (key_size < 0) { + LOGE("Failed to get RSA key size"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + const size_t required_size = + static_cast(key_size) + (explicit_schemes_ ? 8 : 0); + if (*buffer_size < required_size) { + *buffer_size = required_size; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + *buffer_size = required_size; + if (explicit_schemes_) { + memcpy(buffer, "SIGN", 4); + const uint32_t allowed_schemes = htonl(allowed_schemes_); + memcpy(&buffer[4], &allowed_schemes, 4); + memcpy(&buffer[8], key_ptr, required_size - 8); + } else { + memcpy(buffer, key_ptr, required_size); + } + return OEMCrypto_SUCCESS; +} + +std::vector RsaPrivateKey::Serialize() const { + size_t key_size = kPrivateKeySize; + std::vector key_data(key_size, 0); + OEMCryptoResult res = Serialize(key_data.data(), &key_size); + if (res == OEMCrypto_ERROR_SHORT_BUFFER) { + LOGD( + "Actual RSA private key size is larger than expected: " + "expected = %zu, actual = %zu", + kPrivateKeySize, key_size); + key_data.resize(key_size); + res = Serialize(key_data.data(), &key_size); + } + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to serialize private key: result = %d", static_cast(res)); + key_data.clear(); + } else { + key_data.resize(key_size); + } + return key_data; +} + +OEMCryptoResult RsaPrivateKey::GenerateSignature( + const uint8_t* message, size_t message_length, + RsaSignatureAlgorithm algorithm, uint8_t* signature, + size_t* signature_length) const { + if (signature_length == nullptr) { + LOGE("Output signature size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (signature == nullptr && *signature_length > 0) { + LOGE("Output signature is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (message == nullptr && message_length > 0) { + LOGE("Invalid message data"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + switch (algorithm) { + case kRsaPssDefault: + return GenerateSignaturePss(message, message_length, signature, + signature_length); + case kRsaPkcs1Cast: + return GenerateSignaturePkcs1Cast(message, message_length, signature, + signature_length); + } + LOGE("Unknown RSA signature algorithm: %d", static_cast(algorithm)); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; +} + +std::vector RsaPrivateKey::GenerateSignature( + const std::string& message, RsaSignatureAlgorithm algorithm) const { + size_t signature_size = SignatureSize(); + std::vector signature(signature_size); + const OEMCryptoResult res = GenerateSignature( + reinterpret_cast(message.data()), message.size(), + algorithm, signature.data(), &signature_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to generate signature: result = %d", static_cast(res)); + signature.clear(); + } else { + signature.resize(signature_size); + } + return signature; +} + +std::vector RsaPrivateKey::GenerateSignature( + const std::vector& message, + RsaSignatureAlgorithm algorithm) const { + size_t signature_size = SignatureSize(); + std::vector signature(signature_size, 0); + const OEMCryptoResult res = + GenerateSignature(message.data(), message.size(), algorithm, + signature.data(), &signature_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to generate signature: result = %d", static_cast(res)); + signature.clear(); + } else { + signature.resize(signature_size); + } + return signature; +} + +size_t RsaPrivateKey::SignatureSize() const { return RSA_size(key_); } + +OEMCryptoResult RsaPrivateKey::DecryptSessionKey( + const uint8_t* enc_session_key, size_t enc_session_key_size, + uint8_t* session_key, size_t* session_key_size) const { + if (enc_session_key == nullptr || enc_session_key_size == 0) { + LOGE("Encrypted session key is %s", enc_session_key ? "empty" : "null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (session_key_size == nullptr) { + LOGE("Output session key size buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (session_key == nullptr && *session_key_size > 0) { + LOGE("Output session key buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (*session_key_size < kRsaSessionKeySize) { + *session_key_size = kRsaSessionKeySize; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + const OEMCryptoResult result = DecryptOaep( + enc_session_key, enc_session_key_size, session_key, kRsaSessionKeySize); + if (result == OEMCrypto_SUCCESS) { + *session_key_size = kRsaSessionKeySize; + } else { + LOGE("Failed to decrypt session key"); + } + return result; +} + +std::vector RsaPrivateKey::DecryptSessionKey( + const std::vector& enc_session_key) const { + size_t session_key_size = kRsaSessionKeySize; + std::vector session_key(session_key_size, 0); + const OEMCryptoResult res = + DecryptSessionKey(enc_session_key.data(), enc_session_key.size(), + session_key.data(), &session_key_size); + if (res != OEMCrypto_SUCCESS) { + session_key.clear(); + } else { + session_key.resize(session_key_size); + } + return session_key; +} + +std::vector RsaPrivateKey::DecryptSessionKey( + const std::string& enc_session_key) const { + size_t session_key_size = kRsaSessionKeySize; + std::vector session_key(session_key_size, 0); + const OEMCryptoResult res = DecryptSessionKey( + reinterpret_cast(enc_session_key.data()), + enc_session_key.size(), session_key.data(), &session_key_size); + if (res != OEMCrypto_SUCCESS) { + session_key.clear(); + } else { + session_key.resize(session_key_size); + } + return session_key; +} + +size_t RsaPrivateKey::SessionKeyLength() const { return kRsaSessionKeySize; } + +OEMCryptoResult RsaPrivateKey::DecryptEncryptionKey( + const uint8_t* enc_encryption_key, size_t enc_encryption_key_size, + uint8_t* encryption_key, size_t* encryption_key_size) const { + if (enc_encryption_key == nullptr || enc_encryption_key_size == 0) { + LOGE("Encrypted encryption key is %s", + enc_encryption_key ? "empty" : "null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (encryption_key_size == nullptr) { + LOGE("Output encryption key size buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (encryption_key == nullptr && *encryption_key_size > 0) { + LOGE("Output encryption key buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (*encryption_key_size < kEncryptionKeySize) { + *encryption_key_size = kEncryptionKeySize; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + const OEMCryptoResult result = + DecryptOaep(enc_encryption_key, enc_encryption_key_size, encryption_key, + kEncryptionKeySize); + if (result == OEMCrypto_SUCCESS) { + *encryption_key_size = kEncryptionKeySize; + } else { + LOGE("Failed to decrypt encryption key"); + } + return result; +} + +std::vector RsaPrivateKey::DecryptEncryptionKey( + const std::vector& enc_encryption_key) const { + size_t encryption_key_size = kEncryptionKeySize; + std::vector encryption_key(encryption_key_size, 0); + const OEMCryptoResult res = + DecryptEncryptionKey(enc_encryption_key.data(), enc_encryption_key.size(), + encryption_key.data(), &encryption_key_size); + if (res != OEMCrypto_SUCCESS) { + encryption_key.clear(); + } else { + encryption_key.resize(encryption_key_size); + } + return encryption_key; +} + +std::vector RsaPrivateKey::DecryptEncryptionKey( + const std::string& enc_encryption_key) const { + size_t encryption_key_size = kEncryptionKeySize; + std::vector encryption_key(encryption_key_size, 0); + const OEMCryptoResult res = DecryptEncryptionKey( + reinterpret_cast(enc_encryption_key.data()), + enc_encryption_key.size(), encryption_key.data(), &encryption_key_size); + if (res != OEMCrypto_SUCCESS) { + encryption_key.clear(); + } else { + encryption_key.resize(encryption_key_size); + } + return encryption_key; +} + +RsaPrivateKey::~RsaPrivateKey() { + if (key_ != nullptr) { + RSA_free(key_); + key_ = nullptr; + } + allowed_schemes_ = 0; + explicit_schemes_ = false; + field_size_ = kRsaFieldUnknown; +} + +bool RsaPrivateKey::InitFromBuffer(const uint8_t* buffer, size_t length) { + if (length < 8) { + LOGE("Public key is too small: length = %zu", length); + return false; + } + ScopedBio bio; + // Check allowed scheme type. + if (!memcmp("SIGN", buffer, 4)) { + uint32_t allowed_schemes; + memcpy(&allowed_schemes, reinterpret_cast(&buffer[4]), 4); + allowed_schemes_ = ntohl(allowed_schemes); + bio.reset(BIO_new_mem_buf(&buffer[8], length - 8)); + explicit_schemes_ = true; + } else { + allowed_schemes_ = kSign_RSASSA_PSS; + bio.reset(BIO_new_mem_buf(buffer, length)); + } + if (!bio) { + LOGE("Failed to allocate BIO buffer"); + return false; + } + // Step 1: Deserializes PKCS8 RSA private key info. + ScopedRsaPrivKeyInfo priv_info( + d2i_PKCS8_PRIV_KEY_INFO_bio(bio.get(), nullptr)); + if (!priv_info) { + LOGE("Failed to parse private key"); + return false; + } + // Step 2: Convert to RSA key. + ScopedEvpPkey pkey(EVP_PKCS82PKEY(priv_info.get())); + if (!pkey) { + LOGE("Failed to convert PKCS8 to EVP"); + return false; + } + ScopedRsaKey key(EVP_PKEY_get1_RSA(pkey.get())); + if (!key) { + LOGE("Failed to get RSA key"); + return false; + } + const int check = RSA_check_key(key.get()); + if (check == 0) { + LOGE("RSA key parameters are invalid"); + return false; + } else if (check == -1) { + LOGE("Failed to check RSA key"); + return false; + } + // Step 3: Verify field width. + const int bits = RSA_bits(key.get()); + LOGD("Loaded RSA private key size: bits = %d", bits); + field_size_ = RealBitSizeToFieldSize(bits); + if (field_size_ == kRsaFieldUnknown) { + LOGE("Unsupported RSA key size: bits = %d", bits); + return false; + } + key_ = key.release(); + return true; +} + +bool RsaPrivateKey::InitFromFieldSize(RsaFieldSize field_size) { + if (field_size != kRsa2048Bit && field_size != kRsa3072Bit) { + LOGE("Unsupported RSA field size: bits = %d", static_cast(field_size)); + return false; + } + // Step 1: Create exponent. + ScopedBigNum exp(BN_new()); + if (!exp) { + LOGE("Failed to allocate RSA exponent"); + return false; + } + if (!BN_set_word(exp.get(), RSA_F4)) { + LOGE("Failed to set RSA exponent"); + return false; + } + // Step 2: Generate RSA key. + ScopedRsaKey key(RSA_new()); + if (!key) { + LOGE("Failed to allocate RSA key"); + return false; + } + if (!RSA_generate_key_ex(key.get(), static_cast(field_size), exp.get(), + nullptr)) { + LOGE("Failed to generate RSA key"); + return false; + } + key_ = key.release(); + allowed_schemes_ = kSign_RSASSA_PSS; + field_size_ = field_size; + return true; +} + +OEMCryptoResult RsaPrivateKey::GenerateSignaturePss( + const uint8_t* message, size_t message_length, uint8_t* signature, + size_t* signature_length) const { + // Step 0: Ensure the signature algorithm is supported by key. + if (!(allowed_schemes_ & kSign_RSASSA_PSS)) { + LOGE("RSA key cannot sign using PSS"); + return OEMCrypto_ERROR_INVALID_RSA_KEY; + } + // Step 1: Create a high-level key from RSA key. + ScopedEvpPkey pkey(EVP_PKEY_new()); + if (!pkey) { + LOGE("Failed to allocate PKEY"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (!EVP_PKEY_set1_RSA(pkey.get(), key_)) { + LOGE("Failed to set PKEY RSA key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2a: Setup a EVP MD CTX for PSS Signature Generation. + ScopedEvpMdCtx md_ctx(EVP_MD_CTX_new()); + if (!md_ctx) { + LOGE("Failed to allocate MD CTX"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + EVP_PKEY_CTX* pkey_ctx = nullptr; // Ownership is maintained by |md_ctx| + int res = EVP_DigestSignInit(md_ctx.get(), &pkey_ctx, EVP_sha1(), nullptr, + pkey.get()); + if (res != 1) { + LOGE("Failed to initialize MD CTX for signing"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (pkey_ctx == nullptr) { + LOGE("PKEY CTX is unexpectedly null"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2b: Configure OEMCrypto RSASSA-PSS options. + res = EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING); + if (res != 1) { + LOGE("Failed to set PSS padding"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + res = EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, kPssSaltLength); + if (res != 1) { + LOGE("Failed to set PSS salt length"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + res = EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, EVP_sha1()); + if (res != 1) { + LOGE("Failed to set PSS MGF1 MD"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 3: Digest message. + if (message_length > 0) { + res = EVP_DigestSignUpdate(md_ctx.get(), message, message_length); + if (res != 1) { + LOGE("Failed to update MD"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + } + // Step 4a: Determine size of signature. + size_t actual_signature_length = 0; + res = EVP_DigestSignFinal(md_ctx.get(), nullptr, &actual_signature_length); + if (res != 1) { + LOGE("Failed to determine signature length"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (*signature_length < actual_signature_length) { + *signature_length = actual_signature_length; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + // Step 4b: Generate signature. + res = EVP_DigestSignFinal(md_ctx.get(), signature, signature_length); + if (res != 1) { + LOGE("Failed to perform RSASSA-PSS-SIGN"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + return OEMCrypto_SUCCESS; +} + +OEMCryptoResult RsaPrivateKey::GenerateSignaturePkcs1Cast( + const uint8_t* message, size_t message_length, uint8_t* signature, + size_t* signature_length) const { + // Step 0: Ensure the signature algorithm is supported by key. + if (!(allowed_schemes_ & kSign_PKCS1_Block1)) { + LOGE("RSA key cannot sign PKCS1"); + return OEMCrypto_ERROR_INVALID_RSA_KEY; + } + if (message_length > kRsaPkcs1CastMaxMessageSize) { + LOGE("Message is too large for CAST PKCS1 signature: size = %zu", + message_length); + return OEMCrypto_ERROR_SIGNATURE_FAILURE; + } + // Step 1: Ensure signature buffer is large enough. + const size_t expected_signature_size = static_cast(RSA_size(key_)); + if (*signature_length < expected_signature_size) { + *signature_length = expected_signature_size; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + // Step 2: Encrypt with PKCS1 padding. + const int enc_res = RSA_private_encrypt(message_length, message, signature, + key_, RSA_PKCS1_PADDING); + if (enc_res < 0) { + LOGE("Failed to perform private encryption"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + *signature_length = expected_signature_size; + return OEMCrypto_SUCCESS; +} + +OEMCryptoResult RsaPrivateKey::DecryptOaep( + const uint8_t* enc_message, size_t enc_message_size, uint8_t* message, + size_t expected_message_length) const { + // Step 1: Decrypt using RSAES-OAEP. + std::vector decrypt_buffer(RSA_size(key_)); + const int dec_res = + RSA_private_decrypt(enc_message_size, enc_message, decrypt_buffer.data(), + key_, RSA_PKCS1_OAEP_PADDING); + if (dec_res < 0) { + LOGE("Failed to perform RSAES-OAEP decrypt"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (static_cast(dec_res) != expected_message_length) { + LOGE("Unexpected key size: expected = %zu, actual = %d", + expected_message_length, dec_res); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2: Copy decrypted data key. + memcpy(message, decrypt_buffer.data(), expected_message_length); + return OEMCrypto_SUCCESS; +} + +} // namespace wvoec_ref diff --git a/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.h b/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.h new file mode 100644 index 00000000..478b4bd7 --- /dev/null +++ b/libwvdrmengine/oemcrypto/ref/src/oemcrypto_rsa_key.h @@ -0,0 +1,359 @@ +// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation of OEMCrypto APIs +// +#ifndef OEMCRYPTO_RSA_KEY_H_ +#define OEMCRYPTO_RSA_KEY_H_ + +#include +#include + +#include +#include +#include + +#include + +#include "OEMCryptoCENCCommon.h" + +namespace wvoec_ref { + +enum RsaFieldSize { + kRsaFieldUnknown = 0, + kRsa2048Bit = 2048, + kRsa3072Bit = 3084 +}; + +// Identifies the RSA signature algorithm to be used when signing +// messages or verifying message signatures. +// The two standard signing algorithms specified by PKCS1 RSA V2.1 +// are RSASSA-PKCS1 and RSASSA-PSS. Each require agreement on a +// set of options. For OEMCrypto, only one set of options are agreed +// upon for each RSA signature scheme. CAST receivers specify a +// special implementation of PKCS1 where the message is already +// digested and encoded when provided. +enum RsaSignatureAlgorithm { + // RSASSA-PSS with default options: + // Hash algorithm: SHA-1 + // MGF: MGF1 with SHA-1 + // Salt length: 20 bytes + // Trailer field: 0xbc + kRsaPssDefault = 0, + // RSASSA-PKCS1 for CAST receivers. + // Assumes message is already digested & encoded. Max message length + // is 83 bytes. + kRsaPkcs1Cast = 1 +}; + +// Returns the string representation of the provided RSA field size. +// Intended for logging purposes. +std::string RsaFieldSizeToString(RsaFieldSize field_size); + +// Compares two OpenSSL/BoringSSL RSA keys to see if their public RSA +// components are matching. +// This function assumes both keys are valid. +// Returns true if they are matching, false otherwise. +bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key); + +class RsaPrivateKey; + +class RsaPublicKey { + public: + // Creates a new public key equivalent of the provided private key. + static std::unique_ptr New(const RsaPrivateKey& private_key); + + // Loads a serialized RSA public key. + // The provided |buffer| must contain a valid ASN.1 DER encoded + // SubjectPublicKey. This API will reject any RSA key that is not + // approximately to 2048bits or 3072bits. + // + // buffer: SubjectPublicKeyInfo = { + // algorithm: AlgorithmIdentifier = { + // algorithm: OID = rsaEncryption, + // parameters: NULL = null + // }, + // subjectPublicKey: BIT STRING = ... -- ASN.1 DER encoded RSAPublicKey + // } + // + // Failure will occur if the provided |buffer| does not contain a + // valid SubjectPublicKey, or if the specified curve is not + // supported. + static std::unique_ptr Load(const uint8_t* buffer, + size_t length); + static std::unique_ptr Load(const std::string& buffer); + static std::unique_ptr Load(const std::vector& buffer); + + RsaFieldSize field_size() const { return field_size_; } + uint32_t allowed_schemes() const { return allowed_schemes_; } + const RSA* GetRsaKey() const { return key_; } + + // Checks if the provided |private_key| is the RSA private key of this + // public key. + bool IsMatchingPrivateKey(const RsaPrivateKey& private_key) const; + + // Serializes the public key into an ASN.1 DER encoded SubjectPublicKey + // representation. + // On success, |buffer_size| is populated with the number of bytes + // written to |buffer|, and OEMCrypto_SUCCESS is returned. + // If the provided |buffer_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |buffer_size| is set + // to the required buffer size. + OEMCryptoResult Serialize(uint8_t* buffer, size_t* buffer_size) const; + // Same as above, except directly returns the serialized key. + // Returns an empty vector on error. + std::vector Serialize() const; + + // Verifies the |signature| matches the provided |message| by the + // private equivalent of this public key. + // The signature algorithm can be specified via the |algorithm| field. + // See RsaSignatureAlgorithm for details on each algorithm. + // + // Returns: + // OEMCrypto_SUCCESS if signature is valid + // OEMCrypto_ERROR_SIGNATURE_FAILURE if the signature is invalid + // OEMCrypto_ERROR_UNKNOWN_FAILURE if any error occurs + OEMCryptoResult VerifySignature( + const uint8_t* message, size_t message_length, const uint8_t* signature, + size_t signature_length, + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; + OEMCryptoResult VerifySignature( + const std::string& message, const std::string& signature, + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; + OEMCryptoResult VerifySignature( + const std::vector& message, + const std::vector& signature, + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; + + // Encrypts the OEMCrypto session key used for deriving other keys. + // On success, |enc_session_key_size| is populated with the number + // of bytes written to |enc_session_key|, and OEMCrypto_SUCCESS is + // returned. If the provided |enc_session_key_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and + // |enc_session_key_size| is set to the required buffer size. + OEMCryptoResult EncryptSessionKey(const uint8_t* session_key, + size_t session_key_size, + uint8_t* enc_session_key, + size_t* enc_session_key_size) const; + // Same as above, except directly returns the encrypted key. + std::vector EncryptSessionKey( + const std::vector& session_key) const; + std::vector EncryptSessionKey(const std::string& session_key) const; + + // Encrypts the OEMCrypto encryption key used for encrypting the + // DRM private key. + // On success, |enc_encryption_key_size| is populated with the + // number of bytes written to |enc_encryption_key|, and + // OEMCrypto_SUCCESS is returned. + // If the provided |enc_encryption_key_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and + // |enc_encryption_key_size| is set to the required buffer size. + OEMCryptoResult EncryptEncryptionKey(const uint8_t* encryption_key, + size_t encryption_key_size, + uint8_t* enc_encryption_key, + size_t* enc_encryption_key_size) const; + // Same as above, except directly returns the encrypted key. + std::vector EncryptEncryptionKey( + const std::vector& encryption_key) const; + std::vector EncryptEncryptionKey( + const std::string& encryption_key) const; + + ~RsaPublicKey(); + + RsaPublicKey(const RsaPublicKey&) = delete; + RsaPublicKey(RsaPublicKey&&) = delete; + const RsaPublicKey& operator=(const RsaPublicKey&) = delete; + RsaPublicKey& operator=(RsaPublicKey&&) = delete; + + private: + RsaPublicKey() {} + + // Initializes the public key object using the provided |buffer|. + // In case of any failure, false is return and the key should be + // discarded. + bool InitFromBuffer(const uint8_t* buffer, size_t length); + // Initializes the public key object from a private. + bool InitFromPrivateKey(const RsaPrivateKey& private_key); + + // Signature specialization functions. + OEMCryptoResult VerifySignaturePss(const uint8_t* message, + size_t message_length, + const uint8_t* signature, + size_t signature_length) const; + OEMCryptoResult VerifySignaturePkcs1Cast(const uint8_t* message, + size_t message_length, + const uint8_t* signature, + size_t signature_length) const; + + // RSAES-OAEP encrypt. + OEMCryptoResult EncryptOaep(const uint8_t* message, size_t message_size, + uint8_t* enc_message, + size_t* enc_message_length) const; + + // OpenSSL/BoringSSL implementation of an RSA key. + // Will only include components of an RSA public key. + RSA* key_ = nullptr; + uint32_t allowed_schemes_ = 0; + RsaFieldSize field_size_ = kRsaFieldUnknown; +}; + +class RsaPrivateKey { + public: + // Creates a new, pseudorandom RSA private key. + static std::unique_ptr New(RsaFieldSize field_size); + + // Loads a serialized RSA private key. + // The provided |buffer| must contain a valid ASN.1 DER encoded + // PrivateKeyInfo (RFC 5208). + // + // buffer: PrivateKeyInfo = { + // version: INTEGER = v1(0), + // privateKeyAlgorithm: OID = rsaEncryption, + // privateKey: OCTET STRING = ..., + // -- BER encoding of RSAPrivateKey (RFC 3447) + // attributes: Attributes = ... -- Optional, not used by OEMCrypto + // } + // Note: If the public key is not included, then it is computed from + // the private. + // + // Failure will occur if the provided |buffer| does not contain a + // valid ECPrivateKey, or if the specified curve is not supported. + static std::unique_ptr Load(const uint8_t* buffer, + size_t length); + static std::unique_ptr Load(const std::string& buffer); + static std::unique_ptr Load( + const std::vector& buffer); + + // Creates a new RSA public key of this private key. + // Equivalent to calling RsaPublicKey::New with this private + // key. + std::unique_ptr MakePublicKey() const; + + RsaFieldSize field_size() const { return field_size_; } + uint32_t allowed_schemes() const { return allowed_schemes_; } + const RSA* GetRsaKey() const { return key_; } + + // Checks if the provided |public_key| is the RSA public key of this + // private key. + bool IsMatchingPublicKey(const RsaPublicKey& public_key) const; + + // Serializes the private key into an ASN.1 DER encoded X + // representation. + // On success, |buffer_size| is populated with the number of bytes + // written to |buffer|, and OEMCrypto_SUCCESS is returned. + // If the provided |buffer_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |buffer_size| is + // set to the required buffer size. + OEMCryptoResult Serialize(uint8_t* buffer, size_t* buffer_size) const; + // Same as above, except directly returns the serialized key. + // Returns an empty vector on error. + std::vector Serialize() const; + + // Signs the provided |message| using the RSA signing algorithm + // specified by |algorithm|. See RsaSignatureAlgorithm for + // details on each algorithm. + // + // On success, |signature_length| is populated with the number of + // bytes written to |signature|, and OEMCrypto_SUCCESS is returned. + // If the provided |signature_length| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |signature_length| + // is set to the required signature size. + OEMCryptoResult GenerateSignature(const uint8_t* message, + size_t message_length, + RsaSignatureAlgorithm algorithm, + uint8_t* signature, + size_t* signature_length) const; + // Same as above, except directly returns the serialized signature. + // Returns an empty vector on error. + std::vector GenerateSignature( + const std::vector& message, + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; + std::vector GenerateSignature( + const std::string& message, + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; + // Returns an upper bound for the signature size. May be larger than + // the actual signature generated by GenerateSignature(). + size_t SignatureSize() const; + + // Decrypts the OEMCrypto session key used for deriving other keys. + // On success, |session_key_size| is populated with the number of + // bytes written to |session_key|, and OEMCrypto_SUCCESS is returned. + // If the provided |session_key_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |session_key_size| + // is set to the required buffer size. + OEMCryptoResult DecryptSessionKey(const uint8_t* enc_session_key, + size_t enc_session_key_size, + uint8_t* session_key, + size_t* session_key_size) const; + // Same as above, except directly returns the decrypted key. + std::vector DecryptSessionKey( + const std::vector& enc_session_key) const; + std::vector DecryptSessionKey( + const std::string& enc_session_key) const; + // Returns the byte length of the symmetric key that would be derived + // by DecryptSessionKey(). + size_t SessionKeyLength() const; + + // Decrypts the OEMCrypto encryption key used for decrypting DRM + // private key. + // On success, |encryption_key_size| is populated with the number of + // bytes written to |encryption_key|, and OEMCrypto_SUCCESS is + // returned. + // If the provided |encryption_key_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |encryption_key_size| + // is set to the required buffer size. + OEMCryptoResult DecryptEncryptionKey(const uint8_t* enc_encryption_key, + size_t enc_encryption_key_size, + uint8_t* encryption_key, + size_t* encryption_key_size) const; + // Same as above, except directly returns the decrypted key. + std::vector DecryptEncryptionKey( + const std::vector& enc_encryption_key) const; + std::vector DecryptEncryptionKey( + const std::string& enc_encryption_key) const; + + ~RsaPrivateKey(); + + RsaPrivateKey(const RsaPrivateKey&) = delete; + RsaPrivateKey(RsaPrivateKey&&) = delete; + const RsaPrivateKey& operator=(const RsaPrivateKey&) = delete; + RsaPrivateKey& operator=(RsaPrivateKey&&) = delete; + + private: + RsaPrivateKey() {} + + // Initializes the public key object using the provided |buffer|. + // In case of any failure, false is return and the key should be + // discarded. + bool InitFromBuffer(const uint8_t* buffer, size_t length); + // Generates a new key based on the provided field size. + bool InitFromFieldSize(RsaFieldSize field_size); + + // Signature specialization functions. + OEMCryptoResult GenerateSignaturePss(const uint8_t* message, + size_t message_length, + uint8_t* signature, + size_t* signature_length) const; + OEMCryptoResult GenerateSignaturePkcs1Cast(const uint8_t* message, + size_t message_length, + uint8_t* signature, + size_t* signature_length) const; + + // RSAES-OAEP decrypt. + OEMCryptoResult DecryptOaep(const uint8_t* enc_message, + size_t enc_message_size, uint8_t* message, + size_t expected_message_length) const; + + // OpenSSL/BoringSSL implementation of an RSA key. + // Will include all components of an RSA private key. + RSA* key_ = nullptr; + uint32_t allowed_schemes_ = 0; + // Set true if the deserialized key contained an allowed schemes. + bool explicit_schemes_ = false; + RsaFieldSize field_size_ = kRsaFieldUnknown; +}; + +} // namespace wvoec_ref + +#endif // OEMCRYPTO_RSA_KEY_H_