Files
media_cas_packager_sdk_source/common/rsa_key.cc
Lu Chen 79e39b482d Add support for Widevine ECM v3
Widevine ECM v3 is redesigned mainly based on protobuf, and supports new features including carrying fingerprinting and service blocking information. Existing clients must upgrade the Widevine CAS plugin to use the new ECM v3.
2020-12-14 09:49:52 -08:00

390 lines
13 KiB
C++

////////////////////////////////////////////////////////////////////////////////
// Copyright 2016 Google LLC.
//
// This software is licensed under the terms defined in the Widevine Master
// License Agreement. For a copy of this agreement, please contact
// widevine-licensing@google.com.
////////////////////////////////////////////////////////////////////////////////
//
// Description:
// Definition of classes representing RSA private and public keys used
// for message signing, signature verification, encryption and decryption.
//
// RSA signature details:
// Algorithm: RSASSA-PSS
// Hash algorithm: |hash_algorithm| (SHA-1 when unspecified, or SHA-256)
// Mask generation function: mgf1SHA1
// Salt length: 20 bytes
// Trailer field: 0xbc
//
// RSA encryption details:
// Algorithm: RSA-OAEP
// Mask generation function: mgf1SHA1
// Label (encoding parameter): empty string
#include "common/rsa_key.h"
#include "glog/logging.h"
#include "openssl/bn.h"
#include "openssl/digest.h"
#include "openssl/err.h"
#include "openssl/evp.h"
#include "openssl/rsa.h"
#include "openssl/sha.h"
#include "common/hash_algorithm.h"
#include "common/rsa_util.h"
#include "common/sha_util.h"
// Salt length, in bytes, used for both RSASSA-PSS signing and verification.
static constexpr int kPssSaltLength = 20;
namespace {
// Reports whether two RSA keys belong together: a public/private pair, the
// same public key twice, or the same private key twice. Only the public
// modulus |n| is compared. Returns false if either pointer is null.
bool RsaKeyMatch(const RSA* key1, const RSA* key2) {
  if (key1 == nullptr || key2 == nullptr) {
    return false;
  }
  const int modulus_order = BN_cmp(key1->n, key2->n);
  return modulus_order == 0;
}
std::string OpenSSLErrorString(uint32_t error) {
char buf[ERR_ERROR_STRING_BUF_LEN];
ERR_error_string_n(error, buf, sizeof(buf));
return buf;
}
// Hashes |message| with the algorithm selected by |hash_algorithm| and returns
// the raw digest bytes. kUnspecified falls back to SHA-1, the default hash for
// RSA signatures in this file. Any other value is a programming error.
std::string GetMessageDigest(const std::string& message,
                             widevine::HashAlgorithm hash_algorithm) {
  switch (hash_algorithm) {
    case widevine::HashAlgorithm::kSha256:
      return widevine::Sha256_Hash(message);
    case widevine::HashAlgorithm::kUnspecified:  // Defaults to SHA-1.
    case widevine::HashAlgorithm::kSha1:
      return widevine::Sha1_Hash(message);
  }
  // Only reachable if |hash_algorithm| holds a value outside the enum.
  LOG(FATAL) << "Unexpected hash algorithm: "
             << static_cast<int>(hash_algorithm);
  return std::string();
}
// Maps |hash_algorithm| to the matching EVP digest descriptor. kUnspecified
// maps to SHA-1, mirroring GetMessageDigest. Any other value is a programming
// error.
const EVP_MD* GetHashMd(widevine::HashAlgorithm hash_algorithm) {
  switch (hash_algorithm) {
    case widevine::HashAlgorithm::kSha256:
      return EVP_sha256();
    case widevine::HashAlgorithm::kUnspecified:  // Defaults to SHA-1.
    case widevine::HashAlgorithm::kSha1:
      return EVP_sha1();
  }
  // Only reachable if |hash_algorithm| holds a value outside the enum.
  LOG(FATAL) << "Unexpected hash algorithm: "
             << static_cast<int>(hash_algorithm);
  return nullptr;
}
// Returns true if the big-endian integer encoded in |message| is smaller than
// 2^64, i.e. every byte except the trailing kMinimumSizeInBytes is zero.
//
// The most significant byte is encoded first in the message (see
// https://tools.ietf.org/html/rfc8017). For the big number to reach 2^64,
// there must be a non-zero byte among the first "LENGTH_IN_BYTES - 8" bytes.
// A message of 8 bytes or fewer (including an empty one) can never reach
// 2^64, so it is reported as too small.
bool IsMessageTooSmall(const std::string& message) {
  const size_t kMinimumSizeInBytes = 8;
  // Guard the subtraction below. size() is unsigned, so a message shorter
  // than kMinimumSizeInBytes would otherwise wrap around, and the loop would
  // read far past the end of the buffer.
  if (message.size() <= kMinimumSizeInBytes) return true;
  const size_t leading_bytes = message.size() - kMinimumSizeInBytes;
  for (size_t i = 0; i < leading_bytes; ++i) {
    if (message[i] != 0) return false;
  }
  return true;
}
} // namespace
namespace widevine {
// Takes ownership of |key|, which must be non-null. It is released with
// RSA_free in the destructor.
RsaPrivateKey::RsaPrivateKey(RSA* key) : key_(key) {
  CHECK(key_ != nullptr);
}

// Deep-copies the underlying private key so each instance owns its own RSA.
RsaPrivateKey::RsaPrivateKey(const RsaPrivateKey& rsa_key)
    : key_(RSAPrivateKey_dup(rsa_key.key_)) {
  CHECK(key_ != nullptr);
}

RsaPrivateKey::~RsaPrivateKey() {
  RSA_free(key_);
}
// Deserializes |serialized_key| with rsa_util::DeserializeRsaPrivateKey (the
// PKCS#1 form used by RsaKeyFactory::CreateFromPkcs1PrivateKey) and validates
// it with RSA_check_key.
//
// Returns a newly allocated RsaPrivateKey owned by the caller, or nullptr if
// deserialization or validation fails.
RsaPrivateKey* RsaPrivateKey::Create(const std::string& serialized_key) {
  RSA* key = nullptr;
  if (!rsa_util::DeserializeRsaPrivateKey(serialized_key, &key)) return nullptr;
  if (RSA_check_key(key) != 1) {
    LOG(ERROR) << "Invalid private RSA key: "
               << OpenSSLErrorString(ERR_get_error());
    RSA_free(key);
    // |key| has been released. Wrapping it would be a use-after-free, and the
    // destructor would free it a second time.
    return nullptr;
  }
  return new RsaPrivateKey(key);
}
bool RsaPrivateKey::Decrypt(const std::string& encrypted_message,
std::string* decrypted_message) const {
DCHECK(decrypted_message);
size_t rsa_size = RSA_size(key_);
if (encrypted_message.size() != rsa_size) {
LOG(ERROR) << "Encrypted RSA message has the wrong size (expected "
<< rsa_size << ", actual " << encrypted_message.size() << ")";
return false;
}
decrypted_message->assign(rsa_size, 0);
int decrypted_size = RSA_private_decrypt(
rsa_size,
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(encrypted_message.data())),
reinterpret_cast<unsigned char*>(&(*decrypted_message)[0]), key_,
RSA_PKCS1_OAEP_PADDING);
if (decrypted_size == -1) {
LOG(ERROR) << "RSA private decrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
decrypted_message->resize(decrypted_size);
return true;
}
bool RsaPrivateKey::GenerateSignature(const std::string& message,
HashAlgorithm hash_algorithm,
std::string* signature) const {
DCHECK(signature);
if (message.empty()) {
LOG(ERROR) << "Message to be signed is empty";
return false;
}
// Hash the message using corresponding hash algorithm.
std::string message_digest = GetMessageDigest(message, hash_algorithm);
if (message_digest.empty()) {
LOG(ERROR) << "Empty message digest";
return false;
}
const EVP_MD* hash = GetHashMd(hash_algorithm);
if (hash == nullptr) {
LOG(ERROR) << "No hash md";
return false;
}
// Add PSS padding.
size_t rsa_size = RSA_size(key_);
std::string padded_digest(rsa_size, 0);
if (!RSA_padding_add_PKCS1_PSS_mgf1(
key_, reinterpret_cast<unsigned char*>(&padded_digest[0]),
reinterpret_cast<unsigned char*>(&message_digest[0]), hash,
EVP_sha1(), kPssSaltLength)) {
LOG(ERROR) << "RSA padding failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
// Encrypt PSS padded digest.
signature->assign(rsa_size, 0);
if (RSA_private_encrypt(padded_digest.size(),
reinterpret_cast<unsigned char*>(&padded_digest[0]),
reinterpret_cast<unsigned char*>(&(*signature)[0]),
key_, RSA_NO_PADDING) !=
static_cast<int>(signature->size())) {
LOG(ERROR) << "RSA private encrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
return true;
}
bool RsaPrivateKey::GenerateSignatureSha256Pkcs7(const std::string& message,
std::string* signature) const {
DCHECK(signature);
if (message.empty()) {
LOG(ERROR) << "Empty signature verification message";
return false;
}
unsigned char digest[SHA256_DIGEST_LENGTH];
SHA256(reinterpret_cast<const unsigned char*>(message.data()), message.size(),
digest);
unsigned int sig_len = RSA_size(key_);
signature->resize(sig_len);
return RSA_sign(NID_sha256, digest, sizeof(digest),
reinterpret_cast<unsigned char*>(&(*signature)[0]), &sig_len,
key_) == 1;
}
// True if |private_key| shares this key's modulus (see RsaKeyMatch).
bool RsaPrivateKey::MatchesPrivateKey(const RsaPrivateKey& private_key) const {
  const auto* other = private_key.key();
  return RsaKeyMatch(key(), other);
}

// True if |public_key| is the public half of this key (same modulus).
bool RsaPrivateKey::MatchesPublicKey(const RsaPublicKey& public_key) const {
  const auto* other = public_key.key();
  return RsaKeyMatch(key(), other);
}

// Size of the RSA modulus in bytes, as reported by RSA_size.
uint32_t RsaPrivateKey::KeySize() const {
  return static_cast<uint32_t>(RSA_size(key_));
}
// Takes ownership of |key|, which must be non-null. It is released with
// RSA_free in the destructor.
RsaPublicKey::RsaPublicKey(RSA* key) : key_(key) {
  CHECK(key_ != nullptr);
}

// Deep-copies the public key so each instance owns its own RSA.
RsaPublicKey::RsaPublicKey(const RsaPublicKey& rsa_key)
    : key_(RSAPublicKey_dup(rsa_key.key_)) {
  CHECK(key_ != nullptr);
}

// Builds a public key from the public components of |rsa_key|, using
// RSAPublicKey_dup.
RsaPublicKey::RsaPublicKey(const RsaPrivateKey& rsa_key)
    : key_(RSAPublicKey_dup(rsa_key.key_)) {
  CHECK(key_ != nullptr);
}

RsaPublicKey::~RsaPublicKey() {
  RSA_free(key_);
}
// Deserializes |serialized_key| with rsa_util::DeserializeRsaPublicKey (the
// PKCS#1 form used by RsaKeyFactory::CreateFromPkcs1PublicKey).
//
// Returns a newly allocated RsaPublicKey owned by the caller, or nullptr if
// deserialization fails or the key reports a zero modulus size.
RsaPublicKey* RsaPublicKey::Create(const std::string& serialized_key) {
  RSA* key = nullptr;
  if (!rsa_util::DeserializeRsaPublicKey(serialized_key, &key)) return nullptr;
  if (RSA_size(key) == 0) {
    LOG(ERROR) << "Invalid public RSA key: "
               << OpenSSLErrorString(ERR_get_error());
    RSA_free(key);
    // |key| has been released. Wrapping it would be a use-after-free, and the
    // destructor would free it a second time.
    return nullptr;
  }
  return new RsaPublicKey(key);
}
bool RsaPublicKey::Encrypt(const std::string& clear_message,
std::string* encrypted_message) const {
DCHECK(encrypted_message);
if (clear_message.empty()) {
LOG(ERROR) << "Message to be encrypted is empty";
return false;
}
size_t rsa_size = RSA_size(key_);
encrypted_message->assign(rsa_size, 0);
const int kRetryAttempt = 1;
for (int i = 0; i < 1 + kRetryAttempt; i++) {
if (RSA_public_encrypt(
clear_message.size(),
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(clear_message.data())),
reinterpret_cast<unsigned char*>(&(*encrypted_message)[0]), key_,
RSA_PKCS1_OAEP_PADDING) != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA public encrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
if (!IsMessageTooSmall(*encrypted_message)) return true;
}
LOG(ERROR) << "RSA public encryption randomness error";
return false;
}
bool RsaPublicKey::VerifySignature(const std::string& message,
HashAlgorithm hash_algorithm,
const std::string& signature) const {
if (message.empty()) {
LOG(ERROR) << "Signed message is empty";
return false;
}
size_t rsa_size = RSA_size(key_);
if (signature.size() != rsa_size) {
LOG(ERROR) << "Message signature is of the wrong size (expected "
<< rsa_size << ", actual " << signature.size() << ")";
return false;
}
// Decrypt the signature.
std::string padded_digest(signature.size(), 0);
if (RSA_public_decrypt(
signature.size(),
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(signature.data())),
reinterpret_cast<unsigned char*>(&padded_digest[0]), key_,
RSA_NO_PADDING) != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA public decrypt failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
// Hash the message using the corresponding hash algorithm.
std::string message_digest = GetMessageDigest(message, hash_algorithm);
if (message_digest.empty()) {
LOG(ERROR) << "Empty message digest";
return false;
}
const EVP_MD* hash = GetHashMd(hash_algorithm);
if (hash == nullptr) {
LOG(ERROR) << "No hash md";
return false;
}
// Verify PSS padding.
if (RSA_verify_PKCS1_PSS_mgf1(
key_, reinterpret_cast<unsigned char*>(&message_digest[0]), hash,
EVP_sha1(), reinterpret_cast<unsigned char*>(&padded_digest[0]),
kPssSaltLength) == 0) {
LOG(ERROR) << "RSA Verify PSS padding failure: "
<< OpenSSLErrorString(ERR_get_error());
return false;
}
return true;
}
bool RsaPublicKey::VerifySignatureSha256Pkcs7(
const std::string& message, const std::string& signature) const {
if (message.empty()) {
LOG(ERROR) << "Empty signature verification message";
return false;
}
if (signature.empty()) {
LOG(ERROR) << "Empty signature";
return false;
}
if (signature.size() != RSA_size(key_)) {
LOG(ERROR) << "RSA signature has the wrong size";
return false;
}
unsigned char digest[SHA256_DIGEST_LENGTH];
SHA256(reinterpret_cast<const unsigned char*>(message.data()), message.size(),
digest);
return RSA_verify(NID_sha256, digest, sizeof(digest),
reinterpret_cast<const unsigned char*>(signature.data()),
signature.size(), key_) == 1;
}
// True if |private_key| is the private half of this key (same modulus).
bool RsaPublicKey::MatchesPrivateKey(const RsaPrivateKey& private_key) const {
  const auto* other = private_key.key();
  return RsaKeyMatch(key(), other);
}

// True if |public_key| has the same modulus as this key (see RsaKeyMatch).
bool RsaPublicKey::MatchesPublicKey(const RsaPublicKey& public_key) const {
  const auto* other = public_key.key();
  return RsaKeyMatch(key(), other);
}

// Size of the RSA modulus in bytes, as reported by RSA_size.
uint32_t RsaPublicKey::KeySize() const {
  return static_cast<uint32_t>(RSA_size(key_));
}
// The factory holds no state, so construction and destruction do nothing.
RsaKeyFactory::RsaKeyFactory() = default;
RsaKeyFactory::~RsaKeyFactory() = default;
std::unique_ptr<RsaPrivateKey> RsaKeyFactory::CreateFromPkcs1PrivateKey(
const std::string& private_key) const {
return std::unique_ptr<RsaPrivateKey>(RsaPrivateKey::Create(private_key));
}
std::unique_ptr<RsaPrivateKey> RsaKeyFactory::CreateFromPkcs8PrivateKey(
const std::string& private_key,
const std::string& private_key_passphrase) const {
std::string pkcs1_key;
const bool result =
private_key_passphrase.empty()
? rsa_util::PrivateKeyInfoToRsaPrivateKey(private_key, &pkcs1_key)
: rsa_util::EncryptedPrivateKeyInfoToRsaPrivateKey(
private_key, private_key_passphrase, &pkcs1_key);
if (!result) {
LOG(WARNING) << "Failed to get pkcs1_key.";
return std::unique_ptr<RsaPrivateKey>();
}
return std::unique_ptr<RsaPrivateKey>(RsaPrivateKey::Create(pkcs1_key));
}
std::unique_ptr<RsaPublicKey> RsaKeyFactory::CreateFromPkcs1PublicKey(
const std::string& public_key) const {
return std::unique_ptr<RsaPublicKey>(RsaPublicKey::Create(public_key));
}
} // namespace widevine