diff --git a/common/BUILD b/common/BUILD index 7cd14f7..a25bb7d 100644 --- a/common/BUILD +++ b/common/BUILD @@ -16,10 +16,30 @@ filegroup( name = "binary_release_files", srcs = [ "certificate_type.h", + "default_device_security_profile_list.h", + "security_profile_list.h", "status.h", ], ) +cc_library( + name = "playready_interface", + hdrs = ["playready_interface.h"], + deps = [ + "//util:error_space", + "//protos/public:license_protocol_cc_proto", + ], +) + +cc_library( + name = "playready_sdk_impl", + hdrs = ["playready_sdk_impl.h"], + deps = [ + ":playready_interface", + "//protos/public:license_protocol_cc_proto", + ], +) + cc_library( name = "content_id_util", srcs = ["content_id_util.cc"], @@ -27,6 +47,7 @@ cc_library( deps = [ ":error_space", ":status", + "//base", "//license_server_sdk/internal:sdk", "//protos/public:errors_cc_proto", "//protos/public:external_license_cc_proto", @@ -67,9 +88,14 @@ cc_library( hdrs = ["security_profile_list.h"], deps = [ ":client_id_util", + ":device_status_list", + "//base", + "//external:protobuf", "@abseil_repo//absl/synchronization", "//protos/public:client_identification_cc_proto", + "//protos/public:device_certificate_status_cc_proto", "//protos/public:device_common_cc_proto", + "//protos/public:device_security_profile_data_cc_proto", "//protos/public:provisioned_device_info_cc_proto", "//protos/public:security_profile_cc_proto", ], @@ -80,6 +106,7 @@ cc_test( timeout = "short", srcs = ["security_profile_list_test.cc"], deps = [ + ":client_id_util", ":security_profile_list", "//base", "//external:protobuf", @@ -90,6 +117,40 @@ cc_test( ], ) +cc_library( + name = "default_device_security_profile_list", + srcs = ["default_device_security_profile_list.cc"], + hdrs = ["default_device_security_profile_list.h"], + deps = [ + ":client_id_util", + ":device_status_list", + ":security_profile_list", + "//base", + "//external:protobuf", + "//protos/public:client_identification_cc_proto", + "//protos/public:device_certificate_status_cc_proto", + "//protos/public:device_common_cc_proto", + "//protos/public:provisioned_device_info_cc_proto", + "//protos/public:security_profile_cc_proto", + ], +) + +cc_test( + name = "default_device_security_profile_list_test", + timeout = "short", + srcs = ["default_device_security_profile_list_test.cc"], + deps = [ + ":client_id_util", + ":default_device_security_profile_list", + "//base", + "//external:protobuf", + "//testing:gunit_main", + "@abseil_repo//absl/memory", + "//protos/public:device_common_cc_proto", + "//protos/public:security_profile_cc_proto", + ], +) + cc_library( name = "status", srcs = ["status.cc"], @@ -120,6 +181,8 @@ cc_library( "certificate_client_cert.cc", "certificate_client_cert.h", "client_cert.cc", + "dual_certificate_client_cert.cc", + "dual_certificate_client_cert.h", "keybox_client_cert.cc", ], hdrs = [ @@ -132,6 +195,7 @@ cc_library( ":ec_key", ":ec_util", ":error_space", + ":hash_algorithm", ":openssl_util", ":random_util", ":rsa_key", @@ -155,8 +219,11 @@ cc_test( srcs = ["client_cert_test.cc"], deps = [ ":client_cert", + ":ec_key", ":ec_test_keys", ":error_space", + ":hash_algorithm", + ":hash_algorithm_util", ":rsa_key", ":rsa_test_keys", ":sha_util", @@ -179,6 +246,8 @@ cc_library( ":client_cert", ":drm_service_certificate", ":error_space", + ":hash_algorithm", + ":hash_algorithm_util", ":rsa_key", ":status", "//base", @@ -211,6 +280,8 @@ cc_test( deps = [ ":client_cert", ":device_status_list", + ":hash_algorithm", + ":hash_algorithm_util", ":rsa_key", ":rsa_test_keys", ":status", @@ -235,6 +306,8 @@ cc_library( ":certificate_type", ":ec_key", ":error_space", + ":hash_algorithm", + ":hash_algorithm_util", ":rsa_key", ":sha_util", ":signer_public_key", @@ -258,6 +331,8 @@ cc_test( ":ec_key", ":ec_test_keys", ":error_space", + ":hash_algorithm", + ":hash_algorithm_util", ":rsa_key", ":rsa_test_keys", ":test_drm_certificates", @@ -338,9 +413,12 @@ cc_library( srcs = ["rsa_key.cc"], hdrs = ["rsa_key.h"], deps = [ + ":hash_algorithm", ":rsa_util", ":sha_util", "//base", + "@abseil_repo//absl/base:core_headers", + "@abseil_repo//absl/strings", "//external:openssl", ], ) @@ -371,6 +449,7 @@ cc_library( testonly = 1, hdrs = ["mock_rsa_key.h"], deps = [ + ":hash_algorithm", ":rsa_key", "//testing:gunit", ], @@ -384,9 +463,11 @@ cc_library( "ec_util.h", ], deps = [ + ":hash_algorithm", ":openssl_util", ":private_key_util", "//base", + "@abseil_repo//absl/base:core_headers", "@abseil_repo//absl/memory", "//external:openssl", ], @@ -415,9 +496,11 @@ cc_library( deps = [ ":aes_cbc_util", ":ec_util", + ":hash_algorithm", ":openssl_util", ":sha_util", "//base", + "@abseil_repo//absl/base:core_headers", "@abseil_repo//absl/memory", "//external:openssl", ], @@ -671,6 +754,7 @@ cc_library( hdrs = ["signature_util.h"], deps = [ ":aes_cbc_util", + ":hash_algorithm", ":rsa_key", ":sha_util", ":status", @@ -774,6 +858,7 @@ cc_library( ":status", ":x509_cert", "//base", + "@abseil_repo//absl/base:core_headers", "@abseil_repo//absl/strings", "@abseil_repo//absl/synchronization", "//protos/public:client_identification_cc_proto", @@ -795,6 +880,7 @@ cc_library( ":rsa_util", ":status", "//base", + "@abseil_repo//absl/base:core_headers", "@abseil_repo//absl/strings", "@abseil_repo//absl/synchronization", "//util/gtl:map_util", @@ -813,6 +899,7 @@ cc_test( ":aes_cbc_util", ":drm_root_certificate", ":drm_service_certificate", + ":hash_algorithm_util", ":rsa_key", ":rsa_test_keys", ":rsa_util", @@ -849,6 +936,7 @@ cc_library( ":rsa_key", ":status", "//base", + "@abseil_repo//absl/base:core_headers", "@abseil_repo//absl/strings", "@abseil_repo//absl/synchronization", "//external:openssl", @@ -887,6 +975,7 @@ cc_library( deps = [ ":certificate_type", ":error_space", + ":hash_algorithm_util", ":rsa_key", ":status", ":x509_cert", @@ -901,6 +990,7 @@ cc_test( timeout = "short", srcs = ["vmp_checker_test.cc"], deps = [ + ":hash_algorithm_util", ":rsa_key", ":vmp_checker", "//base", @@ -1012,6 +1102,7 @@ cc_library( hdrs = ["signer_public_key.h"], deps = [ ":ec_key", + ":hash_algorithm", ":rsa_key", "@abseil_repo//absl/memory", "//protos/public:drm_certificate_cc_proto", @@ -1024,6 +1115,7 @@ cc_test( deps = [ ":ec_key", ":ec_test_keys", + ":hash_algorithm", ":rsa_key", ":rsa_test_keys", ":signer_public_key", @@ -1041,3 +1133,29 @@ cc_library( "//common/oemcrypto_core_message/odk:kdo", ], ) + +cc_library( + name = "hash_algorithm", + hdrs = ["hash_algorithm.h"], +) + +cc_library( + name = "hash_algorithm_util", + srcs = ["hash_algorithm_util.cc"], + hdrs = ["hash_algorithm_util.h"], + deps = [ + ":hash_algorithm", + "//base", + "//protos/public:hash_algorithm_cc_proto", + ], +) + +cc_test( + name = "hash_algorithm_util_test", + srcs = ["hash_algorithm_util_test.cc"], + deps = [ + ":hash_algorithm", + ":hash_algorithm_util", + "//testing:gunit_main", + ], +) diff --git a/common/aes_cbc_util_test.cc b/common/aes_cbc_util_test.cc index c7f67e6..1713ea9 100644 --- a/common/aes_cbc_util_test.cc +++ b/common/aes_cbc_util_test.cc @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/aes_cbc_util.h" + #include "testing/gmock.h" #include "testing/gunit.h" diff --git a/common/certificate_client_cert.cc b/common/certificate_client_cert.cc index 0c9326f..be44107 100644 --- a/common/certificate_client_cert.cc +++ b/common/certificate_client_cert.cc @@ -66,10 +66,11 @@ class ClientCertAlgorithmRSA : public ClientCertAlgorithm { } Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature) const override { CHECK(rsa_public_key_); - if (!rsa_public_key_->VerifySignature(message, signature)) { + if (!rsa_public_key_->VerifySignature(message, hash_algorithm, signature)) { return Status(error_space, INVALID_SIGNATURE, ""); } return OkStatus(); @@ -143,10 +144,12 @@ class ClientCertAlgorithmECC : public ClientCertAlgorithm { } Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature) const override { CHECK(client_ecc_public_key_); - if (!client_ecc_public_key_->VerifySignature(message, signature)) { + if (!client_ecc_public_key_->VerifySignature(message, hash_algorithm, + signature)) { return Status(error_space, INVALID_SIGNATURE, ""); } return OkStatus(); @@ -165,7 +168,7 @@ class ClientCertAlgorithmECC : public ClientCertAlgorithm { } SignedMessage::SessionKeyType session_key_type() const override { - return SignedMessage::EPHERMERAL_ECC_PUBLIC_KEY; + return SignedMessage::EPHEMERAL_ECC_PUBLIC_KEY; } private: @@ -259,11 +262,11 @@ Status CertificateClientCert::Initialize( } Status CertificateClientCert::VerifySignature( - const std::string& message, const std::string& signature, - ProtocolVersion protocol_version) const { + const std::string& message, HashAlgorithm hash_algorithm, + const std::string& signature, ProtocolVersion protocol_version) const { return algorithm_->VerifySignature( protocol_version < VERSION_2_2 ? message : Sha512_Hash(message), - signature); + hash_algorithm, signature); } void CertificateClientCert::GenerateSigningKey( diff --git a/common/certificate_client_cert.h b/common/certificate_client_cert.h index 2013035..c412fec 100644 --- a/common/certificate_client_cert.h +++ b/common/certificate_client_cert.h @@ -10,6 +10,7 @@ #define COMMON_CERTIFICATE_CLIENT_CERT_H_ #include "common/client_cert.h" +#include "common/hash_algorithm.h" #include "protos/public/drm_certificate.pb.h" namespace widevine { @@ -30,6 +31,7 @@ class ClientCertAlgorithm { // Verify the |signature| of an incoming request |message| using the public // key from the drm certificate. virtual Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature) const = 0; // Returns the key to be used in key derivation of the license @@ -57,6 +59,7 @@ class CertificateClientCert : public ClientCert { const std::string& serialized_certificate); Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature, ProtocolVersion protocol_version) const override; @@ -70,6 +73,7 @@ class CertificateClientCert : public ClientCert { SignedMessage::SessionKeyType key_type() const override { return algorithm_->session_key_type(); } + bool using_dual_certificate() const override { return false; } const std::string& serial_number() const override { return device_cert_.serial_number(); } diff --git a/common/client_cert.cc b/common/client_cert.cc index e4d3181..ec6d30a 100644 --- a/common/client_cert.cc +++ b/common/client_cert.cc @@ -17,6 +17,7 @@ #include "absl/strings/escaping.h" #include "common/certificate_client_cert.h" #include "common/crypto_util.h" +#include "common/dual_certificate_client_cert.h" #include "common/error_space.h" #include "common/keybox_client_cert.h" #include "common/random_util.h" @@ -52,23 +53,34 @@ uint32_t KeyboxClientCert::GetSystemId(const std::string& keybox_bytes) { return WvmTokenHandler::GetSystemId(keybox_bytes); } -Status ClientCert::Create( - const DrmRootCertificate* root_certificate, - widevine::ClientIdentification::TokenType token_type, - const std::string& token, std::unique_ptr* client_cert) { +Status ClientCert::Create(const DrmRootCertificate* root_certificate, + const widevine::ClientIdentification& client_id, + std::unique_ptr* client_cert) { CHECK(client_cert); - Status status; - switch (token_type) { + + switch (client_id.type()) { case ClientIdentification::KEYBOX: - return CreateWithKeybox(token, client_cert); - + return CreateWithKeybox(client_id.token(), client_cert); case ClientIdentification::DRM_DEVICE_CERTIFICATE: - return CreateWithDrmCertificate(root_certificate, token, client_cert); + if (!client_id.has_device_credentials()) { + return CreateWithDrmCertificate(root_certificate, client_id.token(), + client_cert); + } + // Assumes |client_id.token| is the signing cert and + // |client_id.device_credentials().token| is the encryption cert. + if (client_id.device_credentials().type() != + ClientIdentification::DRM_DEVICE_CERTIFICATE) + return Status(error_space, INVALID_DRM_CERTIFICATE, + "unsupported-encryption-certificate"); + return CreateWithDualDrmCertificates( + root_certificate, client_id.token(), + client_id.device_credentials().token(), client_cert); default: return Status(error_space, error::UNIMPLEMENTED, "client-type-not-implemented"); } + return OkStatus(); } @@ -78,6 +90,7 @@ Status ClientCert::CreateWithDrmCertificate( const DrmRootCertificate* root_certificate, const std::string& drm_certificate, std::unique_ptr* client_cert) { + CHECK(root_certificate); CHECK(client_cert); auto device_cert = absl::make_unique(); Status status = device_cert->Initialize(root_certificate, drm_certificate); @@ -87,6 +100,22 @@ Status ClientCert::CreateWithDrmCertificate( return status; } +Status ClientCert::CreateWithDualDrmCertificates( + const DrmRootCertificate* root_certificate, + const std::string& signing_drm_certificate, + const std::string& encryption_drm_certificate, + std::unique_ptr* client_cert) { + CHECK(root_certificate); + CHECK(client_cert); + auto device_cert = absl::make_unique(); + Status status = device_cert->Initialize( + root_certificate, signing_drm_certificate, encryption_drm_certificate); + if (status.ok()) { + *client_cert = std::move(device_cert); + } + return status; +} + Status ClientCert::CreateWithKeybox(const std::string& keybox_token, std::unique_ptr* client_cert) { CHECK(client_cert); diff --git a/common/client_cert.h b/common/client_cert.h index 675c53e..8b458a8 100644 --- a/common/client_cert.h +++ b/common/client_cert.h @@ -12,8 +12,11 @@ #include #include "common/drm_root_certificate.h" +#include "common/error_space.h" +#include "common/hash_algorithm.h" #include "common/status.h" #include "protos/public/client_identification.pb.h" +#include "protos/public/errors.pb.h" #include "protos/public/license_protocol.pb.h" namespace widevine { @@ -24,17 +27,10 @@ class ClientCert { ClientCert() = default; public: - // Creates a ClientCert from the |token|. The type of ClientCert created is - // determined by the |token_type|. - static Status Create( - const DrmRootCertificate* root_certificate, - widevine::ClientIdentification::TokenType token_type, - const std::string& token, std::unique_ptr* client_cert); - - // Creates a Keybox based ClientCert. The |client_cert| is a caller supplied - // unique_ptr to receive the new ClientCert. - static Status CreateWithKeybox(const std::string& keybox_token, - std::unique_ptr* client_cert); + // Creates a Device Certificate from the supplied |client_id|. + static Status Create(const DrmRootCertificate* root_certificate, + const widevine::ClientIdentification& client_id, + std::unique_ptr* client_cert); // Creates a Device Certificate based ClientCert. static Status CreateWithDrmCertificate( @@ -42,6 +38,21 @@ class ClientCert { const std::string& drm_certificate, std::unique_ptr* client_cert); + // Creates a Device Certificate using the supplied certificates. + // The|signing_drm_certificate| will be used to verify an incoming request. + // The |encryption_drm_certificate| will be used to define the session key + // used to protect a response message. + static Status CreateWithDualDrmCertificates( + const DrmRootCertificate* root_certificate, + const std::string& signing_drm_certificate, + const std::string& encryption_drm_certificate, + std::unique_ptr* client_cert); + + // Creates a Keybox based ClientCert. The |client_cert| is a caller supplied + // unique_ptr to receive the new ClientCert. + static Status CreateWithKeybox(const std::string& keybox_token, + std::unique_ptr* client_cert); + virtual ~ClientCert() = default; ClientCert(const ClientCert&) = delete; ClientCert& operator=(const ClientCert&) = delete; @@ -50,6 +61,7 @@ class ClientCert { // classes information and the passed in message. Returns OK if signature // is valid. virtual Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature, ProtocolVersion protocol_version) const = 0; @@ -61,6 +73,7 @@ class ClientCert { virtual const std::string& encrypted_key() const = 0; virtual const std::string& key() const = 0; virtual SignedMessage::SessionKeyType key_type() const = 0; + virtual bool using_dual_certificate() const = 0; virtual const std::string& serial_number() const = 0; virtual const std::string& service_id() const = 0; virtual const std::string& signing_key() const = 0; @@ -71,6 +84,14 @@ class ClientCert { virtual widevine::ClientIdentification::TokenType type() const = 0; virtual const std::string& encrypted_unique_id() const = 0; virtual const std::string& unique_id_hash() const = 0; + virtual Status SystemIdUnknownError() const { + return Status(error_space, DRM_DEVICE_CERTIFICATE_UNKNOWN, + "device-certificate-status-unknown"); + } + virtual Status SystemIdRevokedError() const { + return Status(error_space, DRM_DEVICE_CERTIFICATE_REVOKED, + "device-certificate-revoked"); + } }; } // namespace widevine diff --git a/common/client_cert_test.cc b/common/client_cert_test.cc index ba1f759..e145f60 100644 --- a/common/client_cert_test.cc +++ b/common/client_cert_test.cc @@ -9,12 +9,16 @@ #include "common/client_cert.h" #include +#include #include "testing/gmock.h" #include "testing/gunit.h" #include "absl/strings/escaping.h" +#include "common/ec_key.h" #include "common/ec_test_keys.h" #include "common/error_space.h" +#include "common/hash_algorithm.h" +#include "common/hash_algorithm_util.h" #include "common/keybox_client_cert.h" #include "common/rsa_key.h" #include "common/rsa_test_keys.h" @@ -35,13 +39,16 @@ const DrmCertificate::Type kNoSigner = DrmCertificate::ROOT; const DrmCertificate::Type kDeviceModelSigner = DrmCertificate::DEVICE_MODEL; const DrmCertificate::Type kProvisionerSigner = DrmCertificate::PROVISIONER; +const HashAlgorithm kSha256 = HashAlgorithm::kSha256; + // TODO(user): Change these tests to use on-the-fly generated intermediate // and device certificates based on RsaTestKeys. // TODO(user): Add testcase(s) CreateSignature, // and GenerateSigningKey. class ClientCertTest - : public ::testing::TestWithParam { + : public ::testing::TestWithParam< + std::tuple> { public: ~ClientCertTest() override = default; void SetUp() override { @@ -74,16 +81,30 @@ class ClientCertTest class TestCertificateAndData { public: const std::string certificate_; + const std::string encryption_certificate_; const std::string expected_serial_number_; uint32_t expected_system_id_; Status expected_status_; + SignedMessage::SessionKeyType expected_key_type_; TestCertificateAndData(const std::string& certificate, const std::string& expected_serial_number, uint32_t expected_system_id, Status expected_status) : certificate_(certificate), expected_serial_number_(expected_serial_number), expected_system_id_(expected_system_id), - expected_status_(expected_status) {} + expected_status_(expected_status), + expected_key_type_(SignedMessage::WRAPPED_AES_KEY) {} + TestCertificateAndData(const std::string& certificate, + const std::string& encryption_certificate, + const std::string& expected_serial_number, + uint32_t expected_system_id, Status expected_status, + SignedMessage::SessionKeyType expected_key_type) + : certificate_(certificate), + encryption_certificate_(encryption_certificate), + expected_serial_number_(expected_serial_number), + expected_system_id_(expected_system_id), + expected_status_(expected_status), + expected_key_type_(expected_key_type) {} }; void TestBasicValidation(const TestTokenAndKeys& expectation, @@ -94,31 +115,32 @@ class ClientCertTest void GenerateSignature(const std::string& message, const std::string& private_key, - std::string* signature); - SignedDrmCertificate* SignCertificate(const DrmCertificate& certificate, - SignedDrmCertificate* signer, - const std::string& private_key); - DrmCertificate* GenerateProvisionerCertificate( + HashAlgorithm hash_algorithm, std::string* signature); + std::unique_ptr SignCertificate( + const DrmCertificate& certificate, const SignedDrmCertificate* signer, + const std::string& private_key); + std::unique_ptr GenerateProvisionerCertificate( uint32_t system_id, const std::string& serial_number, const std::string& provider_id); - SignedDrmCertificate* GenerateSignedProvisionerCertificate( + std::unique_ptr GenerateSignedProvisionerCertificate( uint32_t system_id, const std::string& serial_number, const std::string& service_id); - DrmCertificate* GenerateIntermediateCertificate( + std::unique_ptr GenerateIntermediateCertificate( uint32_t system_id, const std::string& serial_number); - SignedDrmCertificate* GenerateSignedIntermediateCertificate( + std::unique_ptr GenerateSignedIntermediateCertificate( SignedDrmCertificate* signer, uint32_t system_id, const std::string& serial_number, DrmCertificate::Type signer_cert_type); - DrmCertificate* GenerateDrmCertificate( + std::unique_ptr GenerateDrmCertificate( uint32_t system_id, const std::string& serial_number, DrmCertificate::Algorithm = DrmCertificate::RSA); - SignedDrmCertificate* GenerateSignedDrmCertificate( + std::unique_ptr GenerateSignedDrmCertificate( SignedDrmCertificate* signer, uint32_t system_id, const std::string& serial_number, DrmCertificate::Algorithm = DrmCertificate::RSA); std::string GetPublicKeyByCertType(DrmCertificate::Type cert_type); std::string GetPrivateKeyByCertType(DrmCertificate::Type cert_type); + std::string GetECCPrivateKey(DrmCertificate::Algorithm algorithm); std::string GetECCPublicKey(DrmCertificate::Algorithm algorithm); RsaTestKeys test_rsa_keys_; @@ -136,8 +158,11 @@ void ClientCertTest::TestBasicValidation(const TestTokenAndKeys& expectation, Status status; std::unique_ptr keybox_cert; - status = ClientCert::Create(root_cert_.get(), ClientIdentification::KEYBOX, - expectation.token_, &keybox_cert); + ClientIdentification client_id; + client_id.set_type(ClientIdentification::KEYBOX); + client_id.set_token(expectation.token_); + + status = ClientCert::Create(root_cert_.get(), client_id, &keybox_cert); if (expect_success) { ASSERT_EQ(OkStatus(), status); ASSERT_TRUE(keybox_cert.get()); @@ -163,12 +188,25 @@ void ClientCertTest::TestBasicValidationDrmCertificate( // Test validation of a valid request. Status status; std::unique_ptr drm_certificate_cert; - status = ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - expectation.certificate_, &drm_certificate_cert); + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(expectation.certificate_); + if (!expectation.encryption_certificate_.empty()) { + client_id.mutable_device_credentials()->set_token( + expectation.encryption_certificate_); + client_id.mutable_device_credentials()->set_type( + ClientIdentification::DRM_DEVICE_CERTIFICATE); + } + + status = + ClientCert::Create(root_cert_.get(), client_id, &drm_certificate_cert); ASSERT_EQ(expectation.expected_status_, status); if (expectation.expected_status_.ok()) { ASSERT_TRUE(drm_certificate_cert.get()); + if (!expectation.encryption_certificate_.empty()) { + ASSERT_TRUE(drm_certificate_cert->using_dual_certificate()); + } + ASSERT_EQ(expectation.expected_key_type_, drm_certificate_cert->key_type()); if (compare_data) { ASSERT_EQ(expectation.expected_serial_number_, drm_certificate_cert->signer_serial_number()); @@ -182,26 +220,29 @@ void ClientCertTest::TestBasicValidationDrmCertificate( void ClientCertTest::GenerateSignature(const std::string& message, const std::string& private_key, + HashAlgorithm hash_algorithm, std::string* signature) { std::unique_ptr rsa_private_key( RsaPrivateKey::Create(private_key)); ASSERT_TRUE(rsa_private_key != nullptr); - rsa_private_key->GenerateSignature(message, signature); + rsa_private_key->GenerateSignature(message, hash_algorithm, signature); } -// The caller relinquishes ownership of |signer|, which may also be nullptr. -SignedDrmCertificate* ClientCertTest::SignCertificate( - const DrmCertificate& certificate, SignedDrmCertificate* signer, +// The caller retains ownership of |signer|, which may also be nullptr. +std::unique_ptr ClientCertTest::SignCertificate( + const DrmCertificate& certificate, const SignedDrmCertificate* signer, const std::string& private_key) { std::unique_ptr signed_certificate( new SignedDrmCertificate); signed_certificate->set_drm_certificate(certificate.SerializeAsString()); - GenerateSignature(signed_certificate->drm_certificate(), private_key, - signed_certificate->mutable_signature()); + GenerateSignature( + signed_certificate->drm_certificate(), private_key, + HashAlgorithmProtoToEnum(signed_certificate->hash_algorithm()), + signed_certificate->mutable_signature()); if (signer != nullptr) { - signed_certificate->set_allocated_signer(signer); + *(signed_certificate->mutable_signer()) = *signer; } - return signed_certificate.release(); + return signed_certificate; } std::string ClientCertTest::GetPublicKeyByCertType( @@ -224,6 +265,21 @@ std::string ClientCertTest::GetPrivateKeyByCertType( return test_rsa_keys_.private_test_key_1_3072_bits(); } +std::string ClientCertTest::GetECCPrivateKey( + DrmCertificate::Algorithm algorithm) { + ECTestKeys keys; + switch (algorithm) { + case DrmCertificate::ECC_SECP256R1: + return keys.private_test_key_1_secp256r1(); + case DrmCertificate::ECC_SECP384R1: + return keys.private_test_key_1_secp384r1(); + case DrmCertificate::ECC_SECP521R1: + return keys.private_test_key_1_secp521r1(); + default: + return ""; + } +} + std::string ClientCertTest::GetECCPublicKey( DrmCertificate::Algorithm algorithm) { ECTestKeys keys; @@ -239,7 +295,7 @@ std::string ClientCertTest::GetECCPublicKey( } } -DrmCertificate* ClientCertTest::GenerateIntermediateCertificate( +std::unique_ptr ClientCertTest::GenerateIntermediateCertificate( uint32_t system_id, const std::string& serial_number) { std::unique_ptr intermediate_certificate(new DrmCertificate); intermediate_certificate->set_type(DrmCertificate::DEVICE_MODEL); @@ -248,10 +304,11 @@ DrmCertificate* ClientCertTest::GenerateIntermediateCertificate( GetPublicKeyByCertType(DrmCertificate::DEVICE_MODEL)); intermediate_certificate->set_system_id(system_id); intermediate_certificate->set_creation_time_seconds(1234); - return intermediate_certificate.release(); + return intermediate_certificate; } -SignedDrmCertificate* ClientCertTest::GenerateSignedIntermediateCertificate( +std::unique_ptr +ClientCertTest::GenerateSignedIntermediateCertificate( SignedDrmCertificate* signer, uint32_t system_id, const std::string& serial_number, DrmCertificate::Type signer_cert_type) { std::unique_ptr intermediate_certificate( @@ -261,7 +318,7 @@ SignedDrmCertificate* ClientCertTest::GenerateSignedIntermediateCertificate( GetPrivateKeyByCertType(signer_cert_type)); } -DrmCertificate* ClientCertTest::GenerateDrmCertificate( +std::unique_ptr ClientCertTest::GenerateDrmCertificate( uint32_t system_id, const std::string& serial_number, DrmCertificate::Algorithm algorithm) { std::unique_ptr drm_certificate(new DrmCertificate); @@ -274,10 +331,11 @@ DrmCertificate* ClientCertTest::GenerateDrmCertificate( : GetECCPublicKey(algorithm)); drm_certificate->set_creation_time_seconds(4321); drm_certificate->set_algorithm(algorithm); - return drm_certificate.release(); + return drm_certificate; } -SignedDrmCertificate* ClientCertTest::GenerateSignedDrmCertificate( +std::unique_ptr +ClientCertTest::GenerateSignedDrmCertificate( SignedDrmCertificate* signer, uint32_t system_id, const std::string& serial_number, DrmCertificate::Algorithm algorithm) { std::unique_ptr drm_certificate( @@ -285,10 +343,10 @@ SignedDrmCertificate* ClientCertTest::GenerateSignedDrmCertificate( std::unique_ptr signed_drm_certificate( SignCertificate(*drm_certificate, signer, GetPrivateKeyByCertType(DrmCertificate::DEVICE_MODEL))); - return signed_drm_certificate.release(); + return signed_drm_certificate; } -DrmCertificate* ClientCertTest::GenerateProvisionerCertificate( +std::unique_ptr ClientCertTest::GenerateProvisionerCertificate( uint32_t system_id, const std::string& serial_number, const std::string& provider_id) { std::unique_ptr provisioner_certificate(new DrmCertificate); @@ -299,10 +357,11 @@ DrmCertificate* ClientCertTest::GenerateProvisionerCertificate( provisioner_certificate->set_system_id(system_id); provisioner_certificate->set_provider_id(provider_id); provisioner_certificate->set_creation_time_seconds(1234); - return provisioner_certificate.release(); + return provisioner_certificate; } -SignedDrmCertificate* ClientCertTest::GenerateSignedProvisionerCertificate( +std::unique_ptr +ClientCertTest::GenerateSignedProvisionerCertificate( uint32_t system_id, const std::string& serial_number, const std::string& service_id) { std::unique_ptr provisioner_certificate( @@ -341,22 +400,48 @@ TEST_F(ClientCertTest, BasicValidation) { TEST_P(ClientCertTest, BasicCertValidation) { const uint32_t system_id = 1234; const std::string serial_number("serial_number"); - std::unique_ptr signed_cert( - GenerateSignedDrmCertificate( - GenerateSignedIntermediateCertificate(nullptr, system_id, - serial_number, kNoSigner), - system_id, serial_number + "-device", GetParam())); + std::unique_ptr intermediate_certificate = + GenerateSignedIntermediateCertificate(nullptr, system_id, serial_number, + kNoSigner); + std::unique_ptr signed_cert = + GenerateSignedDrmCertificate(intermediate_certificate.get(), system_id, + serial_number + "-device1", + std::get<0>(GetParam())); + SignedMessage::SessionKeyType expected_key_type = + std::get<0>(GetParam()) != DrmCertificate::RSA + ? SignedMessage::EPHEMERAL_ECC_PUBLIC_KEY + : SignedMessage::WRAPPED_AES_KEY; + std::unique_ptr encryption_certificate; + if (std::get<1>(GetParam()) != DrmCertificate::UNKNOWN_ALGORITHM) { + encryption_certificate = GenerateSignedDrmCertificate( + intermediate_certificate.get(), system_id, serial_number + "-device2", + std::get<1>(GetParam())); + expected_key_type = std::get<1>(GetParam()) != DrmCertificate::RSA + ? SignedMessage::EPHEMERAL_ECC_PUBLIC_KEY + : SignedMessage::WRAPPED_AES_KEY; + } const TestCertificateAndData kValidCertificateAndExpectedData( - signed_cert->SerializeAsString(), serial_number, system_id, OkStatus()); + signed_cert->SerializeAsString(), + encryption_certificate == nullptr + ? std::string() + : encryption_certificate->SerializeAsString(), + serial_number, system_id, OkStatus(), expected_key_type); const bool compare_data = true; TestBasicValidationDrmCertificate(kValidCertificateAndExpectedData, compare_data); } -INSTANTIATE_TEST_SUITE_P(BasicCertValidation, ClientCertTest, - testing::Values(DrmCertificate::RSA, - DrmCertificate::ECC_SECP256R1, - DrmCertificate::ECC_SECP384R1, - DrmCertificate::ECC_SECP521R1)); + +INSTANTIATE_TEST_SUITE_P( + BasicCertValidation, ClientCertTest, + testing::Combine(testing::Values(DrmCertificate::RSA, + DrmCertificate::ECC_SECP256R1, + DrmCertificate::ECC_SECP384R1, + DrmCertificate::ECC_SECP521R1), + testing::Values(DrmCertificate::UNKNOWN_ALGORITHM, + DrmCertificate::RSA, + DrmCertificate::ECC_SECP256R1, + DrmCertificate::ECC_SECP384R1, + DrmCertificate::ECC_SECP521R1))); TEST_F(ClientCertTest, InvalidKeybox) { const TestTokenAndKeys kInvalidTokenAndExpectedKeys[] = { @@ -395,71 +480,80 @@ TEST_F(ClientCertTest, InvalidCertificate) { std::unique_ptr invalid_drm_cert( new SignedDrmCertificate); invalid_drm_cert->set_drm_certificate("bad-serialized-cert"); - GenerateSignature(invalid_drm_cert->drm_certificate(), - test_rsa_keys_.private_test_key_2_2048_bits(), - invalid_drm_cert->mutable_signature()); - invalid_drm_cert->set_allocated_signer(GenerateSignedIntermediateCertificate( - nullptr, system_id, signer_sn, kNoSigner)); + GenerateSignature( + invalid_drm_cert->drm_certificate(), + test_rsa_keys_.private_test_key_2_2048_bits(), + HashAlgorithmProtoToEnum(invalid_drm_cert->hash_algorithm()), + invalid_drm_cert->mutable_signature()); + invalid_drm_cert->set_allocated_signer( + GenerateSignedIntermediateCertificate(nullptr, system_id, signer_sn, + kNoSigner) + .release()); // Invalid device public key. - dev_cert.reset(GenerateDrmCertificate(system_id, device_sn)); + dev_cert = GenerateDrmCertificate(system_id, device_sn); dev_cert->set_public_key("bad-device-public-key"); - std::unique_ptr bad_device_public_key( + std::unique_ptr bad_device_public_key = SignCertificate(*dev_cert, GenerateSignedIntermediateCertificate( - nullptr, system_id, signer_sn, kNoSigner), - test_rsa_keys_.private_test_key_2_2048_bits())); + nullptr, system_id, signer_sn, kNoSigner) + .get(), + test_rsa_keys_.private_test_key_2_2048_bits()); // Invalid serialized intermediate certificate. - signed_signer.reset(GenerateSignedIntermediateCertificate( - nullptr, system_id, signer_sn, kNoSigner)); + signed_signer = GenerateSignedIntermediateCertificate(nullptr, system_id, + signer_sn, kNoSigner); signed_signer->set_drm_certificate("bad-serialized-cert"); GenerateSignature(signed_signer->drm_certificate(), test_rsa_keys_.private_test_key_1_3072_bits(), + HashAlgorithmProtoToEnum(signed_signer->hash_algorithm()), signed_signer->mutable_signature()); - dev_cert.reset(GenerateDrmCertificate(system_id, device_sn)); + dev_cert = GenerateDrmCertificate(system_id, device_sn); std::unique_ptr invalid_signer( - SignCertificate(*dev_cert, signed_signer.release(), + SignCertificate(*dev_cert, signed_signer.get(), test_rsa_keys_.private_test_key_2_2048_bits())); // Invalid signer public key. - dev_cert.reset(GenerateDrmCertificate(system_id, device_sn)); - signer_cert.reset(GenerateIntermediateCertificate(system_id, signer_sn)); + dev_cert = GenerateDrmCertificate(system_id, device_sn); + signer_cert = GenerateIntermediateCertificate(system_id, signer_sn); signer_cert->set_public_key("bad-signer-public-key"); std::unique_ptr bad_signer_public_key(SignCertificate( *dev_cert, SignCertificate(*signer_cert, nullptr, - test_rsa_keys_.private_test_key_1_3072_bits()), + test_rsa_keys_.private_test_key_1_3072_bits()) + .get(), test_rsa_keys_.private_test_key_2_2048_bits())); // Invalid device certificate signature. std::unique_ptr bad_device_signature( - GenerateSignedDrmCertificate( - GenerateSignedIntermediateCertificate(nullptr, system_id, signer_sn, - kNoSigner), - system_id, device_sn)); + GenerateSignedDrmCertificate(GenerateSignedIntermediateCertificate( + nullptr, system_id, signer_sn, kNoSigner) + .get(), + system_id, device_sn)); bad_device_signature->set_signature("bad-signature"); // Missing model system ID. - dev_cert.reset(GenerateDrmCertificate(system_id, device_sn)); - signer_cert.reset(GenerateIntermediateCertificate(system_id, signer_sn)); + dev_cert = GenerateDrmCertificate(system_id, device_sn); + signer_cert = GenerateIntermediateCertificate(system_id, signer_sn); signer_cert->clear_system_id(); std::unique_ptr missing_model_sn(SignCertificate( *dev_cert, SignCertificate(*signer_cert, nullptr, - test_rsa_keys_.private_test_key_1_3072_bits()), + test_rsa_keys_.private_test_key_1_3072_bits()) + .get(), test_rsa_keys_.private_test_key_2_2048_bits())); // Missing signer serial number. - dev_cert.reset(GenerateDrmCertificate(system_id, device_sn)); - signer_cert.reset(GenerateIntermediateCertificate(system_id, signer_sn)); + dev_cert = GenerateDrmCertificate(system_id, device_sn); + signer_cert = GenerateIntermediateCertificate(system_id, signer_sn); signer_cert->clear_serial_number(); std::unique_ptr missing_signer_sn(SignCertificate( *dev_cert, SignCertificate(*signer_cert, nullptr, - test_rsa_keys_.private_test_key_1_3072_bits()), + test_rsa_keys_.private_test_key_1_3072_bits()) + .get(), test_rsa_keys_.private_test_key_2_2048_bits())); // Invalid serialized intermediate certificate. - dev_cert.reset(GenerateDrmCertificate(system_id, device_sn)); - signed_signer.reset(GenerateSignedIntermediateCertificate( - nullptr, system_id, signer_sn, kNoSigner)); + dev_cert = GenerateDrmCertificate(system_id, device_sn); + signed_signer = GenerateSignedIntermediateCertificate(nullptr, system_id, + signer_sn, kNoSigner); signed_signer->set_signature("bad-signature"); std::unique_ptr bad_signer_signature( - SignCertificate(*dev_cert, signed_signer.release(), + SignCertificate(*dev_cert, signed_signer.get(), test_rsa_keys_.private_test_key_2_2048_bits())); const TestCertificateAndData kInvalidCertificate[] = { @@ -504,8 +598,11 @@ TEST_F(ClientCertTest, MissingPreProvKey) { "beaa24924907e128f9ff49b54a165cd9c33e6547537eb4d29fb7e8df3c2c1cd9" "2517a12f4922953e")); std::unique_ptr client_cert_ptr; - Status status = ClientCert::Create( - root_cert_.get(), ClientIdentification::KEYBOX, token, &client_cert_ptr); + ClientIdentification client_id; + client_id.set_type(ClientIdentification::KEYBOX); + client_id.set_token(token); + Status status = + ClientCert::Create(root_cert_.get(), client_id, &client_cert_ptr); ASSERT_EQ(MISSING_PRE_PROV_KEY, status.error_code()); } @@ -516,26 +613,27 @@ TEST_F(ClientCertTest, ValidProvisionerDeviceCert) { const std::string intermediate_serial_number("intermediate-serial-number"); const std::string provisioner_serial_number("provisioner-serial-number"); - std::unique_ptr signed_provisioner_cert( + std::unique_ptr signed_provisioner_cert = GenerateSignedProvisionerCertificate(system_id, provisioner_serial_number, - service_id)); + service_id); - std::unique_ptr signed_intermediate_cert( + std::unique_ptr signed_intermediate_cert = GenerateSignedIntermediateCertificate( - signed_provisioner_cert.release(), system_id, - intermediate_serial_number, kProvisionerSigner)); + signed_provisioner_cert.get(), system_id, intermediate_serial_number, + kProvisionerSigner); - std::unique_ptr signed_device_cert( - GenerateSignedDrmCertificate(signed_intermediate_cert.release(), - system_id, device_serial_number)); + std::unique_ptr signed_device_cert = + GenerateSignedDrmCertificate(signed_intermediate_cert.get(), system_id, + device_serial_number); std::string serialized_cert; signed_device_cert->SerializeToString(&serialized_cert); std::unique_ptr drm_cert; + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(serialized_cert); - EXPECT_OK(ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - serialized_cert, &drm_cert)); + EXPECT_OK(ClientCert::Create(root_cert_.get(), client_id, &drm_cert)); ASSERT_TRUE(drm_cert); EXPECT_EQ(service_id, drm_cert->service_id()); @@ -551,27 +649,28 @@ TEST_F(ClientCertTest, InvalidProvisionerDeviceCertEmptyServiceId) { const std::string intermediate_serial_number("intermediate-serial-number"); const std::string provisioner_serial_number("provisioner-serial-number"); - std::unique_ptr signed_provisioner_cert( + std::unique_ptr signed_provisioner_cert = GenerateSignedProvisionerCertificate(system_id, provisioner_serial_number, - service_id)); + service_id); std::unique_ptr signed_intermediate_cert( GenerateSignedIntermediateCertificate( - signed_provisioner_cert.release(), system_id, - intermediate_serial_number, kProvisionerSigner)); + signed_provisioner_cert.get(), system_id, intermediate_serial_number, + kProvisionerSigner)); - std::unique_ptr signed_device_cert( - GenerateSignedDrmCertificate(signed_intermediate_cert.release(), - system_id, device_serial_number)); + std::unique_ptr signed_device_cert = + GenerateSignedDrmCertificate(signed_intermediate_cert.get(), system_id, + device_serial_number); std::string serialized_cert; signed_device_cert->SerializeToString(&serialized_cert); std::unique_ptr client_cert_ptr; + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(serialized_cert); EXPECT_EQ("missing-provisioning-service-id", - ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - serialized_cert, &client_cert_ptr) + ClientCert::Create(root_cert_.get(), client_id, &client_cert_ptr) .error_message()); EXPECT_FALSE(client_cert_ptr); } @@ -584,29 +683,30 @@ TEST_F(ClientCertTest, InvalidProvisionerDeviceCertChain) { const std::string intermediate_serial_number("intermediate-serial-number"); const std::string intermediate_serial_number2("intermediate-serial-number-2"); - std::unique_ptr signed_intermediate_cert2( + std::unique_ptr signed_intermediate_cert2 = GenerateSignedIntermediateCertificate( - nullptr, system_id2, intermediate_serial_number2, kNoSigner)); + nullptr, system_id2, intermediate_serial_number2, kNoSigner); // Instead of using a provisioner certificate to sign this intermediate // certificate, use another intermediate certificate. This is an invalid // chain and should generate an error when trying to create a client // certificate. - std::unique_ptr signed_intermediate_cert( + std::unique_ptr signed_intermediate_cert = GenerateSignedIntermediateCertificate( - signed_intermediate_cert2.release(), system_id, - intermediate_serial_number, kDeviceModelSigner)); - std::unique_ptr signed_device_cert( - GenerateSignedDrmCertificate(signed_intermediate_cert.release(), - system_id, device_serial_number)); + signed_intermediate_cert2.get(), system_id, + intermediate_serial_number, kDeviceModelSigner); + std::unique_ptr signed_device_cert = + GenerateSignedDrmCertificate(signed_intermediate_cert.get(), system_id, + device_serial_number); std::string serialized_cert; signed_device_cert->SerializeToString(&serialized_cert); std::unique_ptr client_cert_ptr; + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(serialized_cert); ASSERT_EQ("expected-provisioning-provider-certificate-type", - ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - serialized_cert, &client_cert_ptr) + ClientCert::Create(root_cert_.get(), client_id, &client_cert_ptr) .error_message()); EXPECT_FALSE(client_cert_ptr); } @@ -619,32 +719,33 @@ TEST_F(ClientCertTest, InvalidDeviceCertChainSize_TooLong) { const std::string intermediate_serial_number2("intermediate-serial-number-2"); const std::string provisioner_serial_number("provisioner-serial-number"); - std::unique_ptr signed_provisioner_cert( + std::unique_ptr signed_provisioner_cert = GenerateSignedProvisionerCertificate(system_id, provisioner_serial_number, - service_id)); + service_id); - std::unique_ptr signed_intermediate_cert1( + std::unique_ptr signed_intermediate_cert1 = GenerateSignedIntermediateCertificate( - signed_provisioner_cert.release(), system_id, - intermediate_serial_number1, kProvisionerSigner)); + signed_provisioner_cert.get(), system_id, intermediate_serial_number1, + kProvisionerSigner); - std::unique_ptr signed_intermediate_cert2( + std::unique_ptr signed_intermediate_cert2 = GenerateSignedIntermediateCertificate( - signed_intermediate_cert1.release(), system_id, - intermediate_serial_number2, kDeviceModelSigner)); + signed_intermediate_cert1.get(), system_id, + intermediate_serial_number2, kDeviceModelSigner); - std::unique_ptr signed_device_cert( - GenerateSignedDrmCertificate(signed_intermediate_cert2.release(), - system_id, device_serial_number)); + std::unique_ptr signed_device_cert = + GenerateSignedDrmCertificate(signed_intermediate_cert2.get(), system_id, + device_serial_number); std::string serialized_cert; signed_device_cert->SerializeToString(&serialized_cert); std::unique_ptr client_cert_ptr = nullptr; + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(serialized_cert); ASSERT_EQ("certificate-chain-size-exceeded", - ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - serialized_cert, &client_cert_ptr) + ClientCert::Create(root_cert_.get(), client_id, &client_cert_ptr) .error_message()); EXPECT_FALSE(client_cert_ptr); } @@ -656,26 +757,27 @@ TEST_F(ClientCertTest, DeviceCertTypeNotLeaf) { const std::string provisioner_serial_number("provisioner-serial-number"); const std::string drm_serial_number("drm-serial-number"); - std::unique_ptr signed_provisioner_cert( + std::unique_ptr signed_provisioner_cert = GenerateSignedProvisionerCertificate(system_id, provisioner_serial_number, - service_id)); + service_id); // Use a DEVICE certificate as the intermediate certificate. - std::unique_ptr signed_intermediate_cert( - GenerateSignedDrmCertificate(signed_provisioner_cert.release(), system_id, - intermediate_serial_number)); + std::unique_ptr signed_intermediate_cert = + GenerateSignedDrmCertificate(signed_provisioner_cert.get(), system_id, + intermediate_serial_number); - std::unique_ptr signed_drm_cert( - GenerateSignedDrmCertificate(signed_intermediate_cert.release(), - system_id, drm_serial_number)); + std::unique_ptr signed_drm_cert = + GenerateSignedDrmCertificate(signed_intermediate_cert.get(), system_id, + drm_serial_number); std::string serialized_cert; signed_drm_cert->SerializeToString(&serialized_cert); std::unique_ptr client_cert_ptr; + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(serialized_cert); EXPECT_EQ("device-cert-must-be-leaf", - ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - serialized_cert, &client_cert_ptr) + ClientCert::Create(root_cert_.get(), client_id, &client_cert_ptr) .error_message()); EXPECT_FALSE(client_cert_ptr); } @@ -686,21 +788,22 @@ TEST_F(ClientCertTest, InvalidLeafCertificateType) { const std::string intermediate_serial_number("intermediate-serial-number"); const std::string provisioner_serial_number("provisioner-serial-number"); - std::unique_ptr signed_provisioner_cert( + std::unique_ptr signed_provisioner_cert = GenerateSignedProvisionerCertificate(system_id, provisioner_serial_number, - service_id)); - std::unique_ptr signed_intermediate_cert( + service_id); + std::unique_ptr signed_intermediate_cert = GenerateSignedIntermediateCertificate( - signed_provisioner_cert.release(), system_id, - intermediate_serial_number, kProvisionerSigner)); + signed_provisioner_cert.get(), system_id, intermediate_serial_number, + kProvisionerSigner); std::string serialized_cert; signed_intermediate_cert->SerializeToString(&serialized_cert); std::unique_ptr client_cert_ptr; + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(serialized_cert); // Leaf certificate must be a device certificate. EXPECT_EQ("expected-device-certificate-type", - ClientCert::Create(root_cert_.get(), - ClientIdentification::DRM_DEVICE_CERTIFICATE, - serialized_cert, &client_cert_ptr) + ClientCert::Create(root_cert_.get(), client_id, &client_cert_ptr) .error_message()); EXPECT_FALSE(client_cert_ptr); } @@ -709,9 +812,10 @@ TEST_F(ClientCertTest, Protocol21WithDrmCert) { const char message[] = "A weekend wasted is a weekend well spent."; std::unique_ptr client_cert; - ASSERT_OK(ClientCert::Create( - root_cert_.get(), ClientIdentification::DRM_DEVICE_CERTIFICATE, - test_drm_certs_.test_user_device_certificate(), &client_cert)); + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(test_drm_certs_.test_user_device_certificate()); + ASSERT_OK(ClientCert::Create(root_cert_.get(), client_id, &client_cert)); std::unique_ptr private_key( RsaPrivateKey::Create(test_rsa_keys_.private_test_key_3_2048_bits())); @@ -719,14 +823,16 @@ TEST_F(ClientCertTest, Protocol21WithDrmCert) { // Success std::string signature; - ASSERT_TRUE(private_key->GenerateSignature(message, &signature)); - EXPECT_OK(client_cert->VerifySignature(message, signature, VERSION_2_1)); + ASSERT_TRUE(private_key->GenerateSignature(message, kSha256, &signature)); + EXPECT_OK( + client_cert->VerifySignature(message, kSha256, signature, VERSION_2_1)); // Failure ASSERT_EQ(256, signature.size()); ++signature[127]; EXPECT_FALSE( - client_cert->VerifySignature(message, signature, VERSION_2_1).ok()); + client_cert->VerifySignature(message, kSha256, signature, VERSION_2_1) + .ok()); } TEST_F(ClientCertTest, Protocol22WithDrmCert) { @@ -734,9 +840,10 @@ TEST_F(ClientCertTest, Protocol22WithDrmCert) { const std::string message_hash(Sha512_Hash(message)); std::unique_ptr client_cert; - ASSERT_OK(ClientCert::Create( - root_cert_.get(), ClientIdentification::DRM_DEVICE_CERTIFICATE, - test_drm_certs_.test_user_device_certificate(), &client_cert)); + ClientIdentification client_id; + client_id.set_type(ClientIdentification::DRM_DEVICE_CERTIFICATE); + client_id.set_token(test_drm_certs_.test_user_device_certificate()); + ASSERT_OK(ClientCert::Create(root_cert_.get(), client_id, &client_cert)); std::unique_ptr private_key( RsaPrivateKey::Create(test_rsa_keys_.private_test_key_3_2048_bits())); @@ -744,14 +851,17 @@ TEST_F(ClientCertTest, Protocol22WithDrmCert) { // Success std::string signature; - ASSERT_TRUE(private_key->GenerateSignature(message_hash, &signature)); - EXPECT_OK(client_cert->VerifySignature(message, signature, VERSION_2_2)); + ASSERT_TRUE( + private_key->GenerateSignature(message_hash, kSha256, &signature)); + EXPECT_OK( + client_cert->VerifySignature(message, kSha256, signature, VERSION_2_2)); // Failure ASSERT_EQ(256, signature.size()); ++signature[127]; EXPECT_FALSE( - client_cert->VerifySignature(message, signature, VERSION_2_2).ok()); + client_cert->VerifySignature(message, kSha256, signature, VERSION_2_2) + .ok()); } } // namespace widevine diff --git a/common/client_id_util.cc b/common/client_id_util.cc index a997f10..cb25df8 100644 --- a/common/client_id_util.cc +++ b/common/client_id_util.cc @@ -22,6 +22,11 @@ namespace widevine { const char kModDrmMake[] = "company_name"; const char kModDrmModel[] = "model_name"; +const char kModDrmDeviceName[] = "device_name"; +const char kModDrmProductName[] = "product_name"; +const char kModDrmBuildInfo[] = "build_info"; +const char kModDrmOemCryptoSecurityPatchLevel[] = + "oem_crypto_security_patch_level"; void AddClientInfo(ClientIdentification* client_id, absl::string_view name, absl::string_view value) { diff --git a/common/client_id_util.h b/common/client_id_util.h index 6115935..025b158 100644 --- a/common/client_id_util.h +++ b/common/client_id_util.h @@ -21,6 +21,10 @@ namespace widevine { extern const char kModDrmMake[]; extern const char kModDrmModel[]; +extern const char kModDrmDeviceName[]; +extern const char kModDrmProductName[]; +extern const char kModDrmBuildInfo[]; +extern const char kModDrmOemCryptoSecurityPatchLevel[]; // Append the given name/value pair to client_id->client_info(). Does not // check for duplicates. diff --git a/common/content_id_util.cc b/common/content_id_util.cc index 8c7e170..40c755e 100644 --- a/common/content_id_util.cc +++ b/common/content_id_util.cc @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/content_id_util.h" +#include "glog/logging.h" #include "common/error_space.h" #include "common/status.h" #include "license_server_sdk/internal/parse_content_id.h" @@ -18,33 +19,14 @@ namespace widevine { -// TODO(user): Move the util methods from -// //license_server_sdk/internal/parse_content_id.h -// into this file. - Status GetContentIdFromExternalLicenseRequest( const ExternalLicenseRequest& external_license_request, std::string* content_id) { - LicenseRequest::ContentIdentification content_identification = - external_license_request.content_id(); - WidevinePsshData widevine_pssh_data; - if (content_identification.has_widevine_pssh_data()) { - widevine_pssh_data.ParseFromString( - content_identification.widevine_pssh_data().pssh_data(0)); - } else if (content_identification.has_webm_key_id()) { - widevine_pssh_data.ParseFromString( - content_identification.webm_key_id().header()); - } else if (content_identification.has_init_data()) { - ContentInfo content_info; - if (ParseContentId(content_identification, &content_info).ok()) { - widevine_pssh_data = - content_info.content_info_entry(0).pssh().widevine_data(); - } - } - *content_id = widevine_pssh_data.content_id(); + WidevinePsshData pssh_data; + Status status = ParsePsshData(external_license_request, &pssh_data); + *content_id = pssh_data.content_id(); return OkStatus(); } - Status GetContentIdFromSignedExternalLicenseRequest( const SignedMessage& signed_message, std::string* content_id) { if (signed_message.type() != SignedMessage::EXTERNAL_LICENSE_REQUEST) { @@ -61,4 +43,41 @@ Status GetContentIdFromSignedExternalLicenseRequest( content_id); } +Status ParsePsshData(ExternalLicenseRequest external_license_request, + WidevinePsshData* widevine_pssh_data) { + if (!external_license_request.has_content_id()) { + std::string error = "ExternalLicenseRequest does not include ContentId"; + LOG(ERROR) << error + << ", request = " << external_license_request.ShortDebugString(); + return Status(error_space, MISSING_CONTENT_ID, error); + } + ContentInfo content_info; + Status status = + ParseContentId(external_license_request.content_id(), &content_info); + if (!status.ok()) { + std::string error = + "Unable to retrieve ContentId from ExternalLicenseRequest"; + LOG(ERROR) << error << ", status = " << status + << ", request = " << external_license_request.ShortDebugString(); + return Status(error_space, MISSING_CONTENT_ID, error); + } + switch (external_license_request.content_id().init_data().init_data_type()) { + case LicenseRequest::ContentIdentification::InitData::WEBM: + widevine_pssh_data->ParseFromString( + content_info.content_info_entry(0).key_ids(0)); + break; + default: + *widevine_pssh_data = + content_info.content_info_entry(0).pssh().widevine_data(); + break; + } + if (widevine_pssh_data->content_id().empty()) { + std::string error = + "Missing ContentId within Pssh data for ExternalLicenseRequest"; + LOG(ERROR) << error + << ", request = " << external_license_request.ShortDebugString(); + return Status(error_space, MISSING_CONTENT_ID, error); + } + return OkStatus(); +} } // namespace widevine diff --git a/common/content_id_util.h b/common/content_id_util.h index 7272982..3775cb3 100644 --- a/common/content_id_util.h +++ b/common/content_id_util.h @@ -12,6 +12,7 @@ #include "common/status.h" #include "protos/public/external_license.pb.h" #include "protos/public/license_protocol.pb.h" +#include "protos/public/widevine_pssh.pb.h" namespace widevine { @@ -25,6 +26,12 @@ Status GetContentIdFromExternalLicenseRequest( const ExternalLicenseRequest& external_license_request, std::string* content_id); +// Returns OK if successful and |widevine_pssh_data| will be populated by +// parsing |external_license_request|. Else, error and |widevine_pssh_data| +// will not be set within this method. +Status ParsePsshData(ExternalLicenseRequest external_license_request, + WidevinePsshData* widevine_pssh_data); + } // namespace widevine #endif // COMMON_CONTENT_ID_UTIL_H_ diff --git a/common/content_id_util_test.cc b/common/content_id_util_test.cc index a2f82ea..ca0ea31 100644 --- a/common/content_id_util_test.cc +++ b/common/content_id_util_test.cc @@ -8,6 +8,10 @@ #include "common/content_id_util.h" +#include + +#include + #include "testing/gmock.h" #include "testing/gunit.h" #include "protos/public/errors.pb.h" @@ -23,9 +27,9 @@ const char kPlayReadyChallenge[] = ""; namespace widevine { // Builds a SignedMessage that includes an ExternalLicenseRequest. -SignedMessage BuildSignedExternalLicenseRequest( - const ExternalLicenseRequest::RequestType type, const std::string& request, - const std::string& content_id) { +SignedMessage BuildSignedExternalLicenseRequest(const ExternalLicenseType type, + const std::string& request, + const std::string& content_id) { ExternalLicenseRequest external_license_request; external_license_request.set_request_type(type); external_license_request.set_request(request); @@ -47,9 +51,8 @@ SignedMessage BuildSignedExternalLicenseRequest( TEST(ContentIdUtil, GetContentId) { std::string content_id; EXPECT_OK(GetContentIdFromSignedExternalLicenseRequest( - BuildSignedExternalLicenseRequest( - ExternalLicenseRequest::PLAYREADY_LICENSE_REQUEST, - kPlayReadyChallenge, kContentId), + BuildSignedExternalLicenseRequest(PLAYREADY_LICENSE_NEW, + kPlayReadyChallenge, kContentId), &content_id)); EXPECT_EQ(kContentId, content_id); } @@ -57,8 +60,7 @@ TEST(ContentIdUtil, GetContentId) { TEST(ContentIdUtil, GetContentIdFailureWithIncorrectType) { std::string content_id; SignedMessage signed_message = BuildSignedExternalLicenseRequest( - ExternalLicenseRequest::PLAYREADY_LICENSE_REQUEST, kPlayReadyChallenge, - kContentId); + PLAYREADY_LICENSE_NEW, kPlayReadyChallenge, kContentId); signed_message.set_type(SignedMessage::SERVICE_CERTIFICATE_REQUEST); Status status = GetContentIdFromSignedExternalLicenseRequest(signed_message, &content_id); @@ -69,8 +71,7 @@ TEST(ContentIdUtil, GetContentIdFailureWithIncorrectType) { TEST(ContentIdUtil, GetContentIdFailureWithInvalidExternalLicenseRequest) { std::string content_id; SignedMessage signed_message = BuildSignedExternalLicenseRequest( - ExternalLicenseRequest::PLAYREADY_LICENSE_REQUEST, kPlayReadyChallenge, - kContentId); + PLAYREADY_LICENSE_NEW, kPlayReadyChallenge, kContentId); signed_message.set_msg("Invalid payload"); Status status = GetContentIdFromSignedExternalLicenseRequest(signed_message, &content_id); diff --git a/common/core_message_util.cc b/common/core_message_util.cc index ecc846e..90837c0 100644 --- a/common/core_message_util.cc +++ b/common/core_message_util.cc @@ -42,7 +42,7 @@ bool GetCoreProvisioningResponse( } bool GetCoreRenewalOrReleaseLicenseResponse( - const std::string& request_core_message, + uint64_t renewal_duration_seconds, const std::string& request_core_message, std::string* response_core_message) { oemcrypto_core_message::ODK_RenewalRequest odk_renewal_request; if (request_core_message.empty()) { @@ -52,11 +52,6 @@ bool GetCoreRenewalOrReleaseLicenseResponse( &odk_renewal_request)) { return false; } - // TODO(b/141762043): This function is going to need to know what the - // renewal license is, and extract the renewal duration. This should be the - // sum of renewal_delay_seconds + 2 * renewal_recovery_duration_seconds. - uint64_t renewal_duration_seconds = - 3600; // PTAL when addressing b/141762043. return CreateCoreRenewalResponse( odk_renewal_request, renewal_duration_seconds, response_core_message); } diff --git a/common/core_message_util.h b/common/core_message_util.h index e6196b6..7c88dee 100644 --- a/common/core_message_util.h +++ b/common/core_message_util.h @@ -24,7 +24,7 @@ bool GetCoreProvisioningResponse( // Gets the |response_core_message| by parsing |request_core_message| for // release and renewal response. The output is held in |response_core_message|. bool GetCoreRenewalOrReleaseLicenseResponse( - const std::string& request_core_message, + uint64_t renewal_duration_seconds, const std::string& request_core_message, std::string* response_core_message); // Gets the |response_core_message| by parsing |request_core_message| and diff --git a/common/crypto_util.cc b/common/crypto_util.cc index 52d1352..1ff6747 100644 --- a/common/crypto_util.cc +++ b/common/crypto_util.cc @@ -11,6 +11,7 @@ #include "common/crypto_util.h" #include "glog/logging.h" +#include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "openssl/aes.h" #include "openssl/cmac.h" @@ -37,6 +38,8 @@ const char kGroupKeyLabel[] = "GROUP_ENCRYPTION"; // a real group master key in keystore. // TODO(user): figure out why VerifySignatureHmacSha256 can not crypto_mcmcpy // like VerifySignatureHmacSha1. +// TODO(user): Revert logging signature in VerifySignatureHmacSha256. +// function. const char kPhonyGroupMasterKey[] = "fedcba9876543210"; const int kAes128KeySizeBits = 128; const int kAes128KeySizeBytes = 16; diff --git a/common/default_device_security_profile_list.cc b/common/default_device_security_profile_list.cc new file mode 100644 index 0000000..9aa625e --- /dev/null +++ b/common/default_device_security_profile_list.cc @@ -0,0 +1,135 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// +// Implementation of the DefaultDeviceSecurityProfileList class. + +#include "common/default_device_security_profile_list.h" + +#include + +#include "glog/logging.h" +#include "google/protobuf/text_format.h" +#include "common/client_id_util.h" +#include "common/device_status_list.h" +#include "protos/public/client_identification.pb.h" +#include "protos/public/device_certificate_status.pb.h" +#include "protos/public/device_common.pb.h" +#include "protos/public/provisioned_device_info.pb.h" +#include "protos/public/security_profile.pb.h" + +namespace widevine { +using ClientCapabilities = ClientIdentification::ClientCapabilities; + +const char kWidevine[] = "widevine"; + +// Definition of Widevine default device security profiles. +// TODO(user): Add an OWNER file with per-file access to restrict changes to the +// profile definition. +const char kWidevineProfileMin[] = + (" name: \"minimum\"" + " min_output_requirements {" + " hdcp_version: HDCP_NONE" + " analog_output_capabilities: ANALOG_OUTPUT_UNKNOWN" + " }" + " min_security_requirements {" + " oemcrypto_api_version: 0" + " security_level: LEVEL_3" + " resource_rating_tier: 0" + " vulnerability_level: VULNERABILITY_HIGH" + " }" + " owner: \"Widevine\""); + +const char kWidevineProfileLow[] = + (" name: \"low\"" + " min_output_requirements {" + " hdcp_version: HDCP_NONE" + " analog_output_capabilities: ANALOG_OUTPUT_UNKNOWN" + " }" + " min_security_requirements {" + " oemcrypto_api_version: 8" + " security_level: LEVEL_3" + " resource_rating_tier: 1" + " vulnerability_level: VULNERABILITY_MEDIUM" + " }" + " owner: \"Widevine\""); + +const char kWidevineProfileMed[] = + (" name: \"medium\"" + " min_output_requirements {" + " hdcp_version: HDCP_V1" + " analog_output_capabilities: ANALOG_OUTPUT_UNKNOWN" + " }" + " min_security_requirements {" + " oemcrypto_api_version: 12" + " security_level: LEVEL_3" + " resource_rating_tier: 1" + " vulnerability_level: VULNERABILITY_LOW" + " }" + " owner: \"Widevine\""); + +const char kWidevineProfileHigh[] = + (" name: \"high\"" + " min_output_requirements {" + " hdcp_version: HDCP_V1" + " analog_output_capabilities: ANALOG_OUTPUT_SUPPORTS_CGMS_A" + " }" + " min_security_requirements {" + " oemcrypto_api_version: 12" + " security_level: LEVEL_1" + " resource_rating_tier: 2" + " vulnerability_level: VULNERABILITY_NONE" + " }" + " owner: \"Widevine\""); + +const char kWidevineProfileStrict[] = + (" name: \"strict\"" + " min_output_requirements {" + " hdcp_version: HDCP_V2_2" + " analog_output_capabilities: ANALOG_OUTPUT_SUPPORTS_CGMS_A" + " }" + " min_security_requirements {" + " oemcrypto_api_version: 12" + " security_level: LEVEL_1" + " resource_rating_tier: 3" + " vulnerability_level: VULNERABILITY_NONE" + " }" + " owner: \"Widevine\""); + +DefaultDeviceSecurityProfileList::DefaultDeviceSecurityProfileList() + : SecurityProfileList(kWidevine) {} + +int DefaultDeviceSecurityProfileList::Init() { return AddDefaultProfiles(); } + +int DefaultDeviceSecurityProfileList::AddDefaultProfiles() { + std::vector default_profile_strings; + GetDefaultProfileStrings(&default_profile_strings); + for (auto& profile_string : default_profile_strings) { + SecurityProfile profile; + if (!google::protobuf::TextFormat::ParseFromString(profile_string, &profile)) { + LOG(ERROR) << "Unable to load default profile: " << profile.name(); + ClearAllProfiles(); + return 0; + } + InsertProfile(profile); + } + return NumProfiles(); +} + +int DefaultDeviceSecurityProfileList::GetDefaultProfileStrings( + std::vector* default_profile_strings) const { + if (default_profile_strings == nullptr) { + return 0; + } + default_profile_strings->push_back(kWidevineProfileMin); + default_profile_strings->push_back(kWidevineProfileLow); + default_profile_strings->push_back(kWidevineProfileMed); + default_profile_strings->push_back(kWidevineProfileHigh); + default_profile_strings->push_back(kWidevineProfileStrict); + return default_profile_strings->size(); +} + +} // namespace widevine diff --git a/common/default_device_security_profile_list.h b/common/default_device_security_profile_list.h new file mode 100644 index 0000000..33d96d9 --- /dev/null +++ b/common/default_device_security_profile_list.h @@ -0,0 +1,39 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// +// +// Description: +// Container of Widevine default security profiless. + +#ifndef COMMON_DEFAULT_DEVICE_SECURITY_PROFILE_LIST_H_ +#define COMMON_DEFAULT_DEVICE_SECURITY_PROFILE_LIST_H_ + +#include "common/security_profile_list.h" + +namespace widevine { + +class DefaultDeviceSecurityProfileList : public SecurityProfileList { + public: + DefaultDeviceSecurityProfileList(); + ~DefaultDeviceSecurityProfileList() override {} + + // Initialize the security profile list. The list is initially empty, this + // function will populate the list with default profiles. The size of the + // list is returned. + int Init() override; + + private: + // Initialize the list with Widevine default profiles. The size of the + // profile list after the additions is returned. + virtual int AddDefaultProfiles(); + virtual int GetDefaultProfileStrings( + std::vector* default_profile_strings) const; +}; + +} // namespace widevine + +#endif // COMMON_DEFAULT_DEVICE_SECURITY_PROFILE_LIST_H_ diff --git a/common/default_device_security_profile_list_test.cc b/common/default_device_security_profile_list_test.cc new file mode 100644 index 0000000..958836c --- /dev/null +++ b/common/default_device_security_profile_list_test.cc @@ -0,0 +1,186 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// +// + +#include "common/default_device_security_profile_list.h" + +#include "glog/logging.h" +#include "google/protobuf/util/message_differencer.h" +#include "testing/gmock.h" +#include "testing/gunit.h" +#include "absl/memory/memory.h" +#include "common/client_id_util.h" +#include "protos/public/device_common.pb.h" +#include "protos/public/security_profile.pb.h" + +namespace widevine { +namespace security_profile { + +const uint32_t kResourceTierLow = 1; +const uint32_t kResourceTierMed = 2; +const uint32_t kResourceTierHigh = 3; +const char kMinProfileName[] = "minimum"; +const char kLowProfileName[] = "low"; +const char kMedProfileName[] = "medium"; +const char kHighProfileName[] = "high"; +const char kStrictProfileName[] = "strict"; + +class DefaultDeviceSecurityProfileListTest : public ::testing::Test { + public: + DefaultDeviceSecurityProfileListTest() {} + ~DefaultDeviceSecurityProfileListTest() override {} + + void SetUp() override { + SecurityProfile profile; + std::string profile_namespace = "widevine"; + profile_list_ = absl::make_unique(); + const int kNumWidevineProfiles = 5; + ASSERT_EQ(kNumWidevineProfiles, profile_list_->Init()); + } + + // Configure |client_id| and |device_info| with minimum settings. + void SetupMinDrmParams(ClientIdentification* client_id, + ProvisionedDeviceInfo* device_info) { + client_id->mutable_client_capabilities()->set_max_hdcp_version( + ClientCapabilities::HDCP_NONE); + client_id->mutable_client_capabilities()->set_analog_output_capabilities( + ClientCapabilities::ANALOG_OUTPUT_UNKNOWN); + client_id->mutable_client_capabilities()->set_oem_crypto_api_version(0); + client_id->mutable_client_capabilities()->set_resource_rating_tier( + kResourceTierLow); + device_info->set_security_level(ProvisionedDeviceInfo::LEVEL_3); + } + + // Configure |client_id| and |device_info| with maximum settings. + void SetupMaxDrmParams(ClientIdentification* client_id, + ProvisionedDeviceInfo* device_info) { + client_id->mutable_client_capabilities()->set_max_hdcp_version( + ClientCapabilities::HDCP_V2_3); + client_id->mutable_client_capabilities()->set_analog_output_capabilities( + ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A); + client_id->mutable_client_capabilities()->set_oem_crypto_api_version(16); + client_id->mutable_client_capabilities()->set_resource_rating_tier( + kResourceTierHigh); + device_info->set_security_level(ProvisionedDeviceInfo::LEVEL_1); + } + + std::unique_ptr profile_list_; +}; + +TEST_F(DefaultDeviceSecurityProfileListTest, QualifiedProfiles) { + ClientIdentification client_id; + ProvisionedDeviceInfo device_info; + SetupMinDrmParams(&client_id, &device_info); + + std::vector qualified_profiles; + // Should only return the minimum profile. + ASSERT_EQ(1, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMinProfileName) != qualified_profiles.end()); + + // Increase the device capabilities to include the low profile. + client_id.mutable_client_capabilities()->set_oem_crypto_api_version(8); + ASSERT_EQ(2, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); + + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMinProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kLowProfileName) != qualified_profiles.end()); + + // Increase the device capabilities to include the med profile. + client_id.mutable_client_capabilities()->set_max_hdcp_version( + ClientCapabilities::HDCP_V1); + client_id.mutable_client_capabilities()->set_oem_crypto_api_version(12); + ASSERT_EQ(3, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMinProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kLowProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMedProfileName) != qualified_profiles.end()); + + // Increase the device capabilities to include the high profile. + device_info.set_security_level(ProvisionedDeviceInfo::LEVEL_1); + client_id.mutable_client_capabilities()->set_analog_output_capabilities( + ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A); + client_id.mutable_client_capabilities()->set_resource_rating_tier( + kResourceTierMed); + ASSERT_EQ(4, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMinProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kLowProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMedProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kHighProfileName) != qualified_profiles.end()); + + // Increase the device capabilities to include the strict profile. + client_id.mutable_client_capabilities()->set_max_hdcp_version( + ClientCapabilities::HDCP_V2_2); + client_id.mutable_client_capabilities()->set_resource_rating_tier( + kResourceTierHigh); + ASSERT_EQ(5, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMinProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kLowProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMedProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kHighProfileName) != qualified_profiles.end()); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kStrictProfileName) != qualified_profiles.end()); +} + +TEST_F(DefaultDeviceSecurityProfileListTest, + DeviceQualifiedProfilesForLowEndDevice) { + ClientIdentification client_id; + ProvisionedDeviceInfo device_info; + SetupMinDrmParams(&client_id, &device_info); + + // Only 1 profile should qualify for this device. + std::vector qualified_profiles; + ASSERT_EQ(1, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); + EXPECT_TRUE(std::find(qualified_profiles.begin(), qualified_profiles.end(), + kMinProfileName) != qualified_profiles.end()); +} + +TEST_F(DefaultDeviceSecurityProfileListTest, + QualifiedProfilesForHighEndDevice) { + ClientIdentification client_id; + ProvisionedDeviceInfo device_info; + SetupMaxDrmParams(&client_id, &device_info); + + // All 5 default profiles should qualify for this device. + std::vector qualified_profiles; + ASSERT_EQ(5, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); +} + +// TODO(b/160019477): Add test once provisioned device info supports known +// vulnerability. +TEST_F(DefaultDeviceSecurityProfileListTest, + DISABLED_QualifiedProfilesByVunerabilityLevel) { + ClientIdentification client_id; + ProvisionedDeviceInfo device_info; + SetupMaxDrmParams(&client_id, &device_info); + + std::vector qualified_profiles; + ASSERT_EQ(0, profile_list_->GetQualifiedProfiles(client_id, device_info, + &qualified_profiles)); +} + +} // namespace security_profile +} // namespace widevine diff --git a/common/device_info_util.cc b/common/device_info_util.cc index 37cd0a9..1579e91 100644 --- a/common/device_info_util.cc +++ b/common/device_info_util.cc @@ -30,7 +30,7 @@ bool VerifyMakeModel(const ProvisionedDeviceInfo& device_info, make_from_client, model_from_client)) { return true; } - for (DeviceModel product_info : device_info.model_info()) { + for (const DeviceModel& product_info : device_info.model_info()) { if (IsMatchedMakeModel(product_info.manufacturer(), product_info.model_name(), make_from_client, model_from_client)) { diff --git a/common/device_info_util.h b/common/device_info_util.h index 6347f6d..b202519 100644 --- a/common/device_info_util.h +++ b/common/device_info_util.h @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #ifndef COMMON_DEVICE_INFO_UTIL_H_ #define COMMON_DEVICE_INFO_UTIL_H_ + #include #include "protos/public/provisioned_device_info.pb.h" diff --git a/common/device_status_list.cc b/common/device_status_list.cc index 7a3fbcc..5cfe9ec 100644 --- a/common/device_status_list.cc +++ b/common/device_status_list.cc @@ -27,6 +27,7 @@ #include "common/client_cert.h" #include "common/drm_service_certificate.h" #include "common/error_space.h" +#include "common/hash_algorithm_util.h" #include "common/keybox_client_cert.h" #include "common/rsa_key.h" #include "common/status.h" @@ -63,7 +64,8 @@ DeviceStatusList::~DeviceStatusList() {} Status DeviceStatusList::UpdateStatusList( const std::string& root_certificate_public_key, const std::string& serialized_device_certificate_status_list, - const std::string& signature, uint32_t expiration_period_seconds) { + HashAlgorithm hash_algorithm, const std::string& signature, + uint32_t expiration_period_seconds) { if (serialized_device_certificate_status_list.empty()) { return Status(error_space, INVALID_CERTIFICATE_STATUS_LIST, "missing-status-list"); @@ -79,7 +81,7 @@ Status DeviceStatusList::UpdateStatusList( "invalid-root-public-key"); } if (!root_key->VerifySignature(serialized_device_certificate_status_list, - signature)) { + hash_algorithm, signature)) { return Status(error_space, INVALID_CERTIFICATE_STATUS_LIST, "invalid-status-list-signature"); } @@ -117,29 +119,11 @@ Status DeviceStatusList::UpdateStatusList( return OkStatus(); } -Status DeviceStatusList::GetCertStatus(const ClientCert& client_cert, - const std::string& device_manufacturer, - ProvisionedDeviceInfo* device_info) { - CHECK(device_info); - - // Keybox checks. - if (client_cert.type() == ClientIdentification::KEYBOX) { - if (!KeyboxClientCert::IsSystemIdKnown(client_cert.system_id())) { - return Status(error_space, UNSUPPORTED_SYSTEM_ID, - "keybox-unsupported-system-id"); - } - // Get device information from certificate status list if available. - if (!GetDeviceInfo(client_cert, device_info)) { - device_info->Clear(); - } - return OkStatus(); - } - - // DRM certificate checks. - if (client_cert.type() != ClientIdentification::DRM_DEVICE_CERTIFICATE) { - return Status(error_space, INVALID_DRM_CERTIFICATE, - "device-certificate-unsupported-token-type"); - } +Status DeviceStatusList::GetCertStatus( + const ClientCert& client_cert, const std::string& make, + const std::string& provider, bool allow_revoked_system_id, + DeviceCertificateStatus* device_certificate_status) { + CHECK(device_certificate_status); absl::ReaderMutexLock lock(&status_map_lock_); if (expiration_period_seconds_ && (GetCurrentTime() > @@ -149,59 +133,67 @@ Status DeviceStatusList::GetCertStatus(const ClientCert& client_cert, } DeviceCertificateStatus* device_cert_status = gtl::FindOrNull(device_status_map_, client_cert.system_id()); - if (device_cert_status) { - *device_info = device_cert_status->device_info(); - if (device_cert_status->status() == - DeviceCertificateStatus::STATUS_REVOKED) { - if (IsRevokedSystemIdAllowed(client_cert.system_id())) { - LOG(WARNING) << "Allowing REVOKED device: " - << device_info->ShortDebugString(); - } else { - return Status(error_space, DRM_DEVICE_CERTIFICATE_REVOKED, - "device-certificate-revoked"); - } + + if (device_cert_status == nullptr) { + if (allow_unknown_devices_) { + return OkStatus(); } - if ((device_cert_status->status() == - DeviceCertificateStatus::STATUS_TEST_ONLY) && - !allow_test_only_devices_) { - if (IsTestOnlyDeviceAllowed(client_cert.system_id(), - device_manufacturer)) { - LOG(WARNING) << "Allowing TEST_ONLY device with systemId = " - << client_cert.system_id() - << "make = " << device_manufacturer - << ", device info = " << device_info->ShortDebugString(); - } else { - VLOG(2) << "Not allowing TEST ONLY device with systemId = " - << client_cert.system_id() << "make = " << device_manufacturer - << ", device info = " << device_info->ShortDebugString(); - return Status(error_space, DEVELOPMENT_CERTIFICATE_NOT_ALLOWED, - "test-only-drm-certificate-not-allowed"); - } + return client_cert.SystemIdUnknownError(); + } + *device_certificate_status = *device_cert_status; + + if (device_cert_status->status() == DeviceCertificateStatus::STATUS_REVOKED) { + if (IsRevokedSystemIdAllowed(client_cert.system_id()) || + allow_revoked_system_id) { + LOG(WARNING) << "Allowing REVOKED device: " + << device_cert_status->device_info().ShortDebugString(); + } else { + return client_cert.SystemIdRevokedError(); } - if (!client_cert.signed_by_provisioner() && - (client_cert.signer_serial_number() != - device_cert_status->drm_serial_number())) { - // Widevine-provisioned device, and the intermediate certificate serial - // number does not match that in the status list. If the status list is - // newer than the certificate, indicate an invalid certificate, so that - // the device re-provisions. If, on the other hand, the certificate status - // list is older than the certificate, the certificate is for all purposes - // unknown. - if (client_cert.signer_creation_time_seconds() < creation_time_seconds_) { - return Status(error_space, INVALID_DRM_CERTIFICATE, - "intermediate-certificate-serial-number-mismatch"); - } - return Status(error_space, DRM_DEVICE_CERTIFICATE_UNKNOWN, - "device-certificate-status-unknown"); - } - } else { - if (!allow_unknown_devices_) { - return Status(error_space, DRM_DEVICE_CERTIFICATE_UNKNOWN, - "device-certificate-status-unknown"); - } - device_info->Clear(); } + // The remainder of this function is for DRM certificates. + if (client_cert.type() == ClientIdentification::KEYBOX) { + return OkStatus(); + } + // DRM certificate checks. + if (client_cert.type() != ClientIdentification::DRM_DEVICE_CERTIFICATE) { + return Status(error_space, INVALID_DRM_CERTIFICATE, + "device-certificate-unsupported-token-type"); + } + if ((device_cert_status->status() == + DeviceCertificateStatus::STATUS_TEST_ONLY) && + !allow_test_only_devices_) { + if (IsTestOnlyDeviceAllowedByMake(client_cert.system_id(), make) && + IsTestOnlyDeviceAllowedByProvider(client_cert.system_id(), provider)) { + LOG(WARNING) << "Allowing TEST_ONLY device with systemId = " + << client_cert.system_id() << ", make = " << make + << ", provider = " << provider << ", device info = " + << device_cert_status->device_info().ShortDebugString(); + } else { + VLOG(2) << "Not allowing TEST ONLY device with systemId = " + << client_cert.system_id() << ", provider = " << provider + << ", device info = " + << device_cert_status->device_info().ShortDebugString(); + return Status(error_space, DEVELOPMENT_CERTIFICATE_NOT_ALLOWED, + "test-only-drm-certificate-not-allowed"); + } + } + if (!client_cert.signed_by_provisioner() && + (client_cert.signer_serial_number() != + device_cert_status->drm_serial_number())) { + // Widevine-provisioned device, and the intermediate certificate serial + // number does not match that in the status list. If the status list is + // newer than the certificate, indicate an invalid certificate, so that + // the device re-provisions. If, on the other hand, the certificate status + // list is older than the certificate, the certificate is for all purposes + // unknown. + if (client_cert.signer_creation_time_seconds() < creation_time_seconds_) { + return Status(error_space, INVALID_DRM_CERTIFICATE, + "intermediate-certificate-serial-number-mismatch"); + } + return client_cert.SystemIdUnknownError(); + } return OkStatus(); } @@ -268,29 +260,55 @@ void DeviceStatusList::AllowRevokedDevices(const std::string& system_id_list) { std::sort(allowed_revoked_devices_.begin(), allowed_revoked_devices_.end()); } -void DeviceStatusList::AllowTestOnlyDevices(const std::string& device_list) { +void DeviceStatusList::AllowTestOnlyDevicesByMake( + const std::string& device_list_by_make) { absl::WriterMutexLock lock(&allowed_test_only_devices_mutex_); - if (device_list.empty()) { - allowed_test_only_devices_.clear(); + if (device_list_by_make.empty()) { + allowed_test_only_devices_by_make_.clear(); return; } - for (absl::string_view device : absl::StrSplit(device_list, ',')) { + for (absl::string_view device : absl::StrSplit(device_list_by_make, ',')) { const std::pair device_split = absl::StrSplit(device, ':'); if (device_split.second.empty() || device_split.second == "*") { - allowed_test_only_devices_.emplace( + allowed_test_only_devices_by_make_.emplace( std::stoi(std::string(device_split.first)), "*"); - VLOG(2) << "Whitelisting TEST_ONLY device: systemId = " - << std::stoi(std::string(device_split.first)) - << ", manufacturer = *"; + VLOG(2) << "Allowing TEST_ONLY device: systemId = " + << std::stoi(std::string(device_split.first)) << ", make *"; } else { - allowed_test_only_devices_.emplace( + allowed_test_only_devices_by_make_.emplace( std::stoi(std::string(device_split.first)), absl::AsciiStrToUpper(device_split.second)); - VLOG(2) << "Whitelisting TEST_ONLY device: systemId = " + VLOG(2) << "Allowing TEST_ONLY device: systemId = " << std::stoi(std::string(device_split.first)) - << ", manufacturer = " - << absl::AsciiStrToUpper(device_split.second); + << ", make = " << absl::AsciiStrToUpper(device_split.second); + } + } +} + +void DeviceStatusList::AllowTestOnlyDevicesByProvider( + const std::string& device_list_by_provider) { + absl::WriterMutexLock lock(&allowed_test_only_devices_mutex_); + if (device_list_by_provider.empty()) { + allowed_test_only_devices_by_provider_.clear(); + return; + } + for (absl::string_view device : + absl::StrSplit(device_list_by_provider, ',')) { + const std::pair device_split = + absl::StrSplit(device, ':'); + if (device_split.second.empty() || device_split.second == "*") { + allowed_test_only_devices_by_provider_.emplace( + std::stoi(std::string(device_split.first)), "*"); + VLOG(2) << "Allowing TEST_ONLY device: systemId = " + << std::stoi(std::string(device_split.first)) << ", provider *"; + } else { + allowed_test_only_devices_by_provider_.emplace( + std::stoi(std::string(device_split.first)), + absl::AsciiStrToUpper(device_split.second)); + VLOG(2) << "Allowing TEST_ONLY device: systemId = " + << std::stoi(std::string(device_split.first)) + << ", provider = " << absl::AsciiStrToUpper(device_split.second); } } } @@ -301,17 +319,34 @@ bool DeviceStatusList::IsRevokedSystemIdAllowed(uint32_t system_id) { return it; } -bool DeviceStatusList::IsTestOnlyDeviceAllowed(uint32_t system_id, - const std::string manufacturer) { +bool DeviceStatusList::IsTestOnlyDeviceAllowedByMake( + uint32_t system_id, const std::string& manufacturer) { absl::ReaderMutexLock lock(&allowed_test_only_devices_mutex_); std::pair::iterator, std::multimap::iterator> - allowed_manufacturers = allowed_test_only_devices_.equal_range(system_id); - for (auto it = allowed_manufacturers.first; - it != allowed_manufacturers.second; ++it) { - std::string allowed_manufacturer = (*it).second; - if (allowed_manufacturer == "*" || - allowed_manufacturer == absl::AsciiStrToUpper(manufacturer)) { + allowed_makes = allowed_test_only_devices_by_make_.equal_range(system_id); + for (auto it = allowed_makes.first; it != allowed_makes.second; ++it) { + std::string allowed_makes = (*it).second; + if (allowed_makes == "*" || + allowed_makes == absl::AsciiStrToUpper(manufacturer)) { + return true; + } + } + return false; +} + +bool DeviceStatusList::IsTestOnlyDeviceAllowedByProvider( + uint32_t system_id, const std::string& provider) { + absl::ReaderMutexLock lock(&allowed_test_only_devices_mutex_); + std::pair::iterator, + std::multimap::iterator> + allowed_providers = + allowed_test_only_devices_by_provider_.equal_range(system_id); + for (auto it = allowed_providers.first; it != allowed_providers.second; + ++it) { + std::string allowed_provider = (*it).second; + if (allowed_provider == "*" || + allowed_provider == absl::AsciiStrToUpper(provider)) { return true; } } @@ -321,7 +356,8 @@ bool DeviceStatusList::IsTestOnlyDeviceAllowed(uint32_t system_id, Status DeviceStatusList::DetermineAndDeserializeServiceResponse( const std::string& service_response, DeviceCertificateStatusList* certificate_status_list, - std::string* serialized_certificate_status_list, std::string* signature) { + std::string* serialized_certificate_status_list, + HashAlgorithm* hash_algorithm, std::string* signature) { if (certificate_status_list == nullptr) { return Status(error_space, error::INVALID_ARGUMENT, "certificate_status_list is empty"); @@ -337,13 +373,15 @@ Status DeviceStatusList::DetermineAndDeserializeServiceResponse( // payload. If that doesn't match, then the method will try to parse the // serialized PublishedDeviceInfo proto. Status status = ExtractPublishedDevicesInfo( - service_response, serialized_certificate_status_list, signature); + service_response, serialized_certificate_status_list, hash_algorithm, + signature); // If the payload was not correctly parsed as a PublishedDevices proto. // then attempt to parse it as a legacy payload. if (!status.ok()) { - status = ExtractLegacyDeviceList( - service_response, serialized_certificate_status_list, signature); + status = ExtractLegacyDeviceList(service_response, + serialized_certificate_status_list, + hash_algorithm, signature); // The payload could not be parsed in either format, return the failure // information. if (!status.ok()) { @@ -361,7 +399,8 @@ Status DeviceStatusList::DetermineAndDeserializeServiceResponse( Status DeviceStatusList::ExtractLegacyDeviceList( const std::string& raw_certificate_provisioning_service_response, - std::string* serialized_certificate_status_list, std::string* signature) { + std::string* serialized_certificate_status_list, + HashAlgorithm* hash_algorithm, std::string* signature) { // First, attempt to extract the legacy JSON response. Example legacy format. // "signedList":"" // where the b64 encoded data is a DeviceCertificateStatusListResponse. @@ -424,12 +463,13 @@ Status DeviceStatusList::ExtractLegacyDeviceList( // and extract the serialized status list and signature. return ParseLegacySignedDeviceCertificateStatusList( serialized_signed_certificate_status_list, - serialized_certificate_status_list, signature); + serialized_certificate_status_list, hash_algorithm, signature); } Status DeviceStatusList::ExtractPublishedDevicesInfo( const std::string& serialized_published_devices, - std::string* serialized_certificate_status_list, std::string* signature) { + std::string* serialized_certificate_status_list, + HashAlgorithm* hash_algorithm, std::string* signature) { // TODO(b/139067045): Change from using the SignedDeviceInfo proto // to using the correct proto from the API. This duplicate, wire-compatible // proto was a temporary way to workaround Proto2/Proto3 compatibility issues. @@ -440,6 +480,7 @@ Status DeviceStatusList::ExtractPublishedDevicesInfo( } *serialized_certificate_status_list = devices_info.device_certificate_status_list(); + *hash_algorithm = HashAlgorithmProtoToEnum(devices_info.hash_algorithm()); *signature = devices_info.signature(); return OkStatus(); } @@ -470,12 +511,12 @@ Status DeviceStatusList::GenerateSignedDeviceCertificateStatusListRequest( DrmServiceCertificate::GetDefaultDrmServiceCertificate(); if (sc == nullptr) { signed_device_certificate_status_list_request->clear(); - return Status(error_space, widevine::INVALID_SERVICE_CERTIFICATE, + return Status(error_space, widevine::SERVICE_CERTIFICATE_NOT_FOUND, "Drm service certificate is not loaded."); } const RsaPrivateKey* private_key = sc->private_key(); if (private_key == nullptr) { - return Status(error_space, widevine::INVALID_SERVICE_CERTIFICATE, + return Status(error_space, widevine::INVALID_SERVICE_PRIVATE_KEY, "Private key in the service certificate is null."); } std::string signature; @@ -490,7 +531,7 @@ Status DeviceStatusList::GenerateSignedDeviceCertificateStatusListRequest( Status DeviceStatusList::ParseLegacySignedDeviceCertificateStatusList( const std::string& serialized_signed_device_certificate_status_list, std::string* serialized_device_certificate_status_list, - std::string* signature) { + HashAlgorithm* hash_algorithm, std::string* signature) { // Parse the serialized_signed_device_certificate_status_list to extract the // serialized_device_certificate_status_list SignedDeviceCertificateStatusList signed_device_list; @@ -509,6 +550,8 @@ Status DeviceStatusList::ParseLegacySignedDeviceCertificateStatusList( } *serialized_device_certificate_status_list = signed_device_list.certificate_status_list(); + *hash_algorithm = + HashAlgorithmProtoToEnum(signed_device_list.hash_algorithm()); *signature = signed_device_list.signature(); return OkStatus(); } @@ -532,4 +575,24 @@ bool DeviceStatusList::IsDrmCertificateRevoked( return false; } +Status DeviceStatusList::GetDeviceCertificateStatusBySystemId( + uint32_t system_id, DeviceCertificateStatus* device_certificate_status) { + absl::ReaderMutexLock lock(&status_map_lock_); + if (expiration_period_seconds_ && + (GetCurrentTime() > + (creation_time_seconds_ + expiration_period_seconds_))) { + return Status(error_space, EXPIRED_CERTIFICATE_STATUS_LIST, + "certificate-status-list-expired"); + } + DeviceCertificateStatus* device_cert_status = + gtl::FindOrNull(device_status_map_, system_id); + if (device_cert_status == nullptr) { + return Status(error_space, DRM_DEVICE_CERTIFICATE_UNKNOWN, + "device-certificate-status-unknown"); + } else { + *device_certificate_status = *device_cert_status; + } + return OkStatus(); +} + } // namespace widevine diff --git a/common/device_status_list.h b/common/device_status_list.h index a4e19a0..7607d7f 100644 --- a/common/device_status_list.h +++ b/common/device_status_list.h @@ -16,6 +16,7 @@ #include #include "absl/synchronization/mutex.h" +#include "common/hash_algorithm.h" #include "common/status.h" #include "protos/public/device_certificate_status.pb.h" #include "protos/public/provisioned_device_info.pb.h" @@ -48,7 +49,8 @@ class DeviceStatusList { Status UpdateStatusList( const std::string& root_certificate_public_key, const std::string& serialized_device_certificate_status_list, - const std::string& signature, uint32_t expiration_period_seconds); + HashAlgorithm hash_algorithm, const std::string& signature, + uint32_t expiration_period_seconds); void set_allow_unknown_devices(bool flag) { allow_unknown_devices_ = flag; } bool allow_unknown_devices() const { return allow_unknown_devices_; } void set_allow_test_only_devices(bool allow) { @@ -63,14 +65,15 @@ class DeviceStatusList { // INVALID_DRM_CERTIFICATE // DRM_DEVICE_CERTIFICATE_REVOKED // DRM_DEVICE_CERTIFICATE_UNKNOWN - // If a TEST_ONLY device using "make" as identified by |device_manufacturer|, - // was not whitelisted, then will return // DEVELOPMENT_CERTIFICATE_NOT_ALLOWED - // If status is OK, a copy of the provisioned device info is copied - // into |device_info|. Caller owns |device_info| and it must not be null. - Status GetCertStatus(const ClientCert& client_cert, - const std::string& device_manufacturer, - widevine::ProvisionedDeviceInfo* device_info); + // |provider| is the service provider making the license request. + // If status is OK, a copy of the device certificate status is copied + // into |device_certificate_status|. Caller owns |device_certificate_status| + // and it must not be null. + Status GetCertStatus( + const ClientCert& client_cert, const std::string& make, + const std::string& provider, bool allow_revoked_system_id, + widevine::DeviceCertificateStatus* device_certificate_status); // Returns true if the pre-provisioning key or certificate for the specified // system ID are active (not disallowed or revoked). bool IsSystemIdActive(uint32_t system_id); @@ -107,8 +110,16 @@ class DeviceStatusList { // of the format : // Example usage: // const std::string device_list = "4121:LG,7912:*" - // AllowTestOnlyDevices(device_list); - virtual void AllowTestOnlyDevices(const std::string& device_list); + // AllowTestOnlyDevicesByMake(device_list_by_make); + virtual void AllowTestOnlyDevicesByMake( + const std::string& device_list_by_make); + + // Same as above, except by providers instead of by manufacturers. + // Example usage: + // const std::string device_list = "4121:YouTube,4121:AndroidVideo" + // AllowTestOnlyDevicesByProvider(device_list); + virtual void AllowTestOnlyDevicesByProvider( + const std::string& device_list_by_provider); // A comma separated list of DRM Certificate Serial Numbers that are revoked. virtual void RevokedDrmCertificateSerialNumbers( @@ -119,6 +130,12 @@ class DeviceStatusList { bool IsDrmCertificateRevoked( const std::string& device_certificate_serial_number); + // Returns OK if |system_id| was found in the device certificate status list + // and |device_certificate_status| is populated. If |system_id| is not found, + // this call returns an error. + virtual Status GetDeviceCertificateStatusBySystemId( + uint32_t system_id, DeviceCertificateStatus* device_certificate_status); + // Parses the serialized certificate status list and the signature from the // service_response. The service_response is the JSON payload that comes // in the response to a certificate status list request. Both the legacy @@ -139,11 +156,13 @@ class DeviceStatusList { // serialized proto against the |signature|. // The |signature| is the signature of the serialized_certificate_status_list // using RSASSA-PSS signed with the root certificate private key. + // The |hash_algorithm| is the hash algorithm used in signature. // Returns WvPLStatus - Status::OK if success, else error. static Status DetermineAndDeserializeServiceResponse( const std::string& service_response, DeviceCertificateStatusList* certificate_status_list, - std::string* serialized_certificate_status_list, std::string* signature); + std::string* serialized_certificate_status_list, + HashAlgorithm* hash_algorithm, std::string* signature); /** * Constructs signed device certificate status list request string. @@ -157,6 +176,20 @@ class DeviceStatusList { const std::string& serialized_service_certificate, std::string* signed_device_certificate_status_list_request); + // Returns true if the system ID is allowed to be revoked. + // Caller owns |system_id|. They must not be null. + bool IsRevokedSystemIdAllowed(uint32_t system_id); + + // Returns true if the device, which is identified by system_id and + // device_manufacturer, is present in |allowed_test_only_devices_by_make_|. + bool IsTestOnlyDeviceAllowedByMake(uint32_t system_id, + const std::string& device_manufacturer); + + // Returns true if the device, which is identified by system_id and + // provider, is present in |allowed_test_only_devices_by_provider_|. + bool IsTestOnlyDeviceAllowedByProvider(uint32_t system_id, + const std::string& provider); + private: friend class DeviceStatusListTest; @@ -167,12 +200,14 @@ class DeviceStatusList { * * @param legacy_certificate_provisioning_service_response * @param serialized_certificate_status_list + * @param hash_algorithm * @param signature * @return WvPLStatus - Status::OK if success, else error. */ static Status ExtractLegacyDeviceList( const std::string& raw_certificate_provisioning_service_response, - std::string* serialized_certificate_status_list, std::string* signature); + std::string* serialized_certificate_status_list, + HashAlgorithm* hash_algorithm, std::string* signature); /** * Parses the serialized published devices response. @@ -182,12 +217,14 @@ class DeviceStatusList { * @param published_devices_response the serialized PublishedDevices proto * containing the certificate status list. * @param serialized_certificate_status_list + * @param hash_algorithm * @param signature * @return WvPLStatus - Status::OK if success, else error. */ static Status ExtractPublishedDevicesInfo( const std::string& serialized_published_devices, - std::string* serialized_certificate_status_list, std::string* signature); + std::string* serialized_certificate_status_list, + HashAlgorithm* hash_algorithm, std::string* signature); /** * Returns a |serialized_device_certificate_status_list| in its output @@ -196,25 +233,28 @@ class DeviceStatusList { * * @param serialized_signed_device_certificate_status_list * @param serialized_device_certificate_status_list - * + * @param hash_algorithm * @return Status - Status::OK if success, else error. */ static Status ParseLegacySignedDeviceCertificateStatusList( const std::string& serialized_signed_device_certificate_status_list, std::string* serialized_device_certificate_status_list, - std::string* signature); + HashAlgorithm* hash_algorithm, std::string* signature); - // Returns true if the system ID is allowed to be revoked. - // Caller owns |system_id|. They must not be null. - bool IsRevokedSystemIdAllowed(uint32_t system_id); - // Returns true if the device, which is identified by system_id and - // device_manufacturer, is present in |allowed_test_only_devices_|. - bool IsTestOnlyDeviceAllowed(uint32_t system_id, - const std::string device_manufacturer); + virtual size_t allowed_test_only_devices_by_make_size() { + absl::ReaderMutexLock lock(&allowed_test_only_devices_mutex_); + return allowed_test_only_devices_by_make_.size(); + } - absl::Mutex status_map_lock_; + virtual size_t allowed_test_only_devices_by_provider_size() { + absl::ReaderMutexLock lock(&allowed_test_only_devices_mutex_); + return allowed_test_only_devices_by_provider_.size(); + } + + mutable absl::Mutex status_map_lock_; // Key is the system id for the device. - std::map device_status_map_; + std::map device_status_map_ + ABSL_GUARDED_BY(status_map_lock_); uint32_t creation_time_seconds_ = 0; uint32_t expiration_period_seconds_ = 0; bool allow_unknown_devices_ = false; @@ -222,10 +262,15 @@ class DeviceStatusList { // Contains the list of system_id values that are allowed to succeed even if // revoked. std::vector allowed_revoked_devices_; - absl::Mutex allowed_test_only_devices_mutex_; - // Contains a map of 'system_id' to 'make'. If 'make' value is "*", any - // 'make' for that 'system_id' is allowed. - std::multimap allowed_test_only_devices_; + mutable absl::Mutex allowed_test_only_devices_mutex_; + // Contains a map of 'system_id' to 'manufacturer'. If manufacturer value is + // "*", any manufacturer using that system_id is allowed. + std::multimap allowed_test_only_devices_by_make_ + ABSL_GUARDED_BY(allowed_test_only_devices_mutex_); + // Contains a map of 'system_id' to 'provider'. If provider value is "*", any + // provider using that system_id is allowed. + std::multimap allowed_test_only_devices_by_provider_ + ABSL_GUARDED_BY(allowed_test_only_devices_mutex_); // Revoked DRM certificate serial numbers. std::set revoked_drm_certificate_serial_numbers_; }; diff --git a/common/device_status_list_test.cc b/common/device_status_list_test.cc index 3b4dcaf..6ab8032 100644 --- a/common/device_status_list_test.cc +++ b/common/device_status_list_test.cc @@ -21,6 +21,8 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "common/client_cert.h" +#include "common/hash_algorithm.h" +#include "common/hash_algorithm_util.h" #include "common/keybox_client_cert.h" #include "common/rsa_key.h" #include "common/rsa_test_keys.h" @@ -34,20 +36,19 @@ namespace { const char kTestSystemId_1[] = "4121"; -const char kTestManufacturer_LG[] = "LG"; -const char kTestManufacturer_LGE[] = "LGE"; const char kTestSystemId_2[] = "8242"; -const char kTestManufacturer_Samsung[] = "Samsung"; const char kTestSystemId_3[] = "6556"; const char kTestManufacturer[] = "TestManufacturer"; +const char kTestProvider[] = "TestProvider"; +const char kRevokedManufacturer[] = "RevokedManufacturer"; +const widevine::HashAlgorithm kHashAlgorithm = + widevine::HashAlgorithm::kSha256; } // namespace namespace widevine { -using ::testing::_; using ::testing::Return; using ::testing::ReturnRef; -using ::testing::ReturnRefOfCopy; const uint32_t kValidCertSystemId = 100; const uint32_t kRevokedCertSystemId = 101; @@ -66,32 +67,40 @@ const char kRevokedUniqueIdentifiers[] = "revoked-unique-identifiers"; const char kTestOnlySerialNumber[] = "test_only-serial-number"; const char kMismatchSerialNumber[] = "mismatch-serial-number"; const char kDeviceModel[] = "device-model-x"; +const char kRevokedDeviceModel[] = "device-model-revoked"; const char kTestPreprovKey[] = "00112233445566778899aabbccddeeff"; const uint32_t kStatusListCreationTime = 17798001; const uint32_t kDefaultExpirePeriod = 0; +const bool kDenyRevokedDevice = false; +const bool kAllowRevokedDevice = true; class MockClientCert : public ClientCert { public: MockClientCert() {} ~MockClientCert() override {} - MOCK_CONST_METHOD0(system_id, uint32_t()); - MOCK_CONST_METHOD0(signer_serial_number, std::string &()); - MOCK_CONST_METHOD0(signer_creation_time_seconds, uint32_t()); - MOCK_CONST_METHOD0(type, ClientIdentification::TokenType()); - MOCK_CONST_METHOD0(encrypted_unique_id, const std::string &()); - MOCK_CONST_METHOD0(unique_id_hash, const std::string &()); - MOCK_CONST_METHOD0(signed_by_provisioner, bool()); - MOCK_CONST_METHOD3(VerifySignature, Status(const std::string &message, - const std::string &signature, - ProtocolVersion protocol_version)); - MOCK_METHOD2(GenerateSigningKey, void(const std::string &message, - ProtocolVersion protocol_version)); - MOCK_CONST_METHOD0(serial_number, const std::string &()); - MOCK_CONST_METHOD0(key, const std::string &()); - MOCK_CONST_METHOD0(key_type, SignedMessage::SessionKeyType()); - MOCK_CONST_METHOD0(service_id, const std::string &()); - MOCK_CONST_METHOD0(encrypted_key, const std::string &()); - MOCK_CONST_METHOD0(signing_key, const std::string &()); + MOCK_METHOD(uint32_t, system_id, (), (const, override)); + MOCK_METHOD(std::string &, signer_serial_number, (), (const, override)); + MOCK_METHOD(uint32_t, signer_creation_time_seconds, (), (const, override)); + MOCK_METHOD(ClientIdentification::TokenType, type, (), (const, override)); + MOCK_METHOD(const std::string &, encrypted_unique_id, (), (const, override)); + MOCK_METHOD(const std::string &, unique_id_hash, (), (const, override)); + MOCK_METHOD(bool, signed_by_provisioner, (), (const, override)); + MOCK_METHOD(Status, VerifySignature, + (const std::string &message, HashAlgorithm hash_algorithm, + const std::string &signature, ProtocolVersion protocol_version), + (const, override)); + MOCK_METHOD(void, GenerateSigningKey, + (const std::string &message, ProtocolVersion protocol_version), + (override)); + MOCK_METHOD(const std::string &, serial_number, (), (const, override)); + MOCK_METHOD(const std::string &, key, (), (const, override)); + MOCK_METHOD(SignedMessage::SessionKeyType, key_type, (), (const, override)); + MOCK_METHOD(bool, using_dual_certificate, (), (const override)); + MOCK_METHOD(const std::string &, service_id, (), (const, override)); + MOCK_METHOD(const std::string &, encrypted_key, (), (const, override)); + MOCK_METHOD(const std::string &, signing_key, (), (const, override)); + MOCK_METHOD(Status, SystemIdUnknownError, (), (const, override)); + MOCK_METHOD(Status, SystemIdRevokedError, (), (const, override)); }; class DeviceStatusListTest : public ::testing::Test { @@ -116,6 +125,7 @@ class DeviceStatusListTest : public ::testing::Test { cert_status = cert_status_list_.add_certificate_status(); cert_status->mutable_device_info()->set_system_id(kRevokedCertSystemId); cert_status->set_drm_serial_number(kRevokedSerialNumber); + cert_status->mutable_device_info()->set_model(kRevokedDeviceModel); cert_status->set_status(DeviceCertificateStatus::STATUS_REVOKED); // Device cert with status REVOKED ALLOWED DEVICE. @@ -141,6 +151,7 @@ class DeviceStatusListTest : public ::testing::Test { ASSERT_TRUE(root_key); cert_status_list_.SerializeToString(&serialized_cert_status_list_); ASSERT_TRUE(root_key->GenerateSignature(serialized_cert_status_list_, + kHashAlgorithm, &cert_status_list_signature_)); // Update the device_status_list_ with the serialized status list @@ -148,11 +159,12 @@ class DeviceStatusListTest : public ::testing::Test { ASSERT_EQ(OkStatus(), device_status_list_.UpdateStatusList( test_keys_.public_test_key_1_3072_bits(), - serialized_cert_status_list_, cert_status_list_signature_, - kDefaultExpirePeriod)); + serialized_cert_status_list_, kHashAlgorithm, + cert_status_list_signature_, kDefaultExpirePeriod)); } void GenerateTrivialValidStatusList(std::string *serialized_cert_status_list, + HashAlgorithm hash_algorithm, std::string *signature) { DeviceCertificateStatusList cert_status_list; DeviceCertificateStatus *cert_status; @@ -171,17 +183,27 @@ class DeviceStatusListTest : public ::testing::Test { RsaPrivateKey::Create(test_keys_.private_test_key_1_3072_bits())); ASSERT_TRUE(root_key); cert_status_list.SerializeToString(serialized_cert_status_list); - ASSERT_TRUE( - root_key->GenerateSignature(*serialized_cert_status_list, signature)); + ASSERT_TRUE(root_key->GenerateSignature(*serialized_cert_status_list, + hash_algorithm, signature)); } - int VerifyAllowedTestOnlyDevicesAdded() { - return device_status_list_.allowed_test_only_devices_.size(); + int AllowedTestOnlyDevicesByMakeSize() { + return device_status_list_.allowed_test_only_devices_by_make_size(); } - bool VerifyIsTestOnlyDeviceAllowed(uint32_t system_id, - std::string manufacturer) { - return device_status_list_.IsTestOnlyDeviceAllowed(system_id, manufacturer); + int AllowedTestOnlyDevicesByProviderSize() { + return device_status_list_.allowed_test_only_devices_by_provider_size(); + } + + bool IsTestOnlyDeviceAllowedByMake(uint32_t system_id, + const std::string &make) { + return device_status_list_.IsTestOnlyDeviceAllowedByMake(system_id, make); + } + + bool IsTestOnlyDeviceAllowedByProvider(uint32_t system_id, + const std::string &provider) { + return device_status_list_.IsTestOnlyDeviceAllowedByProvider(system_id, + provider); } int VerifyRevokedDeviceCertificatesCount() { @@ -203,7 +225,7 @@ class DeviceStatusListTest : public ::testing::Test { TEST_F(DeviceStatusListTest, CheckForValidAndRevokedCert) { // Test case where the Certificate status is set to Valid. - ProvisionedDeviceInfo device_info; + DeviceCertificateStatus device_certificate_status; MockClientCert valid_client_cert; std::string valid_drm_serial_number(kValidSerialNumber); EXPECT_CALL(valid_client_cert, type()) @@ -212,9 +234,10 @@ TEST_F(DeviceStatusListTest, CheckForValidAndRevokedCert) { .WillRepeatedly(Return(kValidCertSystemId)); EXPECT_CALL(valid_client_cert, signer_serial_number()) .WillRepeatedly(ReturnRef(valid_drm_serial_number)); - EXPECT_EQ(OkStatus(), - device_status_list_.GetCertStatus(valid_client_cert, - kTestManufacturer, &device_info)); + EXPECT_EQ(OkStatus(), device_status_list_.GetCertStatus( + valid_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status)); + ProvisionedDeviceInfo device_info = device_certificate_status.device_info(); EXPECT_TRUE(device_info.has_model()); EXPECT_EQ(kDeviceModel, device_info.model()); @@ -227,20 +250,25 @@ TEST_F(DeviceStatusListTest, CheckForValidAndRevokedCert) { .WillRepeatedly(Return(kRevokedCertSystemId)); EXPECT_CALL(revoked_client_cert, signer_serial_number()) .WillRepeatedly(ReturnRef(revoked_drm_serial_number)); + EXPECT_CALL(revoked_client_cert, SystemIdRevokedError()) + .WillRepeatedly( + Return(Status(error_space, DRM_DEVICE_CERTIFICATE_REVOKED, ""))); EXPECT_EQ( DRM_DEVICE_CERTIFICATE_REVOKED, device_status_list_ - .GetCertStatus(revoked_client_cert, kTestManufacturer, &device_info) + .GetCertStatus(revoked_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status) .error_code()); // Test case where the revoked cert is allowed. device_status_list_.AllowRevokedDevices(absl::StrCat(kRevokedCertSystemId)); - EXPECT_OK(device_status_list_.GetCertStatus(revoked_client_cert, - kTestManufacturer, &device_info)); + EXPECT_OK(device_status_list_.GetCertStatus( + revoked_client_cert, kTestManufacturer, kTestProvider, kDenyRevokedDevice, + &device_certificate_status)); } TEST_F(DeviceStatusListTest, TestOnlyCertNotAllowed) { - ProvisionedDeviceInfo device_info; + DeviceCertificateStatus device_certificate_status; MockClientCert test_only_client_cert; std::string test_only_drm_serial_number(kTestOnlySerialNumber); EXPECT_CALL(test_only_client_cert, type()) @@ -249,11 +277,12 @@ TEST_F(DeviceStatusListTest, TestOnlyCertNotAllowed) { .WillRepeatedly(Return(kTestOnlyCertSystemId)); EXPECT_CALL(test_only_client_cert, signer_serial_number()) .WillRepeatedly(ReturnRef(test_only_drm_serial_number)); - EXPECT_EQ( - DEVELOPMENT_CERTIFICATE_NOT_ALLOWED, - device_status_list_ - .GetCertStatus(test_only_client_cert, kTestManufacturer, &device_info) - .error_code()); + EXPECT_EQ(DEVELOPMENT_CERTIFICATE_NOT_ALLOWED, + device_status_list_ + .GetCertStatus(test_only_client_cert, kTestManufacturer, + kTestProvider, kDenyRevokedDevice, + &device_certificate_status) + .error_code()); } TEST_F(DeviceStatusListTest, GetRevokedIfentifiers) { @@ -269,7 +298,7 @@ TEST_F(DeviceStatusListTest, GetRevokedIfentifiers) { } TEST_F(DeviceStatusListTest, TestOnlyCertAllowed) { - ProvisionedDeviceInfo device_info; + DeviceCertificateStatus device_certificate_status; MockClientCert test_only_client_cert; std::string test_only_drm_serial_number(kTestOnlySerialNumber); device_status_list_.set_allow_test_only_devices(true); @@ -280,41 +309,98 @@ TEST_F(DeviceStatusListTest, TestOnlyCertAllowed) { EXPECT_CALL(test_only_client_cert, signer_serial_number()) .WillRepeatedly(ReturnRef(test_only_drm_serial_number)); EXPECT_EQ(OkStatus(), - device_status_list_.GetCertStatus(test_only_client_cert, - kTestManufacturer, &device_info)); + device_status_list_.GetCertStatus( + test_only_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status)); } -TEST_F(DeviceStatusListTest, ValidAndUnknownKeybox) { +TEST_F(DeviceStatusListTest, RevokedSystemIdAllowed) { + DeviceCertificateStatus device_certificate_status; + MockClientCert revoked_client_cert; + std::string revoked_drm_serial_number(kRevokedSerialNumber); + EXPECT_CALL(revoked_client_cert, type()) + .WillRepeatedly(Return(ClientIdentification::DRM_DEVICE_CERTIFICATE)); + EXPECT_CALL(revoked_client_cert, system_id()) + .WillRepeatedly(Return(kRevokedCertSystemId)); + EXPECT_CALL(revoked_client_cert, signer_serial_number()) + .WillRepeatedly(ReturnRef(revoked_drm_serial_number)); + EXPECT_EQ(OkStatus(), + device_status_list_.GetCertStatus( + revoked_client_cert, kRevokedManufacturer, kTestProvider, + kAllowRevokedDevice, &device_certificate_status)); +} + +// Test case where the Certificate status is set to Valid. +TEST_F(DeviceStatusListTest, ValidKeybox) { std::multimap preprov_keys; preprov_keys.insert(std::make_pair(kValidCertSystemId, kTestPreprovKey)); KeyboxClientCert::SetPreProvisioningKeys(preprov_keys); - - // Test case where the Certificate status is set to Valid. - ProvisionedDeviceInfo device_info; + DeviceCertificateStatus device_certificate_status; MockClientCert valid_client_keybox; + std::string valid_drm_serial_number(kValidSerialNumber); EXPECT_CALL(valid_client_keybox, type()) .WillRepeatedly(Return(ClientIdentification::KEYBOX)); EXPECT_CALL(valid_client_keybox, system_id()) .WillRepeatedly(Return(kValidCertSystemId)); EXPECT_EQ(OkStatus(), - device_status_list_.GetCertStatus(valid_client_keybox, - kTestManufacturer, &device_info)); - EXPECT_TRUE(device_info.has_model()); + device_status_list_.GetCertStatus( + valid_client_keybox, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status)); + ProvisionedDeviceInfo device_info = device_certificate_status.device_info(); + ASSERT_TRUE(device_info.has_model()); EXPECT_EQ(kDeviceModel, device_info.model()); +} +// Test case where the keybox was not loaded into the pre-prov list. +TEST_F(DeviceStatusListTest, UnknownKeybox) { + std::multimap preprov_keys; + preprov_keys.insert(std::make_pair(kValidCertSystemId, kTestPreprovKey)); + KeyboxClientCert::SetPreProvisioningKeys(preprov_keys); + DeviceCertificateStatus device_certificate_status; MockClientCert unknown_client_keybox; + EXPECT_CALL(unknown_client_keybox, type()) .WillRepeatedly(Return(ClientIdentification::KEYBOX)); EXPECT_CALL(unknown_client_keybox, system_id()) .WillRepeatedly(Return(kUnknownSystemId)); - EXPECT_EQ( - UNSUPPORTED_SYSTEM_ID, - device_status_list_ - .GetCertStatus(unknown_client_keybox, kTestManufacturer, &device_info) - .error_code()); - EXPECT_TRUE(device_info.has_model()); - EXPECT_EQ(kDeviceModel, device_info.model()); + EXPECT_CALL(unknown_client_keybox, SystemIdUnknownError()) + .WillRepeatedly(Return(Status(error_space, UNSUPPORTED_SYSTEM_ID, ""))); + EXPECT_EQ(UNSUPPORTED_SYSTEM_ID, + device_status_list_ + .GetCertStatus(unknown_client_keybox, kTestManufacturer, + kTestProvider, kDenyRevokedDevice, + &device_certificate_status) + .error_code()); + ProvisionedDeviceInfo device_info = device_certificate_status.device_info(); + ASSERT_FALSE(device_info.has_model()); +} + +// Test case where the keybox was loaded into the pre-prov list but it's +// certificate status is REVOKED. +TEST_F(DeviceStatusListTest, RevokedKeybox) { + std::multimap preprov_keys; + preprov_keys.insert(std::make_pair(kRevokedCertSystemId, kTestPreprovKey)); + KeyboxClientCert::SetPreProvisioningKeys(preprov_keys); + DeviceCertificateStatus device_certificate_status; + MockClientCert revoked_client_keybox; + + EXPECT_CALL(revoked_client_keybox, type()) + .WillRepeatedly(Return(ClientIdentification::KEYBOX)); + EXPECT_CALL(revoked_client_keybox, system_id()) + .WillRepeatedly(Return(kRevokedCertSystemId)); + EXPECT_CALL(revoked_client_keybox, SystemIdRevokedError()) + .WillRepeatedly( + Return(Status(error_space, DRM_DEVICE_CERTIFICATE_REVOKED, ""))); + EXPECT_EQ(DRM_DEVICE_CERTIFICATE_REVOKED, + device_status_list_ + .GetCertStatus(revoked_client_keybox, kTestManufacturer, + kTestProvider, kDenyRevokedDevice, + &device_certificate_status) + .error_code()); + ProvisionedDeviceInfo device_info = device_certificate_status.device_info(); + ASSERT_TRUE(device_info.has_model()); + EXPECT_EQ(kRevokedDeviceModel, device_info.model()); } TEST_F(DeviceStatusListTest, SignerSerialNumberMismatch) { @@ -323,7 +409,7 @@ TEST_F(DeviceStatusListTest, SignerSerialNumberMismatch) { // Test case where the signer certificate is older than the current status // list. MockClientCert older_client_cert; - ProvisionedDeviceInfo device_info; + DeviceCertificateStatus device_certificate_status; std::string mismatch_drm_serial_number(kMismatchSerialNumber); EXPECT_CALL(older_client_cert, type()) .WillRepeatedly(Return(ClientIdentification::DRM_DEVICE_CERTIFICATE)); @@ -336,20 +422,23 @@ TEST_F(DeviceStatusListTest, SignerSerialNumberMismatch) { EXPECT_EQ( INVALID_DRM_CERTIFICATE, device_status_list_ - .GetCertStatus(older_client_cert, kTestManufacturer, &device_info) + .GetCertStatus(older_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status) .error_code()); // We allow this case only for certs signed by a provisioner cert. EXPECT_CALL(older_client_cert, signed_by_provisioner()) .WillOnce(Return(true)); - EXPECT_EQ(OkStatus(), - device_status_list_.GetCertStatus(older_client_cert, - kTestManufacturer, &device_info)); + EXPECT_EQ(OkStatus(), device_status_list_.GetCertStatus( + older_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status)); + ProvisionedDeviceInfo device_info = device_certificate_status.device_info(); EXPECT_TRUE(device_info.has_system_id()); EXPECT_EQ(kValidCertSystemId, device_info.system_id()); // Test case where the signer certificate is newer than the current status // list, and unknown devices are allowed. + device_certificate_status.Clear(); MockClientCert newer_client_cert1; EXPECT_CALL(newer_client_cert1, type()) .WillRepeatedly(Return(ClientIdentification::DRM_DEVICE_CERTIFICATE)); @@ -359,14 +448,19 @@ TEST_F(DeviceStatusListTest, SignerSerialNumberMismatch) { .WillRepeatedly(ReturnRef(mismatch_drm_serial_number)); EXPECT_CALL(newer_client_cert1, signer_creation_time_seconds()) .WillRepeatedly(Return(kStatusListCreationTime)); + EXPECT_CALL(newer_client_cert1, SystemIdUnknownError()) + .WillRepeatedly( + Return(Status(error_space, DRM_DEVICE_CERTIFICATE_UNKNOWN, ""))); EXPECT_EQ( DRM_DEVICE_CERTIFICATE_UNKNOWN, device_status_list_ - .GetCertStatus(newer_client_cert1, kTestManufacturer, &device_info) + .GetCertStatus(newer_client_cert1, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status) .error_code()); // Test case where the signer certificate is newer than the current status // list, and unknown devices are not allowed. + device_certificate_status.Clear(); device_status_list_.set_allow_unknown_devices(false); MockClientCert newer_client_cert2; EXPECT_CALL(newer_client_cert2, type()) @@ -377,10 +471,14 @@ TEST_F(DeviceStatusListTest, SignerSerialNumberMismatch) { .WillRepeatedly(ReturnRef(mismatch_drm_serial_number)); EXPECT_CALL(newer_client_cert2, signer_creation_time_seconds()) .WillRepeatedly(Return(kStatusListCreationTime + 1)); + EXPECT_CALL(newer_client_cert2, SystemIdUnknownError()) + .WillRepeatedly( + Return(Status(error_space, DRM_DEVICE_CERTIFICATE_UNKNOWN, ""))); EXPECT_EQ( DRM_DEVICE_CERTIFICATE_UNKNOWN, device_status_list_ - .GetCertStatus(newer_client_cert2, kTestManufacturer, &device_info) + .GetCertStatus(newer_client_cert2, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status) .error_code()); } @@ -388,7 +486,7 @@ TEST_F(DeviceStatusListTest, InvalidStatusList) { EXPECT_EQ(INVALID_CERTIFICATE_STATUS_LIST, device_status_list_ .UpdateStatusList(test_keys_.public_test_key_2_2048_bits(), - serialized_cert_status_list_, + serialized_cert_status_list_, kHashAlgorithm, cert_status_list_signature_, 0) .error_code()); @@ -396,14 +494,14 @@ TEST_F(DeviceStatusListTest, InvalidStatusList) { EXPECT_EQ(INVALID_CERTIFICATE_STATUS_LIST, device_status_list_ .UpdateStatusList(test_keys_.public_test_key_1_3072_bits(), - serialized_cert_status_list_, + serialized_cert_status_list_, kHashAlgorithm, cert_status_list_signature_, 0) .error_code()); } class MockDeviceStatusList : public DeviceStatusList { public: - MOCK_CONST_METHOD0(GetCurrentTime, uint32_t()); + MOCK_METHOD(uint32_t, GetCurrentTime, (), (const, override)); }; TEST_F(DeviceStatusListTest, ExpiredStatusListOnSet) { @@ -414,12 +512,12 @@ TEST_F(DeviceStatusListTest, ExpiredStatusListOnSet) { .WillOnce(Return(kStatusListCreationTime + 101)); EXPECT_EQ(OkStatus(), mock_device_status_list.UpdateStatusList( test_keys_.public_test_key_1_3072_bits(), - serialized_cert_status_list_, + serialized_cert_status_list_, kHashAlgorithm, cert_status_list_signature_, 100)); EXPECT_EQ(EXPIRED_CERTIFICATE_STATUS_LIST, mock_device_status_list .UpdateStatusList(test_keys_.public_test_key_1_3072_bits(), - serialized_cert_status_list_, + serialized_cert_status_list_, kHashAlgorithm, cert_status_list_signature_, 100) .error_code()); } @@ -433,10 +531,10 @@ TEST_F(DeviceStatusListTest, ExpiredStatusListOnCertCheck) { .WillOnce(Return(kStatusListCreationTime + 101)); EXPECT_EQ(OkStatus(), mock_device_status_list.UpdateStatusList( test_keys_.public_test_key_1_3072_bits(), - serialized_cert_status_list_, + serialized_cert_status_list_, kHashAlgorithm, cert_status_list_signature_, 100)); - ProvisionedDeviceInfo device_info; + DeviceCertificateStatus device_certificate_status; MockClientCert valid_client_cert; std::string valid_drm_serial_number(kValidSerialNumber); EXPECT_CALL(valid_client_cert, type()) @@ -447,14 +545,15 @@ TEST_F(DeviceStatusListTest, ExpiredStatusListOnCertCheck) { .WillRepeatedly(ReturnRef(valid_drm_serial_number)); EXPECT_CALL(valid_client_cert, signer_creation_time_seconds()) .WillRepeatedly(Return(kStatusListCreationTime - 1)); - EXPECT_EQ(OkStatus(), - mock_device_status_list.GetCertStatus( - valid_client_cert, kTestManufacturer, &device_info)); + EXPECT_EQ(OkStatus(), mock_device_status_list.GetCertStatus( + valid_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status)); EXPECT_EQ( EXPIRED_CERTIFICATE_STATUS_LIST, mock_device_status_list - .GetCertStatus(valid_client_cert, kTestManufacturer, &device_info) + .GetCertStatus(valid_client_cert, kTestManufacturer, kTestProvider, + kDenyRevokedDevice, &device_certificate_status) .error_code()); } @@ -477,49 +576,109 @@ TEST_F(DeviceStatusListTest, IsSystemIdActive) { device_status_list_.IsSystemIdActive(kRevokedAllowedDeviceCertSystemId)); } -TEST_F(DeviceStatusListTest, IsTestOnlyDeviceAllowed) { - std::string whitelisted_device_list = - std::string(kTestSystemId_1) + ":" + std::string(kTestManufacturer_LG); - whitelisted_device_list += "," + std::string(kTestSystemId_2) + ":" + - std::string(kTestManufacturer_Samsung); - whitelisted_device_list += "," + std::string(kTestSystemId_3) + ":"; - whitelisted_device_list += ", " + std::string(kTestSystemId_1) + ":" + - std::string(kTestManufacturer_LGE); - device_status_list_.AllowTestOnlyDevices(whitelisted_device_list); - EXPECT_EQ(4, VerifyAllowedTestOnlyDevicesAdded()); +TEST_F(DeviceStatusListTest, IsTestOnlyDeviceAllowedByMake) { + const char kTestManufacturer_AA[] = "AA"; + const char kTestManufacturer_AAA[] = "AAA"; + const char kTestManufacturer_BBB[] = "BBB"; + const char kTestManufacturer_BbB[] = "BbB"; + const char kTestManufacturer_bbb[] = "bbb"; + const char kTestManufacturer_CCC[] = "CCC"; + const char kTestManufacturer_DDD[] = "AAA"; + std::string allowed_device_list = + std::string(kTestSystemId_1) + ":" + std::string(kTestManufacturer_AA); + allowed_device_list += "," + std::string(kTestSystemId_2) + ":" + + std::string(kTestManufacturer_BBB); + allowed_device_list += "," + std::string(kTestSystemId_3) + ":"; + allowed_device_list += ", " + std::string(kTestSystemId_1) + ":" + + std::string(kTestManufacturer_AAA); + device_status_list_.AllowTestOnlyDevicesByMake(allowed_device_list); + EXPECT_EQ(4, AllowedTestOnlyDevicesByMakeSize()); // Verify that device with system_id = kTestSystemId_1 and - // manufacturer = kTestManufacturer_LG is allowed. - EXPECT_TRUE(VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_1), - kTestManufacturer_LG)); + // manufacturer AA is allowed. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_1), + kTestManufacturer_AA)); // Verify that device with system_id = kTestSystemId_1 and - // manufacturer = kTestManufacturer_LGE is allowed. - EXPECT_TRUE(VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_1), - kTestManufacturer_LGE)); + // manufacturer AAA is allowed. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_1), + kTestManufacturer_AAA)); // Verify that device with system_id = kTestSystemId_2 and - // manufacturer = kTestManufacturer_LGE is not allowed. - // This is because this combination is not 'whitelisted'. - EXPECT_FALSE(VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_2), - kTestManufacturer_LGE)); + // manufacturer AAA is not allowed. + // This is because this combination is not in the allowed list. + EXPECT_FALSE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_2), + kTestManufacturer_AAA)); // Verify that device with system_id = kTestSystemId_2 and - // manufacturer = kTestManufacturer_Samsung is allowed. - EXPECT_TRUE(VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_2), - kTestManufacturer_Samsung)); + // manufacturer BBB is allowed. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_2), + kTestManufacturer_BBB)); // Verifes that device with mixed case succeeds. - EXPECT_TRUE( - VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_2), "samSung")); - EXPECT_TRUE( - VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_2), "SAMsung")); + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_2), + kTestManufacturer_BbB)); + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_2), + kTestManufacturer_bbb)); // Verify that device with system_id = kTestSystemId_3 and // any manufacturer is allowed. This checks that any manufacturer is // allowed for this system_id. - EXPECT_TRUE( - VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_3), "Cisco")); - EXPECT_TRUE(VerifyIsTestOnlyDeviceAllowed(std::stoi(kTestSystemId_3), - "ScientificAtlanta")); + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_3), + kTestManufacturer_CCC)); + EXPECT_TRUE(IsTestOnlyDeviceAllowedByMake(std::stoi(kTestSystemId_3), + kTestManufacturer_DDD)); uint32_t unknown_system_id = 7890; // Verify that device with system_id = unknown_system_id and - // manufacturer = "Cisco" is not allowed. - EXPECT_FALSE(VerifyIsTestOnlyDeviceAllowed(unknown_system_id, "Cisco")); + // manufacturer CCC is not allowed. + EXPECT_FALSE( + IsTestOnlyDeviceAllowedByMake(unknown_system_id, kTestManufacturer_CCC)); +} + +TEST_F(DeviceStatusListTest, IsTestOnlyDeviceAllowedByProvider) { + const char kTestProvider_AA[] = "AA"; + const char kTestProvider_AAA[] = "AAA"; + const char kTestProvider_BBB[] = "BBB"; + const char kTestProvider_BbB[] = "BbB"; + const char kTestProvider_bbb[] = "bbb"; + const char kTestProvider_CCC[] = "CCC"; + std::string allowed_device_list = + std::string(kTestSystemId_1) + ":" + std::string(kTestProvider_AA); + allowed_device_list += + "," + std::string(kTestSystemId_2) + ":" + std::string(kTestProvider_BBB); + allowed_device_list += "," + std::string(kTestSystemId_3) + ":"; + allowed_device_list += ", " + std::string(kTestSystemId_1) + ":" + + std::string(kTestProvider_AAA); + device_status_list_.AllowTestOnlyDevicesByProvider(allowed_device_list); + EXPECT_EQ(4, AllowedTestOnlyDevicesByProviderSize()); + // Verify that device with system_id = kTestSystemId_1 and + // provider AA is allowed. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_1), + kTestProvider_AA)); + // Verify that device with system_id = kTestSystemId_1 and + // provider AAA is allowed. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_1), + kTestProvider_AAA)); + // Verify that device with system_id = kTestSystemId_2 and + // provider AAA is not allowed. + // This is because this combination is not 'whitelisted'. + EXPECT_FALSE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_2), + kTestProvider_AAA)); + // Verify that device with system_id = kTestSystemId_2 and + // provider BBB is allowed. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_2), + kTestProvider_BBB)); + // Verifes that device with mixed case succeeds. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_2), + kTestProvider_BbB)); + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_2), + kTestProvider_bbb)); + // Verify that device with system_id = kTestSystemId_3 and + // any provider is allowed. This checks that any provider is + // allowed for this system_id. + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_3), + kTestProvider_CCC)); + EXPECT_TRUE(IsTestOnlyDeviceAllowedByProvider(std::stoi(kTestSystemId_3), + kTestProvider_AAA)); + uint32_t unknown_system_id = 7890; + // Verify that device with system_id = unknown_system_id and + // provider CCC is not allowed. + EXPECT_FALSE( + IsTestOnlyDeviceAllowedByProvider(unknown_system_id, kTestProvider_CCC)); } TEST_F(DeviceStatusListTest, IsDrmDeviceCertificateRevoked) { @@ -552,6 +711,7 @@ TEST_F(DeviceStatusListTest, DetermineAndDeserializeServiceResponseSuccess) { SignedDeviceInfo published_devices; GenerateTrivialValidStatusList( published_devices.mutable_device_certificate_status_list(), + HashAlgorithmProtoToEnum(published_devices.hash_algorithm()), published_devices.mutable_signature()); std::string serialized_published_devices; @@ -561,13 +721,17 @@ TEST_F(DeviceStatusListTest, DetermineAndDeserializeServiceResponseSuccess) { DeviceCertificateStatusList actual_cert_status_list; std::string actual_serialized_cert_status_list; std::string actual_signature; + HashAlgorithm hash_algorithm; ASSERT_EQ(OkStatus(), DeviceStatusList::DetermineAndDeserializeServiceResponse( serialized_published_devices, &actual_cert_status_list, - &actual_serialized_cert_status_list, &actual_signature)); + &actual_serialized_cert_status_list, &hash_algorithm, + &actual_signature)); EXPECT_EQ(published_devices.device_certificate_status_list(), actual_serialized_cert_status_list); EXPECT_EQ(published_devices.signature(), actual_signature); + EXPECT_EQ(HashAlgorithmProtoToEnum(published_devices.hash_algorithm()), + hash_algorithm); DeviceCertificateStatusList expected_cert_status_list; ASSERT_TRUE(expected_cert_status_list.ParseFromString( @@ -580,9 +744,11 @@ TEST_F(DeviceStatusListTest, DetermineAndDeserializeServiceResponseLegacySuccess) { std::string serialized_cert_status_list; std::string signature; - GenerateTrivialValidStatusList(&serialized_cert_status_list, &signature); - SignedDeviceCertificateStatusList legacy_signed_cert_status_list; + GenerateTrivialValidStatusList( + &serialized_cert_status_list, + HashAlgorithmProtoToEnum(legacy_signed_cert_status_list.hash_algorithm()), + &signature); *(legacy_signed_cert_status_list.mutable_certificate_status_list()) = serialized_cert_status_list; *(legacy_signed_cert_status_list.mutable_signature()) = signature; @@ -604,12 +770,17 @@ TEST_F(DeviceStatusListTest, std::string actual_serialized_cert_status_list; std::string actual_signature; DeviceCertificateStatusList actual_cert_status_list; + HashAlgorithm hash_algorithm; ASSERT_EQ(OkStatus(), DeviceStatusList::DetermineAndDeserializeServiceResponse( server_response, &actual_cert_status_list, - &actual_serialized_cert_status_list, &actual_signature)); + &actual_serialized_cert_status_list, &hash_algorithm, + &actual_signature)); EXPECT_EQ(serialized_cert_status_list, actual_serialized_cert_status_list); EXPECT_EQ(signature, actual_signature); + EXPECT_EQ( + HashAlgorithmProtoToEnum(legacy_signed_cert_status_list.hash_algorithm()), + hash_algorithm); DeviceCertificateStatusList expected_cert_status_list; ASSERT_TRUE( @@ -622,9 +793,11 @@ TEST_F(DeviceStatusListTest, DetermineAndDeserializeServiceResponseLegacyWebSafeBase64Success) { std::string serialized_cert_status_list; std::string signature; - GenerateTrivialValidStatusList(&serialized_cert_status_list, &signature); - SignedDeviceCertificateStatusList legacy_signed_cert_status_list; + GenerateTrivialValidStatusList( + &serialized_cert_status_list, + HashAlgorithmProtoToEnum(legacy_signed_cert_status_list.hash_algorithm()), + &signature); *(legacy_signed_cert_status_list.mutable_certificate_status_list()) = serialized_cert_status_list; *(legacy_signed_cert_status_list.mutable_signature()) = signature; @@ -639,14 +812,18 @@ TEST_F(DeviceStatusListTest, std::string actual_serialized_cert_status_list; std::string actual_signature; + HashAlgorithm hash_algorithm; DeviceCertificateStatusList actual_cert_status_list; ASSERT_EQ(OkStatus(), DeviceStatusList::DetermineAndDeserializeServiceResponse( websafe_b64_serialized_signed_cert_status_list, &actual_cert_status_list, &actual_serialized_cert_status_list, - &actual_signature)); + &hash_algorithm, &actual_signature)); EXPECT_EQ(serialized_cert_status_list, actual_serialized_cert_status_list); EXPECT_EQ(signature, actual_signature); + EXPECT_EQ( + HashAlgorithmProtoToEnum(legacy_signed_cert_status_list.hash_algorithm()), + hash_algorithm); DeviceCertificateStatusList expected_cert_status_list; ASSERT_TRUE( @@ -659,9 +836,11 @@ TEST_F(DeviceStatusListTest, DetermineAndDeserializeServiceResponseLegacyBase64Success) { std::string serialized_cert_status_list; std::string signature; - GenerateTrivialValidStatusList(&serialized_cert_status_list, &signature); - SignedDeviceCertificateStatusList legacy_signed_cert_status_list; + GenerateTrivialValidStatusList( + &serialized_cert_status_list, + HashAlgorithmProtoToEnum(legacy_signed_cert_status_list.hash_algorithm()), + &signature); *(legacy_signed_cert_status_list.mutable_certificate_status_list()) = serialized_cert_status_list; *(legacy_signed_cert_status_list.mutable_signature()) = signature; @@ -676,14 +855,18 @@ TEST_F(DeviceStatusListTest, std::string actual_serialized_cert_status_list; std::string actual_signature; + HashAlgorithm hash_algorithm; DeviceCertificateStatusList actual_cert_status_list; ASSERT_EQ(OkStatus(), DeviceStatusList::DetermineAndDeserializeServiceResponse( websafe_b64_serialized_signed_cert_status_list, &actual_cert_status_list, &actual_serialized_cert_status_list, - &actual_signature)); + &hash_algorithm, &actual_signature)); EXPECT_EQ(serialized_cert_status_list, actual_serialized_cert_status_list); EXPECT_EQ(signature, actual_signature); + EXPECT_EQ( + HashAlgorithmProtoToEnum(legacy_signed_cert_status_list.hash_algorithm()), + hash_algorithm); DeviceCertificateStatusList expected_cert_status_list; ASSERT_TRUE( @@ -692,4 +875,44 @@ TEST_F(DeviceStatusListTest, expected_cert_status_list, actual_cert_status_list)); } +TEST_F(DeviceStatusListTest, + GetDeviceCertificateStatusBySystemIdExpiredDeviceCertificateStatusList) { + MockDeviceStatusList mock_device_status_list; + EXPECT_CALL(mock_device_status_list, GetCurrentTime()) + .Times(3) + .WillOnce(Return(kStatusListCreationTime + 100)) + .WillOnce(Return(kStatusListCreationTime + 100)) + .WillOnce(Return(kStatusListCreationTime + 101)); + EXPECT_EQ(OkStatus(), mock_device_status_list.UpdateStatusList( + test_keys_.public_test_key_1_3072_bits(), + serialized_cert_status_list_, kHashAlgorithm, + cert_status_list_signature_, 100)); + DeviceCertificateStatus device_certificate_status; + EXPECT_EQ(OkStatus(), + mock_device_status_list.GetDeviceCertificateStatusBySystemId( + kValidCertSystemId, &device_certificate_status)); + EXPECT_EQ(EXPIRED_CERTIFICATE_STATUS_LIST, + mock_device_status_list + .GetDeviceCertificateStatusBySystemId( + kValidCertSystemId, &device_certificate_status) + .error_code()); +} + +TEST_F(DeviceStatusListTest, + GetDeviceCertificateStatusBySystemIdUnknownDevice) { + DeviceCertificateStatus device_certificate_status; + uint32_t unknown_system_id = 2000; + EXPECT_EQ(DRM_DEVICE_CERTIFICATE_UNKNOWN, + device_status_list_ + .GetDeviceCertificateStatusBySystemId( + unknown_system_id, &device_certificate_status) + .error_code()); +} + +TEST_F(DeviceStatusListTest, GetDeviceCertificateStatusBySystemIdOk) { + DeviceCertificateStatus device_certificate_status; + EXPECT_OK(device_status_list_.GetDeviceCertificateStatusBySystemId( + kValidCertSystemId, &device_certificate_status)); +} + } // namespace widevine diff --git a/common/drm_root_certificate.cc b/common/drm_root_certificate.cc index 61cde09..eac92c6 100644 --- a/common/drm_root_certificate.cc +++ b/common/drm_root_certificate.cc @@ -18,6 +18,8 @@ #include "absl/synchronization/mutex.h" #include "common/ec_key.h" #include "common/error_space.h" +#include "common/hash_algorithm.h" +#include "common/hash_algorithm_util.h" #include "common/rsa_key.h" #include "common/sha_util.h" #include "common/signer_public_key.h" @@ -281,6 +283,7 @@ class VerifiedCertSignatureCache { // cache. Status VerifySignature(const std::string& cert, const std::string& serial_number, + HashAlgorithm hash_algorithm, const std::string& signature, const DrmCertificate& signer) { { @@ -314,7 +317,7 @@ class VerifiedCertSignatureCache { return Status(error_space, INVALID_DRM_CERTIFICATE, "invalid-signer-public-key"); } - if (!signer_public_key->VerifySignature(cert, signature)) { + if (!signer_public_key->VerifySignature(cert, hash_algorithm, signature)) { return Status(error_space, INVALID_SIGNATURE, "cache-miss-invalid-signature"); } @@ -428,8 +431,10 @@ Status DrmRootCertificate::Create(CertificateType cert_type, return Status(error_space, INVALID_DRM_CERTIFICATE, "invalid-root-public-key"); } - if (!public_key->VerifySignature(signed_root_cert.drm_certificate(), - signed_root_cert.signature())) { + if (!public_key->VerifySignature( + signed_root_cert.drm_certificate(), + HashAlgorithmProtoToEnum(signed_root_cert.hash_algorithm()), + signed_root_cert.signature())) { return Status(error_space, INVALID_DRM_CERTIFICATE, "invalid-root-certificate-signature"); } @@ -519,6 +524,7 @@ Status DrmRootCertificate::VerifySignatures( // Always use cache for root-signed certificates. return signature_cache_->VerifySignature( signed_cert.drm_certificate(), cert_serial_number, + HashAlgorithmProtoToEnum(signed_cert.hash_algorithm()), signed_cert.signature(), root_cert_); } DrmCertificate signer; @@ -539,9 +545,10 @@ Status DrmRootCertificate::VerifySignatures( } if (use_cache) { - status = signature_cache_->VerifySignature(signed_cert.drm_certificate(), - cert_serial_number, - signed_cert.signature(), signer); + status = signature_cache_->VerifySignature( + signed_cert.drm_certificate(), cert_serial_number, + HashAlgorithmProtoToEnum(signed_cert.hash_algorithm()), + signed_cert.signature(), signer); if (!status.ok()) { return status; } @@ -552,8 +559,10 @@ Status DrmRootCertificate::VerifySignatures( return Status(error_space, INVALID_DRM_CERTIFICATE, "invalid-leaf-signer-public-key"); } - if (!signer_public_key->VerifySignature(signed_cert.drm_certificate(), - signed_cert.signature())) { + if (!signer_public_key->VerifySignature( + signed_cert.drm_certificate(), + HashAlgorithmProtoToEnum(signed_cert.hash_algorithm()), + signed_cert.signature())) { return Status(error_space, INVALID_SIGNATURE, "cache-miss-invalid-signature"); } diff --git a/common/drm_root_certificate_test.cc b/common/drm_root_certificate_test.cc index 2c5f6f8..a671885 100644 --- a/common/drm_root_certificate_test.cc +++ b/common/drm_root_certificate_test.cc @@ -21,6 +21,8 @@ #include "common/ec_key.h" #include "common/ec_test_keys.h" #include "common/error_space.h" +#include "common/hash_algorithm.h" +#include "common/hash_algorithm_util.h" #include "common/rsa_key.h" #include "common/rsa_test_keys.h" #include "common/test_drm_certificates.h" @@ -101,6 +103,7 @@ class SignerPrivateKey { public: virtual ~SignerPrivateKey() {} virtual bool GenerateSignature(const std::string& message, + HashAlgorithm hash_algorithm, std::string* signature) const = 0; virtual DrmCertificate::Algorithm algorithm() const = 0; static std::unique_ptr Create( @@ -119,8 +122,9 @@ class SignerPrivateKeyImpl : public SignerPrivateKey { : private_key_(std::move(private_key)), algorithm_(algorithm) {} ~SignerPrivateKeyImpl() override {} bool GenerateSignature(const std::string& message, + HashAlgorithm hash_algorithm, std::string* signature) const override { - return private_key_->GenerateSignature(message, signature); + return private_key_->GenerateSignature(message, hash_algorithm, signature); } DrmCertificate::Algorithm algorithm() const override { return algorithm_; } @@ -255,7 +259,9 @@ class DrmRootCertificateTest : public testing::TestWithParam { ASSERT_TRUE(drm_certificates_[kClientKey].SerializeToString( current_sc->mutable_drm_certificate())); ASSERT_TRUE(private_keys_[kInterMediateKey]->GenerateSignature( - current_sc->drm_certificate(), current_sc->mutable_signature())); + current_sc->drm_certificate(), + HashAlgorithmProtoToEnum(current_sc->hash_algorithm()), + current_sc->mutable_signature())); current_sc = current_sc->mutable_signer(); drm_certificates_[kInterMediateKey].set_algorithm( @@ -263,7 +269,9 @@ class DrmRootCertificateTest : public testing::TestWithParam { ASSERT_TRUE(drm_certificates_[kInterMediateKey].SerializeToString( current_sc->mutable_drm_certificate())); ASSERT_TRUE(private_keys_[kDrmRootKey]->GenerateSignature( - current_sc->drm_certificate(), current_sc->mutable_signature())); + current_sc->drm_certificate(), + HashAlgorithmProtoToEnum(current_sc->hash_algorithm()), + current_sc->mutable_signature())); current_sc = current_sc->mutable_signer(); drm_certificates_[kDrmRootKey].set_algorithm( @@ -271,7 +279,9 @@ class DrmRootCertificateTest : public testing::TestWithParam { ASSERT_TRUE(drm_certificates_[kDrmRootKey].SerializeToString( current_sc->mutable_drm_certificate())); ASSERT_TRUE(private_keys_[kDrmRootKey]->GenerateSignature( - current_sc->drm_certificate(), current_sc->mutable_signature())); + current_sc->drm_certificate(), + HashAlgorithmProtoToEnum(current_sc->hash_algorithm()), + current_sc->mutable_signature())); } RsaTestKeys rsa_test_keys_; diff --git a/common/drm_service_certificate.cc b/common/drm_service_certificate.cc index 67674b9..1cc77eb 100644 --- a/common/drm_service_certificate.cc +++ b/common/drm_service_certificate.cc @@ -13,7 +13,7 @@ #include #include "glog/logging.h" -#include "base/thread_annotations.h" +#include "absl/base/thread_annotations.h" #include "absl/strings/escaping.h" #include "absl/synchronization/mutex.h" #include "util/gtl/map_util.h" diff --git a/common/drm_service_certificate_test.cc b/common/drm_service_certificate_test.cc index b60e108..a978990 100644 --- a/common/drm_service_certificate_test.cc +++ b/common/drm_service_certificate_test.cc @@ -17,6 +17,7 @@ #include "absl/strings/escaping.h" #include "common/aes_cbc_util.h" #include "common/drm_root_certificate.h" +#include "common/hash_algorithm_util.h" #include "common/rsa_key.h" #include "common/rsa_test_keys.h" #include "common/rsa_util.h" @@ -62,8 +63,10 @@ class DrmServiceCertificateTest : public ::testing::Test { cert.set_creation_time_seconds(creation_time_seconds); SignedDrmCertificate signed_cert; cert.SerializeToString(signed_cert.mutable_drm_certificate()); - root_private_key_->GenerateSignature(signed_cert.drm_certificate(), - signed_cert.mutable_signature()); + root_private_key_->GenerateSignature( + signed_cert.drm_certificate(), + HashAlgorithmProtoToEnum(signed_cert.hash_algorithm()), + signed_cert.mutable_signature()); std::string serialized_cert; signed_cert.SerializeToString(&serialized_cert); return serialized_cert; diff --git a/common/dual_certificate_client_cert.cc b/common/dual_certificate_client_cert.cc new file mode 100644 index 0000000..ac7d7cd --- /dev/null +++ b/common/dual_certificate_client_cert.cc @@ -0,0 +1,113 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#include "common/dual_certificate_client_cert.h" + +#include "common/error_space.h" +#include "common/status.h" +#include "protos/public/errors.pb.h" + +namespace widevine { + +Status DualCertificateClientCert::Initialize( + const DrmRootCertificate* root_certificate, + const std::string& serialized_signing_certificate, + const std::string& serialized_encryption_certificate) { + Status status = signing_certificate_.Initialize( + root_certificate, serialized_signing_certificate); + if (!status.ok()) { + return status; + } + status = encryption_certificate_.Initialize( + root_certificate, serialized_encryption_certificate); + if (!status.ok()) { + return status; + } + if (encryption_certificate_.signer_serial_number() != + signing_certificate_.signer_serial_number()) { + return Status(error_space, INVALID_DRM_CERTIFICATE, + "certificate_signer_mismatch"); + } + if ((encryption_certificate_.system_id() != + signing_certificate_.system_id()) || + (encryption_certificate_.service_id() != + signing_certificate_.service_id()) || + (encryption_certificate_.signer_creation_time_seconds() != + signing_certificate_.signer_creation_time_seconds()) || + (encryption_certificate_.signed_by_provisioner() != + signing_certificate_.signed_by_provisioner())) { + return Status(error_space, INVALID_DRM_CERTIFICATE, + "invalid_certificate_pair"); + } + return OkStatus(); +} + +Status DualCertificateClientCert::VerifySignature( + const std::string& message, HashAlgorithm hash_algorithm, + const std::string& signature, ProtocolVersion protocol_version) const { + return signing_certificate_.VerifySignature(message, hash_algorithm, + signature, protocol_version); +} + +void DualCertificateClientCert::GenerateSigningKey( + const std::string& message, ProtocolVersion protocol_version) { + encryption_certificate_.GenerateSigningKey(message, protocol_version); +} + +const std::string& DualCertificateClientCert::encrypted_key() const { + return encryption_certificate_.encrypted_key(); +} + +const std::string& DualCertificateClientCert::key() const { + return encryption_certificate_.key(); +} + +SignedMessage::SessionKeyType DualCertificateClientCert::key_type() const { + return encryption_certificate_.key_type(); +} + +// TODO(b/155979840): Support revocation check for the encryption certificate. +const std::string& DualCertificateClientCert::serial_number() const { + return signing_certificate_.serial_number(); +} + +const std::string& DualCertificateClientCert::service_id() const { + return signing_certificate_.service_id(); +} + +const std::string& DualCertificateClientCert::signing_key() const { + return encryption_certificate_.signing_key(); +} + +const std::string& DualCertificateClientCert::signer_serial_number() const { + return signing_certificate_.signer_serial_number(); +} + +uint32_t DualCertificateClientCert::signer_creation_time_seconds() const { + return signing_certificate_.signer_creation_time_seconds(); +} + +bool DualCertificateClientCert::signed_by_provisioner() const { + return signing_certificate_.signed_by_provisioner(); +} + +uint32_t DualCertificateClientCert::system_id() const { + return signing_certificate_.system_id(); +} + +// TODO(b/155979840): Support revocation check for the encryption certificate. +const std::string& DualCertificateClientCert::encrypted_unique_id() const { + return signing_certificate_.encrypted_unique_id(); +} + +// TODO(b/155979840): Support revocation check for the encryption certificate. +const std::string& DualCertificateClientCert::unique_id_hash() const { + return signing_certificate_.unique_id_hash(); +} + +} // namespace widevine diff --git a/common/dual_certificate_client_cert.h b/common/dual_certificate_client_cert.h new file mode 100644 index 0000000..306c2d2 --- /dev/null +++ b/common/dual_certificate_client_cert.h @@ -0,0 +1,57 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef COMMON_DUAL_CERTIFICATE_CLIENT_CERT_H_ +#define COMMON_DUAL_CERTIFICATE_CLIENT_CERT_H_ + +#include "common/certificate_client_cert.h" + +namespace widevine { + +class DualCertificateClientCert : public ClientCert { + public: + DualCertificateClientCert() = default; + ~DualCertificateClientCert() override = default; + DualCertificateClientCert(const DualCertificateClientCert&) = delete; + DualCertificateClientCert& operator=(const DualCertificateClientCert&) = + delete; + + Status Initialize(const DrmRootCertificate* root_certificate, + const std::string& serialized_signing_certificate, + const std::string& serialized_encryption_certificate); + Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, + const std::string& signature, + ProtocolVersion protocol_version) const override; + void GenerateSigningKey(const std::string& message, + ProtocolVersion protocol_version) override; + + const std::string& encrypted_key() const override; + const std::string& key() const override; + SignedMessage::SessionKeyType key_type() const override; + bool using_dual_certificate() const override { return true; } + const std::string& serial_number() const override; + const std::string& service_id() const override; + const std::string& signing_key() const override; + const std::string& signer_serial_number() const override; + uint32_t signer_creation_time_seconds() const override; + bool signed_by_provisioner() const override; + uint32_t system_id() const override; + widevine::ClientIdentification::TokenType type() const override { + return ClientIdentification::DRM_DEVICE_CERTIFICATE; + } + const std::string& encrypted_unique_id() const override; + const std::string& unique_id_hash() const override; + + private: + CertificateClientCert signing_certificate_; + CertificateClientCert encryption_certificate_; +}; + +} // namespace widevine +#endif // COMMON_DUAL_CERTIFICATE_CLIENT_CERT_H_ diff --git a/common/ec_key.cc b/common/ec_key.cc index f16fb84..0b16c75 100644 --- a/common/ec_key.cc +++ b/common/ec_key.cc @@ -23,6 +23,7 @@ #include "openssl/sha.h" #include "common/aes_cbc_util.h" #include "common/ec_util.h" +#include "common/hash_algorithm.h" #include "common/openssl_util.h" #include "common/sha_util.h" @@ -53,6 +54,22 @@ std::string OpenSSLErrorString(uint32_t error) { return buf; } +std::string GetMessageDigest(const std::string& message, + widevine::HashAlgorithm hash_algorithm) { + switch (hash_algorithm) { + case widevine::HashAlgorithm::kUnspecified: + case widevine::HashAlgorithm::kSha256: + return widevine::Sha256_Hash(message); + case widevine::HashAlgorithm::kSha1: + LOG(ERROR) << "Unexpected hash algorithm: " + << static_cast(hash_algorithm); + return ""; + } + LOG(FATAL) << "Unexpected hash algorithm: " + << static_cast(hash_algorithm); + return ""; +} + } // namespace ECPrivateKey::ECPrivateKey(EC_KEY* ec_key) : key_(ec_key) { @@ -159,6 +176,47 @@ bool ECPrivateKey::GenerateSignature(const std::string& message, return true; } +bool ECPrivateKey::GenerateSignature(const std::string& message, + HashAlgorithm hash_algorithm, + std::string* signature) const { + if (message.empty()) { + LOG(ERROR) << "|message| cannot be empty"; + return false; + } + if (signature == nullptr) { + LOG(ERROR) << "|signature| cannot be nullptr"; + return false; + } + + // Hash the message using corresponding hash algorithm. + std::string message_digest = GetMessageDigest(message, hash_algorithm); + if (message_digest.empty()) { + LOG(ERROR) << "Empty message digest"; + return false; + } + + size_t max_signature_size = ECDSA_size(key()); + if (max_signature_size == 0) { + LOG(ERROR) << "key_ does not have a group set"; + return false; + } + signature->resize(max_signature_size); + unsigned int bytes_written = 0; + int result = ECDSA_sign( + 0 /* unused type */, + reinterpret_cast(message_digest.data()), + message_digest.size(), + reinterpret_cast(const_cast(signature->data())), + &bytes_written, key()); + if (result != 1) { + LOG(ERROR) << "Could not calculate signature: " + << OpenSSLErrorString(ERR_get_error()); + return false; + } + signature->resize(bytes_written); + return true; +} + bool ECPrivateKey::MatchesPrivateKey(const ECPrivateKey& private_key) const { return BN_cmp(EC_KEY_get0_private_key(key()), EC_KEY_get0_private_key(private_key.key())) == 0; @@ -254,6 +312,39 @@ bool ECPublicKey::VerifySignature(const std::string& message, return true; } +bool ECPublicKey::VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, + const std::string& signature) const { + if (message.empty()) { + LOG(ERROR) << "|message| cannot be empty"; + return false; + } + if (signature.empty()) { + LOG(ERROR) << "|signature| cannot be empty"; + return false; + } + + // Hash the message using corresponding hash algorithm. + std::string message_digest = GetMessageDigest(message, hash_algorithm); + if (message_digest.empty()) { + LOG(ERROR) << "Empty message digest"; + return false; + } + + int result = ECDSA_verify( + 0 /* unused type */, + reinterpret_cast(message_digest.data()), + message_digest.size(), + reinterpret_cast(const_cast(signature.data())), + signature.size(), key()); + if (result != 1) { + LOG(ERROR) << "Could not verify signature: " + << OpenSSLErrorString(ERR_get_error()); + return false; + } + return true; +} + bool ECPublicKey::MatchesPrivateKey(const ECPrivateKey& private_key) const { return private_key.MatchesPublicKey(*this); } diff --git a/common/ec_key.h b/common/ec_key.h index fe9726d..e000f53 100644 --- a/common/ec_key.h +++ b/common/ec_key.h @@ -16,7 +16,9 @@ #include #include +#include "absl/base/macros.h" #include "openssl/ec.h" +#include "common/hash_algorithm.h" #include "common/openssl_util.h" namespace widevine { @@ -65,9 +67,24 @@ class ECPrivateKey { // DER-encoded signature. // Caller retains ownership of all pointers. // Returns true on success and false on error. + // TODO(b/155438325): remove this function after the below function is fully + // propagated. + ABSL_DEPRECATED( + "Use the below function with |hash_algorithm| argument instead.") virtual bool GenerateSignature(const std::string& message, std::string* signature) const; + // Given a message, calculates a signature using ECDSA with the key_. + // |message| is the message to be signed. + // |hash_algorithm| specifies the hash algorithm. + // |signature| will contain the resulting signature. This will be an ASN.1 + // DER-encoded signature. + // Caller retains ownership of all pointers. + // Returns true on success and false on error. + virtual bool GenerateSignature(const std::string& message, + HashAlgorithm hash_algorithm, + std::string* signature) const; + // Returns whether the given private key is the same as key_. virtual bool MatchesPrivateKey(const ECPrivateKey& private_key) const; @@ -110,9 +127,23 @@ class ECPublicKey { // |message| is the message that was signed. // |signature| is an ASN.1 DER-encoded signature. // Returns true on success and false on error. + // TODO(b/155438325): remove this function after the below function is fully + // propagated. + ABSL_DEPRECATED( + "Use the below function with |hash_algorithm| argument instead.") virtual bool VerifySignature(const std::string& message, const std::string& signature) const; + // Given a message and a signature, verifies that the signature was created + // using the private key associated with key_. + // |message| is the message that was signed. + // |hash_algorithm| specifies the hash algorithm. + // |signature| is an ASN.1 DER-encoded signature. + // Returns true on success and false on error. + virtual bool VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, + const std::string& signature) const; + // Returns whether the given private key is part of the same key pair as key_. virtual bool MatchesPrivateKey(const ECPrivateKey& private_key) const; diff --git a/common/ec_key_test.cc b/common/ec_key_test.cc index 4a06e7d..ec6fcdc 100644 --- a/common/ec_key_test.cc +++ b/common/ec_key_test.cc @@ -111,6 +111,10 @@ class ECKeyTestKeyPairs : public ECKeyTest, std::unique_ptr public_key_; }; +// Death test naming convention. See below link for details: +// go/gunitadvanced#death-test-naming +using ECKeyTestKeyPairsDeathTest = ECKeyTestKeyPairs; + TEST_P(ECKeyTestKeyPairs, CreateWrongKey) { EXPECT_EQ(ECPrivateKey::Create(test_public_key_), nullptr); EXPECT_EQ(ECPublicKey::Create(test_private_key_), nullptr); @@ -165,6 +169,40 @@ TEST_P(ECKeyTestKeyPairs, SignVerify) { EXPECT_TRUE(public_key_->VerifySignature(plaintext_message_, signature)); } +TEST_P(ECKeyTestKeyPairs, SignVerifySha1) { + std::string signature; + EXPECT_FALSE(private_key_->GenerateSignature( + plaintext_message_, HashAlgorithm::kSha1, &signature)); + EXPECT_FALSE(public_key_->VerifySignature(plaintext_message_, + HashAlgorithm::kSha1, signature)); +} + +TEST_P(ECKeyTestKeyPairs, SignVerifySha256) { + std::string signature; + ASSERT_TRUE(private_key_->GenerateSignature( + plaintext_message_, HashAlgorithm::kSha256, &signature)); + ASSERT_TRUE(public_key_->VerifySignature(plaintext_message_, + HashAlgorithm::kSha256, signature)); +} + +TEST_P(ECKeyTestKeyPairs, SignVerifyUnspecified) { + std::string signature; + ASSERT_TRUE(private_key_->GenerateSignature( + plaintext_message_, HashAlgorithm::kUnspecified, &signature)); + ASSERT_TRUE(public_key_->VerifySignature( + plaintext_message_, HashAlgorithm::kUnspecified, signature)); +} + +TEST_P(ECKeyTestKeyPairsDeathTest, SignVerifyUnexpected) { + std::string signature; + HashAlgorithm unexpected_hash_algorithm = static_cast(1234); + EXPECT_DEATH(private_key_->GenerateSignature( + plaintext_message_, unexpected_hash_algorithm, &signature), + "Unexpected hash algorithm: 1234"); + EXPECT_FALSE(public_key_->VerifySignature( + plaintext_message_, unexpected_hash_algorithm, signature)); +} + TEST_P(ECKeyTestKeyPairs, InvalidSignVerifyParameters) { std::string signature; EXPECT_FALSE(private_key_->GenerateSignature("", &signature)); @@ -220,6 +258,9 @@ TEST_P(ECKeyTestKeyPairs, KeyPointEncodingSuccess) { INSTANTIATE_TEST_SUITE_P(ECKeyTestKeyPairs, ECKeyTestKeyPairs, ::testing::ValuesIn(ECKeyTest::GetTestKeyList())); +INSTANTIATE_TEST_SUITE_P(ECKeyTestKeyPairsDeathTest, ECKeyTestKeyPairsDeathTest, + ::testing::ValuesIn(ECKeyTest::GetTestKeyList())); + class ECKeyTestCurveMismatch : public ECKeyTest, public ::testing::WithParamInterface< diff --git a/common/ecies_crypto_test.cc b/common/ecies_crypto_test.cc index 783ec1c..f3e2ff8 100644 --- a/common/ecies_crypto_test.cc +++ b/common/ecies_crypto_test.cc @@ -200,9 +200,10 @@ TEST(EciesEncryptorTest, EciesEncryptNullKeySource) { class MockEcKeySource : public ECKeySource { public: MockEcKeySource() = default; - MOCK_METHOD3(GetECKey, - bool(ECPrivateKey::EllipticCurve curve, std::string* private_key, - std::string* public_key)); + MOCK_METHOD(bool, GetECKey, + (ECPrivateKey::EllipticCurve curve, std::string* private_key, + std::string* public_key), + (override)); }; TEST(EciesEncryptorTest, EciesEncryptKeysourceFail) { diff --git a/common/file_util_test.cc b/common/file_util_test.cc index 97e94cd..e07669e 100644 --- a/common/file_util_test.cc +++ b/common/file_util_test.cc @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/file_util.h" + #include "testing/gunit.h" #include "absl/strings/str_cat.h" diff --git a/common/hash_algorithm.h b/common/hash_algorithm.h new file mode 100644 index 0000000..356b190 --- /dev/null +++ b/common/hash_algorithm.h @@ -0,0 +1,18 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef COMMON_HASH_ALGORITHM_H_ +#define COMMON_HASH_ALGORITHM_H_ + +namespace widevine { + +enum class HashAlgorithm { kUnspecified, kSha1, kSha256 }; + +} // namespace widevine + +#endif // COMMON_HASH_ALGORITHM_H_ diff --git a/common/hash_algorithm_util.cc b/common/hash_algorithm_util.cc new file mode 100644 index 0000000..d52e7af --- /dev/null +++ b/common/hash_algorithm_util.cc @@ -0,0 +1,51 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#include "common/hash_algorithm_util.h" + +#include "glog/logging.h" +#include "protos/public/hash_algorithm.pb.h" + +namespace widevine { + +HashAlgorithm HashAlgorithmProtoToEnum( + HashAlgorithmProto hash_algorithm_proto) { + switch (hash_algorithm_proto) { + case HASH_ALGORITHM_UNSPECIFIED: + return HashAlgorithm::kUnspecified; + case HASH_ALGORITHM_SHA_1: + return HashAlgorithm::kSha1; + case HASH_ALGORITHM_SHA_256: + return HashAlgorithm::kSha256; + default: + // See below link for using proto3 enum in switch statement: + // http://shortn/_ma9MY7V9wh + if (HashAlgorithmProto_IsValid(hash_algorithm_proto)) { + LOG(ERROR) << "Unsupported value " << hash_algorithm_proto; + } else { + LOG(WARNING) << "Unexpected value " << hash_algorithm_proto; + } + return HashAlgorithm::kUnspecified; + } +} + +HashAlgorithmProto HashAlgorithmEnumToProto(HashAlgorithm hash_algorithm) { + switch (hash_algorithm) { + case HashAlgorithm::kUnspecified: + return HASH_ALGORITHM_UNSPECIFIED; + case HashAlgorithm::kSha1: + return HASH_ALGORITHM_SHA_1; + case HashAlgorithm::kSha256: + return HASH_ALGORITHM_SHA_256; + } + LOG(WARNING) << "Unexpected hash algorithm " + << static_cast(hash_algorithm); + return HASH_ALGORITHM_UNSPECIFIED; +} + +} // namespace widevine diff --git a/common/hash_algorithm_util.h b/common/hash_algorithm_util.h new file mode 100644 index 0000000..5541544 --- /dev/null +++ b/common/hash_algorithm_util.h @@ -0,0 +1,23 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef COMMON_HASH_ALGORITHM_UTIL_H_ +#define COMMON_HASH_ALGORITHM_UTIL_H_ + +#include "common/hash_algorithm.h" +#include "protos/public/hash_algorithm.pb.h" + +namespace widevine { + +HashAlgorithm HashAlgorithmProtoToEnum(HashAlgorithmProto hash_algorithm_proto); + +HashAlgorithmProto HashAlgorithmEnumToProto(HashAlgorithm hash_algorithm); + +} // namespace widevine + +#endif // COMMON_HASH_ALGORITHM_UTIL_H_ diff --git a/common/hash_algorithm_util_test.cc b/common/hash_algorithm_util_test.cc new file mode 100644 index 0000000..f5fbfdd --- /dev/null +++ b/common/hash_algorithm_util_test.cc @@ -0,0 +1,58 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#include "common/hash_algorithm_util.h" + +#include "testing/gmock.h" +#include "testing/gunit.h" +#include "common/hash_algorithm.h" + +namespace widevine { + +TEST(HashAlgorithmTest, ProtoToEnumUnspecified) { + EXPECT_EQ(HashAlgorithm::kUnspecified, + HashAlgorithmProtoToEnum(HASH_ALGORITHM_UNSPECIFIED)); +} + +TEST(HashAlgorithmTest, ProtoToEnumSha1) { + EXPECT_EQ(HashAlgorithm::kSha1, + HashAlgorithmProtoToEnum(HASH_ALGORITHM_SHA_1)); +} + +TEST(HashAlgorithmTest, ProtoToEnumSha256) { + EXPECT_EQ(HashAlgorithm::kSha256, + HashAlgorithmProtoToEnum(HASH_ALGORITHM_SHA_256)); +} + +TEST(HashAlgorithmTest, ProtoToEnumUnsupported) { + EXPECT_EQ(HashAlgorithm::kUnspecified, + HashAlgorithmProtoToEnum(static_cast(1234))); +} + +TEST(HashAlgorithmTest, EnumToProtoUnspecified) { + EXPECT_EQ(HASH_ALGORITHM_UNSPECIFIED, + HashAlgorithmEnumToProto(HashAlgorithm::kUnspecified)); +} + +TEST(HashAlgorithmTest, EnumToProtoSha1) { + EXPECT_EQ(HASH_ALGORITHM_SHA_1, + HashAlgorithmEnumToProto(HashAlgorithm::kSha1)); +} + +TEST(HashAlgorithmTest, EnumToProtoSha256) { + EXPECT_EQ(HASH_ALGORITHM_SHA_256, + HashAlgorithmEnumToProto(HashAlgorithm::kSha256)); +} + +TEST(HashAlgorithmTest, EnumToProtoUnexpected) { + int some_value = 1234; + EXPECT_EQ(HASH_ALGORITHM_UNSPECIFIED, + HashAlgorithmEnumToProto(static_cast(some_value))); +} + +} // namespace widevine diff --git a/common/keybox_client_cert.cc b/common/keybox_client_cert.cc index c4f4899..bd8e748 100644 --- a/common/keybox_client_cert.cc +++ b/common/keybox_client_cert.cc @@ -33,9 +33,11 @@ Status KeyboxClientCert::Initialize(const std::string& keybox_token) { return OkStatus(); } +// |hash_algorithm| is needed for function inheritance. +// For KeyBoxClientCert, we always use HMAC-SHA256 in signature verification. Status KeyboxClientCert::VerifySignature( - const std::string& message, const std::string& signature, - ProtocolVersion protocol_version) const { + const std::string& message, HashAlgorithm hash_algorithm, + const std::string& signature, ProtocolVersion protocol_version) const { DCHECK(!signing_key_.empty()); using crypto_util::VerifySignatureHmacSha256; if (!VerifySignatureHmacSha256( diff --git a/common/keybox_client_cert.h b/common/keybox_client_cert.h index a068b8f..98fb111 100644 --- a/common/keybox_client_cert.h +++ b/common/keybox_client_cert.h @@ -10,10 +10,12 @@ #define COMMON_KEYBOX_CLIENT_CERT_H_ #include "common/client_cert.h" +#include "common/error_space.h" +#include "common/hash_algorithm.h" +#include "protos/public/errors.pb.h" namespace widevine { -// class KeyboxClientCert : public ClientCert { public: KeyboxClientCert() {} @@ -23,6 +25,7 @@ class KeyboxClientCert : public ClientCert { Status Initialize(const std::string& keybox_token); Status VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature, ProtocolVersion protocol_version) const override; @@ -34,6 +37,7 @@ class KeyboxClientCert : public ClientCert { SignedMessage::SessionKeyType key_type() const override { return SignedMessage::WRAPPED_AES_KEY; } + bool using_dual_certificate() const override { return false; } const std::string& serial_number() const override { return serial_number_; } const std::string& service_id() const override { return unimplemented_; } const std::string& signing_key() const override { return signing_key_; } @@ -59,6 +63,14 @@ class KeyboxClientCert : public ClientCert { const std::multimap& keymap); static bool IsSystemIdKnown(const uint32_t system_id); static uint32_t GetSystemId(const std::string& keybox_bytes); + Status SystemIdUnknownError() const override { + return Status(error_space, UNSUPPORTED_SYSTEM_ID, + "keybox-unsupported-system-id"); + } + Status SystemIdRevokedError() const override { + return Status(error_space, DRM_DEVICE_CERTIFICATE_REVOKED, + "keybox-system-id-revoked"); + } private: std::string unimplemented_; diff --git a/common/mock_rsa_key.h b/common/mock_rsa_key.h index 711cba9..a241e1d 100644 --- a/common/mock_rsa_key.h +++ b/common/mock_rsa_key.h @@ -10,7 +10,9 @@ #define COMMON_MOCK_RSA_KEY_H_ #include + #include "testing/gmock.h" +#include "common/hash_algorithm.h" #include "common/rsa_key.h" namespace widevine { @@ -20,12 +22,23 @@ class MockRsaPrivateKey : public RsaPrivateKey { MockRsaPrivateKey() : RsaPrivateKey(RSA_new()) {} ~MockRsaPrivateKey() override {} - MOCK_CONST_METHOD2(Decrypt, bool(const std::string& encrypted_message, - std::string* decrypted_message)); - MOCK_CONST_METHOD2(GenerateSignature, - bool(const std::string& message, std::string* signature)); - MOCK_CONST_METHOD1(MatchesPrivateKey, bool(const RsaPrivateKey& private_key)); - MOCK_CONST_METHOD1(MatchesPublicKey, bool(const RsaPublicKey& public_key)); + MOCK_METHOD(bool, Decrypt, + (const std::string& encrypted_message, + std::string* decrypted_message), + (const, override)); + // TODO(b/155438325): remove this function after the below function is fully + // propagated. + MOCK_METHOD(bool, GenerateSignature, + (const std::string& message, std::string* signature), + (const, override)); + MOCK_METHOD(bool, GenerateSignature, + (const std::string& message, HashAlgorithm hash_algorithm, + std::string* signature), + (const, override)); + MOCK_METHOD(bool, MatchesPrivateKey, (const RsaPrivateKey& private_key), + (const, override)); + MOCK_METHOD(bool, MatchesPublicKey, (const RsaPublicKey& public_key), + (const, override)); private: MockRsaPrivateKey(const MockRsaPrivateKey&) = delete; @@ -37,12 +50,23 @@ class MockRsaPublicKey : public RsaPublicKey { MockRsaPublicKey() : RsaPublicKey(RSA_new()) {} ~MockRsaPublicKey() override {} - MOCK_CONST_METHOD2(Encrypt, bool(const std::string& clear_message, - std::string* encrypted_message)); - MOCK_CONST_METHOD2(VerifySignature, bool(const std::string& message, - const std::string& signature)); - MOCK_CONST_METHOD1(MatchesPrivateKey, bool(const RsaPrivateKey& private_key)); - MOCK_CONST_METHOD1(MatchesPublicKey, bool(const RsaPublicKey& public_key)); + MOCK_METHOD(bool, Encrypt, + (const std::string& clear_message, + std::string* encrypted_message), + (const, override)); + // TODO(b/155438325): remove this function after the below function is fully + // propagated. + MOCK_METHOD(bool, VerifySignature, + (const std::string& message, const std::string& signature), + (const, override)); + MOCK_METHOD(bool, VerifySignature, + (const std::string& message, HashAlgorithm hash_algorithm, + const std::string& signature), + (const, override)); + MOCK_METHOD(bool, MatchesPrivateKey, (const RsaPrivateKey& private_key), + (const, override)); + MOCK_METHOD(bool, MatchesPublicKey, (const RsaPublicKey& public_key), + (const, override)); private: MockRsaPublicKey(const MockRsaPublicKey&) = delete; @@ -54,16 +78,14 @@ class MockRsaKeyFactory : public RsaKeyFactory { MockRsaKeyFactory() {} ~MockRsaKeyFactory() override {} - MOCK_CONST_METHOD1( - CreateFromPkcs1PrivateKey, - std::unique_ptr(const std::string& private_key)); - MOCK_CONST_METHOD2(CreateFromPkcs8PrivateKey, - std::unique_ptr( - const std::string& private_key, - const std::string& private_key_passphrase)); - MOCK_CONST_METHOD1( - CreateFromPkcs1PublicKey, - std::unique_ptr(const std::string& public_key)); + MOCK_METHOD(std::unique_ptr, CreateFromPkcs1PrivateKey, + (const std::string& private_key), (const, override)); + MOCK_METHOD(std::unique_ptr, CreateFromPkcs8PrivateKey, + (const std::string& private_key, + const std::string& private_key_passphrase), + (const, override)); + MOCK_METHOD(std::unique_ptr, CreateFromPkcs1PublicKey, + (const std::string& public_key), (const, override)); private: MockRsaKeyFactory(const MockRsaKeyFactory&) = delete; diff --git a/common/playready_interface.h b/common/playready_interface.h new file mode 100644 index 0000000..caa7d30 --- /dev/null +++ b/common/playready_interface.h @@ -0,0 +1,42 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// +#ifndef COMMON_PLAYREADY_INTERFACE_H_ +#define COMMON_PLAYREADY_INTERFACE_H_ + +#include + +#include "util/error_space.h" +#include "protos/public/license_protocol.pb.h" + +namespace widevine { + +class PlayReadyInterface { + public: + PlayReadyInterface() {} + virtual ~PlayReadyInterface() {} + + // Sends to a PlayReady Service running the PlayReady license server on + // Windows . + // Args: + // - |challenge| is a std::string which contains PlayReadyLicenseRequest. + // - |policy| is a std::string which contains the PlayReady Policy Setting. + // - |license| is a std::string of PlayReadyLicenseResponse returned from PlayReady + // Service. + + // Returns: + // - status code from downstream components. + virtual util::Status SendToPlayReady( + const std::string& playready_challenge, const std::string& provider, + const std::string& content_id, + const std::list& keys, + const License::Policy& policy, std::string* playready_license) = 0; +}; + +} // namespace widevine + +#endif // COMMON_PLAYREADY_INTERFACE_H_ diff --git a/common/playready_sdk_impl.cc b/common/playready_sdk_impl.cc new file mode 100644 index 0000000..0540ed0 --- /dev/null +++ b/common/playready_sdk_impl.cc @@ -0,0 +1,23 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// +#include "common/playready_sdk_impl.h" + +#include "absl/status/status.h" +#include "absl/strings/escaping.h" +#include "util/task/codes.pb.h" + +namespace widevine { + +// TODO(user): fill in SendToPlayReady function. +util::Status PlayReadySdkImpl::SendToPlayReady( + const std::string& playready_challenge, const std::string& provider, + const std::string& content_id, const std::list& keys, + const License::Policy& policy, std::string* playready_license) { + return OkStatus; +} +} // namespace widevine diff --git a/common/playready_sdk_impl.h b/common/playready_sdk_impl.h new file mode 100644 index 0000000..8595487 --- /dev/null +++ b/common/playready_sdk_impl.h @@ -0,0 +1,28 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// +#ifndef COMMON_PLAYREADY_SDK_IMPL_H_ +#define COMMON_PLAYREADY_SDK_IMPL_H_ + +#include "common/playready_interface.h" +#include "protos/public/license_protocol.pb.h" + +namespace widevine { +class PlayReadySdkImpl : public PlayReadyInterface { + public: + PlayReadySdkImpl() : PlayReadyInterface() {} + ~PlayReadySdkImpl() override {} + + util::Status SendToPlayReady(const std::string& playready_challenge, + const std::string& provider, + const std::string& content_id, + const std::list& keys, + const License::Policy& policy, + std::string* playready_license) override; +}; +} // namespace widevine +#endif // COMMON_PLAYREADY_SDK_IMPL_H_ diff --git a/common/remote_attestation_verifier.h b/common/remote_attestation_verifier.h index d32db1d..7ae9ed7 100644 --- a/common/remote_attestation_verifier.h +++ b/common/remote_attestation_verifier.h @@ -16,7 +16,7 @@ #include #include -#include "base/thread_annotations.h" +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "common/status.h" #include "common/x509_cert.h" diff --git a/common/rot_id_generator.cc b/common/rot_id_generator.cc index da48749..58c89fe 100644 --- a/common/rot_id_generator.cc +++ b/common/rot_id_generator.cc @@ -15,6 +15,7 @@ #include #include "glog/logging.h" +#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "common/crypto_util.h" #include "common/ec_key.h" @@ -77,11 +78,7 @@ Status RootOfTrustIdGenerator::Generate(uint32_t system_id, std::string RootOfTrustIdGenerator::GenerateUniqueIdHash( const std::string& unique_id) const { - if (unique_id.empty()) { - LOG(WARNING) << "unique_id should not be empty."; - return ""; - } - return Sha256_Hash(absl::StrCat(unique_id, wv_shared_salt_)); + return widevine::GenerateUniqueIdHash(unique_id, wv_shared_salt_); } Status RootOfTrustIdDecryptor::DecryptUniqueId( @@ -104,4 +101,30 @@ Status RootOfTrustIdDecryptor::DecryptUniqueId( return OkStatus(); } +Status RootOfTrustIdDecryptor::VerifyAndExtractAllValues( + uint32_t system_id, const RootOfTrustId& root_of_trust_id, + std::string* device_unique_id, std::string* device_unique_id_hash) const { + CHECK(device_unique_id != nullptr) << "device_unique_id was null."; + CHECK(device_unique_id_hash != nullptr) << "device_unique_id_hash was null."; + + Status status = DecryptUniqueId( + system_id, root_of_trust_id.encrypted_unique_id(), device_unique_id); + if (!status.ok()) { + return status; + } + *device_unique_id_hash = + widevine::GenerateUniqueIdHash(*device_unique_id, wv_shared_salt_); + std::string revocation_hash = + GenerateRotIdHash(root_of_trust_id.encrypted_unique_id(), system_id, + *device_unique_id_hash); + // This should not happen unless there's a bug in the way we issue root of + // trust ids. + if (revocation_hash != root_of_trust_id.unique_id_hash()) { + return Status(error::INVALID_ARGUMENT, + "The generated revocation hash did not match the one in the " + "root_of_trust_id"); + } + return OkStatus(); +} + } // namespace widevine diff --git a/common/rot_id_generator.h b/common/rot_id_generator.h index cece136..b6ae69a 100644 --- a/common/rot_id_generator.h +++ b/common/rot_id_generator.h @@ -38,7 +38,7 @@ class RootOfTrustIdGenerator { // values. The unique id hash values identify revoked devices and are // published in the DCSL and consumed by the License SDK. RootOfTrustIdGenerator(std::unique_ptr ecies_encryptor, - std::string wv_shared_salt) + const std::string& wv_shared_salt) : ecies_encryptor_(std::move(ecies_encryptor)), wv_shared_salt_(std::move(wv_shared_salt)) {} @@ -73,8 +73,10 @@ class RootOfTrustIdGenerator { class RootOfTrustIdDecryptor { public: explicit RootOfTrustIdDecryptor( - std::unique_ptr ecies_decryptor) - : ecies_decryptor_(std::move(ecies_decryptor)) {} + std::unique_ptr ecies_decryptor, + const std::string& wv_shared_salt) + : ecies_decryptor_(std::move(ecies_decryptor)), + wv_shared_salt_(wv_shared_salt) {} // Decrypts the |rot_encrypted_id| using the |system_id| as part of the // context. |unique_id| contains the decrypted value on success. @@ -83,8 +85,18 @@ class RootOfTrustIdDecryptor { Status DecryptUniqueId(uint32_t system_id, const std::string& rot_encrypted_id, std::string* unique_id) const; + // Decrypts the encrypted id within the |root_of_trust_id|, extacting the + // |device_unique_id|, and generating the |device_unique_id_hash|. It then + // generates the rot id revocation hash and verifies that it matches the + // unique_id_hash from the root_of_trust_id. + Status VerifyAndExtractAllValues(uint32_t system_id, + const RootOfTrustId& root_of_trust_id, + std::string* device_unique_id, + std::string* device_unique_id_hash) const; + private: std::unique_ptr ecies_decryptor_; + std::string wv_shared_salt_; }; } // namespace widevine diff --git a/common/rot_id_generator_test.cc b/common/rot_id_generator_test.cc index 91956b0..e8b9bc1 100644 --- a/common/rot_id_generator_test.cc +++ b/common/rot_id_generator_test.cc @@ -80,8 +80,9 @@ class MockEciesEncryptor : public EciesEncryptor { ECPublicKey::Create(test_keys.public_test_key_1_secp256r1()); return new MockEciesEncryptor(std::move(ec_key)); } - MOCK_CONST_METHOD3(Encrypt, bool(const std::string&, const std::string&, - std::string*)); + MOCK_METHOD(bool, Encrypt, + (const std::string&, const std::string&, std::string*), + (const, override)); private: explicit MockEciesEncryptor(std::unique_ptr ec_key) @@ -92,7 +93,7 @@ class MockEciesEncryptor : public EciesEncryptor { TEST_F(RootOfTrustIdGeneratorTest, GenerateIdSuccess) { RootOfTrustIdGenerator generator(CreateEncryptor(), kTestSharedSalt); - RootOfTrustIdDecryptor decryptor(CreateDecryptor()); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); // Generate the root of trust id. RootOfTrustId root_of_trust_id; @@ -117,7 +118,7 @@ TEST_F(RootOfTrustIdGeneratorTest, GenerateIdSuccess) { TEST_F(RootOfTrustIdGeneratorTest, GenerateIdUniqueSuccess) { RootOfTrustIdGenerator generator(CreateEncryptor(), kTestSharedSalt); - RootOfTrustIdDecryptor decryptor(CreateDecryptor()); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); std::string rot_encrypted_id; std::string rot_id_hash; @@ -208,7 +209,7 @@ TEST_F(RootOfTrustIdGeneratorTest, GenerateIdNullRotIdFail) { TEST_F(RootOfTrustIdGeneratorTest, DecryptorSystemIdMismatchFails) { RootOfTrustIdGenerator generator(CreateEncryptor(), kTestSharedSalt); - RootOfTrustIdDecryptor decryptor(CreateDecryptor()); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); // Generate the root of trust id. RootOfTrustId root_of_trust_id; @@ -228,7 +229,7 @@ TEST_F(RootOfTrustIdGeneratorTest, DecryptorSystemIdMismatchFails) { } TEST_F(RootOfTrustIdGeneratorTest, DecryptorBlankUniqueId) { - RootOfTrustIdDecryptor decryptor(CreateDecryptor()); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); // Attempt to decrypt empty encrypted id. std::string decrypted_unique_id; @@ -239,7 +240,7 @@ TEST_F(RootOfTrustIdGeneratorTest, DecryptorBlankUniqueId) { TEST_F(RootOfTrustIdGeneratorTest, DecryptorSystemIdNullDecryptedIdFails) { RootOfTrustIdGenerator generator(CreateEncryptor(), kTestSharedSalt); - RootOfTrustIdDecryptor decryptor(CreateDecryptor()); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); // Generate the root of trust id. RootOfTrustId root_of_trust_id; @@ -256,4 +257,48 @@ TEST_F(RootOfTrustIdGeneratorTest, DecryptorSystemIdNullDecryptedIdFails) { "unique_id"); } +TEST_F(RootOfTrustIdGeneratorTest, VerifyAndExtractAllValuesSuccess) { + RootOfTrustIdGenerator generator(CreateEncryptor(), kTestSharedSalt); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); + + // Generate the root of trust id. + RootOfTrustId root_of_trust_id; + ASSERT_OK( + generator.Generate(kTestSystemId, kTestUniqueId, &root_of_trust_id)); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + expected_root_of_trust_id_, root_of_trust_id)); + + // Verify decrypted unique id. + std::string decrypted_unique_id; + std::string decrypted_unique_id_hash; + EXPECT_OK(decryptor.VerifyAndExtractAllValues(kTestSystemId, root_of_trust_id, + &decrypted_unique_id, + &decrypted_unique_id_hash)); + EXPECT_EQ(kTestUniqueId, decrypted_unique_id); + EXPECT_EQ(generator.GenerateUniqueIdHash(kTestUniqueId), + decrypted_unique_id_hash); +} + +TEST_F(RootOfTrustIdGeneratorTest, VerifyAndExtractAllValuesSystemIdMismatch) { + RootOfTrustIdGenerator generator(CreateEncryptor(), kTestSharedSalt); + RootOfTrustIdDecryptor decryptor(CreateDecryptor(), kTestSharedSalt); + + // Generate the root of trust id. + RootOfTrustId root_of_trust_id; + ASSERT_OK( + generator.Generate(kTestSystemId, kTestUniqueId, &root_of_trust_id)); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + expected_root_of_trust_id_, root_of_trust_id)); + + // Verify decrypted unique id. + std::string decrypted_unique_id; + std::string decrypted_unique_id_hash; + EXPECT_EQ(error::INTERNAL, + decryptor + .VerifyAndExtractAllValues(kTestSystemId + 1, root_of_trust_id, + &decrypted_unique_id, + &decrypted_unique_id_hash) + .error_code()); +} + } // namespace widevine diff --git a/common/rot_id_util.cc b/common/rot_id_util.cc index f994baa..a43e942 100644 --- a/common/rot_id_util.cc +++ b/common/rot_id_util.cc @@ -29,4 +29,17 @@ std::string GenerateRotIdHash(const std::string& salt, uint32_t system_id, return Sha256_Hash(absl::StrCat(salt, system_id, unique_id_hash)); } +std::string GenerateUniqueIdHash(const std::string& unique_id, + const std::string& salt) { + if (unique_id.empty()) { + LOG(WARNING) << "unique_id should not be empty."; + return ""; + } + if (salt.empty()) { + LOG(WARNING) << "salt should not be empty."; + return ""; + } + return widevine::Sha256_Hash(absl::StrCat(unique_id, salt)); +} + } // namespace widevine diff --git a/common/rot_id_util.h b/common/rot_id_util.h index 49a55dc..835985e 100644 --- a/common/rot_id_util.h +++ b/common/rot_id_util.h @@ -21,6 +21,14 @@ namespace widevine { +// Helper function that generates the unique id hash from the |unique_id| and +// the |salt|. |salt| is an internal secret. +// +// Returns the hash value on success. +// If |salt| or |unique_id| are empty, this will return an empty string. +std::string GenerateUniqueIdHash(const std::string& unique_id, + const std::string& salt); + // Helper function that generates the hash for the ROT id from the // |unique_id_hash|, the |system_id| and the |salt|. |salt| is typically an // encrypted unique id. Since we use an ephemeral eliptic curve key as part of diff --git a/common/rot_id_util_test.cc b/common/rot_id_util_test.cc index 51bf6df..9b89d98 100644 --- a/common/rot_id_util_test.cc +++ b/common/rot_id_util_test.cc @@ -19,9 +19,15 @@ namespace { constexpr char kFakeEncryptedId[] = "fake encrypted id"; constexpr char kFakeUniqueIdHash[] = "fake unique_id hash"; +constexpr char kFakeUniqueId[] = "fake unique_id"; +constexpr char kFakeSecretSalt[] = "fake secret salt"; + // This is the ROT ID Hash generated from the fake values. constexpr char kRotIdHashHex[] = "0a757dde0f1080b60f34bf8e46af573ce987b5ed1c831b44952e2feed5243a95"; +// This is the unique id hash generated from the fake unique id value. +constexpr char kUniqueIdHashHex[] = + "da20922e84b48e52223496f44b07632a4db19d488cd71cf813de300b9d244e06"; constexpr uint32_t kFakeSystemId = 1234; constexpr uint32_t kOtherFakeSystemId = 9876; @@ -63,4 +69,17 @@ TEST(RotIdUtilTest, GenerateRotIdHashSuccess) { GenerateRotIdHash(kFakeEncryptedId, kFakeSystemId, kFakeUniqueIdHash)); } +// This test really only ensures the stability of the GenerateUniqueIdHash +// implementation. If the hash ever changes, then it will introduce problems +// into the ecosystem. +TEST(RotIdUtilTest, GenerateUniqueIdHashSuccess) { + ASSERT_EQ(absl::HexStringToBytes(kUniqueIdHashHex), + GenerateUniqueIdHash(kFakeUniqueId, kFakeSecretSalt)); +} + +TEST(RotIdUtilTest, GenerateUniqueIdHashEmptyValues) { + ASSERT_EQ("", GenerateUniqueIdHash(kFakeUniqueId, "")); + ASSERT_EQ("", GenerateUniqueIdHash("", kFakeSecretSalt)); +} + } // namespace widevine diff --git a/common/rsa_key.cc b/common/rsa_key.cc index a57935b..809ae03 100644 --- a/common/rsa_key.cc +++ b/common/rsa_key.cc @@ -13,7 +13,7 @@ // // RSA signature details: // Algorithm: RSASSA-PSS -// Hash algorithm: SHA1 +// Hash algorithm: |hash_algorithm| // Mask generation function: mgf1SHA1 // Salt length: 20 bytes // Trailer field: 0xbc @@ -21,7 +21,7 @@ // RSA encryption details: // Algorithm: RSA-OAEP // Mask generation function: mgf1SHA1 -// Label (encoding paramter): empty std::string +// Label (encoding parameter): empty std::string #include "common/rsa_key.h" @@ -31,6 +31,7 @@ #include "openssl/evp.h" #include "openssl/rsa.h" #include "openssl/sha.h" +#include "common/hash_algorithm.h" #include "common/rsa_util.h" #include "common/sha_util.h" @@ -51,6 +52,21 @@ std::string OpenSSLErrorString(uint32_t error) { return buf; } +std::string GetMessageDigest(const std::string& message, + widevine::HashAlgorithm hash_algorithm) { + switch (hash_algorithm) { + // The default hash algorithm of RSA signature is SHA1. + case widevine::HashAlgorithm::kUnspecified: + case widevine::HashAlgorithm::kSha1: + return widevine::Sha1_Hash(message); + case widevine::HashAlgorithm::kSha256: + return widevine::Sha256_Hash(message); + } + LOG(FATAL) << "Unexpected hash algorithm: " + << static_cast(hash_algorithm); + return ""; +} + } // namespace namespace widevine { @@ -137,6 +153,47 @@ bool RsaPrivateKey::GenerateSignature(const std::string& message, return true; } +bool RsaPrivateKey::GenerateSignature(const std::string& message, + HashAlgorithm hash_algorithm, + std::string* signature) const { + DCHECK(signature); + + if (message.empty()) { + LOG(ERROR) << "Message to be signed is empty"; + return false; + } + // Hash the message using corresponding hash algorithm. + std::string message_digest = GetMessageDigest(message, hash_algorithm); + if (message_digest.empty()) { + LOG(ERROR) << "Empty message digest"; + return false; + } + + // Add PSS padding. + size_t rsa_size = RSA_size(key_); + std::string padded_digest(rsa_size, 0); + if (!RSA_padding_add_PKCS1_PSS_mgf1( + key_, reinterpret_cast(&padded_digest[0]), + reinterpret_cast(&message_digest[0]), EVP_sha1(), + EVP_sha1(), kPssSaltLength)) { + LOG(ERROR) << "RSA padding failure: " + << OpenSSLErrorString(ERR_get_error()); + return false; + } + // Encrypt PSS padded digest. + signature->assign(rsa_size, 0); + if (RSA_private_encrypt(padded_digest.size(), + reinterpret_cast(&padded_digest[0]), + reinterpret_cast(&(*signature)[0]), + key_, RSA_NO_PADDING) != + static_cast(signature->size())) { + LOG(ERROR) << "RSA private encrypt failure: " + << OpenSSLErrorString(ERR_get_error()); + return false; + } + return true; +} + bool RsaPrivateKey::GenerateSignatureSha256Pkcs7(const std::string& message, std::string* signature) const { DCHECK(signature); @@ -253,6 +310,52 @@ bool RsaPublicKey::VerifySignature(const std::string& message, return true; } +bool RsaPublicKey::VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, + const std::string& signature) const { + if (message.empty()) { + LOG(ERROR) << "Signed message is empty"; + return false; + } + size_t rsa_size = RSA_size(key_); + if (signature.size() != rsa_size) { + LOG(ERROR) << "Message signature is of the wrong size (expected " + << rsa_size << ", actual " << signature.size() << ")"; + return false; + } + // Decrypt the signature. + std::string padded_digest(signature.size(), 0); + if (RSA_public_decrypt( + signature.size(), + const_cast( + reinterpret_cast(signature.data())), + reinterpret_cast(&padded_digest[0]), key_, + RSA_NO_PADDING) != static_cast(rsa_size)) { + LOG(ERROR) << "RSA public decrypt failure: " + << OpenSSLErrorString(ERR_get_error()); + return false; + } + + // Hash the message using SHA1 using corresponding hash algorithm. + std::string message_digest = GetMessageDigest(message, hash_algorithm); + if (message_digest.empty()) { + LOG(ERROR) << "Empty message digest"; + return false; + } + + // Verify PSS padding. + if (RSA_verify_PKCS1_PSS_mgf1( + key_, reinterpret_cast(&message_digest[0]), + EVP_sha1(), EVP_sha1(), + reinterpret_cast(&padded_digest[0]), + kPssSaltLength) == 0) { + LOG(ERROR) << "RSA Verify PSS padding failure: " + << OpenSSLErrorString(ERR_get_error()); + return false; + } + return true; +} + bool RsaPublicKey::VerifySignatureSha256Pkcs7( const std::string& message, const std::string& signature) const { if (message.empty()) { diff --git a/common/rsa_key.h b/common/rsa_key.h index 68e6ce2..9c4c212 100644 --- a/common/rsa_key.h +++ b/common/rsa_key.h @@ -18,7 +18,9 @@ #include #include +#include "absl/base/macros.h" #include "openssl/rsa.h" +#include "common/hash_algorithm.h" namespace widevine { @@ -41,9 +43,20 @@ class RsaPrivateKey { // Generate RSSASSA-PSS signature. Caller retains ownership of all parameters. // Returns true if successful, false otherwise. + // TODO(b/155438325): remove this function after the below function is fully + // propagated. + ABSL_DEPRECATED( + "Use the below function with |hash_algorithm| argument instead.") virtual bool GenerateSignature(const std::string& message, std::string* signature) const; + // Generate RSSASSA-PSS signature. Caller retains ownership of all parameters. + // |hash_algorithm| indicates the hash algorithm used. Returns true if + // successful, false otherwise. + virtual bool GenerateSignature(const std::string& message, + HashAlgorithm hash_algorithm, + std::string* signature) const; + // Generate SHA256 digest, PKCS#7 padded signature. Caller retains ownership // of all parameters. Returns true if successful, false otherwise. virtual bool GenerateSignatureSha256Pkcs7(const std::string& message, @@ -98,9 +111,20 @@ class RsaPublicKey { // Verify RSSASSA-PSS signature. Caller retains ownership of all parameters. // Returns true if validation succeeds, false otherwise. + // TODO(b/155438325): remove this function after the below function is fully + // propagated. + ABSL_DEPRECATED( + "Use the below function with |hash_algorithm| argument instead.") virtual bool VerifySignature(const std::string& message, const std::string& signature) const; + // Verify RSSASSA-PSS signature. Caller retains ownership of all parameters. + // |hash_algorithm| indicates the hash algorithm used. Returns true if + // validation succeeds, false otherwise. + virtual bool VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, + const std::string& signature) const; + // Verify a signature. This method takes two parameters: |message| which is a // std::string containing the data which was signed, and |signature| which is a // std::string containing the message SHA256 digest signature with PKCS#7 diff --git a/common/rsa_key_test.cc b/common/rsa_key_test.cc index b1153f1..e0822fd 100644 --- a/common/rsa_key_test.cc +++ b/common/rsa_key_test.cc @@ -10,10 +10,11 @@ // Description: // Unit test for rsa_key RSA encryption and signing. +#include "common/rsa_key.h" + #include #include "testing/gunit.h" -#include "common/rsa_key.h" #include "common/rsa_test_keys.h" #include "common/rsa_util.h" @@ -46,8 +47,7 @@ TEST_F(RsaKeyTest, CopyConstructor) { std::unique_ptr private_key_copy( new RsaPrivateKey(*private_key)); - std::unique_ptr public_key_copy( - new RsaPublicKey(*public_key)); + std::unique_ptr public_key_copy(new RsaPublicKey(*public_key)); EXPECT_TRUE(public_key_copy->MatchesPublicKey(*public_key)); EXPECT_TRUE(public_key_copy->MatchesPrivateKey(*private_key)); diff --git a/common/security_profile_list.cc b/common/security_profile_list.cc index f9008c4..2d753a2 100644 --- a/common/security_profile_list.cc +++ b/common/security_profile_list.cc @@ -11,8 +11,12 @@ #include +#include "glog/logging.h" +#include "google/protobuf/text_format.h" #include "common/client_id_util.h" +#include "common/device_status_list.h" #include "protos/public/client_identification.pb.h" +#include "protos/public/device_certificate_status.pb.h" #include "protos/public/device_common.pb.h" #include "protos/public/provisioned_device_info.pb.h" #include "protos/public/security_profile.pb.h" @@ -23,98 +27,44 @@ using ClientCapabilities = ClientIdentification::ClientCapabilities; SecurityProfileList::SecurityProfileList(const std::string& profile_namespace) : profile_namespace_(profile_namespace) {} -int SecurityProfileList::Init() { return AddDefaultProfiles(); } +int SecurityProfileList::Init() { return 0; } -int SecurityProfileList::AddDefaultProfiles() { - const uint32_t oemcrypto_8 = 8; - const uint32_t oemcrypto_12 = 12; - const bool make_model_not_verified = false; - SecurityProfile profile; - - PopulateProfile(SecurityProfile::SECURITY_PROFILE_LEVEL_1, "WVSP1", - ClientCapabilities::HDCP_NONE, - ClientCapabilities::ANALOG_OUTPUT_UNKNOWN, oemcrypto_8, - make_model_not_verified, ProvisionedDeviceInfo::LEVEL_3, - kResourceTierLow, &profile); - InsertProfile(profile); - - PopulateProfile(SecurityProfile::SECURITY_PROFILE_LEVEL_2, "WVSP2", - ClientCapabilities::HDCP_NONE, - ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A, - oemcrypto_12, make_model_not_verified, - ProvisionedDeviceInfo::LEVEL_2, kResourceTierLow, &profile); - InsertProfile(profile); - - PopulateProfile(SecurityProfile::SECURITY_PROFILE_LEVEL_3, "WVSP3", - ClientCapabilities::HDCP_V1, - ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A, - oemcrypto_12, make_model_not_verified, - ProvisionedDeviceInfo::LEVEL_1, kResourceTierMed, &profile); - InsertProfile(profile); - - PopulateProfile(SecurityProfile::SECURITY_PROFILE_LEVEL_4, "WVSP4", - ClientCapabilities::HDCP_V2_2, - ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A, - oemcrypto_12, make_model_not_verified, - ProvisionedDeviceInfo::LEVEL_1, kResourceTierHigh, &profile); - InsertProfile(profile); - absl::ReaderMutexLock lock(&mutex_); - return security_profiles_.size(); -} - -SecurityProfile::Level SecurityProfileList::GetBestProfileLevel( - const ClientIdentification& client_id, - const ProvisionedDeviceInfo& device_info, - SecurityProfile::DrmInfo* drm_info) const { - // Iterate through each SP starting from the strictest first. - absl::ReaderMutexLock lock(&mutex_); - // Profile list is assumed to be sorted. - for (auto& profile : security_profiles_) { - if (!IsProfileAllowed(profile, client_id, device_info)) { - continue; - } - if (drm_info != nullptr) { - GetDrmInfo(client_id, device_info, drm_info); - } - return profile.level(); - } - return SecurityProfile::SECURITY_PROFILE_LEVEL_UNDEFINED; -} - -int SecurityProfileList::GetAllowedProfilesFromList( +int SecurityProfileList::GetQualifiedProfilesFromSpecifiedProfiles( const std::vector& profiles_to_check, const ClientIdentification& client_id, const ProvisionedDeviceInfo& device_info, - std::vector* profiles_to_allow) const { - if (profiles_to_allow == nullptr) { + std::vector* qualified_profiles) const { + if (qualified_profiles == nullptr) { return 0; } + qualified_profiles->clear(); absl::ReaderMutexLock lock(&mutex_); for (auto& profile_name : profiles_to_check) { SecurityProfile profile; if (GetProfileByName(profile_name, &profile)) { - if (IsProfileAllowed(profile, client_id, device_info)) { - profiles_to_allow->push_back(profile.name()); + if (DoesProfileQualify(profile, client_id, device_info)) { + qualified_profiles->push_back(profile.name()); } } } - return profiles_to_allow->size(); + return qualified_profiles->size(); } -int SecurityProfileList::GetAllowedProfiles( +int SecurityProfileList::GetQualifiedProfiles( const ClientIdentification& client_id, const ProvisionedDeviceInfo& device_info, - std::vector* profiles_to_allow) const { - if (profiles_to_allow == nullptr) { + std::vector* qualified_profiles) const { + if (qualified_profiles == nullptr) { return 0; } + qualified_profiles->clear(); absl::ReaderMutexLock lock(&mutex_); for (auto& profile : security_profiles_) { - if (IsProfileAllowed(profile, client_id, device_info)) { - profiles_to_allow->push_back(profile.name()); + if (DoesProfileQualify(profile, client_id, device_info)) { + qualified_profiles->push_back(profile.name()); } } - return profiles_to_allow->size(); + return qualified_profiles->size(); } bool SecurityProfileList::GetDrmInfo(const ClientIdentification& client_id, @@ -127,62 +77,65 @@ bool SecurityProfileList::GetDrmInfo(const ClientIdentification& client_id, client_id.client_capabilities().max_hdcp_version()); drm_info->mutable_output()->set_analog_output_capabilities( client_id.client_capabilities().analog_output_capabilities()); - drm_info->mutable_security()->set_oemcrypto_version( + drm_info->mutable_security()->set_oemcrypto_api_version( client_id.client_capabilities().oem_crypto_api_version()); drm_info->mutable_security()->set_resource_rating_tier( client_id.client_capabilities().resource_rating_tier()); drm_info->mutable_security()->set_security_level( device_info.security_level()); - drm_info->mutable_security()->set_request_model_info_status(false); drm_info->mutable_request_model_info()->set_manufacturer( GetClientInfo(client_id, kModDrmMake)); drm_info->mutable_request_model_info()->set_model_name( GetClientInfo(client_id, kModDrmModel)); - drm_info->set_system_id(device_info.system_id()); - return true; -} - -bool SecurityProfileList::PopulateProfile( - const SecurityProfile::Level profile_level, const std::string& profile_name, - const ClientCapabilities::HdcpVersion min_hdcp_version, - const ClientCapabilities::AnalogOutputCapabilities - analog_output_capabilities, - const uint32_t min_oemcrypto_version, const bool make_model_verified, - const ProvisionedDeviceInfo::WvSecurityLevel security_level, - const uint32_t resource_rating_tier, - SecurityProfile* profile_to_create) const { - if (profile_to_create == nullptr) { - return false; - } - profile_to_create->set_level(profile_level); - profile_to_create->set_name(profile_name); - profile_to_create->mutable_min_output_requirements()->set_hdcp_version( - min_hdcp_version); - profile_to_create->mutable_min_output_requirements() - ->set_analog_output_capabilities(analog_output_capabilities); - profile_to_create->mutable_min_security_requirements()->set_oemcrypto_version( - min_oemcrypto_version); - profile_to_create->mutable_min_security_requirements()->set_security_level( - security_level); - profile_to_create->mutable_min_security_requirements() - ->set_resource_rating_tier(resource_rating_tier); - profile_to_create->mutable_min_security_requirements() - ->set_request_model_info_status(make_model_verified); - return true; -} - -bool SecurityProfileList::GetProfileByLevel( - SecurityProfile::Level level, SecurityProfile* security_profile) const { - absl::ReaderMutexLock lock(&mutex_); - for (auto& profile : security_profiles_) { - if (profile.level() == level) { - if (security_profile != nullptr) { - *security_profile = profile; - } - return true; + drm_info->mutable_request_model_info()->set_status( + DeviceModel::MODEL_STATUS_UNVERIFIED); + for (const auto& model_info : device_info.model_info()) { + if (model_info.manufacturer() == + drm_info->request_model_info().manufacturer() && + model_info.model_name() == + drm_info->request_model_info().model_name()) { + drm_info->mutable_request_model_info()->set_status(model_info.status()); + drm_info->mutable_request_model_info()->set_model_year( + model_info.model_year()); + break; } } - return false; + drm_info->set_system_id(device_info.system_id()); + SecurityProfile::ClientInfo* client_info = drm_info->mutable_client_info(); + client_info->set_device_name(GetClientInfo(client_id, kModDrmDeviceName)); + SecurityProfile::ClientInfo::ProductInfo* product_info = + client_info->mutable_product_info(); + product_info->set_product_name(GetClientInfo(client_id, kModDrmProductName)); + product_info->set_build_info(GetClientInfo(client_id, kModDrmBuildInfo)); + product_info->set_oem_crypto_security_patch_level( + GetClientInfo(client_id, kModDrmOemCryptoSecurityPatchLevel)); + // TODO(user): Figure out how to get device platform pushed into SPL API. + DeviceCertificateStatus device_certificate_status; + SecurityProfile::DeviceState device_model_state = + SecurityProfile::DEVICE_STATE_UNKNOWN; + Status status = + DeviceStatusList::Instance()->GetDeviceCertificateStatusBySystemId( + device_info.system_id(), &device_certificate_status); + if (status.ok() && device_certificate_status.has_status()) { + switch (device_certificate_status.status()) { + case DeviceCertificateStatus::STATUS_IN_TESTING: + device_model_state = SecurityProfile::IN_TESTING; + break; + case DeviceCertificateStatus::STATUS_RELEASED: + device_model_state = SecurityProfile::RELEASED; + break; + case DeviceCertificateStatus::STATUS_TEST_ONLY: + device_model_state = SecurityProfile::TEST_ONLY; + break; + case DeviceCertificateStatus::STATUS_REVOKED: + device_model_state = SecurityProfile::REVOKED; + break; + default: + break; + } + } + drm_info->set_device_model_state(device_model_state); + return true; } bool SecurityProfileList::GetProfileByName( @@ -202,49 +155,91 @@ bool SecurityProfileList::GetProfileByName( bool SecurityProfileList::InsertProfile( const SecurityProfile& profile_to_insert) { // Check if profile already exist. - if (GetProfileByLevel(profile_to_insert.level(), nullptr)) { - return false; - } if (GetProfileByName(profile_to_insert.name(), nullptr)) { + LOG(ERROR) << "Unable to insert profile: " << profile_to_insert.name() + << ". Name already exist."; return false; } + if (profile_to_insert.min_security_requirements().security_level() == + ProvisionedDeviceInfo::LEVEL_UNSPECIFIED) { + LOG(ERROR) << "Unable to insert profile: " << profile_to_insert.name() + << ". Security level not specified."; + return false; + } + absl::WriterMutexLock lock(&mutex_); security_profiles_.push_back(profile_to_insert); - sort(security_profiles_.begin(), security_profiles_.end(), - CompareProfileLevel); return true; } -bool SecurityProfileList::CompareProfileLevel(const SecurityProfile& p1, - const SecurityProfile& p2) { - // Profiles are sorted from highest to lowest (strictest) level. - return (p1.level() > p2.level()); +int SecurityProfileList::NumProfiles() const { + absl::ReaderMutexLock lock(&mutex_); + return security_profiles_.size(); } -bool SecurityProfileList::IsProfileAllowed( +void SecurityProfileList::ClearAllProfiles() { + absl::WriterMutexLock lock(&mutex_); + security_profiles_.clear(); +} + +bool SecurityProfileList::DoesProfileQualify( const SecurityProfile& profile, const ClientIdentification& client_id, const ProvisionedDeviceInfo& device_info) const { if (profile.min_security_requirements().security_level() < device_info.security_level()) { + VLOG(1) << "Profile does not qualify <" << profile.name() + << "> security level: " + << profile.min_security_requirements().security_level() + << ", device: " << device_info.security_level(); return false; } - if (profile.min_security_requirements().oemcrypto_version() > + if (profile.min_security_requirements().oemcrypto_api_version() > client_id.client_capabilities().oem_crypto_api_version()) { + VLOG(1) << "Profile does not qualify <" << profile.name() + << "> oemcrypto version: " + << profile.min_security_requirements().oemcrypto_api_version() + << ", device: " + << client_id.client_capabilities().oem_crypto_api_version(); return false; } if (profile.min_output_requirements().hdcp_version() > client_id.client_capabilities().max_hdcp_version()) { + VLOG(1) << "profile does not qualify <" << profile.name() + << "> hdcp_version: " + << profile.min_output_requirements().hdcp_version() << ", device: " + << client_id.client_capabilities().max_hdcp_version(); return false; } if (profile.min_output_requirements().analog_output_capabilities() > client_id.client_capabilities().analog_output_capabilities()) { + VLOG(1) << "Profile idoes not qualify <" << profile.name() + << "> analog output: " + << profile.min_output_requirements().analog_output_capabilities() + << ", device: " + << client_id.client_capabilities().analog_output_capabilities(); return false; } if (profile.min_security_requirements().resource_rating_tier() > client_id.client_capabilities().resource_rating_tier()) { + VLOG(1) << "Profile does not qualify <" << profile.name() + << "> resource rating tier: " + << profile.min_security_requirements().resource_rating_tier() + << ", device: " + << client_id.client_capabilities().resource_rating_tier(); return false; } return true; } +void SecurityProfileList::GetProfileNames( + std::vector* profile_names) const { + if (profile_names == nullptr) { + return; + } + absl::ReaderMutexLock lock(&mutex_); + for (auto& profile : security_profiles_) { + profile_names->push_back(profile.name()); + } +} + } // namespace widevine diff --git a/common/security_profile_list.h b/common/security_profile_list.h index 459c014..07b2a3b 100644 --- a/common/security_profile_list.h +++ b/common/security_profile_list.h @@ -7,25 +7,22 @@ //////////////////////////////////////////////////////////////////////////////// // // Description: -// Container of Widevine security profiles. Security profiles indicate the -// level of security of a device based on the device's output protections, -// version of OEMCrypto and security level. +// Container of device security profiles. Security profiles indicate rules +// to allow using the profile. The rules are based on DRM capabilities of a +// device. #ifndef COMMON_SECURITY_PROFILE_LIST_H_ #define COMMON_SECURITY_PROFILE_LIST_H_ #include "absl/synchronization/mutex.h" #include "protos/public/client_identification.pb.h" +#include "protos/public/device_security_profile_data.pb.h" #include "protos/public/provisioned_device_info.pb.h" #include "protos/public/security_profile.pb.h" namespace widevine { using ClientCapabilities = ClientIdentification::ClientCapabilities; -const uint32_t kResourceTierLow = 1; -const uint32_t kResourceTierMed = 2; -const uint32_t kResourceTierHigh = 3; - // The SecurityProfileList will hold all security profiles. During license // acquisition, information from the client and information from the server are // combined to deternmine the device's security profile level. @@ -33,57 +30,31 @@ const uint32_t kResourceTierHigh = 3; class SecurityProfileList { public: explicit SecurityProfileList(const std::string& profile_namespace); - ~SecurityProfileList() {} + virtual ~SecurityProfileList() {} - // Initialize the security profile list. The list is initially empty, this - // function will populate the list with default profiles. The size of the - // list is returned. - int Init(); + // Initialize the security profile list. The size of the profile list is + // returned. + virtual int Init(); // Add the specified profile to the existing list of profiles. Returns true // if successfully inserted, false if unable to insert. bool InsertProfile(const SecurityProfile& profile_to_insert); - // Populate |profile_to_create| with the specified output protections and - // security parameters. All input parameters are used hence should be set. - bool PopulateProfile( - const SecurityProfile::Level profile_level, - const std::string& profile_name, - const ClientCapabilities::HdcpVersion min_hdcp_version, - const ClientCapabilities::AnalogOutputCapabilities - analog_output_capabilities, - const uint32_t min_oemcrypto_version, const bool make_model_verified, - const ProvisionedDeviceInfo::WvSecurityLevel security_level, - const uint32_t resource_rating_tier, - SecurityProfile* profile_to_create) const; - - // Return the highest security level based on the device capabilities. - // If |drm_info| is not null, |drm_info| is populated with the device data. - SecurityProfile::Level GetBestProfileLevel( - const ClientIdentification& client_id, - const ProvisionedDeviceInfo& device_info, - SecurityProfile::DrmInfo* drm_info) const; - - // Populates |profiles_to_allow| with a list of profiles that meet the - // requirements for the this device. The number of profiles is returned. - int GetAllowedProfiles(const ClientIdentification& client_id, - const ProvisionedDeviceInfo& device_info, - std::vector* profiles_to_allow) const; - // Populates |profiles_allow| with a list of profiles from the specified // |profiles_to_check| list that meet the requirements for the this device. // The number of profiles is returned. - int GetAllowedProfilesFromList( + virtual int GetQualifiedProfilesFromSpecifiedProfiles( const std::vector& profiles_to_check, const ClientIdentification& client_id, const ProvisionedDeviceInfo& device_info, - std::vector* profiles_to_allow) const; + std::vector* qualified_profiles) const; - // Return true if a profile exist matching the specified |level|. - // |security_profile| is owned by the caller and is populated if a profile - // exist. - bool GetProfileByLevel(SecurityProfile::Level level, - SecurityProfile* security_profile) const; + // Populates |profiles_to_allow| with a list of profiles that meet the + // requirements for the this device. The number of profiles is returned. + virtual int GetQualifiedProfiles( + const ClientIdentification& client_id, + const ProvisionedDeviceInfo& device_info, + std::vector* qualified_profiles) const; // Return true if a profile exist matching the specified |name|. // |security_profile| is owned by the caller and is populated if a profile @@ -97,17 +68,20 @@ class SecurityProfileList { const ProvisionedDeviceInfo& device_info, SecurityProfile::DrmInfo* drm_info) const; + // Return the number of profiles in the list. + int NumProfiles() const; + + // Return a list of profile names. + virtual void GetProfileNames(std::vector* profile_names) const; + + protected: + void ClearAllProfiles(); + private: - // Initialize the list with Widevine default profiles. The size of the - // profile list after the additions is returned. - int AddDefaultProfiles(); + bool DoesProfileQualify(const SecurityProfile& profile, + const ClientIdentification& client_id, + const ProvisionedDeviceInfo& device_info) const; - static bool CompareProfileLevel(const SecurityProfile& p1, - const SecurityProfile& p2); - - bool IsProfileAllowed(const SecurityProfile& profile, - const ClientIdentification& client_id, - const ProvisionedDeviceInfo& device_info) const; mutable absl::Mutex mutex_; // Security profiles diff --git a/common/security_profile_list_test.cc b/common/security_profile_list_test.cc index 76bd98c..964a779 100644 --- a/common/security_profile_list_test.cc +++ b/common/security_profile_list_test.cc @@ -13,6 +13,7 @@ #include "testing/gmock.h" #include "testing/gunit.h" #include "absl/memory/memory.h" +#include "common/client_id_util.h" #include "protos/public/device_common.pb.h" #include "protos/public/security_profile.pb.h" @@ -23,7 +24,16 @@ const char kMakeName[] = "company_name"; const char kMakeValue[] = "Google"; const char kModelName[] = "model_name"; const char kModelValue[] = "model1"; +const char kDeviceNameValue[] = "TestDeviceName"; +const char kProductNameValue[] = "TestProductName"; +const char kBuildInfoValue[] = "TestBuildInfo"; +const char kOemCryptoSecurityPatchLevelValue[] = + "TestOemCryptoSecurityPatchLevel"; +const char kDefaultContentOwnerName[] = "Widevine"; const uint32_t kSystemId = 1234; +const uint32_t kResourceTierLow = 1; +const uint32_t kResourceTierMed = 2; +const uint32_t kResourceTierHigh = 3; class SecurityProfileListTest : public ::testing::Test { public: @@ -32,75 +42,71 @@ class SecurityProfileListTest : public ::testing::Test { void SetUp() override { const uint32_t oemcrypto_12 = 12; - const bool make_model_not_verified = false; - const ClientIdentification::ClientCapabilities::HdcpVersion hdcp_version = - ClientCapabilities::HDCP_V2_2; - test_profile_1_.set_level(SecurityProfile::SECURITY_PROFILE_LEVEL_1); - test_profile_1_.mutable_min_output_requirements()->set_hdcp_version( - hdcp_version); - test_profile_1_.mutable_min_output_requirements() - ->set_analog_output_capabilities( - ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A); - test_profile_1_.mutable_min_security_requirements()->set_oemcrypto_version( - oemcrypto_12); - test_profile_1_.mutable_min_security_requirements()->set_security_level( - ProvisionedDeviceInfo::LEVEL_1); - test_profile_1_.mutable_min_security_requirements() - ->set_resource_rating_tier(kResourceTierHigh); - test_profile_1_.mutable_min_security_requirements() - ->set_request_model_info_status(make_model_not_verified); - std::string profile_namespace = "widevine_test"; + SecurityProfile profile; + std::string profile_namespace = "widevine"; profile_list_ = absl::make_unique(profile_namespace); - ClientIdentification_NameValue *nv = client_id_.add_client_info(); - nv->set_name(kMakeName); - nv->set_value(kMakeValue); - nv = client_id_.add_client_info(); - nv->set_name(kModelName); - nv->set_value(kModelValue); + AddClientInfo(&client_id_, kMakeName, kMakeValue); + AddClientInfo(&client_id_, kModelName, kModelValue); + AddClientInfo(&client_id_, kModDrmDeviceName, kDeviceNameValue); + AddClientInfo(&client_id_, kModDrmProductName, kProductNameValue); + AddClientInfo(&client_id_, kModDrmBuildInfo, kBuildInfoValue); + AddClientInfo(&client_id_, kModDrmOemCryptoSecurityPatchLevel, + kOemCryptoSecurityPatchLevelValue); client_id_.mutable_client_capabilities()->set_oem_crypto_api_version( oemcrypto_12); client_id_.mutable_client_capabilities()->set_max_hdcp_version( - hdcp_version); + ClientCapabilities::HDCP_V2_2); client_id_.mutable_client_capabilities()->set_resource_rating_tier( kResourceTierHigh); device_info_.set_security_level(ProvisionedDeviceInfo::LEVEL_1); device_info_.set_system_id(kSystemId); } - SecurityProfile test_profile_1_; std::unique_ptr profile_list_; ClientIdentification client_id_; ProvisionedDeviceInfo device_info_; }; TEST_F(SecurityProfileListTest, InsertProfile) { - // This test will not initialize the SecurityProfileList, hence it's empty. - // Insert test profile 1 into the list. - EXPECT_TRUE(profile_list_->InsertProfile(test_profile_1_)); - // Should not allow insertion of an already existing level. - EXPECT_FALSE(profile_list_->InsertProfile(test_profile_1_)); - SecurityProfile profile; - // Should not allow insertion of an already existing name. - // Make sure the level is not the same as already inserted level_1. - profile.set_level(SecurityProfile::SECURITY_PROFILE_LEVEL_2); - profile.set_name(test_profile_1_.name()); - EXPECT_FALSE(profile_list_->InsertProfile(test_profile_1_)); - ASSERT_TRUE(profile_list_->GetProfileByLevel( - SecurityProfile::SECURITY_PROFILE_LEVEL_1, &profile)); - EXPECT_TRUE( - google::protobuf::util::MessageDifferencer::Equals(test_profile_1_, profile)); + // Insert test profile1 into the list. + SecurityProfileList profile_list("widevine-test"); + SecurityProfile profile1; + profile1.set_name("profile1"); + profile1.mutable_min_security_requirements()->set_security_level( + ProvisionedDeviceInfo::LEVEL_3); + EXPECT_TRUE(profile_list.InsertProfile(profile1)); + // Verify the list still has one profile. + EXPECT_EQ(1, profile_list.NumProfiles()); + // Should not allow insert if existing profile has the same name. + SecurityProfile profile2; + profile2.set_name(profile1.name()); + profile2.mutable_min_security_requirements()->set_security_level( + ProvisionedDeviceInfo::LEVEL_3); + EXPECT_FALSE(profile_list.InsertProfile(profile2)); + // Verify the list still has one profile. + EXPECT_EQ(1, profile_list.NumProfiles()); + // Should allow insert since this profile has a different name. + profile2.set_name("profile2"); + EXPECT_TRUE(profile_list.InsertProfile(profile2)); + EXPECT_EQ(2, profile_list.NumProfiles()); } TEST_F(SecurityProfileListTest, GetDrmInfo) { SecurityProfile::DrmInfo drm_info; + DeviceModel* device_model = device_info_.add_model_info(); + device_model->set_manufacturer(GetClientInfo(client_id_, kModDrmMake)); + device_model->set_model_name(GetClientInfo(client_id_, kModDrmModel)); + device_model->set_status(DeviceModel::MODEL_STATUS_VERIFIED); + const uint32_t model_launch_year = 2015; + device_model->set_model_year(model_launch_year); ASSERT_TRUE(profile_list_->GetDrmInfo(client_id_, device_info_, &drm_info)); EXPECT_EQ(client_id_.client_capabilities().max_hdcp_version(), drm_info.output().hdcp_version()); EXPECT_EQ(client_id_.client_capabilities().analog_output_capabilities(), drm_info.output().analog_output_capabilities()); EXPECT_EQ(client_id_.client_capabilities().oem_crypto_api_version(), - drm_info.security().oemcrypto_version()); + drm_info.security().oemcrypto_api_version()); EXPECT_EQ(client_id_.client_capabilities().resource_rating_tier(), drm_info.security().resource_rating_tier()); @@ -108,104 +114,92 @@ TEST_F(SecurityProfileListTest, GetDrmInfo) { drm_info.security().security_level()); EXPECT_EQ(device_info_.system_id(), drm_info.system_id()); - // make_mode status is currently hard-coded to false. - EXPECT_EQ(false, drm_info.security().request_model_info_status()); EXPECT_EQ(kMakeValue, drm_info.request_model_info().manufacturer()); EXPECT_EQ(kModelValue, drm_info.request_model_info().model_name()); + EXPECT_EQ(DeviceModel::MODEL_STATUS_VERIFIED, + drm_info.request_model_info().status()); + EXPECT_EQ(model_launch_year, drm_info.request_model_info().model_year()); + EXPECT_EQ(kDeviceNameValue, drm_info.client_info().device_name()); + EXPECT_EQ(kProductNameValue, + drm_info.client_info().product_info().product_name()); + EXPECT_EQ(kBuildInfoValue, + drm_info.client_info().product_info().build_info()); + EXPECT_EQ( + kOemCryptoSecurityPatchLevelValue, + drm_info.client_info().product_info().oem_crypto_security_patch_level()); } -TEST_F(SecurityProfileListTest, ProfileLevels) { - SecurityProfile::DrmInfo drm_info; - profile_list_->Init(); +TEST_F(SecurityProfileListTest, QualifiedProfiles) { + SecurityProfile profile1; + profile1.set_name("profile1"); + profile1.mutable_min_security_requirements()->set_security_level( + ProvisionedDeviceInfo::LEVEL_3); + profile1.mutable_min_output_requirements()->set_hdcp_version( + ClientCapabilities::HDCP_V1); + profile_list_->InsertProfile(profile1); - client_id_.mutable_client_capabilities()->set_max_hdcp_version( - ClientCapabilities::HDCP_NONE); - client_id_.mutable_client_capabilities()->set_analog_output_capabilities( - ClientCapabilities::ANALOG_OUTPUT_UNKNOWN); - client_id_.mutable_client_capabilities()->set_oem_crypto_api_version(7); - client_id_.mutable_client_capabilities()->set_resource_rating_tier( - kResourceTierLow); - device_info_.set_security_level(ProvisionedDeviceInfo::LEVEL_3); + SecurityProfile profile2; + profile2.set_name("profile2"); + profile2.mutable_min_security_requirements()->set_security_level( + ProvisionedDeviceInfo::LEVEL_1); + profile2.mutable_min_output_requirements()->set_hdcp_version( + ClientCapabilities::HDCP_V2); + profile_list_->InsertProfile(profile2); - // Lowest profile level requires OEMCrypto version 8. - ASSERT_EQ( - SecurityProfile::SECURITY_PROFILE_LEVEL_UNDEFINED, - profile_list_->GetBestProfileLevel(client_id_, device_info_, &drm_info)); + // Both profiles should qualify based on client_info and device_info from the + // Setup function. + std::vector qualified_profiles; + EXPECT_EQ(2, profile_list_->GetQualifiedProfiles(client_id_, device_info_, + &qualified_profiles)); + EXPECT_NE(qualified_profiles.end(), + std::find(qualified_profiles.begin(), qualified_profiles.end(), + profile1.name())); + EXPECT_NE(qualified_profiles.end(), + std::find(qualified_profiles.begin(), qualified_profiles.end(), + profile2.name())); - // Move up to profile 1 - client_id_.mutable_client_capabilities()->set_oem_crypto_api_version(8); - ASSERT_EQ( - SecurityProfile::SECURITY_PROFILE_LEVEL_1, - profile_list_->GetBestProfileLevel(client_id_, device_info_, &drm_info)); - - // Move up to profile 2 - client_id_.mutable_client_capabilities()->set_analog_output_capabilities( - ClientCapabilities::ANALOG_OUTPUT_SUPPORTS_CGMS_A); - client_id_.mutable_client_capabilities()->set_oem_crypto_api_version(12); - device_info_.set_security_level(ProvisionedDeviceInfo::LEVEL_2); - ASSERT_EQ( - SecurityProfile::SECURITY_PROFILE_LEVEL_2, - profile_list_->GetBestProfileLevel(client_id_, device_info_, &drm_info)); - - // Move up to profile 3 + // Reduce the DRM capabilities of the device so profile2 will not qualify. client_id_.mutable_client_capabilities()->set_max_hdcp_version( ClientCapabilities::HDCP_V1); - device_info_.set_security_level(ProvisionedDeviceInfo::LEVEL_1); - client_id_.mutable_client_capabilities()->set_resource_rating_tier( - kResourceTierMed); - ASSERT_EQ( - SecurityProfile::SECURITY_PROFILE_LEVEL_3, - profile_list_->GetBestProfileLevel(client_id_, device_info_, &drm_info)); - - // Move up to profile 4 - client_id_.mutable_client_capabilities()->set_max_hdcp_version( - ClientCapabilities::HDCP_V2_2); - client_id_.mutable_client_capabilities()->set_resource_rating_tier( - kResourceTierHigh); - ASSERT_EQ( - SecurityProfile::SECURITY_PROFILE_LEVEL_4, - profile_list_->GetBestProfileLevel(client_id_, device_info_, &drm_info)); + ASSERT_EQ(1, profile_list_->GetQualifiedProfiles(client_id_, device_info_, + &qualified_profiles)); + EXPECT_NE(qualified_profiles.end(), + std::find(qualified_profiles.begin(), qualified_profiles.end(), + profile1.name())); } TEST_F(SecurityProfileListTest, FindProfile) { - // This test will not initialize the SecurityProfileList, hence it's empty. - // Insert test profile 1 into the list. + SecurityProfileList profile_list("widevine-test"); SecurityProfile profile1; - profile1.set_level(SecurityProfile::SECURITY_PROFILE_LEVEL_1); profile1.set_name("profile1"); + profile1.mutable_min_security_requirements()->set_security_level( + ProvisionedDeviceInfo::LEVEL_3); + EXPECT_EQ(kDefaultContentOwnerName, profile1.owner()); SecurityProfile profile2; - profile2.set_level(SecurityProfile::SECURITY_PROFILE_LEVEL_2); profile2.set_name("profile2"); - SecurityProfile profile3; - profile3.set_level(SecurityProfile::SECURITY_PROFILE_LEVEL_3); - profile3.set_name("profile3"); - - // Insert profiles 1 & 2, but not 3.. - EXPECT_TRUE(profile_list_->InsertProfile(profile1)); - EXPECT_TRUE(profile_list_->InsertProfile(profile2)); - - // Find the profile by its level. - SecurityProfile profile; - EXPECT_TRUE(profile_list_->GetProfileByLevel(profile1.level(), &profile)); - EXPECT_EQ(profile1.name(), profile.name()); - EXPECT_EQ(profile1.level(), profile.level()); - - EXPECT_TRUE(profile_list_->GetProfileByLevel(profile2.level(), &profile)); - EXPECT_EQ(profile2.name(), profile.name()); - EXPECT_EQ(profile2.level(), profile.level()); - - EXPECT_FALSE(profile_list_->GetProfileByName(profile3.name(), nullptr)); + profile2.mutable_min_security_requirements()->set_security_level( + ProvisionedDeviceInfo::LEVEL_3); + // Override the default owner name. + profile2.set_owner("owner2"); + // Insert profiles 1 & 2. + EXPECT_TRUE(profile_list.InsertProfile(profile1)); + EXPECT_TRUE(profile_list.InsertProfile(profile2)); + EXPECT_EQ(2, profile_list.NumProfiles()); // Find the profile by its name. - EXPECT_TRUE(profile_list_->GetProfileByName(profile1.name(), &profile)); + SecurityProfile profile; + EXPECT_TRUE(profile_list.GetProfileByName(profile1.name(), &profile)); EXPECT_EQ(profile1.name(), profile.name()); EXPECT_EQ(profile1.level(), profile.level()); + EXPECT_EQ(profile1.owner(), profile.owner()); - EXPECT_TRUE(profile_list_->GetProfileByName(profile2.name(), &profile)); + EXPECT_TRUE(profile_list.GetProfileByName(profile2.name(), &profile)); EXPECT_EQ(profile2.name(), profile.name()); EXPECT_EQ(profile2.level(), profile.level()); + EXPECT_EQ(profile2.owner(), profile.owner()); - EXPECT_FALSE(profile_list_->GetProfileByName(profile3.name(), nullptr)); + EXPECT_FALSE( + profile_list.GetProfileByName("you-should-not-find-me", &profile)); } } // namespace security_profile diff --git a/common/sha_util.cc b/common/sha_util.cc index 5f3d5b2..0844855 100644 --- a/common/sha_util.cc +++ b/common/sha_util.cc @@ -37,7 +37,8 @@ std::string Sha512_Hash(const std::string& message) { return digest; } -std::string GenerateSha1Uuid(const std::string& name_space, const std::string& name) { +std::string GenerateSha1Uuid(const std::string& name_space, + const std::string& name) { // X.667 14 Setting the fields of a name-based UUID. // - Allocate a UUID to use as a "name space identifier" for all UUIDs // generated from names in that name space. diff --git a/common/sha_util_test.cc b/common/sha_util_test.cc index e0354d2..9d71936 100644 --- a/common/sha_util_test.cc +++ b/common/sha_util_test.cc @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/sha_util.h" + #include "testing/gunit.h" #include "absl/strings/escaping.h" diff --git a/common/signature_util.cc b/common/signature_util.cc index 3927390..48a4a59 100644 --- a/common/signature_util.cc +++ b/common/signature_util.cc @@ -39,6 +39,7 @@ Status GenerateAesSignature(const std::string& message, Status GenerateRsaSignature(const std::string& message, const std::string& private_key, + HashAlgorithm hash_algorithm, std::string* signature) { if (signature == nullptr) { return Status(error::INVALID_ARGUMENT, "signature is nullptr"); @@ -49,7 +50,7 @@ Status GenerateRsaSignature(const std::string& message, return Status(error::INTERNAL, "Failed to construct a RsaPrivateKey"); } std::string sig; - if (!rsa_private_key->GenerateSignature(message, &sig)) { + if (!rsa_private_key->GenerateSignature(message, hash_algorithm, &sig)) { return Status(error::INTERNAL, "Failed to generate a RSA signature"); } if (sig.empty()) { diff --git a/common/signature_util.h b/common/signature_util.h index ee64d47..8e6a086 100644 --- a/common/signature_util.h +++ b/common/signature_util.h @@ -11,6 +11,7 @@ #include +#include "common/hash_algorithm.h" #include "common/status.h" namespace widevine { @@ -23,11 +24,13 @@ Status GenerateAesSignature(const std::string& message, const std::string& aes_key, const std::string& aes_iv, std::string* signature); -// Generates a RSA signature of |message| using |private_key|. -// Signature is returned via |sigature| if generation was successful. -// Returns a Status that carries the details of error if generation failed. +// Generates a RSA signature of |message| using |private_key| and +// |hash_algorithm|. Signature is returned via |sigature| if generation was +// successful. Returns a Status that carries the details of error if generation +// failed. Status GenerateRsaSignature(const std::string& message, const std::string& private_key, + HashAlgorithm hash_algorithm, std::string* signature); } // namespace signature_util diff --git a/common/signer_public_key.cc b/common/signer_public_key.cc index c1ac103..87dd617 100644 --- a/common/signer_public_key.cc +++ b/common/signer_public_key.cc @@ -25,9 +25,10 @@ class SignerPublicKeyImpl : public SignerPublicKey { SignerPublicKeyImpl(const SignerPublicKeyImpl&) = delete; SignerPublicKeyImpl& operator=(const SignerPublicKeyImpl&) = delete; - bool VerifySignature(const std::string& message, + bool VerifySignature(const std::string& message, HashAlgorithm hash_algorithm, const std::string& signature) const override { - if (!signer_public_key_->VerifySignature(message, signature)) { + if (!signer_public_key_->VerifySignature(message, hash_algorithm, + signature)) { return false; } return true; diff --git a/common/signer_public_key.h b/common/signer_public_key.h index 228db48..9670f6d 100644 --- a/common/signer_public_key.h +++ b/common/signer_public_key.h @@ -11,6 +11,7 @@ #include +#include "common/hash_algorithm.h" #include "protos/public/drm_certificate.pb.h" namespace widevine { @@ -24,8 +25,9 @@ class SignerPublicKey { SignerPublicKey(const SignerPublicKey&) = delete; SignerPublicKey& operator=(const SignerPublicKey&) = delete; - // Verify message using |signer_public_key_|. + // Verify message using |signer_public_key_| and |hash_algorithm|. virtual bool VerifySignature(const std::string& message, + HashAlgorithm hash_algorithm, const std::string& signature) const = 0; // A factory method to create a SignerPublicKey. The |algorithm| is used to diff --git a/common/signer_public_key_test.cc b/common/signer_public_key_test.cc index 7f079d0..2243690 100644 --- a/common/signer_public_key_test.cc +++ b/common/signer_public_key_test.cc @@ -13,6 +13,7 @@ #include "testing/gunit.h" #include "common/ec_key.h" #include "common/ec_test_keys.h" +#include "common/hash_algorithm.h" #include "common/rsa_key.h" #include "common/rsa_test_keys.h" #include "protos/public/drm_certificate.pb.h" @@ -20,6 +21,7 @@ namespace widevine { static const char kMessage[] = "The rain in Spain falls mainly in the blank?"; +const HashAlgorithm kHashAlgorithm = HashAlgorithm::kSha256; class SignerPublicKeyTest : public ::testing::Test { public: @@ -32,12 +34,13 @@ TEST_F(SignerPublicKeyTest, RSA) { RsaPrivateKey::Create(rsa_test_keys_.private_test_key_1_3072_bits())); std::string signature; - ASSERT_TRUE(private_key->GenerateSignature(kMessage, &signature)); + ASSERT_TRUE( + private_key->GenerateSignature(kMessage, kHashAlgorithm, &signature)); std::unique_ptr public_key = SignerPublicKey::Create( rsa_test_keys_.public_test_key_1_3072_bits(), DrmCertificate::RSA); ASSERT_NE(public_key, nullptr); - EXPECT_TRUE(public_key->VerifySignature(kMessage, signature)); + EXPECT_TRUE(public_key->VerifySignature(kMessage, kHashAlgorithm, signature)); } TEST_F(SignerPublicKeyTest, ECC) { @@ -45,13 +48,14 @@ TEST_F(SignerPublicKeyTest, ECC) { ECPrivateKey::Create(ec_test_keys_.private_test_key_1_secp521r1()); std::string signature; - ASSERT_TRUE(private_key->GenerateSignature(kMessage, &signature)); + ASSERT_TRUE( + private_key->GenerateSignature(kMessage, kHashAlgorithm, &signature)); std::unique_ptr public_key = SignerPublicKey::Create(ec_test_keys_.public_test_key_1_secp521r1(), DrmCertificate::ECC_SECP521R1); ASSERT_NE(public_key, nullptr); - EXPECT_TRUE(public_key->VerifySignature(kMessage, signature)); + EXPECT_TRUE(public_key->VerifySignature(kMessage, kHashAlgorithm, signature)); } TEST_F(SignerPublicKeyTest, IncorrectAlgorithm) { diff --git a/common/signing_key_util_test.cc b/common/signing_key_util_test.cc index cb9fb56..70e448c 100644 --- a/common/signing_key_util_test.cc +++ b/common/signing_key_util_test.cc @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/signing_key_util.h" + #include "testing/gunit.h" #include "absl/strings/escaping.h" #include "common/crypto_util.h" diff --git a/common/string_util_test.cc b/common/string_util_test.cc index 9354571..a28e619 100644 --- a/common/string_util_test.cc +++ b/common/string_util_test.cc @@ -7,6 +7,7 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/string_util.h" + #include "testing/gmock.h" #include "testing/gunit.h" diff --git a/common/test_utils.cc b/common/test_utils.cc index 5dfce0b..2b4973a 100644 --- a/common/test_utils.cc +++ b/common/test_utils.cc @@ -9,6 +9,7 @@ #include "common/test_utils.h" #include + #include #include "glog/logging.h" diff --git a/common/test_utils.h b/common/test_utils.h index 9252f93..bc8f701 100644 --- a/common/test_utils.h +++ b/common/test_utils.h @@ -22,7 +22,7 @@ namespace widevine { // and PKCS#1 1.5 padding. |pem_private_key| is a PEM-encoded private RSA key, // |message| is the message to be signed, and |signature| is a pointer to a // std::string where the signature will be stored. The caller returns ownership of -// all paramters. +// all parameters. Status GenerateRsaSignatureSha256Pkcs1(const std::string& pem_private_key, const std::string& message, std::string* signature); diff --git a/common/verified_media_pipeline.h b/common/verified_media_pipeline.h index a6372ce..8a95a22 100644 --- a/common/verified_media_pipeline.h +++ b/common/verified_media_pipeline.h @@ -13,6 +13,7 @@ #define COMMON_VERIFIED_MEDIA_PIPELINE_H_ #include + #include "common/status.h" #include "protos/public/license_protocol.pb.h" diff --git a/common/vmp_checker.cc b/common/vmp_checker.cc index 64f1474..9479582 100644 --- a/common/vmp_checker.cc +++ b/common/vmp_checker.cc @@ -13,12 +13,14 @@ #include "common/vmp_checker.h" #include + #include #include #include "glog/logging.h" #include "common/certificate_type.h" #include "common/error_space.h" +#include "common/hash_algorithm_util.h" #include "common/rsa_key.h" #include "common/x509_cert.h" #include "protos/public/errors.pb.h" @@ -334,7 +336,9 @@ Status VmpChecker::VerifyVmpData(const std::string& vmp_data, Result* result) { std::unique_ptr key(cert->GetRsaPublicKey()); std::string message(binary_info.binary_hash()); message += binary_info.flags() & 0xff; - if (!key->VerifySignature(message, binary_info.signature())) { + if (!key->VerifySignature( + message, HashAlgorithmProtoToEnum(binary_info.hash_algorithm()), + binary_info.signature())) { LOG(INFO) << "Code signature verification failed for file \"" << binary_info.file_name() << "\"."; *result = kTampered; diff --git a/common/vmp_checker_test.cc b/common/vmp_checker_test.cc index 5c67830..d10958a 100644 --- a/common/vmp_checker_test.cc +++ b/common/vmp_checker_test.cc @@ -6,14 +6,16 @@ // widevine-licensing@google.com. //////////////////////////////////////////////////////////////////////////////// +#include "common/vmp_checker.h" + #include #include "glog/logging.h" #include "testing/gmock.h" #include "testing/gunit.h" #include "absl/strings/escaping.h" +#include "common/hash_algorithm_util.h" #include "common/rsa_key.h" -#include "common/vmp_checker.h" #include "protos/public/errors.pb.h" #include "protos/public/verified_media_pipeline.pb.h" @@ -184,7 +186,9 @@ class VmpCheckerTest : public ::testing::Test { std::string message(binary_hash); message += flags & 0xff; std::string signature; - ASSERT_TRUE(signing_key_->GenerateSignature(message, &signature)); + ASSERT_TRUE(signing_key_->GenerateSignature( + message, HashAlgorithmProtoToEnum(new_binary->hash_algorithm()), + &signature)); new_binary->set_signature(signature); } diff --git a/common/wvm_test_keys.cc b/common/wvm_test_keys.cc index 7c39ad9..c0842a7 100644 --- a/common/wvm_test_keys.cc +++ b/common/wvm_test_keys.cc @@ -6,11 +6,12 @@ // widevine-licensing@google.com. //////////////////////////////////////////////////////////////////////////////// +#include "common/wvm_test_keys.h" + #include #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" -#include "common/wvm_test_keys.h" #include "common/wvm_token_handler.h" namespace widevine { diff --git a/common/wvm_token_handler_test.cc b/common/wvm_token_handler_test.cc index b3f0bf9..872096a 100644 --- a/common/wvm_token_handler_test.cc +++ b/common/wvm_token_handler_test.cc @@ -7,22 +7,23 @@ //////////////////////////////////////////////////////////////////////////////// #include "common/wvm_token_handler.h" + #include "testing/gmock.h" #include "testing/gunit.h" #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "common/wvm_test_keys.h" -using widevine::wvm_test_keys::kTestSystemId; -using widevine::wvm_test_keys::kTestSystemId3Des; -using widevine::wvm_test_keys::kTestPreprovKeyHex; +using widevine::wvm_test_keys::GetPreprovKeyVector; using widevine::wvm_test_keys::kTestDeviceKey1Hex; using widevine::wvm_test_keys::kTestDeviceKey2Hex; using widevine::wvm_test_keys::kTestDeviceKey3DesHex; +using widevine::wvm_test_keys::kTestPreprovKeyHex; +using widevine::wvm_test_keys::kTestSystemId; +using widevine::wvm_test_keys::kTestSystemId3Des; using widevine::wvm_test_keys::kTestToken1Hex; using widevine::wvm_test_keys::kTestToken2Hex; using widevine::wvm_test_keys::kTestToken3DesHex; -using widevine::wvm_test_keys::GetPreprovKeyVector; namespace widevine { diff --git a/common/x509_cert.h b/common/x509_cert.h index 6de4dcb..20f062e 100644 --- a/common/x509_cert.h +++ b/common/x509_cert.h @@ -13,12 +13,13 @@ #define COMMON_X509_CERT_H_ #include + #include #include #include #include -#include "base/thread_annotations.h" +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "openssl/pem.h" #include "openssl/x509.h" diff --git a/common/x509_cert_test.cc b/common/x509_cert_test.cc index efc9ee3..0ebe8a1 100644 --- a/common/x509_cert_test.cc +++ b/common/x509_cert_test.cc @@ -6,13 +6,14 @@ // widevine-licensing@google.com. //////////////////////////////////////////////////////////////////////////////// +#include "common/x509_cert.h" + #include #include "testing/gunit.h" #include "absl/strings/escaping.h" #include "common/rsa_key.h" #include "common/test_utils.h" -#include "common/x509_cert.h" namespace widevine { const char kTestRootCaDerCert[] = @@ -352,7 +353,6 @@ const char kTestDevCodeSigningCert[] = const char kDevCertFlagOid[] = "1.3.6.1.4.1.11129.4.1.2"; const bool kTestDevCodeSigningCertFlagValue = true; - TEST(X509CertTest, LoadCert) { X509Cert test_cert; EXPECT_EQ(OkStatus(), diff --git a/example/test_emmg_messages.h b/example/test_emmg_messages.h index aeac11a..cbc53ca 100644 --- a/example/test_emmg_messages.h +++ b/example/test_emmg_messages.h @@ -130,11 +130,11 @@ const char kTestEmmgDataProvision[] = { '\x47', '\x40', '\x00', '\x10', '\x0a', '\x0d', '\x77', '\x69', '\x64', '\x65', '\x76', '\x69', '\x6e', '\x65', '\x5f', '\x74', '\x65', '\x73', '\x74', '\x12', '\x09', '\x43', '\x61', '\x73', '\x54', '\x73', '\x46', - '\x61', '\x6b', '\x65', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', - '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', - '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', - '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', - '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', + '\x61', '\x6b', '\x65', '\x1a', '\x10', '\x66', '\x61', '\x6b', '\x65', + '\x4b', '\x65', '\x79', '\x49', '\x64', '\x31', '\x4b', '\x65', '\x79', + '\x49', '\x64', '\x31', '\x1a', '\x10', '\x66', '\x61', '\x6b', '\x65', + '\x4b', '\x65', '\x79', '\x49', '\x64', '\x32', '\x4b', '\x65', '\x79', + '\x49', '\x64', '\x32', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', @@ -149,6 +149,26 @@ const char kTestEmmgDataProvision[] = { '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'}; +const char kTestEmptyEmmgDataProvision[] = { + '\x02', // protocol_version + '\x02', '\x11', // message_type - Data_provision + '\x00', '\x00', // message_length + '\x00', '\x01', // parameter_type - client_id + '\x00', '\x04', // parameter_length + '\x4a', '\xd4', '\x00', '\x00', // parameter_value + '\x00', '\x03', // parameter_type - data_channel_id + '\x00', '\x02', // parameter_length + '\x00', '\x01', // parameter_value + '\x00', '\x04', // parameter_type - data_stream_id + '\x00', '\x02', // parameter_length + '\x00', '\x01', // parameter_value + '\x00', '\x08', // parameter_type - data_id + '\x00', '\x02', // parameter_length + '\x00', '\x01', // parameter_value + '\x00', '\x00', // parameter_type - datagram + '\x00', '\x00', // parameter_length +}; + const char kTestEmmgStreamCloseRequest[] = { '\x02', // protocol_version '\x01', '\x14', // message_type - Stream_close_request diff --git a/media_cas_packager_sdk/internal/BUILD b/media_cas_packager_sdk/internal/BUILD index 72844f9..767ec4a 100644 --- a/media_cas_packager_sdk/internal/BUILD +++ b/media_cas_packager_sdk/internal/BUILD @@ -90,6 +90,7 @@ cc_library( "@abseil_repo//absl/container:node_hash_map", "@abseil_repo//absl/memory", "@abseil_repo//absl/strings", + "@abseil_repo//absl/types:optional", "//common:crypto_util", "//common:random_util", "//common:status", @@ -242,3 +243,25 @@ cc_test( "//common:status", ], ) + +cc_library( + name = "emm", + srcs = ["emm.cc"], + hdrs = ["emm.h"], + deps = [ + "//base", + "@abseil_repo//absl/strings", + "//common:status", + "//common:string_util", + "//protos/public:media_cas_cc_proto", + ], +) + +cc_test( + name = "emm_test", + srcs = ["emm_test.cc"], + deps = [ + ":emm", + "//testing:gunit_main", + ], +) diff --git a/media_cas_packager_sdk/internal/ecm_test.cc b/media_cas_packager_sdk/internal/ecm_test.cc index 390323b..2d93259 100644 --- a/media_cas_packager_sdk/internal/ecm_test.cc +++ b/media_cas_packager_sdk/internal/ecm_test.cc @@ -95,10 +95,10 @@ class MockEcm : public Ecm { MockEcm() = default; ~MockEcm() override = default; - MOCK_CONST_METHOD0(age_restriction, uint8_t()); - MOCK_CONST_METHOD0(crypto_mode, CryptoMode()); - MOCK_CONST_METHOD0(paired_keys_required, bool()); - MOCK_CONST_METHOD0(content_iv_size, size_t()); + MOCK_METHOD(uint8_t, age_restriction, (), (const, override)); + MOCK_METHOD(CryptoMode, crypto_mode, (), (const, override)); + MOCK_METHOD(bool, paired_keys_required, (), (const, override)); + MOCK_METHOD(size_t, content_iv_size, (), (const, override)); std::string CallSerializeEcm(const std::vector& keys) { return SerializeEcm(keys); diff --git a/media_cas_packager_sdk/internal/ecmg_client_handler.cc b/media_cas_packager_sdk/internal/ecmg_client_handler.cc index e2823d8..4be8aa9 100644 --- a/media_cas_packager_sdk/internal/ecmg_client_handler.cc +++ b/media_cas_packager_sdk/internal/ecmg_client_handler.cc @@ -8,6 +8,7 @@ #include "media_cas_packager_sdk/internal/ecmg_client_handler.h" +#include #include #include "glog/logging.h" @@ -873,32 +874,55 @@ Status EcmgClientHandler::BuildEcmDatagram(const EcmgParameters& params, EcmgStreamInfo* stream_info = streams_info_.at(params.ecm_stream_id).get(); DCHECK(stream_info->ecm); + size_t key_count = params.cp_cw_combinations.size(); + // Number of keys can only be 1 or 2. + if (key_count != ecmg_config_->number_of_content_keys || key_count < 1 || + key_count > 2) { + return {error::INVALID_ARGUMENT, "Unexpected cp_cw_combinations size."}; + } + // Generate serialized ECM. CryptoMode crypto_mode = stream_info->crypto_mode == CryptoMode::kInvalid ? ecmg_config_->crypto_mode : stream_info->crypto_mode; + + // If two keys are present, the even key should be put first, followed by the + // odd key. std::vector keys; - keys.reserve(ecmg_config_->number_of_content_keys); - for (size_t i = 0; i < ecmg_config_->number_of_content_keys; i++) { - DCHECK(params.cp_cw_combinations[i].cp == params.cp_number + i); - keys.emplace_back(); - keys[i].key_value = params.cp_cw_combinations[i].cw; + keys.reserve(key_count); + for (const auto& cp_cw : params.cp_cw_combinations) { + auto key_info = cp_cw.cp % 2 == 0 ? keys.emplace(keys.begin()) + : keys.emplace(keys.end()); + key_info->key_value = cp_cw.cw; // Make content key to 16 bytes if crypto mode is Csa2. - if (crypto_mode == CryptoMode::kDvbCsa2 && keys[i].key_value.size() == 8) { - keys[i].key_value = keys[i].key_value + keys[i].key_value; + if (crypto_mode == CryptoMode::kDvbCsa2 && + key_info->key_value.size() == 8) { + key_info->key_value = + absl::StrCat(key_info->key_value, key_info->key_value); } - keys[i].key_id = crypto_util::DeriveKeyId(keys[i].key_value); - keys[i].content_iv = stream_info->content_ivs.empty() - ? content_ivs_[i] - : stream_info->content_ivs[i]; - if (!RandomBytes(kWrappedKeyIvSizeBytes, &keys[i].wrapped_key_iv)) { + key_info->key_id = crypto_util::DeriveKeyId(key_info->key_value); + auto generated_key_iv = GenerateRandomWrappedKeyIv(); + if (!generated_key_iv.has_value() || + generated_key_iv->size() != kWrappedKeyIvSizeBytes) { return {error::INTERNAL, "Unable to generate random wrapped key iv."}; } + key_info->wrapped_key_iv = generated_key_iv.value(); + } + + if (content_ivs_.size() < key_count && + stream_info->content_ivs.size() < key_count) { + return {error::INVALID_ARGUMENT, "Not enough content iv."}; + } + // The first iv received is for even key and second is for odd key. + for (size_t i = 0; i < key_count; ++i) { + keys[i].content_iv = stream_info->content_ivs.size() >= key_count + ? stream_info->content_ivs[i] + : content_ivs_[i]; } Status status; std::string serialized_ecm; - if (ecmg_config_->number_of_content_keys > 1) { + if (key_count > 1) { status = stream_info->ecm->GenerateEcm( &keys[0], &keys[1], stream_info->track_type, &serialized_ecm); } else { @@ -926,5 +950,14 @@ Status EcmgClientHandler::BuildEcmDatagram(const EcmgParameters& params, return OkStatus(); } +absl::optional EcmgClientHandler::GenerateRandomWrappedKeyIv() + const { + std::string output_iv; + if (RandomBytes(kWrappedKeyIvSizeBytes, &output_iv)) { + return output_iv; + } + return absl::nullopt; +} + } // namespace cas } // namespace widevine diff --git a/media_cas_packager_sdk/internal/ecmg_client_handler.h b/media_cas_packager_sdk/internal/ecmg_client_handler.h index fd11de4..7ffe6e7 100644 --- a/media_cas_packager_sdk/internal/ecmg_client_handler.h +++ b/media_cas_packager_sdk/internal/ecmg_client_handler.h @@ -16,6 +16,7 @@ #include #include "absl/container/node_hash_map.h" +#include "absl/types/optional.h" #include "common/status.h" #include "media_cas_packager_sdk/internal/ecm.h" #include "media_cas_packager_sdk/public/wv_cas_types.h" @@ -126,6 +127,10 @@ class EcmgClientHandler { Status BuildEcmDatagram(const EcmgParameters& params, uint8_t* ecm_datagram) const; + // Generates a random wrapped key iv string. Returns true on success, false + // otherwise. The main purpose for this function is easier testing. + virtual absl::optional GenerateRandomWrappedKeyIv() const; + EcmgConfig* ecmg_config_; // Per spec, "There is always one (and only one) channel per TCP connection". bool channel_id_set_; diff --git a/media_cas_packager_sdk/internal/ecmg_client_handler_test.cc b/media_cas_packager_sdk/internal/ecmg_client_handler_test.cc index 86637b6..0f04e79 100644 --- a/media_cas_packager_sdk/internal/ecmg_client_handler_test.cc +++ b/media_cas_packager_sdk/internal/ecmg_client_handler_test.cc @@ -12,12 +12,14 @@ #include #include +#include "testing/gmock.h" #include "testing/gunit.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "example/test_ecmg_messages.h" #include "media_cas_packager_sdk/internal/ecmg_constants.h" +#include "media_cas_packager_sdk/internal/mpeg2ts.h" #include "media_cas_packager_sdk/internal/simulcrypt_util.h" #include "media_cas_packager_sdk/internal/util.h" @@ -31,28 +33,36 @@ using simulcrypt_util::AddUint32Param; using simulcrypt_util::AddUint8Param; using simulcrypt_util::BuildMessageHeader; -static constexpr size_t kBufferSize = 1024; -static constexpr size_t kSuperCasId = 0x4AD40000; -static constexpr size_t kChannelId = 1; -static constexpr size_t kStreamId = 1; -static constexpr size_t kEcmId = 2; -static constexpr size_t kNominalCpDuration = 0x64; -static constexpr size_t kCpNumber = 0; -static constexpr char kContentKeyEven[] = "0123456701234567"; -static constexpr char kContentKeyEven8Bytes[] = "01234567"; -static constexpr char kContentKeyOdd[] = "abcdefghabcdefgh"; -static constexpr char kContentKeyOdd8Bytes[] = "abcdefgh"; -static constexpr char kEntitlementKeyIdEven[] = "0123456701234567"; -static constexpr char kEntitlementKeyValueEven[] = - "01234567012345670123456701234567"; -static constexpr char kEntitlementKeyIdOdd[] = "abcdefghabcdefgh"; -static constexpr char kEntitlementKeyValueOdd[] = - "abcdefghabcdefghabcdefghabcdefgh"; -static constexpr size_t kAgeRestriction = 3; -static constexpr char kCryptoMode[] = "AesScte"; -static constexpr char kCryptoModeCsa2[] = "DvbCsa2"; -static constexpr char kTrackTypesSD[] = "SD"; -static constexpr char kTrackTypesHD[] = "HD"; +constexpr size_t kBufferSize = 1024; +constexpr size_t kSuperCasId = 0x4AD40000; +constexpr size_t kChannelId = 1; +constexpr size_t kStreamId = 1; +constexpr size_t kEcmId = 2; +constexpr size_t kNominalCpDuration = 0x64; +constexpr size_t kCpNumber = 0; +constexpr char kContentKeyEven[] = "0123456701234567"; +constexpr char kContentKeyEven8Bytes[] = "01234567"; +constexpr char kContentKeyOdd[] = "abcdefghabcdefgh"; +constexpr char kContentKeyOdd8Bytes[] = "abcdefgh"; +constexpr char kEntitlementKeyIdEven[] = "0123456701234567"; +constexpr char kEntitlementKeyValueEven[] = "01234567012345670123456701234567"; +constexpr char kEntitlementKeyIdOdd[] = "abcdefghabcdefgh"; +constexpr char kEntitlementKeyValueOdd[] = "abcdefghabcdefghabcdefghabcdefgh"; +constexpr size_t kAgeRestriction = 3; +constexpr char kCryptoMode[] = "AesScte"; +constexpr char kCryptoModeCsa2[] = "DvbCsa2"; +constexpr char kTrackTypesSD[] = "SD"; +constexpr char kTrackTypesHD[] = "HD"; +constexpr absl::string_view kWrappedKeyIv = "0123456701234567"; + +class MockEcmgClientHandler : public EcmgClientHandler { + public: + explicit MockEcmgClientHandler(EcmgConfig* ecmg_config) + : EcmgClientHandler(ecmg_config) {} + absl::optional GenerateRandomWrappedKeyIv() const override { + return std::string(kWrappedKeyIv); + } +}; class EcmgClientHandlerTest : public ::testing::Test { protected: @@ -63,7 +73,7 @@ class EcmgClientHandlerTest : public ::testing::Test { config_.max_comp_time = 100; config_.access_criteria_transfer_mode = 1; config_.number_of_content_keys = 2; - handler_ = absl::make_unique(&config_); + handler_ = absl::make_unique(&config_); } protected: @@ -564,6 +574,37 @@ TEST_F(EcmgClientHandlerTest, WrongMessageLength) { CheckChannelError(UNKNOWN_PARAMETER_TYPE_VALUE, response_, response_len_); } +TEST_F(EcmgClientHandlerTest, BuildEcmDatagramSequenceOfEvenOdd) { + SetupValidChannelStream(); + + std::vector cp_cw_combination = { + {kCpNumber, kContentKeyEven}, {kCpNumber + 1, kContentKeyOdd}}; + BuildCwProvisionRequest(kChannelId, kStreamId, kCpNumber, cp_cw_combination, + request_, &request_len_); + handler_->HandleRequest(request_, response_, &response_len_); + EXPECT_EQ(response_len_, sizeof(kTestEcmgEcmResponse)); + std::string first_response(response_, response_len_); + + // Change the sequence of cp_cw_combination. + cp_cw_combination = {{kCpNumber + 1, kContentKeyOdd}, + {kCpNumber, kContentKeyEven}}; + BuildCwProvisionRequest(kChannelId, kStreamId, kCpNumber, cp_cw_combination, + request_, &request_len_); + handler_->HandleRequest(request_, response_, &response_len_); + // Sequence of cp_cw_combination does not matter as even/odd is based on cp + // number. + EXPECT_EQ(std::string(response_, response_len_), first_response); + + // Swap the key value in cp_cw_combination. + cp_cw_combination = {{kCpNumber, kContentKeyOdd}, + {kCpNumber + 1, kContentKeyEven}}; + BuildCwProvisionRequest(kChannelId, kStreamId, kCpNumber, cp_cw_combination, + request_, &request_len_); + handler_->HandleRequest(request_, response_, &response_len_); + // Swapping key value changes generated ecm. + EXPECT_NE(std::string(response_, response_len_), first_response); +} + } // namespace } // namespace cas } // namespace widevine diff --git a/media_cas_packager_sdk/internal/emm.cc b/media_cas_packager_sdk/internal/emm.cc new file mode 100644 index 0000000..0f3266c --- /dev/null +++ b/media_cas_packager_sdk/internal/emm.cc @@ -0,0 +1,136 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#include "media_cas_packager_sdk/internal/emm.h" + +#include + +#include "glog/logging.h" +#include "absl/strings/str_cat.h" +#include "common/status.h" +#include "common/string_util.h" + +namespace widevine { +namespace cas { + +namespace { + +constexpr int kNumBitsVersionField = 8; +constexpr int kNumBitsHeaderLengthField = 8; +constexpr int kNumBitsTimestampLengthField = 64; +constexpr int kNumBitsPayloadLengthField = 16; + +// Version - this should be incremented if there are changes to the EMM. +constexpr uint8_t kEmmVersion = 1; +} // namespace + +Status Emm::SetFingerprinting( + const std::vector& fingerprintings) { + // First clear all current fingerprinting payload. + emm_payload_.clear_fingerprinting(); + + // TODO(b/161149665): validate passed in data. + Status status; + for (const auto& fingerprinting_param : fingerprintings) { + Fingerprinting* fingerprinting_payload = emm_payload_.add_fingerprinting(); + for (const auto& channel : fingerprinting_param.channels) { + fingerprinting_payload->add_channels(channel); + } + fingerprinting_payload->set_control(fingerprinting_param.control); + } + return status; +} + +Status Emm::SetServiceBlocking( + const std::vector& service_blockings) { + // First clear all current service blocking payload. + emm_payload_.clear_service_blocking(); + + // TODO(b/161149665): validate passed in data. + Status status; + for (const auto& service_blocking_param : service_blockings) { + ServiceBlocking* service_blocking_payload = + emm_payload_.add_service_blocking(); + for (const auto& channel : service_blocking_param.channels) { + service_blocking_payload->add_channels(channel); + } + for (const auto& device_group : service_blocking_param.device_groups) { + service_blocking_payload->add_device_groups(device_group); + } + if (service_blocking_param.start_time != 0) { + service_blocking_payload->set_start_time_sec( + service_blocking_param.start_time); + } + service_blocking_payload->set_end_time_sec(service_blocking_param.end_time); + } + return status; +} + +Status Emm::GenerateEmm(std::string* serialized_emm) const { + if (serialized_emm == nullptr) { + return {error::INVALID_ARGUMENT, "No return emm std::string pointer."}; + } + + EmmSerializingParameters serializing_params; + serializing_params.payload = emm_payload_.SerializeAsString(); + serializing_params.timestamp = GenerateTimestamp(); + + // Generate serialized emm (without signature yet). + Status status = + GenerateSerializedEmmNoSignature(serializing_params, serialized_emm); + if (!status.ok()) { + return status; + } + + // Calculate and append signature. + absl::StrAppend(serialized_emm, GenerateSignature(*serialized_emm)); + return OkStatus(); +} + +Status Emm::GenerateSerializedEmmNoSignature( + const EmmSerializingParameters& params, std::string* serialized_emm) const { + if (serialized_emm == nullptr) { + return {error::INVALID_ARGUMENT, "No return emm std::string pointer."}; + } + + std::bitset version(kEmmVersion); + std::bitset header_length( + sizeof(params.timestamp)); + std::bitset timestamp(params.timestamp); + std::bitset payload_length( + params.payload.length()); + + std::string emm_bitset = + absl::StrCat(version.to_string(), header_length.to_string(), + timestamp.to_string(), payload_length.to_string()); + + Status status = + string_util::BitsetStringToBinaryString(emm_bitset, serialized_emm); + if (!status.ok() || serialized_emm->empty()) { + LOG(ERROR) << "Failed to convert EMM bitset to std::string"; + return {error::INTERNAL, "Failed to convert EMM bitset to std::string"}; + } + + // Appends payload. + absl::StrAppend(serialized_emm, params.payload); + return OkStatus(); +} + +int64_t Emm::GenerateTimestamp() const { + // TODO(b/161252065): Generate timestamp. + return 0; +} + +std::string Emm::GenerateSignature(const std::string& content) const { + // TODO(b/161252442): Calculate signature. + std::string signature(32, 'x'); // A fake 32 bytes signature. + return signature; +} + +} // namespace cas +} // namespace widevine diff --git a/media_cas_packager_sdk/internal/emm.h b/media_cas_packager_sdk/internal/emm.h new file mode 100644 index 0000000..cfb6a84 --- /dev/null +++ b/media_cas_packager_sdk/internal/emm.h @@ -0,0 +1,74 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#ifndef MEDIA_CAS_PACKAGER_SDK_INTERNAL_EMM_H_ +#define MEDIA_CAS_PACKAGER_SDK_INTERNAL_EMM_H_ + +#include +#include + +#include +#include "common/status.h" +#include "protos/public/media_cas.pb.h" + +namespace widevine { +namespace cas { + +struct FingerprintingInitParameters { + std::vector channels; + std::string control; +}; + +struct ServiceBlockingInitParameters { + std::vector channels; + std::vector device_groups; + // Value 0 in start_time means immediate. + int64_t start_time = 0; + int64_t end_time; +}; + +// Generator for producing Widevine CAS-compliant EMMs. Used to construct the +// Transport Stream packet payload of an EMM containing messages including +// fingerprinting and service blocking. +// Class Emm is not thread safe. +class Emm { + public: + Emm() = default; + Emm(const Emm&) = delete; + Emm& operator=(const Emm&) = delete; + virtual ~Emm() = default; + + // Replaces current fingerprinting info with |fingerprintings|. + Status SetFingerprinting( + const std::vector& fingerprintings); + + // Replaces current service blocking info with |service_blockings|. + Status SetServiceBlocking( + const std::vector& service_blockings); + + // Generates serialized EMM to |serialized_emm|. + Status GenerateEmm(std::string* serialized_emm) const; + + private: + struct EmmSerializingParameters { + int64_t timestamp; + std::string payload; + }; + + Status GenerateSerializedEmmNoSignature( + const EmmSerializingParameters& params, + std::string* serialized_emm) const; + int64_t GenerateTimestamp() const; + std::string GenerateSignature(const std::string& content) const; + + EmmPayload emm_payload_; +}; + +} // namespace cas +} // namespace widevine +#endif // MEDIA_CAS_PACKAGER_SDK_INTERNAL_EMM_H_ diff --git a/media_cas_packager_sdk/internal/emm_test.cc b/media_cas_packager_sdk/internal/emm_test.cc new file mode 100644 index 0000000..ffa4e1c --- /dev/null +++ b/media_cas_packager_sdk/internal/emm_test.cc @@ -0,0 +1,236 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +#include "media_cas_packager_sdk/internal/emm.h" + +#include "testing/gmock.h" +#include "testing/gunit.h" + + +namespace widevine { +namespace cas { +namespace { + +constexpr char kChannelOne[] = "CH1"; +constexpr char kChannelTwo[] = "CH2"; +constexpr char kChannelThree[] = "CH3"; +constexpr char kFingerprintingControl[] = "controls"; +constexpr char kDeviceGroupOne[] = "Group1"; +constexpr char kDeviceGroupTwo[] = "Group2"; +constexpr int64_t kServiceBockingStartTime = 100; +constexpr int64_t kServiceBockingEndTime = 1000; + +// Length in bytes when there is no payload. +constexpr uint8_t kExpectedNoPayloadLengthBytes = 44; +constexpr uint8_t kExpectedEmmVersion = 1; +constexpr uint8_t kExpectedHeaderLength = 8; + +constexpr uint8_t kVersionStartIndex = 0; +constexpr uint8_t kHeaderLengthStartIndex = 1; +constexpr uint8_t kPayloadLengthStartIndex = 10; +constexpr uint8_t kPayloadStartIndex = 12; +constexpr uint8_t kSignatureLength = 32; + +FingerprintingInitParameters GetValidFingerprintingParams() { + FingerprintingInitParameters fingerprinting_params; + fingerprinting_params.channels = {kChannelOne, kChannelTwo}; + fingerprinting_params.control = kFingerprintingControl; + return fingerprinting_params; +} + +void LoadExpectedFingerprintingProto(Fingerprinting* fingerprinting_payload) { + fingerprinting_payload->add_channels(kChannelOne); + fingerprinting_payload->add_channels(kChannelTwo); + fingerprinting_payload->set_control(kFingerprintingControl); +} + +ServiceBlockingInitParameters GetValidServiceBlockingParams() { + ServiceBlockingInitParameters service_blocking_params; + service_blocking_params.channels = {kChannelOne, kChannelTwo}; + service_blocking_params.device_groups = {kDeviceGroupOne, kDeviceGroupTwo}; + service_blocking_params.start_time = kServiceBockingStartTime; + service_blocking_params.end_time = kServiceBockingEndTime; + return service_blocking_params; +} + +void LoadExpectedServiceBlockingProto( + ServiceBlocking* service_blocking_payload) { + service_blocking_payload->add_channels(kChannelOne); + service_blocking_payload->add_channels(kChannelTwo); + service_blocking_payload->add_device_groups(kDeviceGroupOne); + service_blocking_payload->add_device_groups(kDeviceGroupTwo); + service_blocking_payload->set_start_time_sec(kServiceBockingStartTime); + service_blocking_payload->set_end_time_sec(kServiceBockingEndTime); +} + +TEST(EmmTest, GenerateEmmSinglePayloadSuccess) { + Emm emm_gen; + FingerprintingInitParameters fingerprinting = GetValidFingerprintingParams(); + ServiceBlockingInitParameters service_blocking = + GetValidServiceBlockingParams(); + EXPECT_EQ(emm_gen.SetFingerprinting({fingerprinting}), OkStatus()); + EXPECT_EQ(emm_gen.SetServiceBlocking({service_blocking}), OkStatus()); + + std::string result; + EXPECT_EQ(emm_gen.GenerateEmm(&result), OkStatus()); + EXPECT_GT(result.length(), kExpectedNoPayloadLengthBytes); + + EXPECT_EQ(static_cast(result[kVersionStartIndex]), + kExpectedEmmVersion); + EXPECT_EQ(static_cast(result[kHeaderLengthStartIndex]), + kExpectedHeaderLength); + int payload_lengh = static_cast(result[kPayloadLengthStartIndex]) + << 8 | + static_cast(result[kPayloadLengthStartIndex + 1]); + ASSERT_GT(payload_lengh, 0); + ASSERT_EQ(result.length(), + kPayloadStartIndex + payload_lengh + kSignatureLength); + + std::string payload_section = + result.substr(kPayloadStartIndex, payload_lengh); + // Parse the payload and validate fields. + EmmPayload emm_payload; + ASSERT_TRUE(emm_payload.ParseFromString(payload_section)); + + EmmPayload expected_payload; + LoadExpectedFingerprintingProto(expected_payload.add_fingerprinting()); + LoadExpectedServiceBlockingProto(expected_payload.add_service_blocking()); + std::string serialized_expected_payload; + expected_payload.SerializeToString(&serialized_expected_payload); + EXPECT_EQ(payload_section, serialized_expected_payload); +} + +TEST(EmmTest, GenerateEmmMultiplePayloadSuccess) { + Emm emm_gen; + FingerprintingInitParameters fingerprinting_params; + fingerprinting_params.channels = {kChannelThree}; + fingerprinting_params.control = kFingerprintingControl; + EXPECT_EQ(emm_gen.SetFingerprinting( + {GetValidFingerprintingParams(), fingerprinting_params}), + OkStatus()); + + ServiceBlockingInitParameters service_blocking_params; + service_blocking_params.channels = {kChannelThree}; + service_blocking_params.device_groups = {kDeviceGroupOne, kDeviceGroupTwo}; + service_blocking_params.start_time = kServiceBockingStartTime; + service_blocking_params.end_time = kServiceBockingEndTime; + EXPECT_EQ(emm_gen.SetServiceBlocking( + {GetValidServiceBlockingParams(), service_blocking_params}), + OkStatus()); + + std::string result; + EXPECT_EQ(emm_gen.GenerateEmm(&result), OkStatus()); + EXPECT_GT(result.length(), kExpectedNoPayloadLengthBytes); + + EXPECT_EQ(static_cast(result[kVersionStartIndex]), + kExpectedEmmVersion); + EXPECT_EQ(static_cast(result[kHeaderLengthStartIndex]), + kExpectedHeaderLength); + int payload_lengh = static_cast(result[kPayloadLengthStartIndex]) + << 8 | + static_cast(result[kPayloadLengthStartIndex + 1]); + ASSERT_GT(payload_lengh, 0); + ASSERT_EQ(result.length(), + kPayloadStartIndex + payload_lengh + kSignatureLength); + + std::string payload_section = + result.substr(kPayloadStartIndex, payload_lengh); + // Parse the payload and validate fields. + EmmPayload emm_payload; + ASSERT_TRUE(emm_payload.ParseFromString(payload_section)); + EXPECT_EQ(emm_payload.fingerprinting_size(), 2); + EXPECT_EQ(emm_payload.service_blocking_size(), 2); +} + +TEST(EmmTest, GenerateEmmFingerprintingOnlySuccess) { + Emm emm_gen; + FingerprintingInitParameters fingerprinting = GetValidFingerprintingParams(); + EXPECT_EQ(emm_gen.SetFingerprinting({fingerprinting}), OkStatus()); + // OK to be called again. + EXPECT_EQ(emm_gen.SetFingerprinting({fingerprinting}), OkStatus()); + + std::string result; + EXPECT_EQ(emm_gen.GenerateEmm(&result), OkStatus()); + EXPECT_GT(result.length(), kExpectedNoPayloadLengthBytes); + + EXPECT_EQ(static_cast(result[kVersionStartIndex]), + kExpectedEmmVersion); + EXPECT_EQ(static_cast(result[kHeaderLengthStartIndex]), + kExpectedHeaderLength); + int payload_lengh = static_cast(result[kPayloadLengthStartIndex]) + << 8 | + static_cast(result[kPayloadLengthStartIndex + 1]); + ASSERT_GT(payload_lengh, 0); + ASSERT_EQ(result.length(), + kPayloadStartIndex + payload_lengh + kSignatureLength); + + std::string payload_section = + result.substr(kPayloadStartIndex, payload_lengh); + // Parse the payload and validate fields. + EmmPayload emm_payload; + ASSERT_TRUE(emm_payload.ParseFromString(payload_section)); + EmmPayload expected_payload; + LoadExpectedFingerprintingProto(expected_payload.add_fingerprinting()); + std::string serialized_expected_payload; + expected_payload.SerializeToString(&serialized_expected_payload); + EXPECT_EQ(payload_section, serialized_expected_payload); +} + +TEST(EmmTest, GenerateEmmServiceBlockingOnlySuccess) { + Emm emm_gen; + ServiceBlockingInitParameters service_blocking = + GetValidServiceBlockingParams(); + EXPECT_EQ(emm_gen.SetServiceBlocking({service_blocking}), OkStatus()); + + std::string result; + EXPECT_EQ(emm_gen.GenerateEmm(&result), OkStatus()); + EXPECT_GT(result.length(), kExpectedNoPayloadLengthBytes); + + EXPECT_EQ(static_cast(result[kVersionStartIndex]), + kExpectedEmmVersion); + EXPECT_EQ(static_cast(result[kHeaderLengthStartIndex]), + kExpectedHeaderLength); + int payload_lengh = static_cast(result[kPayloadLengthStartIndex]) + << 8 | + static_cast(result[kPayloadLengthStartIndex + 1]); + ASSERT_GT(payload_lengh, 0); + ASSERT_EQ(result.length(), + kPayloadStartIndex + payload_lengh + kSignatureLength); + + std::string payload_section = + result.substr(kPayloadStartIndex, payload_lengh); + // Parse the payload and validate fields. + EmmPayload emm_payload; + ASSERT_TRUE(emm_payload.ParseFromString(payload_section)); + EmmPayload expected_payload; + LoadExpectedServiceBlockingProto(expected_payload.add_service_blocking()); + std::string serialized_expected_payload; + expected_payload.SerializeToString(&serialized_expected_payload); + EXPECT_EQ(payload_section, serialized_expected_payload); +} + +TEST(EmmTest, GenerateEmmNoPayloadSuccess) { + Emm emm_gen; + std::string result; + EXPECT_EQ(emm_gen.GenerateEmm(&result), OkStatus()); + EXPECT_EQ(result.length(), kExpectedNoPayloadLengthBytes); + + EXPECT_EQ(static_cast(result[kVersionStartIndex]), + kExpectedEmmVersion); + EXPECT_EQ(static_cast(result[kHeaderLengthStartIndex]), + kExpectedHeaderLength); + int payload_lengh = static_cast(result[kPayloadLengthStartIndex]) + << 8 | + static_cast(result[kPayloadLengthStartIndex + 1]); + EXPECT_EQ(payload_lengh, 0); + EXPECT_EQ(result.length(), kPayloadStartIndex + kSignatureLength); +} + +} // namespace +} // namespace cas +} // namespace widevine diff --git a/media_cas_packager_sdk/internal/emmg.cc b/media_cas_packager_sdk/internal/emmg.cc index bb18540..baed4bd 100644 --- a/media_cas_packager_sdk/internal/emmg.cc +++ b/media_cas_packager_sdk/internal/emmg.cc @@ -28,6 +28,10 @@ // Minimum sending interval in milliseconds. static constexpr uint16_t KMinSendIntervalMs = 2; +// Maximum number of entitlement key ids shown in private data. +static constexpr uint16_t kMaxNumOfEntitlementKeyIds = 2; +// Entitlement key id length is fixed to 16 bytes. +static constexpr uint16_t kEntitlementKeyIdLength = 16; namespace widevine { namespace cas { @@ -153,8 +157,8 @@ void Emmg::Start() { for (size_t i = 0; i < emmg_config_->max_num_message; i++) { SendDataProvision(); - absl::SleepFor( - absl::Milliseconds(std::max(KMinSendIntervalMs, send_interval_ms_))); + absl::SleepFor(std::max(absl::Milliseconds(KMinSendIntervalMs), + absl::Milliseconds(send_interval_ms_))); } SendStreamCloseRequest(); @@ -220,13 +224,33 @@ void Emmg::BuildStreamBwRequest() { Host16ToBigEndian(request_ + 3, &total_param_length); } -Status Emmg::GeneratePrivateData(const std::string& content_provider, - const std::string& content_id, uint8_t* buffer) { +Status Emmg::GeneratePrivateData( + const std::string& content_provider, const std::string& content_id, + const std::vector& entitlement_key_ids, uint8_t* buffer) { DCHECK(buffer); // Generate payload. CaDescriptorPrivateData private_data; private_data.set_provider(content_provider); private_data.set_content_id(content_id); + + if (entitlement_key_ids.size() > kMaxNumOfEntitlementKeyIds) { + return Status( + error::INVALID_ARGUMENT, + absl::StrCat("Number of entitlement key ids shouldn't exceed ", + kMaxNumOfEntitlementKeyIds)); + } + for (const auto& entitlement_key_id : entitlement_key_ids) { + if (entitlement_key_id.size() != kEntitlementKeyIdLength) { + return Status( + error::INVALID_ARGUMENT, + absl::StrCat("Entitlement key id length must be ", + kEntitlementKeyIdLength, ". The offending key id is ", + entitlement_key_id)); + } + } + for (const auto& entitlement_key_id : entitlement_key_ids) { + private_data.add_entitlement_key_ids(entitlement_key_id); + } std::string private_data_str = private_data.SerializeAsString(); std::string payload_filler(kMaxTsPayloadSize - private_data_str.size(), 0); @@ -268,8 +292,13 @@ void Emmg::BuildDataProvision() { &request_length_); uint8_t datagram[kTsPacketSize]; - GeneratePrivateData(emmg_config_->content_provider, emmg_config_->content_id, - datagram); + Status status = GeneratePrivateData( + emmg_config_->content_provider, emmg_config_->content_id, + emmg_config_->entitlement_key_ids, datagram); + if (!status.ok()) { + LOG(ERROR) << "Fail to generate private data. " << status.ToString(); + return; + } simulcrypt_util::AddParam(EMMG_DATAGRAM, datagram, kTsPacketSize, request_, &request_length_); diff --git a/media_cas_packager_sdk/internal/emmg.h b/media_cas_packager_sdk/internal/emmg.h index ce5c1af..005b0a6 100644 --- a/media_cas_packager_sdk/internal/emmg.h +++ b/media_cas_packager_sdk/internal/emmg.h @@ -11,6 +11,7 @@ #include #include +#include #include #include "common/status.h" @@ -30,6 +31,7 @@ struct EmmgConfig { uint8_t data_type; std::string content_provider; std::string content_id; + std::vector entitlement_key_ids; uint16_t bandwidth; uint32_t max_num_message; }; @@ -82,8 +84,9 @@ class Emmg { void UpdateSendInterval(uint16_t bandwidth_kbps); - Status GeneratePrivateData(const std::string& content_provider, - const std::string& content_id, uint8_t* buffer); + Status GeneratePrivateData( + const std::string& content_provider, const std::string& content_id, + const std::vector& entitlement_key_ids, uint8_t* buffer); void ReceiveResponseAndVerify(uint16_t expected_type); void Send(uint16_t message_type); diff --git a/media_cas_packager_sdk/internal/emmg_test.cc b/media_cas_packager_sdk/internal/emmg_test.cc index b47967a..bcf4e58 100644 --- a/media_cas_packager_sdk/internal/emmg_test.cc +++ b/media_cas_packager_sdk/internal/emmg_test.cc @@ -51,6 +51,7 @@ class EmmgTest : public ::testing::Test { config_.data_type = 0x01; config_.content_provider = "widevine_test"; config_.content_id = "CasTsFake"; + config_.entitlement_key_ids = {"fakeKeyId1KeyId1", "fakeKeyId2KeyId2"}; config_.bandwidth = 100; config_.max_num_message = 100; emmg_ = absl::make_unique(&config_); @@ -103,6 +104,25 @@ TEST_F(EmmgTest, BuildDataProvision) { sizeof(kTestEmmgDataProvision))); } +TEST_F(EmmgTest, + BuildDataProvisionFailedWhenNumOfEntitlementKeyIdsExceedLimit) { + config_.entitlement_key_ids = {"fakeKeyId1KeyId1", "fakeKeyId2KeyId2", + "fakeKeyId3KeyId3"}; + emmg_ = absl::make_unique(&config_); + emmg_->PublicBuildDataProvision(); + EXPECT_EQ(0, memcmp(kTestEmptyEmmgDataProvision, emmg_->GetRequest(), + sizeof(kTestEmptyEmmgDataProvision))); +} + +TEST_F(EmmgTest, + BuildDataProvisionFailedWhenEntitlementKeyIdLengthExceedLimit) { + config_.entitlement_key_ids = {"fakeKeyId1KeyId1", "fakeKeyId2KeyId2KeyId2"}; + emmg_ = absl::make_unique(&config_); + emmg_->PublicBuildDataProvision(); + EXPECT_EQ(0, memcmp(kTestEmptyEmmgDataProvision, emmg_->GetRequest(), + sizeof(kTestEmptyEmmgDataProvision))); +} + TEST_F(EmmgTest, BuildStreamCloseRequest) { emmg_->PublicBuildStreamCloseRequest(); EXPECT_EQ(0, memcmp(kTestEmmgStreamCloseRequest, emmg_->GetRequest(), diff --git a/media_cas_packager_sdk/public/wv_cas_ca_descriptor.cc b/media_cas_packager_sdk/public/wv_cas_ca_descriptor.cc index b3d0c9d..d6ab327 100644 --- a/media_cas_packager_sdk/public/wv_cas_ca_descriptor.cc +++ b/media_cas_packager_sdk/public/wv_cas_ca_descriptor.cc @@ -9,6 +9,7 @@ #include "media_cas_packager_sdk/public/wv_cas_ca_descriptor.h" #include +#include #include "glog/logging.h" #include "absl/strings/str_cat.h" @@ -21,39 +22,44 @@ namespace cas { namespace { // Size of fixed portion of CA descriptor (before any private bytes). -static constexpr uint32_t kCaDescriptorBaseSize = 6; +constexpr uint32_t kCaDescriptorBaseSize = 6; // Size of fixed portion of CA descriptor that follows the length field. // This and the size of any private bytes must be placed in the length field. -static constexpr uint32_t kCaDescriptorBasePostLengthSize = 4; +constexpr uint32_t kCaDescriptorBasePostLengthSize = 4; // Bitfield lengths for the CA descriptor fields -static constexpr int kNumBitsCaDescriptorTagField = 8; -static constexpr int kNumBitsCaDescriptorLengthField = 8; -static constexpr int kNumBitsCaSystemIdField = 16; -static constexpr int kNumBitsCaDescriptorReservedField = 3; -static constexpr int kNumBitsCaDescriptorPidField = 13; +constexpr int kNumBitsCaDescriptorTagField = 8; +constexpr int kNumBitsCaDescriptorLengthField = 8; +constexpr int kNumBitsCaSystemIdField = 16; +constexpr int kNumBitsCaDescriptorReservedField = 3; +constexpr int kNumBitsCaDescriptorPidField = 13; // Bitfield constants for the CA descriptor fields. // CA descriptor tag value, from table 2-45. -static constexpr uint32_t kCaDescriptorTag = 9; +constexpr uint32_t kCaDescriptorTag = 9; // CA System ID for Widevine. From table in // https://en.wikipedia.org/wiki/Conditional_access -static constexpr uint32_t kWidevineCaSystemId = 0x4AD4; +constexpr uint32_t kWidevineCaSystemId = 0x4AD4; // Value for CA descriptor reserved field should be set to 1. -static constexpr uint32_t kReservedBit = 0x0007; +constexpr uint32_t kReservedBit = 0x0007; // The range of valid PIDs, from section 2.4.3.3, and table 2-3. -static constexpr uint32_t kMinValidPID = 0x0010; -static constexpr uint32_t kMaxValidPID = 0x1FFE; +constexpr uint32_t kMinValidPID = 0x0010; +constexpr uint32_t kMaxValidPID = 0x1FFE; +// Maximum number of entitlement key ids shown in private data. +constexpr uint32_t kMaxNumOfEntitlementKeyIds = 2; +// Entitlement key id length is fixed to 16 bytes. +constexpr uint16_t kEntitlementKeyIdLength = 16; } // namespace Status WvCasCaDescriptor::GenerateCaDescriptor( uint16_t ca_pid, const std::string& provider, const std::string& content_id, + const std::vector& entitlement_key_ids, std::string* serialized_ca_desc) const { if (serialized_ca_desc == nullptr) { return {error::INVALID_ARGUMENT, @@ -62,10 +68,25 @@ Status WvCasCaDescriptor::GenerateCaDescriptor( if (ca_pid < kMinValidPID || ca_pid > kMaxValidPID) { return {error::INVALID_ARGUMENT, "PID value is out of the valid range."}; } + if (entitlement_key_ids.size() > kMaxNumOfEntitlementKeyIds) { + return {error::INVALID_ARGUMENT, + absl::StrCat("Number of entitlement key ids shouldn't exceed ", + kMaxNumOfEntitlementKeyIds)}; + } + for (const auto& entitlement_key_id : entitlement_key_ids) { + if (entitlement_key_id.size() != kEntitlementKeyIdLength) { + return {error::INVALID_ARGUMENT, + absl::StrCat("Entitlement key id length must be ", + kEntitlementKeyIdLength, + ". The offending key id is ", entitlement_key_id)}; + } + } std::string private_data = ""; + // Field of Entitlement_key_ids could be empty. if (!provider.empty() && !content_id.empty()) { - private_data = GeneratePrivateData(provider, content_id); + private_data = + GeneratePrivateData(provider, content_id, entitlement_key_ids); } const size_t descriptor_length = @@ -107,10 +128,14 @@ size_t WvCasCaDescriptor::CaDescriptorBaseSize() const { } std::string WvCasCaDescriptor::GeneratePrivateData( - const std::string& provider, const std::string& content_id) const { + const std::string& provider, const std::string& content_id, + const std::vector& entitlement_key_ids) const { CaDescriptorPrivateData private_data; private_data.set_provider(provider); private_data.set_content_id(content_id); + for (const auto& entitlement_key_id : entitlement_key_ids) { + private_data.add_entitlement_key_ids(entitlement_key_id); + } return private_data.SerializeAsString(); } diff --git a/media_cas_packager_sdk/public/wv_cas_ca_descriptor.h b/media_cas_packager_sdk/public/wv_cas_ca_descriptor.h index 1a044fc..428457c 100644 --- a/media_cas_packager_sdk/public/wv_cas_ca_descriptor.h +++ b/media_cas_packager_sdk/public/wv_cas_ca_descriptor.h @@ -12,6 +12,7 @@ #include #include +#include #include #include "common/status.h" @@ -48,6 +49,9 @@ class WvCasCaDescriptor { // |ca_pid| the 13-bit PID of the ECMs // |provider| provider name, put in private data for client to construct pssh // |content_id| content ID, put in private data for client to construct pssh + // |entitlement_key_ids| entitlement key ids, put in private data for client + // to select entitlement keys from single fat license. This field is only used + // when client uses single fat license. // |serialized_ca_desc| a std::string object to receive the encoded descriptor. // // Notes: @@ -55,10 +59,10 @@ class WvCasCaDescriptor { // section (for an EMM stream) or into a TS Program Map Table section (for an // ECM stream). The descriptor will be 6 bytes plus any bytes added as // (user-defined) private data. - virtual Status GenerateCaDescriptor(uint16_t ca_pid, - const std::string& provider, - const std::string& content_id, - std::string* serialized_ca_desc) const; + virtual Status GenerateCaDescriptor( + uint16_t ca_pid, const std::string& provider, const std::string& content_id, + const std::vector& entitlement_key_ids, + std::string* serialized_ca_desc) const; // Return the base size (before private data is added) of the CA // descriptor. The user can call this to plan the layout of the Table section @@ -66,8 +70,9 @@ class WvCasCaDescriptor { virtual size_t CaDescriptorBaseSize() const; // Return private data in the CA descriptor. - virtual std::string GeneratePrivateData(const std::string& provider, - const std::string& content_id) const; + virtual std::string GeneratePrivateData( + const std::string& provider, const std::string& content_id, + const std::vector& entitlement_key_ids) const; }; } // namespace cas diff --git a/media_cas_packager_sdk/public/wv_cas_ca_descriptor_test.cc b/media_cas_packager_sdk/public/wv_cas_ca_descriptor_test.cc index af14eee..852d5b2 100644 --- a/media_cas_packager_sdk/public/wv_cas_ca_descriptor_test.cc +++ b/media_cas_packager_sdk/public/wv_cas_ca_descriptor_test.cc @@ -20,9 +20,11 @@ namespace cas { namespace { // Random value for PID -static constexpr int kTestPid = 50; -static constexpr char kProvider[] = "widevine_test"; -static constexpr char kContentId[] = "1234"; +constexpr int kTestPid = 50; +constexpr char kProvider[] = "widevine_test"; +constexpr char kContentId[] = "1234"; +const std::vector* const kEntitlementKeyIds = + new std::vector({"fakekey1fakekey1", "fakekey2fakekey2"}); } // namespace @@ -31,6 +33,7 @@ class WvCasCaDescriptorTest : public Test { WvCasCaDescriptorTest() {} WvCasCaDescriptor ca_descriptor_; std::string actual_ca_descriptor_; + std::vector entitlement_key_ids_; }; TEST_F(WvCasCaDescriptorTest, BaseSize) { @@ -38,38 +41,46 @@ TEST_F(WvCasCaDescriptorTest, BaseSize) { } TEST_F(WvCasCaDescriptorTest, BasicGoodGen) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(kTestPid, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + kTestPid, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xE0\x32", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, NoReturnStringFail) { EXPECT_EQ(error::INVALID_ARGUMENT, - ca_descriptor_.GenerateCaDescriptor(kTestPid, "", "", nullptr) + ca_descriptor_ + .GenerateCaDescriptor( + kTestPid, /*provider=*/"", /*content_id=*/"", + /*entitlement_key_ids=*/{}, /*serialized_ca_desc=*/nullptr) .error_code()); } TEST_F(WvCasCaDescriptorTest, PidTooLowFail) { const uint32_t bad_pid = 0x10 - 1; - EXPECT_EQ(error::INVALID_ARGUMENT, - ca_descriptor_ - .GenerateCaDescriptor(bad_pid, "", "", &actual_ca_descriptor_) - .error_code()); + EXPECT_EQ( + error::INVALID_ARGUMENT, + ca_descriptor_ + .GenerateCaDescriptor(bad_pid, /*provider=*/"", /*content_id=*/"", + entitlement_key_ids_, &actual_ca_descriptor_) + .error_code()); } TEST_F(WvCasCaDescriptorTest, PidMinOK) { const uint32_t min_pid = 0x10; - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(min_pid, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + min_pid, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xE0\x10", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidMaxOK) { const uint32_t max_pid = 0x1FFE; - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(max_pid, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + max_pid, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xff\xfe"); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } @@ -78,83 +89,93 @@ TEST_F(WvCasCaDescriptorTest, PidTooHighFail) { const uint32_t bad_pid = 0x1FFF; EXPECT_EQ(error::INVALID_ARGUMENT, ca_descriptor_ - .GenerateCaDescriptor(bad_pid, "", "", &actual_ca_descriptor_) + .GenerateCaDescriptor( + bad_pid, /*provider=*/"", /*content_id=*/"", + /*entitlement_key_ids=*/{}, &actual_ca_descriptor_) .error_code()); } TEST_F(WvCasCaDescriptorTest, PidOneByte) { - EXPECT_OK( - ca_descriptor_.GenerateCaDescriptor(255, "", "", &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 255, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe0\xff", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidSecondByte) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0x1F00, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0x1F00, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xff\x00", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidTwelveBits) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0xFFF, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0xFFF, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xef\xff"); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidThirteenthBit) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0x1000, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0x1000, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xf0\x00", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidTwelthBit) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0x800, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0x800, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe8\x00", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidElevenththBit) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0x400, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0x400, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe4\x00", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidTenthBit) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0x200, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0x200, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe2\x00", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } TEST_F(WvCasCaDescriptorTest, PidNinthBit) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(0x100, "", "", - &actual_ca_descriptor_)); + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + 0x100, /*provider=*/"", /*content_id=*/"", /*entitlement_key_ids=*/{}, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe1\x00", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } -TEST_F(WvCasCaDescriptorTest, PrivateDataOnlyProviderIgnored) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(kTestPid, kProvider, "", - &actual_ca_descriptor_)); +TEST_F(WvCasCaDescriptorTest, PrivateDataWithNoContentIdIgnored) { + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + kTestPid, kProvider, "", entitlement_key_ids_, &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe0\x32", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } -TEST_F(WvCasCaDescriptorTest, PrivateDataOnlyContentIdIgnored) { - EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(kTestPid, "", kContentId, - &actual_ca_descriptor_)); +TEST_F(WvCasCaDescriptorTest, PrivateDataWithNoProviderIgnored) { + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor( + kTestPid, "", kContentId, entitlement_key_ids_, &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x04\x4a\xd4\xe0\x32", 6); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } -TEST_F(WvCasCaDescriptorTest, PrivateData) { +TEST_F(WvCasCaDescriptorTest, PrivateDataWithNoEntitlementKeyIds) { EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(kTestPid, kProvider, kContentId, - &actual_ca_descriptor_)); + {}, &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x19\x4a\xd4\xe0\x32", 6); CaDescriptorPrivateData private_data; private_data.set_provider(kProvider); @@ -163,6 +184,44 @@ TEST_F(WvCasCaDescriptorTest, PrivateData) { actual_ca_descriptor_); } +TEST_F(WvCasCaDescriptorTest, + PrivateDataFailedWhenNumberOfEntitlementKeyIdsExceedLimit) { + const std::vector entitlement_key_ids = { + "fakekey1fakekey1", "fakekey2fakekey2", "fakekey3fakekey3"}; + Status status = {error::INVALID_ARGUMENT, + "Number of entitlement key ids shouldn't exceed 2"}; + EXPECT_EQ(status, ca_descriptor_.GenerateCaDescriptor( + kTestPid, kProvider, kContentId, entitlement_key_ids, + &actual_ca_descriptor_)); +} + +TEST_F(WvCasCaDescriptorTest, + PrivateDataFailedWhenEntitlementKeyIdLengthExceedLimit) { + const std::vector entitlement_key_ids = { + "fakekey1fakekey1", "fakekey2fakekey2fakekey2"}; + Status status = {error::INVALID_ARGUMENT, + "Entitlement key id length must be 16. The offending key id " + "is fakekey2fakekey2fakekey2"}; + EXPECT_EQ(status, ca_descriptor_.GenerateCaDescriptor( + kTestPid, kProvider, kContentId, entitlement_key_ids, + &actual_ca_descriptor_)); +} + +TEST_F(WvCasCaDescriptorTest, PrivateData) { + EXPECT_OK(ca_descriptor_.GenerateCaDescriptor(kTestPid, kProvider, kContentId, + *kEntitlementKeyIds, + &actual_ca_descriptor_)); + const std::string expected_ca_descriptor("\x09\x3d\x4a\xd4\xe0\x32", 6); + CaDescriptorPrivateData private_data; + private_data.set_provider(kProvider); + private_data.set_content_id(kContentId); + for (const auto& entitlementKeyId : *kEntitlementKeyIds) { + private_data.add_entitlement_key_ids(entitlementKeyId); + } + EXPECT_EQ(expected_ca_descriptor + private_data.SerializeAsString(), + actual_ca_descriptor_); +} + class FakePrivateDataCaDescriptor : public WvCasCaDescriptor { public: void set_private_data(std::string private_data) { @@ -170,8 +229,8 @@ class FakePrivateDataCaDescriptor : public WvCasCaDescriptor { } std::string GeneratePrivateData( - const std::string& provider, - const std::string& content_id) const override { + const std::string& provider, const std::string& content_id, + const std::vector& entitlement_key_ids) const override { return private_data_; } @@ -183,7 +242,8 @@ TEST_F(WvCasCaDescriptorTest, PrivateDataOneByte) { FakePrivateDataCaDescriptor fake_descriptor; fake_descriptor.set_private_data("X"); EXPECT_OK(fake_descriptor.GenerateCaDescriptor( - kTestPid, kProvider, kContentId, &actual_ca_descriptor_)); + kTestPid, kProvider, kContentId, *kEntitlementKeyIds, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x05\x4a\xd4\xe0\x32X", 7); EXPECT_EQ(expected_ca_descriptor, actual_ca_descriptor_); } @@ -193,7 +253,8 @@ TEST_F(WvCasCaDescriptorTest, PrivateDataMultipleBytes) { FakePrivateDataCaDescriptor fake_descriptor; fake_descriptor.set_private_data(private_data_bytes); EXPECT_OK(fake_descriptor.GenerateCaDescriptor( - kTestPid, kProvider, kContentId, &actual_ca_descriptor_)); + kTestPid, kProvider, kContentId, *kEntitlementKeyIds, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\x0e\x4a\xd4\xe0\x32", 6); EXPECT_EQ(expected_ca_descriptor + private_data_bytes, actual_ca_descriptor_); } @@ -203,7 +264,8 @@ TEST_F(WvCasCaDescriptorTest, PrivateDataMaxNumberBytes) { FakePrivateDataCaDescriptor fake_descriptor; fake_descriptor.set_private_data(private_data_bytes); EXPECT_OK(fake_descriptor.GenerateCaDescriptor( - kTestPid, kProvider, kContentId, &actual_ca_descriptor_)); + kTestPid, kProvider, kContentId, *kEntitlementKeyIds, + &actual_ca_descriptor_)); const std::string expected_ca_descriptor("\x09\xff\x4a\xd4\xe0\x32", 6); EXPECT_EQ(expected_ca_descriptor + private_data_bytes, actual_ca_descriptor_); } @@ -212,11 +274,12 @@ TEST_F(WvCasCaDescriptorTest, PrivateDataTooManyBytesFail) { const std::string private_data_bytes(252, 'X'); FakePrivateDataCaDescriptor fake_descriptor; fake_descriptor.set_private_data(private_data_bytes); - EXPECT_EQ(error::INVALID_ARGUMENT, - fake_descriptor - .GenerateCaDescriptor(kTestPid, kProvider, kContentId, - &actual_ca_descriptor_) - .error_code()); + EXPECT_EQ( + error::INVALID_ARGUMENT, + fake_descriptor + .GenerateCaDescriptor(kTestPid, kProvider, kContentId, + *kEntitlementKeyIds, &actual_ca_descriptor_) + .error_code()); } } // namespace cas diff --git a/media_cas_packager_sdk/public/wv_cas_key_fetcher_test.cc b/media_cas_packager_sdk/public/wv_cas_key_fetcher_test.cc index b0a96bc..8c51684 100644 --- a/media_cas_packager_sdk/public/wv_cas_key_fetcher_test.cc +++ b/media_cas_packager_sdk/public/wv_cas_key_fetcher_test.cc @@ -18,7 +18,7 @@ using testing::_; using testing::DoAll; using testing::Return; -using testing::SetArgumentPointee; +using testing::SetArgPointee; namespace { @@ -59,9 +59,10 @@ class HardcodedWvCasKeyFetcher : public WvCasKeyFetcher { const std::string& signing_iv) : WvCasKeyFetcher(signing_provider, signing_key, signing_iv) {} ~HardcodedWvCasKeyFetcher() override {} - MOCK_CONST_METHOD2(MakeHttpRequest, - Status(const std::string& signed_request_json, - std::string* http_response_json)); + MOCK_METHOD(Status, MakeHttpRequest, + (const std::string& signed_request_json, + std::string* http_response_json), + (const, override)); }; class MockWvCasKeyFetcher : public WvCasKeyFetcher { @@ -168,7 +169,7 @@ TEST_F(WvCasKeyFetcherTest, TestRequestEntitlementKey) { EXPECT_EQ(signed_request_json, kSignedCasEncryptionRequest); EXPECT_CALL(mock_key_fetcher, MakeHttpRequest(kSignedCasEncryptionRequest, _)) - .WillOnce(DoAll(SetArgumentPointee<1>(std::string(kHttpResponse)), + .WillOnce(DoAll(SetArgPointee<1>(std::string(kHttpResponse)), Return(OkStatus()))); std::string actual_signed_response; EXPECT_OK(mock_key_fetcher.MakeHttpRequest(signed_request_json, diff --git a/media_cas_packager_sdk/public/wv_emmg.cc b/media_cas_packager_sdk/public/wv_emmg.cc index 60e97b9..cc0f5e9 100644 --- a/media_cas_packager_sdk/public/wv_emmg.cc +++ b/media_cas_packager_sdk/public/wv_emmg.cc @@ -43,6 +43,9 @@ ABSL_FLAG(int32_t, data_id, 0, "EMMG data_id."); ABSL_FLAG(int32_t, data_type, 1, "EMMG data_type"); ABSL_FLAG(std::string, content_provider, "", "Content provider"); ABSL_FLAG(std::string, content_id, "", "Content id"); +ABSL_FLAG( + std::vector, entitlement_key_ids, {}, + "Comma-separated list of entitlement_key_ids to put into private data"); ABSL_FLAG(int32_t, bandwidth, 100, "Requested bandwidth in kbps"); ABSL_FLAG(int32_t, max_num_message, 100, "Maximum number of messages that can be sent"); @@ -59,6 +62,7 @@ void BuildEmmgConfig(widevine::cas::EmmgConfig *config) { config->data_type = absl::GetFlag(FLAGS_data_type); config->content_provider = absl::GetFlag(FLAGS_content_provider); config->content_id = absl::GetFlag(FLAGS_content_id); + config->entitlement_key_ids = absl::GetFlag(FLAGS_entitlement_key_ids); config->bandwidth = absl::GetFlag(FLAGS_bandwidth); config->max_num_message = absl::GetFlag(FLAGS_max_num_message); } diff --git a/protos/public/BUILD b/protos/public/BUILD index cd40ec4..a0b2373 100644 --- a/protos/public/BUILD +++ b/protos/public/BUILD @@ -18,6 +18,7 @@ filegroup( proto_library( name = "media_cas_encryption_proto", srcs = ["media_cas_encryption.proto"], + deps = ["hash_algorithm_proto"], ) cc_proto_library( @@ -34,3 +35,18 @@ cc_proto_library( name = "media_cas_cc_proto", deps = [":media_cas_proto"], ) + +proto_library( + name = "hash_algorithm_proto", + srcs = ["hash_algorithm.proto"], +) + +cc_proto_library( + name = "hash_algorithm_cc_proto", + deps = [":hash_algorithm_proto"], +) + +java_proto_library( + name = "hash_algorithm_java_proto", + deps = [":hash_algorithm_proto"], +) diff --git a/protos/public/hash_algorithm.proto b/protos/public/hash_algorithm.proto new file mode 100644 index 0000000..bcb0751 --- /dev/null +++ b/protos/public/hash_algorithm.proto @@ -0,0 +1,20 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 2020 Google LLC. +// +// This software is licensed under the terms defined in the Widevine Master +// License Agreement. For a copy of this agreement, please contact +// widevine-licensing@google.com. +//////////////////////////////////////////////////////////////////////////////// + +syntax = "proto3"; + +package widevine; + +// LINT.IfChange +enum HashAlgorithmProto { + // Unspecified hash algorithm: SHA_256 shall be used for ECC based algorithms + // and SHA_1 shall be used otherwise. + HASH_ALGORITHM_UNSPECIFIED = 0; + HASH_ALGORITHM_SHA_1 = 1; + HASH_ALGORITHM_SHA_256 = 2; +} diff --git a/protos/public/media_cas.proto b/protos/public/media_cas.proto index 928c3ce..e9aea06 100644 --- a/protos/public/media_cas.proto +++ b/protos/public/media_cas.proto @@ -18,4 +18,35 @@ message CaDescriptorPrivateData { // Content ID. optional bytes content_id = 2; + + // Entitlement key IDs for current content per track. Each track will allow up + // to 2 entitlement key ids (odd and even entitlement keys). + repeated bytes entitlement_key_ids = 3; +} + +// Widevine fingerprinting. +message Fingerprinting { + // Channels that will be applied with the controls. + repeated bytes channels = 1; + // Fingerprinting controls are opaque to Widevine. + optional bytes control = 2; +} + +// Widevine service blocking. +message ServiceBlocking { + // Channels that will be blocked. + repeated bytes channels = 1; + // Device groups that will be blocked. Group definition is opaque to Widevine. + repeated bytes device_groups = 2; + // Blocking start time in seconds since epoch. Start time is "immediate" if + // this field is not set. + optional int64 start_time_sec = 3; + // Required. Blocking end time in seconds since epoch. + optional int64 end_time_sec = 4; +} + +// The payload field for an EMM. +message EmmPayload { + repeated Fingerprinting fingerprinting = 1; + repeated ServiceBlocking service_blocking = 2; } diff --git a/protos/public/media_cas_encryption.proto b/protos/public/media_cas_encryption.proto index 84f8044..c40712a 100644 --- a/protos/public/media_cas_encryption.proto +++ b/protos/public/media_cas_encryption.proto @@ -12,8 +12,11 @@ syntax = "proto2"; package widevine; +import "protos/public/hash_algorithm.proto"; + option java_package = "com.google.video.widevine.mediacasencryption"; + message CasEncryptionRequest { optional bytes content_id = 1; optional string provider = 2; @@ -23,6 +26,10 @@ message CasEncryptionRequest { // return one key for EVEN and one key for ODD, otherwise only a single key is // returned. optional bool key_rotation = 4; + // Optional value which can be used to indicate a group. + // If present the CasEncryptionResponse will return key based on the group + // id. + optional bytes group_id = 5; } message CasEncryptionResponse { @@ -54,6 +61,8 @@ message CasEncryptionResponse { optional string status_message = 2; optional bytes content_id = 3; repeated KeyInfo entitlement_keys = 4; + // If this is a group key license, this is the group identifier. + optional bytes group_id = 5; } message SignedCasEncryptionRequest { @@ -61,6 +70,8 @@ message SignedCasEncryptionRequest { optional bytes signature = 2; // Identifies the entity sending / signing the request. optional string signer = 3; + // Optional field that indicates the hash algorithm used in signature scheme. + optional HashAlgorithmProto hash_algorithm = 4; } message SignedCasEncryptionResponse { diff --git a/util/error_space.h b/util/error_space.h index 2e66a39..ce63880 100644 --- a/util/error_space.h +++ b/util/error_space.h @@ -17,7 +17,9 @@ namespace util { class ErrorSpace { public: std::string SpaceName() const { return space_name_func_(this); } - std::string String(int code) const { return code_to_string_func_(this, code); } + std::string String(int code) const { + return code_to_string_func_(this, code); + } protected: // typedef instead of using statements for SWIG compatibility.