Files

409 lines
13 KiB
C++

////////////////////////////////////////////////////////////////////////////////
// Copyright 2016 Google LLC.
//
// This software is licensed under the terms defined in the Widevine Master
// License Agreement. For a copy of this agreement, please contact
// widevine-licensing@google.com.
////////////////////////////////////////////////////////////////////////////////
//
// Description:
// Definition of classes representing RSA private and public keys used
// for message signing, signature verification, encryption and decryption.
//
// RSA signature details:
// Algorithm: RSASSA-PSS
// Hash algorithm: |hash_algorithm|
// Mask generation function: mgf1SHA1
// Salt length: 20 bytes
// Trailer field: 0xbc
//
// RSA encryption details:
// Algorithm: RSA-OAEP
// Mask generation function: mgf1SHA1
// Label (encoding parameter): empty string
#include "common/rsa_key.h"
#include <cstdint>
#include "glog/logging.h"
#include "openssl/asn1.h"
#include "openssl/bn.h"
#include "openssl/digest.h"
#include "openssl/err.h"
#include "openssl/evp.h"
#include "openssl/rsa.h"
#include "openssl/sha.h"
#include "common/hash_algorithm.h"
#include "common/rsa_util.h"
#include "common/sha_util.h"
// Salt length, in bytes, passed to the RSASSA-PSS padding calls used for
// signature generation and verification in this file.
static const int kPssSaltLength = 20;
namespace {
// Returns true when both keys are non-null and have equal moduli. Keys that
// match are either a public/private key pair, two copies of the same public
// key, or two copies of the same private key.
bool RsaKeyMatch(const RSA* key1, const RSA* key2) {
  if (key1 == nullptr || key2 == nullptr) {
    return false;
  }
  const int modulus_order = BN_cmp(key1->n, key2->n);
  return modulus_order == 0;
}
std::string OpenSSLErrorString(uint32_t error) {
char buf[ERR_ERROR_STRING_BUF_LEN];
ERR_error_string_n(error, buf, sizeof(buf));
return buf;
}
// Hashes |message| with the digest selected by |hash_algorithm| and returns
// the raw digest. An unspecified algorithm falls back to SHA1, the default
// hash for RSA signatures. Any value outside the handled enumerators is a
// fatal error.
std::string GetMessageDigest(const std::string& message,
                             widevine::HashAlgorithm hash_algorithm) {
  switch (hash_algorithm) {
    case widevine::HashAlgorithm::kSha256: {
      return widevine::Sha256_Hash(message);
    }
    case widevine::HashAlgorithm::kSha384: {
      return widevine::Sha384_Hash(message);
    }
    // The default hash algorithm of RSA signature is SHA1.
    case widevine::HashAlgorithm::kUnspecified:
    case widevine::HashAlgorithm::kSha1: {
      return widevine::Sha1_Hash(message);
    }
  }
  // Only reachable when |hash_algorithm| holds a value that is not one of the
  // enumerators handled above.
  LOG(FATAL) << "Unexpected hash algorithm: "
             << static_cast<int>(hash_algorithm);
  return std::string();
}
// Maps |hash_algorithm| to the matching OpenSSL message digest descriptor.
// An unspecified algorithm maps to SHA1, mirroring GetMessageDigest(). Any
// value outside the handled enumerators is a fatal error.
const EVP_MD* GetHashMd(widevine::HashAlgorithm hash_algorithm) {
  switch (hash_algorithm) {
    case widevine::HashAlgorithm::kSha256: {
      return EVP_sha256();
    }
    case widevine::HashAlgorithm::kSha384: {
      return EVP_sha384();
    }
    case widevine::HashAlgorithm::kUnspecified:
    case widevine::HashAlgorithm::kSha1: {
      return EVP_sha1();
    }
  }
  // Only reachable when |hash_algorithm| holds a value that is not one of the
  // enumerators handled above.
  LOG(FATAL) << "Unexpected hash algorithm: "
             << static_cast<int>(hash_algorithm);
  return nullptr;
}
// Returns true if |message|, read as a big-endian unsigned integer (the most
// significant byte is encoded first, see https://tools.ietf.org/html/rfc8017),
// is smaller than 2^64. That is the case exactly when every byte ahead of the
// trailing 8 bytes is zero.
//
// A message of 8 bytes or fewer (including the empty message) cannot encode a
// value of 2^64 or more, so it is always reported as too small. Handling that
// case up front also keeps the length subtraction below from wrapping around
// in unsigned arithmetic, which would otherwise send the loop out of bounds.
bool IsMessageTooSmall(const std::string& message) {
  const size_t kMinimumSizeInBytes = 8;
  if (message.size() <= kMinimumSizeInBytes) return true;
  // Number of leading bytes that carry bits at or above position 64.
  const size_t leading_size = message.size() - kMinimumSizeInBytes;
  for (size_t i = 0; i < leading_size; ++i) {
    if (message[i] != 0) return false;
  }
  return true;
}
} // namespace
namespace widevine {
// Wraps |key|, which must be non-null. The wrapped key is released with
// RSA_free() when this object is destroyed.
RsaPrivateKey::RsaPrivateKey(RSA* key) : key_(key) {
  CHECK(key_ != nullptr);
}
// Copy constructor: holds its own duplicate of |rsa_key|'s underlying RSA
// key. Duplication failure is fatal.
RsaPrivateKey::RsaPrivateKey(const RsaPrivateKey& rsa_key)
    : key_(RSAPrivateKey_dup(rsa_key.key_)) {
  CHECK(key_ != nullptr);
}
// Releases the wrapped RSA key.
RsaPrivateKey::~RsaPrivateKey() {
  RSA_free(key_);
}
// Builds an RsaPrivateKey from |serialized_key|, which is parsed by
// rsa_util::DeserializeRsaPrivateKey().
//
// Returns nullptr if the key cannot be deserialized or fails
// RSA_check_key(). Otherwise returns a heap-allocated object that the caller
// owns.
RsaPrivateKey* RsaPrivateKey::Create(const std::string& serialized_key) {
  RSA* key = nullptr;
  if (!rsa_util::DeserializeRsaPrivateKey(serialized_key, &key) ||
      key == nullptr) {
    return nullptr;
  }
  if (RSA_check_key(key) != 1) {
    LOG(ERROR) << "Invalid private RSA key: "
               << OpenSSLErrorString(ERR_get_error());
    RSA_free(key);
    // |key| has just been freed. Wrapping it would hand out a dangling
    // pointer that the RsaPrivateKey destructor would free a second time.
    return nullptr;
  }
  return new RsaPrivateKey(key);
}
bool RsaPrivateKey::Decrypt(const std::string& encrypted_message,
std::string* decrypted_message) const {
DCHECK(decrypted_message);
size_t rsa_size = RSA_size(key_);
if (encrypted_message.size() != rsa_size) {
LOG(ERROR) << "Encrypted RSA message has the wrong size (expected "
<< rsa_size << ", actual " << encrypted_message.size() << ")";
return false;
}
decrypted_message->assign(rsa_size, 0);
int decrypted_size = RSA_private_decrypt(
rsa_size,
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(encrypted_message.data())),
reinterpret_cast<unsigned char*>(&(*decrypted_message)[0]), key_,
RSA_PKCS1_OAEP_PADDING);
if (decrypted_size == -1) {
LOG(ERROR) << "RSA private decrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
decrypted_message->resize(decrypted_size);
return true;
}
bool RsaPrivateKey::GenerateSignature(const std::string& message,
HashAlgorithm hash_algorithm,
std::string* signature) const {
DCHECK(signature);
if (message.empty()) {
LOG(ERROR) << "Message to be signed is empty";
return false;
}
// Hash the message using corresponding hash algorithm.
std::string message_digest = GetMessageDigest(message, hash_algorithm);
if (message_digest.empty()) {
LOG(ERROR) << "Empty message digest";
return false;
}
const EVP_MD* hash = GetHashMd(hash_algorithm);
if (hash == nullptr) {
LOG(ERROR) << "No hash md";
return false;
}
// Add PSS padding.
size_t rsa_size = RSA_size(key_);
std::string padded_digest(rsa_size, 0);
if (!RSA_padding_add_PKCS1_PSS_mgf1(
key_, reinterpret_cast<unsigned char*>(&padded_digest[0]),
reinterpret_cast<unsigned char*>(&message_digest[0]), hash,
EVP_sha1(), kPssSaltLength)) {
LOG(ERROR) << "RSA padding failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
// Encrypt PSS padded digest.
signature->assign(rsa_size, 0);
if (RSA_private_encrypt(padded_digest.size(),
reinterpret_cast<unsigned char*>(&padded_digest[0]),
reinterpret_cast<unsigned char*>(&(*signature)[0]),
key_, RSA_NO_PADDING) !=
static_cast<int>(signature->size())) {
LOG(ERROR) << "RSA private encrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
return true;
}
bool RsaPrivateKey::GenerateSignatureSha256Pkcs7(const std::string& message,
std::string* signature) const {
DCHECK(signature);
if (message.empty()) {
LOG(ERROR) << "Empty signature verification message";
return false;
}
unsigned char digest[SHA256_DIGEST_LENGTH];
SHA256(reinterpret_cast<const unsigned char*>(message.data()), message.size(),
digest);
unsigned int sig_len = RSA_size(key_);
signature->resize(sig_len);
return RSA_sign(NID_sha256, digest, sizeof(digest),
reinterpret_cast<unsigned char*>(&(*signature)[0]), &sig_len,
key_) == 1;
}
bool RsaPrivateKey::MatchesPrivateKey(const RsaPrivateKey& private_key) const {
return RsaKeyMatch(key(), private_key.key());
}
bool RsaPrivateKey::MatchesPublicKey(const RsaPublicKey& public_key) const {
return RsaKeyMatch(key(), public_key.key());
}
// Returns the value reported by RSA_size() for the wrapped key.
uint32_t RsaPrivateKey::KeySize() const {
  return RSA_size(key_);
}
// Wraps |key|, which must be non-null. The wrapped key is released with
// RSA_free() when this object is destroyed.
RsaPublicKey::RsaPublicKey(RSA* key) : key_(key) {
  CHECK(key_ != nullptr);
}
// Copy constructor: holds its own duplicate of |rsa_key|'s underlying RSA
// public key. Duplication failure is fatal.
RsaPublicKey::RsaPublicKey(const RsaPublicKey& rsa_key)
    : key_(RSAPublicKey_dup(rsa_key.key_)) {
  CHECK(key_ != nullptr);
}
// Builds a public key from a private key by duplicating the RSA key held by
// |rsa_key| with RSAPublicKey_dup(). Duplication failure is fatal.
RsaPublicKey::RsaPublicKey(const RsaPrivateKey& rsa_key)
    : key_(RSAPublicKey_dup(rsa_key.key_)) {
  CHECK(key_ != nullptr);
}
// Releases the wrapped RSA key.
RsaPublicKey::~RsaPublicKey() {
  RSA_free(key_);
}
// Builds an RsaPublicKey from |serialized_key|, which is parsed by
// rsa_util::DeserializeRsaPublicKey().
//
// Returns nullptr if the key cannot be deserialized or if RSA_size() reports
// a zero size for it. Otherwise returns a heap-allocated object that the
// caller owns.
RsaPublicKey* RsaPublicKey::Create(const std::string& serialized_key) {
  RSA* key = nullptr;
  if (!rsa_util::DeserializeRsaPublicKey(serialized_key, &key) ||
      key == nullptr) {
    return nullptr;
  }
  if (RSA_size(key) == 0) {
    LOG(ERROR) << "Invalid public RSA key: "
               << OpenSSLErrorString(ERR_get_error());
    RSA_free(key);
    // |key| has just been freed. Wrapping it would hand out a dangling
    // pointer that the RsaPublicKey destructor would free a second time.
    return nullptr;
  }
  return new RsaPublicKey(key);
}
bool RsaPublicKey::Encrypt(const std::string& clear_message,
std::string* encrypted_message) const {
DCHECK(encrypted_message);
if (clear_message.empty()) {
LOG(ERROR) << "Message to be encrypted is empty";
return false;
}
size_t rsa_size = RSA_size(key_);
encrypted_message->assign(rsa_size, 0);
const int kRetryAttempt = 1;
for (int i = 0; i < 1 + kRetryAttempt; i++) {
if (RSA_public_encrypt(
clear_message.size(),
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(clear_message.data())),
reinterpret_cast<unsigned char*>(&(*encrypted_message)[0]), key_,
RSA_PKCS1_OAEP_PADDING) != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA public encrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
if (!IsMessageTooSmall(*encrypted_message)) return true;
}
LOG(ERROR) << "RSA public encryption randomness error";
return false;
}
bool RsaPublicKey::VerifySignature(const std::string& message,
HashAlgorithm hash_algorithm,
const std::string& signature) const {
if (message.empty()) {
LOG(ERROR) << "Signed message is empty";
return false;
}
size_t rsa_size = RSA_size(key_);
if (signature.size() != rsa_size) {
LOG(ERROR) << "Message signature is of the wrong size (expected "
<< rsa_size << ", actual " << signature.size() << ")";
return false;
}
// Decrypt the signature.
std::string padded_digest(signature.size(), 0);
if (RSA_public_decrypt(
signature.size(),
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(signature.data())),
reinterpret_cast<unsigned char*>(&padded_digest[0]), key_,
RSA_NO_PADDING) != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA public decrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
// Hash the message using the corresponding hash algorithm.
std::string message_digest = GetMessageDigest(message, hash_algorithm);
if (message_digest.empty()) {
LOG(ERROR) << "Empty message digest";
return false;
}
const EVP_MD* hash = GetHashMd(hash_algorithm);
if (hash == nullptr) {
LOG(ERROR) << "No hash md";
return false;
}
// Verify PSS padding.
if (RSA_verify_PKCS1_PSS_mgf1(
key_, reinterpret_cast<unsigned char*>(&message_digest[0]), hash,
EVP_sha1(), reinterpret_cast<unsigned char*>(&padded_digest[0]),
kPssSaltLength) == 0) {
LOG(ERROR) << "RSA Verify PSS padding failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
return true;
}
// Verifies |signature| against the SHA-256 digest of |message| via
// RSA_verify() with NID_sha256.
//
// Returns false if |message| or |signature| is empty, if |signature| is not
// exactly the RSA modulus size (each of these is logged), or if RSA_verify()
// does not report success.
bool RsaPublicKey::VerifySignatureSha256Pkcs7(
    const std::string& message, const std::string& signature) const {
  if (message.empty()) {
    LOG(ERROR) << "Empty signature verification message";
    return false;
  }
  if (signature.empty()) {
    LOG(ERROR) << "Empty signature";
    return false;
  }
  if (signature.size() != RSA_size(key_)) {
    LOG(ERROR) << "RSA signature has the wrong size";
    return false;
  }

  unsigned char digest[SHA256_DIGEST_LENGTH];
  const unsigned char* message_bytes =
      reinterpret_cast<const unsigned char*>(message.data());
  SHA256(message_bytes, message.size(), digest);

  const unsigned char* signature_bytes =
      reinterpret_cast<const unsigned char*>(signature.data());
  const int verified = RSA_verify(NID_sha256, digest, sizeof(digest),
                                  signature_bytes, signature.size(), key_);
  return verified == 1;
}
bool RsaPublicKey::MatchesPrivateKey(const RsaPrivateKey& private_key) const {
return RsaKeyMatch(key(), private_key.key());
}
bool RsaPublicKey::MatchesPublicKey(const RsaPublicKey& public_key) const {
return RsaKeyMatch(key(), public_key.key());
}
// Returns the value reported by RSA_size() for the wrapped key.
uint32_t RsaPublicKey::KeySize() const {
  return RSA_size(key_);
}
// Serializes this public key with rsa_util::SerializeRsaPublicKey() and
// stores the result in |serialized_key|.
//
// Returns false if |serialized_key| is null or serialization fails; in the
// failure case |*serialized_key| is left untouched.
bool RsaPublicKey::SerializedKey(std::string* serialized_key) const {
  if (serialized_key == nullptr) {
    return false;
  }
  // Serialize into a scratch string so the output is only written on success.
  std::string scratch;
  const bool serialized = rsa_util::SerializeRsaPublicKey(key(), &scratch);
  if (serialized) {
    *serialized_key = scratch;
  }
  return serialized;
}
// The factory needs no setup of its own.
RsaKeyFactory::RsaKeyFactory() = default;
// The factory needs no teardown of its own.
RsaKeyFactory::~RsaKeyFactory() = default;
std::unique_ptr<RsaPrivateKey> RsaKeyFactory::CreateFromPkcs1PrivateKey(
const std::string& private_key) const {
return std::unique_ptr<RsaPrivateKey>(RsaPrivateKey::Create(private_key));
}
std::unique_ptr<RsaPrivateKey> RsaKeyFactory::CreateFromPkcs8PrivateKey(
const std::string& private_key,
const std::string& private_key_passphrase) const {
std::string pkcs1_key;
const bool result =
private_key_passphrase.empty()
? rsa_util::PrivateKeyInfoToRsaPrivateKey(private_key, &pkcs1_key)
: rsa_util::EncryptedPrivateKeyInfoToRsaPrivateKey(
private_key, private_key_passphrase, &pkcs1_key);
if (!result) {
LOG(WARNING) << "Failed to get pkcs1_key.";
return std::unique_ptr<RsaPrivateKey>();
}
return std::unique_ptr<RsaPrivateKey>(RsaPrivateKey::Create(pkcs1_key));
}
std::unique_ptr<RsaPublicKey> RsaKeyFactory::CreateFromPkcs1PublicKey(
const std::string& public_key) const {
return std::unique_ptr<RsaPublicKey>(RsaPublicKey::Create(public_key));
}
} // namespace widevine