From a2b9e085e987446a2aafb87a7f5625c892537a47 Mon Sep 17 00:00:00 2001 From: Matt Feddersen Date: Tue, 20 May 2025 20:30:59 -0700 Subject: [PATCH] OEMCrypto and OPK v20 prerelease initial commit --- oemcrypto/Android.bp | 145 ++ oemcrypto/AndroidTest.xml | 42 + oemcrypto/include/OEMCryptoCENC.h | 299 +-- oemcrypto/include/level3.h | 25 + oemcrypto/odk/Android.bp | 148 +- oemcrypto/odk/include/OEMCryptoCENCCommon.h | 132 +- .../odk/include/core_message_deserialize.h | 12 + oemcrypto/odk/include/core_message_features.h | 29 +- .../odk/include/core_message_serialize.h | 4 - oemcrypto/odk/include/odk.h | 25 +- oemcrypto/odk/include/odk_message.h | 8 +- oemcrypto/odk/include/odk_structs.h | 50 +- oemcrypto/odk/include/odk_target.h | 1 + .../odk/src/core_message_deserialize.cpp | 98 +- oemcrypto/odk/src/core_message_features.cpp | 94 +- oemcrypto/odk/src/core_message_serialize.cpp | 33 +- .../odk/src/core_message_serialize_proto.cpp | 81 +- oemcrypto/odk/src/odk.c | 182 +- oemcrypto/odk/src/odk_message_priv.h | 8 +- oemcrypto/odk/src/odk_overflow.c | 42 +- oemcrypto/odk/src/odk_overflow.h | 6 + oemcrypto/odk/src/odk_serialize.c | 206 +- oemcrypto/odk/src/odk_serialize.h | 5 + oemcrypto/odk/src/odk_structs_priv.h | 11 +- oemcrypto/odk/src/odk_timer.c | 27 +- oemcrypto/odk/src/odk_util.c | 2 + oemcrypto/odk/src/serialization_base.c | 146 +- oemcrypto/odk/src/serialization_base.h | 12 +- oemcrypto/odk/test/fuzzing/Android.bp | 113 +- .../odk/test/fuzzing/odk_fuzz_helper.cpp | 119 +- oemcrypto/odk/test/fuzzing/odk_fuzz_helper.h | 24 +- ...odk_license_response_fuzz_with_mutator.cpp | 4 +- .../fuzzing/odk_renewal_response_fuzz.cpp | 1 + oemcrypto/odk/test/odk_core_message_test.cpp | 9 +- oemcrypto/odk/test/odk_golden_v16.cpp | 11 +- oemcrypto/odk/test/odk_golden_v17.cpp | 11 +- oemcrypto/odk/test/odk_golden_v18.cpp | 17 +- oemcrypto/odk/test/odk_golden_v19.cpp | 17 +- oemcrypto/odk/test/odk_test.cpp | 613 +++--- oemcrypto/odk/test/odk_test_helper.cpp | 412 +++- oemcrypto/odk/test/odk_test_helper.h | 10 +- oemcrypto/odk/test/odk_timer_test.cpp | 72 +- oemcrypto/oemcrypto_security_tests.gyp | 5 +- oemcrypto/oemcrypto_unittests.gyp | 7 +- oemcrypto/opk/Android.bp | 6 +- oemcrypto/opk/build/ree-sources.mk | 5 +- oemcrypto/opk/oemcrypto_ta/oemcrypto.c | 288 +-- .../opk/oemcrypto_ta/oemcrypto_api_macros.h | 6 +- oemcrypto/opk/oemcrypto_ta/oemcrypto_key.c | 8 +- .../oemcrypto_key_control_block.c | 1 + .../oemcrypto_serialized_usage_table.h | 4 +- .../opk/oemcrypto_ta/oemcrypto_session.c | 67 +- .../opk/oemcrypto_ta/oemcrypto_session.h | 20 +- .../opk/oemcrypto_ta/oemcrypto_usage_table.c | 23 +- .../opk/oemcrypto_ta/oemcrypto_usage_table.h | 4 +- oemcrypto/opk/oemcrypto_ta/wtpi/README.md | 45 +- .../wtpi/wtpi_clock_interface_layer1.h | 2 +- .../wtpi/wtpi_crypto_asymmetric_interface.h | 4 +- .../oemcrypto_ta/wtpi_reference/cose_util.c | 20 +- .../wtpi_reference/wtpi_clock_and_gn_layer1.c | 4 +- .../wtpi_reference/wtpi_ref_compat.c | 2 +- .../opk/oemcrypto_ta/wtpi_test/README.md | 2 +- .../wtpi_test/common/GEN_common_serializer.c | 94 +- .../wtpi_test/common/GEN_common_serializer.h | 12 +- .../wtpi_test/common/common_special_cases.c | 56 +- .../wtpi_test/common/common_special_cases.h | 21 +- .../oemcrypto_ta/wtpi_test/crypto_test.cpp | 3 +- .../ree/GEN_oemcrypto_tee_test_api.c | 4 +- .../wtpi_test/ree/GEN_ree_serializer.c | 7 +- .../wtpi_test/ree/GEN_ree_serializer.h | 2 +- .../wtpi_test/tee/GEN_dispatcher.c | 22 +- .../wtpi_test/tee/GEN_tee_serializer.c | 4 +- .../wtpi_test/tee/GEN_tee_serializer.h | 3 +- .../wtpi_test/tee/tee_special_cases.c | 5 +- .../wtpi_test/tee/tee_special_cases.h | 3 +- .../linux/cas/tee/tee_simulator_cas/Makefile | 11 +- .../cas/tee/tee_simulator_cas/data_share.h | 2 +- .../opk/ports/linux/common/posix_services.h | 15 +- .../oemcrypto_ref_compat_test.cpp | 2 +- .../ports/linux/oemcrypto_tee_simulator.gyp | 4 +- oemcrypto/opk/ports/linux/ta/common/clock.cpp | 28 +- .../opk/ports/linux/wtpi_tee_simulator.gyp | 1 - .../ta/common/wtpi_impl/genkeypair_ecc.c | 2 +- .../ports/trusty/ta/reference/rustfmt.toml | 5 + .../common/GEN_common_serializer.c | 106 +- .../common/GEN_common_serializer.h | 13 +- .../common/common_special_cases.c | 39 + .../common/common_special_cases.h | 6 + .../common/include/marshaller_base.h | 3 + .../serialization/common/marshaller_base.c | 9 +- .../common/shared_buffer_allocator.c | 2 +- .../os_interfaces/opk_dispatcher.h | 6 +- .../os_interfaces/tos_transport_interface.h | 13 +- .../opk/serialization/ree/GEN_oemcrypto_api.c | 106 +- .../serialization/ree/GEN_ree_serializer.c | 116 +- .../serialization/ree/GEN_ree_serializer.h | 19 +- .../opk/serialization/tee/GEN_dispatcher.c | 79 +- .../serialization/tee/GEN_tee_serializer.c | 105 +- .../serialization/tee/GEN_tee_serializer.h | 20 +- oemcrypto/test/Android.mk | 27 - oemcrypto/test/GEN_api_lock_file.c | 12 +- oemcrypto/test/common.mk | 76 - oemcrypto/test/fuzz_tests/README.md | 10 +- .../oemcrypto_deactivate_usage_entry_fuzz.cc | 1 - .../test/fuzz_tests/oemcrypto_fuzz_structs.h | 2 - .../test/fuzz_tests/oemcrypto_fuzztests.gypi | 6 +- ...ypto_generate_certificate_key_pair_fuzz.cc | 32 - .../fuzz_tests/oemcrypto_get_random_fuzz.cc | 2 +- .../oemcrypto_install_oem_private_key_fuzz.cc | 25 - .../oemcrypto_opk_dispatcher_fuzz.cc | 12 +- .../fuzz_tests/oemcrypto_opk_fuzztests.gyp | 25 +- .../oemcrypto_release_request_fuzz.cc | 1 - .../fuzz_tests/oemcrypto_report_usage_fuzz.cc | 1 - .../partner_oemcrypto_fuzztests.gyp | 17 +- .../partner_oemcrypto_fuzztests.gypi | 2 +- oemcrypto/test/oec_decrypt_fallback_chain.h | 5 +- oemcrypto/test/oec_device_features.cpp | 11 +- oemcrypto/test/oec_device_features.h | 3 +- oemcrypto/test/oec_session_util.cpp | 508 +++-- oemcrypto/test/oec_session_util.h | 369 +++- oemcrypto/test/oemcrypto_basic_test.cpp | 51 +- oemcrypto/test/oemcrypto_cast_test.h | 39 +- oemcrypto/test/oemcrypto_license_test.cpp | 86 +- oemcrypto/test/oemcrypto_license_test.h | 4 +- .../test/oemcrypto_provisioning_test.cpp | 281 +-- oemcrypto/test/oemcrypto_security_test.cpp | 30 +- oemcrypto/test/oemcrypto_security_tests.gypi | 2 +- .../test/oemcrypto_session_tests_helper.cpp | 189 +- .../test/oemcrypto_session_tests_helper.h | 4 +- oemcrypto/test/oemcrypto_unittests.gypi | 3 +- oemcrypto/test/oemcrypto_usage_table_test.cpp | 148 +- oemcrypto/util/Android.bp | 127 ++ oemcrypto/util/WVCRC32.md | 71 + oemcrypto/util/build.gyp | 103 + .../util/include/bcc_validation_config.h | 27 + oemcrypto/util/include/bcc_validator.h | 14 +- oemcrypto/util/include/cose_utils.h | 462 ++++ oemcrypto/util/include/oemcrypto_cose_key.h | 343 +++ oemcrypto/util/include/oemcrypto_drm_key.h | 99 +- oemcrypto/util/include/oemcrypto_ecc_key.h | 72 + oemcrypto/util/include/oemcrypto_ed_key.h | 282 +++ oemcrypto/util/include/oemcrypto_oem_cert.h | 29 +- .../util/include/oemcrypto_oem_cert_chain.h | 130 ++ oemcrypto/util/include/oemcrypto_rsa_key.h | 108 +- oemcrypto/util/include/wvcrc32.h | 115 +- oemcrypto/util/oec_ref_util.gyp | 20 - oemcrypto/util/oec_ref_util.gypi | 42 - oemcrypto/util/oec_ref_util_unittests.gypi | 5 +- oemcrypto/util/src/bcc_validator.cpp | 18 +- oemcrypto/util/src/cose_utils.cpp | 757 +++++++ oemcrypto/util/src/oemcrypto_cose_key.cpp | 874 ++++++++ oemcrypto/util/src/oemcrypto_drm_key.cpp | 214 +- oemcrypto/util/src/oemcrypto_ecc_key.cpp | 430 +++- oemcrypto/util/src/oemcrypto_ed_key.cpp | 855 ++++++++ oemcrypto/util/src/oemcrypto_oem_cert.cpp | 211 +- .../util/src/oemcrypto_oem_cert_chain.cpp | 179 ++ oemcrypto/util/src/oemcrypto_rsa_key.cpp | 214 +- oemcrypto/util/src/wvcrc.cpp | 88 - oemcrypto/util/src/wvcrc32.cpp | 156 ++ oemcrypto/util/test/cose_utils_unittest.cpp | 1335 ++++++++++++ oemcrypto/util/test/hmac_unittest.cpp | 4 +- .../util/test/oemcrypto_cose_key_unittest.cpp | 1073 ++++++++++ .../util/test/oemcrypto_ecc_key_unittest.cpp | 422 ++-- .../util/test/oemcrypto_ed_key_unittest.cpp | 415 ++++ .../oemcrypto_oem_cert_chain_unittest.cpp | 226 ++ .../util/test/oemcrypto_oem_cert_unittest.cpp | 6 +- .../util/test/oemcrypto_rsa_key_unittest.cpp | 151 +- .../util/test/oemcrypto_wvcrc32_unittest.cpp | 74 - oemcrypto/util/test/wvcrc32_unittest.cpp | 248 +++ oemcrypto/util/wvcrc32.gyp | 21 + oemcrypto/util/wvcrc32_unittests.gypi | 15 + util/include/buffer_reader.h | 78 + util/include/clock.h | 11 +- util/include/hls_attribute_list.h | 211 ++ util/include/rw_lock.h | 65 - util/include/string_utils.h | 89 + util/include/wv_date_time.h | 223 ++ util/include/wv_duration.h | 286 +++ util/include/wv_timestamp.h | 170 ++ util/src/buffer_reader.cpp | 133 ++ util/src/hls_attribute_list.cpp | 981 +++++++++ util/src/rw_lock.cpp | 60 - util/src/string_utils.cpp | 244 +++ util/src/time_struct.h | 83 + util/src/wv_date_time.cpp | 195 ++ util/src/wv_duration.cpp | 71 + util/test/buffer_reader_test.cpp | 841 ++++++++ util/test/hls_attribute_list_unittest.cpp | 1385 ++++++++++++ util/test/string_utils_unittest.cpp | 1890 +++++++++++++++++ util/test/test_clock.cpp | 21 +- util/test/wv_date_time_unittest.cpp | 464 ++++ util/test/wv_duration_unittest.cpp | 1158 ++++++++++ util/test/wv_timestamp_unittest.cpp | 300 +++ 193 files changed, 22480 insertions(+), 3275 deletions(-) create mode 100644 oemcrypto/Android.bp create mode 100644 oemcrypto/AndroidTest.xml create mode 100644 oemcrypto/opk/ports/trusty/ta/reference/rustfmt.toml delete mode 100644 oemcrypto/test/Android.mk delete mode 100644 oemcrypto/test/common.mk delete mode 100644 oemcrypto/test/fuzz_tests/oemcrypto_install_oem_private_key_fuzz.cc create mode 100644 oemcrypto/util/Android.bp create mode 100644 oemcrypto/util/WVCRC32.md create mode 100644 oemcrypto/util/build.gyp create mode 100644 oemcrypto/util/include/bcc_validation_config.h create mode 100644 oemcrypto/util/include/cose_utils.h create mode 100644 oemcrypto/util/include/oemcrypto_cose_key.h create mode 100644 oemcrypto/util/include/oemcrypto_ed_key.h create mode 100644 oemcrypto/util/include/oemcrypto_oem_cert_chain.h delete mode 100644 oemcrypto/util/oec_ref_util.gyp delete mode 100644 oemcrypto/util/oec_ref_util.gypi create mode 100644 oemcrypto/util/src/cose_utils.cpp create mode 100644 oemcrypto/util/src/oemcrypto_cose_key.cpp create mode 100644 oemcrypto/util/src/oemcrypto_ed_key.cpp create mode 100644 oemcrypto/util/src/oemcrypto_oem_cert_chain.cpp delete mode 100644 oemcrypto/util/src/wvcrc.cpp create mode 100644 oemcrypto/util/src/wvcrc32.cpp create mode 100644 oemcrypto/util/test/cose_utils_unittest.cpp create mode 100644 oemcrypto/util/test/oemcrypto_cose_key_unittest.cpp create mode 100644 oemcrypto/util/test/oemcrypto_ed_key_unittest.cpp create mode 100644 oemcrypto/util/test/oemcrypto_oem_cert_chain_unittest.cpp delete mode 100644 oemcrypto/util/test/oemcrypto_wvcrc32_unittest.cpp create mode 100644 oemcrypto/util/test/wvcrc32_unittest.cpp create mode 100644 oemcrypto/util/wvcrc32.gyp create mode 100644 oemcrypto/util/wvcrc32_unittests.gypi create mode 100644 util/include/buffer_reader.h create mode 100644 util/include/hls_attribute_list.h delete mode 100644 util/include/rw_lock.h create mode 100644 util/include/string_utils.h create mode 100644 util/include/wv_date_time.h create mode 100644 util/include/wv_duration.h create mode 100644 util/include/wv_timestamp.h create mode 100644 util/src/buffer_reader.cpp create mode 100644 util/src/hls_attribute_list.cpp delete mode 100644 util/src/rw_lock.cpp create mode 100644 util/src/string_utils.cpp create mode 100644 util/src/time_struct.h create mode 100644 util/src/wv_date_time.cpp create mode 100644 util/src/wv_duration.cpp create mode 100644 util/test/buffer_reader_test.cpp create mode 100644 util/test/hls_attribute_list_unittest.cpp create mode 100644 util/test/string_utils_unittest.cpp create mode 100644 util/test/wv_date_time_unittest.cpp create mode 100644 util/test/wv_duration_unittest.cpp create mode 100644 util/test/wv_timestamp_unittest.cpp diff --git a/oemcrypto/Android.bp b/oemcrypto/Android.bp new file mode 100644 index 0000000..3416917 --- /dev/null +++ b/oemcrypto/Android.bp @@ -0,0 +1,145 @@ +// Copyright 2024 Google LLC. This file and proprietary +// source code may only be used and distributed under the Widevine +// License Agreement. + +package { + default_applicable_licenses: ["vendor_widevine_license"], +} + +// Header Library for public OEMCrypto API. +cc_library_headers { + name: "oemcrypto_api_headers", + + export_include_dirs: [ + "include", + ], + + owner: "widevine", + min_sdk_version: "34", +} + +cc_defaults { + name: "oemcrypto_test_defaults", + team: "trendy_team_media_framework_drm", + + // The unit tests can access v15 functions through the dynamic adapter: + cflags: ["-DTEST_OEMCRYPTO_V15"], + + srcs: [ + "test/GEN_api_lock_file.c", + "test/oec_device_features.cpp", + "test/oec_decrypt_fallback_chain.cpp", + "test/oec_key_deriver.cpp", + "test/oec_session_util.cpp", + "test/oemcrypto_corpus_generator_helper.cpp", + "test/oemcrypto_session_tests_helper.cpp", + "test/oemcrypto_basic_test.cpp", + "test/oemcrypto_cast_test.cpp", + "test/oemcrypto_decrypt_test.cpp", + "test/oemcrypto_generic_crypto_test.cpp", + "test/oemcrypto_license_test.cpp", + "test/oemcrypto_provisioning_test.cpp", + "test/oemcrypto_security_test.cpp", + "test/oemcrypto_usage_table_test.cpp", + "test/oemcrypto_test.cpp", + "test/oemcrypto_test_android.cpp", + "test/oemcrypto_test_main.cpp", + "test/ota_keybox_test.cpp", + "util/src/bcc_validator.cpp", + "util/src/cbor_validator.cpp", + "util/src/cose_utils.cpp", + "util/src/device_info_validator.cpp", + "util/src/oemcrypto_cose_key.cpp", + "util/src/oemcrypto_ecc_key.cpp", + "util/src/oemcrypto_ed_key.cpp", + "util/src/oemcrypto_rsa_key.cpp", + "util/src/prov4_validation_helper.cpp", + "util/src/signed_csr_payload_validator.cpp", + "util/src/wvcrc32.cpp", + ":cdm_test_sleep", + ], + + include_dirs: [ + "external/googletest/googlemock/include", + "vendor/widevine/libwvdrmengine/cdm/core/include", + "vendor/widevine/libwvdrmengine/cdm/util/include", + "vendor/widevine/libwvdrmengine/cdm/util/test", + ], + local_include_dirs: [ + "test/fuzz_tests", + "include", + "odk/include", + "util/include", + ], + + static_libs: [ + "libcppbor", + "libjsmn", + "libgmock", + "libgtest", + "libgtest_main", + "libPlatformProperties", + ], + + shared_libs: [ + "libbase", + "libcrypto", + "libdl", + "libbinder_ndk", + "liblog", + "libmedia_omx", + "libprotobuf-cpp-lite", + "libstagefright_foundation", + "libutils", + "libz", + ], +} + +cc_defaults { + name: "oemcrypto_test_defaults_vendor", + static_libs: [ + "libcdm", + "libcdm_protos", + "libcdm_utils", + "libwv_odk", + "libwv_kdo", + "libwvlevel3", + ], +} + +cc_defaults { + name: "oemcrypto_test_defaults_system", + static_libs: [ + "libcdm.system", + "libcdm_protos.system", + "libcdm_utils.system", + "libwv_odk.system", + "libwv_kdo.system", + "libwvlevel3.system", + ], +} + +cc_test { + name: "oemcrypto_test", + team: "trendy_team_media_framework_drm", + owner: "widevine", + defaults: [ + "oemcrypto_test_defaults", + "oemcrypto_test_defaults_vendor", + ], + proprietary: true, +} + +cc_test { + name: "oemcrypto_test.system", + team: "trendy_team_media_framework_drm", + owner: "widevine", + defaults: [ + "oemcrypto_test_defaults", + "oemcrypto_test_defaults_system", + ], + test_config_template: "AndroidTest.xml", + test_suites: [ + "general-tests", + ], +} diff --git a/oemcrypto/AndroidTest.xml b/oemcrypto/AndroidTest.xml new file mode 100644 index 0000000..4d1ff2b --- /dev/null +++ b/oemcrypto/AndroidTest.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + diff --git a/oemcrypto/include/OEMCryptoCENC.h b/oemcrypto/include/OEMCryptoCENC.h index f2cf389..fc4fc2e 100644 --- a/oemcrypto/include/OEMCryptoCENC.h +++ b/oemcrypto/include/OEMCryptoCENC.h @@ -3,7 +3,7 @@ // License Agreement. /** - * @mainpage OEMCrypto API v19.5 + * @mainpage OEMCrypto API v20.0 * * OEMCrypto is the low level library implemented by the OEM to provide key and * content protection, usually in a separate secure memory or process space. The @@ -444,7 +444,7 @@ typedef struct { */ typedef struct { uint8_t signature[20]; // -- HMAC SHA1 of the rest of the report. - uint8_t status; // current status of entry. (OEMCrypto_Usage_Entry_Status) + uint8_t status; // current status of entry. (OEMCrypto_UsageEntryStatus) uint8_t clock_security_level; uint8_t pst_length; uint8_t padding; // make int64's word aligned. @@ -455,17 +455,6 @@ typedef struct { } __attribute__((packed)) OEMCrypto_PST_Report; #endif -/** - * Valid values for clock_security_level in OEMCrypto_PST_Report. - */ -typedef enum OEMCrypto_Clock_Security_Level { - kInsecureClock = 0, - kMonotonicClock = 1, - kSecureTimer = 1, // DEPRECATED. Do not use. - kSecureClock = 2, - kHardwareSecureClock = 3 -} OEMCrypto_Clock_Security_Level; - typedef uint8_t RSA_Padding_Scheme; // RSASSA-PSS with SHA1. Scheme used for DRM certificates for signing a license // request. @@ -766,6 +755,8 @@ typedef enum OEMCrypto_SignatureHashAlgorithm { #define OEMCrypto_GetBCCSignatureType _oecc156 #define OEMCrypto_GetPVRKey _oecc157 #define OEMCrypto_LoadPVRKey _oecc158 +#define OEMCrypto_LoadLicenseData _oecc159 +#define OEMCrypto_SaveLicenseData _oecc160 // clang-format on /// @addtogroup initcontrol @@ -1027,7 +1018,10 @@ OEMCryptoResult OEMCrypto_CloseSession(OEMCrypto_SESSION session); * state, an error of OEMCrypto_ERROR_INVALID_CONTEXT is returned. * * @param[in] session: handle for the session to be used. - * @param[out] nonce: pointer to memory to receive the computed nonce. + * @param[out] nonce pointer to memory to receive the computed nonce. The nonce + * will only be stored into this memory location if the function returns + * OEMCrypto_SUCCESS. If any other OEMCryptoResult is returned, the contents + * of the memory pointed to by nonce will remain unchanged. * * Results: * nonce: the nonce is also stored in secure memory. @@ -1192,14 +1186,12 @@ OEMCryptoResult OEMCrypto_PrepAndSignLicenseRequest( * larger than the supported size. * * @threading - * This is a "Session Function" and may be called simultaneously with session - * functions for other sessions but not simultaneously with other functions - * for this session. It will not be called simultaneously with initialization - * or usage table functions. It is as if the CDM holds a write lock for this - * session, and a read lock on the OEMCrypto system. + * This is a "Usage Table Function" and will not be called simultaneously + * with any other function, as if the CDM holds a write lock on the OEMCrypto + * system. * * @version - * This method is new in API version 19. + * This method is new in API version 20. */ OEMCryptoResult OEMCrypto_PrepAndSignReleaseRequest( OEMCrypto_SESSION session, uint8_t* message, size_t message_length, @@ -1319,12 +1311,12 @@ OEMCryptoResult OEMCrypto_PrepAndSignRenewalRequest( * mac_keys may be deleted. * * If the field license_type is OEMCrypto_ContentLicense, then the fields - * key_id and key_data in an OEMCrypto_KeyObject are loaded in to the + * key_id and key_data in an OEMCrypto_KeyObjectV2 are loaded in to the * content_key_id and content_key_data fields of the key table entry. In this * case, entitlement key ids and entitlement key data is left blank. * * If the field license_type is OEMCrypto_EntitlementLicense, then the - * fields key_id and key_data in an OEMCrypto_KeyObject are loaded in to the + * fields key_id and key_data in an OEMCrypto_KeyObjectV2 are loaded in to the * entitlement_key_id and entitlement_key_data fields of the key table entry. * In this case, content key ids and content key data will be loaded later * with a call to OEMCrypto_LoadEntitledContentKeys(). @@ -1381,8 +1373,8 @@ OEMCryptoResult OEMCrypto_PrepAndSignRenewalRequest( * 22. If key_array_length == 0, then return * OEMCrypto_ERROR_INVALID_CONTEXT. * 23. If this session is associated with a usage table entry, and that - * entry is marked as "inactive" (either kInactiveUsed or - * kInactiveUnused), then the keys are not loaded, and the error + * entry is marked as "inactive" (either OEMCrypto_InactiveUsed or + * OEMCrypto_InactiveUnused), then the keys are not loaded, and the error * OEMCrypto_ERROR_LICENSE_INACTIVE is returned. * 24. The data in enc_mac_keys_iv is not identical to the 16 bytes before * enc_mac_keys. If it is, return OEMCrypto_ERROR_INVALID_CONTEXT. @@ -1643,14 +1635,12 @@ OEMCryptoResult OEMCrypto_LoadRenewal(OEMCrypto_SESSION session, * @retval OEMCrypto_ERROR_SIGNATURE_FAILURE * * @threading - * This is a "Session Function" and may be called simultaneously with session - * functions for other sessions but not simultaneously with other functions - * for this session. It will not be called simultaneously with initialization - * or usage table functions. It is as if the CDM holds a write lock for this - * session, and a read lock on the OEMCrypto system. + * This is a "Usage Table Function" and will not be called simultaneously + * with any other function, as if the CDM holds a write lock on the OEMCrypto + * system. * * @version - * This method is new in API version 19. + * This method is new in API version 20. */ OEMCryptoResult OEMCrypto_LoadRelease(OEMCrypto_SESSION session, const uint8_t* message, @@ -2491,8 +2481,8 @@ OEMCryptoResult OEMCrypto_GetKeyHandle(OEMCrypto_SESSION session, * OEMCrypto_WARNING_MIXED_OUTPUT_PROTECTION. (See note on delayed * error conditions below) * 5. If the current key has an entry in the Usage Table, and the status of - * that entry is either kInactiveUsed or kInactiveUnused, then return the - * error OEMCrypto_ERROR_LICENSE_INACTIVE. + * that entry is either OEMCrypto_InactiveUsed or OEMCrypto_InactiveUnused, + * then return the error OEMCrypto_ERROR_LICENSE_INACTIVE. * 6. If a Decrypt Hash has been initialized via OEMCrypto_SetDecryptHash(), * and the current key's control block does not have the * Allow_Hash_Verification bit set, then do not compute a hash and @@ -2732,8 +2722,8 @@ OEMCryptoResult OEMCrypto_CopyBuffer( * 1. The control bit for the key shall have the Allow_Encrypt set. If not, * return OEMCrypto_ERROR_UNKNOWN_FAILURE. * 2. If the key has an entry in the Usage Table, and the status of that - * entry is either kInactiveUsed or kInactiveUnused, then return the - * error OEMCrypto_ERROR_LICENSE_INACTIVE. + * entry is either OEMCrypto_InactiveUsed or OEMCrypto_InactiveUnused, + * then return the error OEMCrypto_ERROR_LICENSE_INACTIVE. * * @param[in] key_handle: pointer to a buffer containing the key handle for a * key previously installed with OEMCrypto_GetKeyHandle(). @@ -2816,8 +2806,8 @@ OEMCryptoResult OEMCrypto_Generic_Encrypt( * 2. If the key's control block has the Data_Path_Type bit set, then return * OEMCrypto_ERROR_DECRYPT_FAILED. * 3. If the key has an entry in the Usage Table, and the status of that - * entry is either kInactiveUsed or kInactiveUnused, then return the - * error OEMCrypto_ERROR_LICENSE_INACTIVE. + * entry is either OEMCrypto_InactiveUsed or OEMCrypto_InactiveUnused, + * then return the error OEMCrypto_ERROR_LICENSE_INACTIVE. * * @param[in] key_handle: pointer to a buffer containing the key handle for a * key previously installed with OEMCrypto_GetKeyHandle(). @@ -2896,8 +2886,8 @@ OEMCryptoResult OEMCrypto_Generic_Decrypt( * returned, and the data is not signed. * 1. The control bit for the key shall have the Allow_Sign set. * 2. If the key has an entry in the Usage Table, and the status of that - * entry is either kInactiveUsed or kInactiveUnused, then return the - * error OEMCrypto_ERROR_LICENSE_INACTIVE. + * entry is either OEMCrypto_InactiveUsed or OEMCrypto_InactiveUnused, + * then return the error OEMCrypto_ERROR_LICENSE_INACTIVE. * * @param[in] key_handle: pointer to a buffer containing the key handle for a * key previously installed with OEMCrypto_GetKeyHandle(). @@ -2986,8 +2976,8 @@ OEMCryptoResult OEMCrypto_Generic_Sign(const uint8_t* key_handle, * signature mismatch will always take the same time as a successful * comparison). * 4. If the key has an entry in the Usage Table, and the status of that - * entry is either kInactiveUsed or kInactiveUnused, then return the - * error OEMCrypto_ERROR_LICENSE_INACTIVE. + * entry is either OEMCrypto_InactiveUsed or OEMCrypto_InactiveUnused, + * then return the error OEMCrypto_ERROR_LICENSE_INACTIVE. * * @param[in] key_handle: pointer to a buffer containing the key handle for a * key previously installed with OEMCrypto_GetKeyHandle(). @@ -4314,8 +4304,8 @@ OEMCryptoResult OEMCrypto_LoadProvisioning( * Receiver certificates may refuse to load these keys and return an error of * OEMCrypto_ERROR_NOT_IMPLEMENTED. The main use case for these alternative * signing algorithms is to support devices that use X509 certificates for - * authentication when acting as a ChromeCast receiver. This is not needed for - * devices that wish to send data to a ChromeCast. Keys loaded from this + * authentication when acting as a Google Cast receiver. This is not needed for + * devices that wish to send data to a Google Cast. Keys loaded from this * function may not be used with OEMCrypto_PrepAndSignLicenseRequest(). * * First, OEMCrypto should generate three secondary keys, mac_key[server], @@ -4388,8 +4378,8 @@ OEMCryptoResult OEMCrypto_LoadProvisioning( * algorithms may refuse to load these keys and return an error of * OEMCrypto_ERROR_NOT_IMPLEMENTED. The main use case for these * alternative signing algorithms is to support devices that use X.509 - * certificates for authentication when acting as a ChromeCast receiver. - * This is not needed for devices that wish to send data to a ChromeCast. + * certificates for authentication when acting as a Google Cast receiver. + * This is not needed for devices that wish to send data to a Google Cast. * 7. After possibly skipping past the first 8 bytes signifying the allowed * signing algorithm, the rest of the buffer private_key contains an ECC * private key or an RSA private key in PKCS#8 binary DER encoded @@ -4562,7 +4552,7 @@ OEMCryptoResult OEMCrypto_LoadTestRSAKey(void); * * The second padding scheme is for devices that use X509 certificates for * authentication. The main example is devices that work as a Cast receiver, - * like a ChromeCast, not for devices that wish to send to the Cast device, + * like a Google Cast, not for devices that wish to send to the Cast device, * such as almost all Android devices. OEMs that do not support X509 * certificate authentication need not implement this function and can return * OEMCrypto_ERROR_NOT_IMPLEMENTED. @@ -4703,8 +4693,8 @@ OEMCryptoResult OEMCrypto_PrepAndSignProvisioningRequest( * will encrypt and sign the new, empty, header and return it in the provided * buffer. * - * The new entry should be created with a status of kUnused and all times - * times should be set to 0. + * The new entry should be created with a status of OEMCrypto_Unused and all + * times times should be set to 0. * * Devices that do not implement a Session Usage Table may return * OEMCrypto_ERROR_NOT_IMPLEMENTED. @@ -5056,14 +5046,15 @@ OEMCryptoResult OEMCrypto_DeactivateUsageEntry(OEMCrypto_SESSION session, * * Valid values for status are: * - * - 0 = kUnused -- the keys have not been used to decrypt. - * - 1 = kActive -- the keys have been used, and have not been deactivated. - * - 2 = kInactive - deprecated. Use kInactiveUsed or kInactiveUnused. - * - 3 = kInactiveUsed -- the keys have been marked inactive after being - * active. - * - 4 = kInactiveUnused -- they keys have been marked inactive, but were - * never active. - * The clock_security_level is reported as follows: + * - 0 = OEMCrypto_Unused -- the keys have not been used to decrypt. + * - 1 = OEMCrypto_Active -- the keys have been used, and have not been + * deactivated. + * - 2 = OEMCrypto_Inactive - deprecated. Use OEMCrypto_InactiveUsed or + * OEMCrypto_InactiveUnused. + * - 3 = OEMCrypto_InactiveUsed -- the keys have been marked inactive after + * being active. + * - 4 = OEMCrypto_InactiveUnused -- they keys have been marked inactive, but + * were never active. The clock_security_level is reported as follows: * * - 0 = Insecure Clock - clock just uses system time. * - 1 = Secure Timer - clock runs from a secure timer which is initialized @@ -5142,7 +5133,7 @@ OEMCryptoResult OEMCrypto_ReportUsage(OEMCrypto_SESSION session, * or deactivating the license. * * @param[in] session: handle for the session to be used. - * @param[out] status: the enumeration of OEMCrypto_Usage_Entry_Status. + * @param[out] status: the enumeration of OEMCrypto_UsageEntryStatus. * @param[out] seconds_since_license_received: the time since the license being * requested in seconds. * @param[out] seconds_since_first_decrypt: the time since playback has @@ -5163,7 +5154,7 @@ OEMCryptoResult OEMCrypto_ReportUsage(OEMCrypto_SESSION session, * This method is new in API version 19. */ OEMCryptoResult OEMCrypto_GetUsageEntryInfo( - OEMCrypto_SESSION session, OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_SESSION session, OEMCrypto_UsageEntryStatus* status, int64_t* seconds_since_license_received, int64_t* seconds_since_first_decrypt); @@ -5306,19 +5297,13 @@ OEMCryptoResult OEMCrypto_GetBootCertificateChain( size_t* additional_signature_length); /** - * Generates a key pair used in OEM and DRM certificate provisioning. The public - * key is supposed to be certified by the server. The private key is wrapped - * with the encryption key so it can be stored in the file system. + * Generates a key pair used in DRM certificate provisioning. The public key is + * supposed to be certified by the server. The private key is wrapped with the + * encryption key, so it can be stored in the file system. * - * The |public_key_signature| output is formatted differently depending - * on whether or not an OEM private key has been loaded. - * - * If an OEM private key is unavailable, the request is assumed to be for OEM - * certificate provisioning. In this case, the public key is signed by the - * device private key. The format of |public_key_signature| in this case is a - * COSE_Sign1 CBOR array. The format is described in RFC 8152 Section 4.2 and - * 4.4, as well as Android IRemotelyProvisionedComponent.aidl (under - * "SignedData") + * The public key is signed by the device private key. |public_key_signature| is + * a COSE_Sign1 CBOR array, described in RFC 8152 Section 4.2 and 4.4 as well as + * Android IRemotelyProvisionedComponent.aidl (under "SignedData"). * * ~~~ * |public_key_signature|: COSE_Sign1 CBOR array @@ -5356,35 +5341,25 @@ OEMCryptoResult OEMCrypto_GetBootCertificateChain( * ] * ~~~ * - * If an OEM private key is available, the request is assumed to be for DRM - * certificate provisioning and the public key is signed by the OEM private key. - * If the OEM private key is an RSA key, then |public_key_signature| is the raw - * output of the RSA sign operation with RSASSA-PSS padding. If the OEM private - * key is an ECC key, then |public_key_signature| is the ASN.1 DER-encoded (R,S) - * signature as specified in RFC 3279 2.2.3. - * - * After this function completes successfully, the session will hold a private - * key and will be ready for a call to - * OEMCrypto_PrepAndSignProvisioningRequest(). In particular, when this - * function is used to generate a DRM Certificate key pair, the session will be - * ready to sign a provisioning request with the DRM Cert private key. When this - * function is used to generate an OEM Certificate key pair, the session will be - * ready to sign a provisioning request with the OEM Cert private key. + * After this function completes successfully, the session will hold the private + * key and be ready for a call to OEMCrypto_PrepAndSignProvisioningRequest(). In + * particular, the session will be ready to sign a provisioning request with the + * DRM cert private key. * * The public key shall be an ASN.1 DER-encoded SubjectPublicKeyInfo as * specified in RFC 5280. Widevine recommends ECC keys for Provisioning 4.0, but * an RSA key may also be used. If the key is an RSA key, then the encoding - * should use "rsaEncryption" (OID 1.2.840.113549.1.1.1), and not RSASSA-PSS. + * should use "rsaEncryption" (OID 1.2.840.113549.1.1.1) instead of RSASSA-PSS. * * @param[in] session: session id. * @param[out] public_key: pointer to the buffer that receives the public key - * that is to be certified by the server. The key must be an ASN.1 - * DER-encoded SubjectPublicKeyInfo as specified in RFC 5280. + * that is to be certified by the server. The key is an ASN.1 DER-encoded + * SubjectPublicKeyInfo as specified in RFC 5280. * @param[in,out] public_key_length: on input, size of the caller's public_key * buffer. On output, the number of bytes written into the buffer. * @param[out] public_key_signature: pointer to the buffer that receives the - * signature of the public key. The format depends on whether an OEM private - * key has been loaded. + * signature of the public key. The signature is a COSE_Sign1 CBOR array as + * described in in RFC 8152. * @param[in,out] public_key_signature_length: on input, size of the caller's * public_key_signature buffer. On output, the number of bytes written into * the buffer. @@ -5410,7 +5385,7 @@ OEMCryptoResult OEMCrypto_GetBootCertificateChain( * session, and a read lock on the OEMCrypto system. * * @version - * This method is new in API version 17. + * This method was new in API version 17 and changed in API version 20. */ OEMCryptoResult OEMCrypto_GenerateCertificateKeyPair( OEMCrypto_SESSION session, uint8_t* public_key, size_t* public_key_length, @@ -5570,43 +5545,6 @@ OEMCryptoResult OEMCrypto_GetDeviceSignedCsrPayload( const uint8_t* encoded_device_info, size_t encoded_device_info_length, uint8_t* signed_csr_payload, size_t* signed_csr_payload_length); -/** - * Loads an OEM private key to a session. The key will be used in signing DRM - * certificate request, or the public key generated by calling - * OEMCrypto_GenerateCertificateKeyPair. - * - * @param[in] session: session id. - * @param[in] key_type: type of the leaf key (RSA or ECC). - * @param[in] wrapped_private_key: the encrypted private key. This is the - * wrapped key generated by OEMCrypto_GenerateCertificateKeyPair. - * @param[in] wrapped_private_key_length: length of |wrapped_private_key| in - * bytes. - * - * @retval OEMCrypto_SUCCESS - * @retval OEMCrypto_ERROR_INVALID_CONTEXT - * @retval OEMCrypto_ERROR_NO_DEVICE_KEY - * @retval OEMCrypto_ERROR_INVALID_SESSION - * @retval OEMCrypto_ERROR_INVALID_KEY - * @retval OEMCrypto_ERROR_INSUFFICIENT_RESOURCES - * @retval OEMCrypto_ERROR_UNKNOWN_FAILURE - * @retval OEMCrypto_ERROR_SESSION_LOST_STATE - * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED - * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED - * @retval OEMCrypto_ERROR_UNKNOWN_FAILURE - * - * @threading - * This is a "Session Function" and may be called simultaneously with session - * functions for other sessions but not simultaneously with other functions - * for this session. It will not be called simultaneously with initialization - * or usage table functions. It is as if the CDM holds a write lock for this - * session, and a read lock on the OEMCrypto system. - * - * @version - * This method is new in API version 17. - */ -OEMCryptoResult OEMCrypto_InstallOemPrivateKey( - OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, - const uint8_t* wrapped_private_key, size_t wrapped_private_key_length); /// @} /// @addtogroup test_verify @@ -5850,6 +5788,37 @@ OEMCryptoResult OEMCrypto_FreeSecureBuffer( OEMCrypto_SESSION session, OEMCrypto_DestBufferDesc* output_descriptor, int secure_fd); +/** + * Fill a buffer with hardware-generated random data to diagnose entropy issues. + * + * @param[out] random_data: pointer to the buffer that receives random data + * @param[in] random_data_length: length of the random data buffer in bytes + * + * @retval OEMCrypto_SUCCESS success + * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED OEMCrypto is a production build, and + * does not support debug or test-only functions. + * @retval OEMCrypto_ERROR_RNG_FAILED + * @retval OEMCrypto_ERROR_BUFFER_TOO_LARGE + * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED + * + * @buffer_size + * OEMCrypto shall support random_data_length sizes of at least 100 KiB for + * random number generation. + * OEMCrypto shall return OEMCrypto_ERROR_BUFFER_TOO_LARGE if the buffer is + * larger than the supported size. + * + * @threading + * This is a "Property Function" and may be called simultaneously with any + * other property function or session function, but not any initialization or + * usage table function, as if the CDM holds a read lock on the OEMCrypto + * system. + * + * @version + * This method changed in API version 20. + */ +OEMCryptoResult OEMCrypto_GetRandom(OEMCrypto_SharedMemory* random_data, + size_t random_data_length); + /// @} /* @@ -6082,7 +6051,7 @@ OEMCryptoResult OEMCrypto_DeleteOldUsageTable(void); */ OEMCryptoResult OEMCrypto_CreateOldUsageEntry( uint64_t time_since_license_received, uint64_t time_since_first_decrypt, - uint64_t time_since_last_decrypt, OEMCrypto_Usage_Entry_Status status, + uint64_t time_since_last_decrypt, OEMCrypto_UsageEntryStatus status, uint8_t* server_mac_key, uint8_t* client_mac_key, const uint8_t* pst, size_t pst_length); @@ -6223,14 +6192,6 @@ OEMCryptoResult OEMCrypto_RefreshKeys( const uint8_t* signature, size_t signature_length, size_t num_keys, const OEMCrypto_KeyRefreshObject* key_array); -/** - * OEMCrypto_GetRandom - * @deprecated - * OEMCrypto_GetRandom is not needed to export random numbers. - */ -OEMCryptoResult OEMCrypto_GetRandom(uint8_t* random_data, - size_t random_data_length); - /** * OEMCrypto_SelectKey * @deprecated @@ -6348,6 +6309,36 @@ OEMCryptoResult OEMCrypto_LoadProvisioning_V18( size_t signature_length, uint8_t* wrapped_private_key, size_t* wrapped_private_key_length); +/** + * OEMCrypto_InstallOemPrivateKey + * TODO(b/374834498): Mark as deprecated once build system is updated + * to handle removing symbol. + * Not required for the current version of OEMCrypto. Declared here to + * help with backward compatibility. + * + * @param[in] session: session id. + * @param[in] key_type: type of the leaf key (RSA or ECC). + * @param[in] wrapped_private_key: the encrypted private key. This is the + * wrapped key generated by OEMCrypto_GenerateCertificateKeyPair. + * @param[in] wrapped_private_key_length: length of |wrapped_private_key| in + * bytes. + * + * @retval OEMCrypto_SUCCESS + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * @retval OEMCrypto_ERROR_NO_DEVICE_KEY + * @retval OEMCrypto_ERROR_INVALID_SESSION + * @retval OEMCrypto_ERROR_INVALID_KEY + * @retval OEMCrypto_ERROR_INSUFFICIENT_RESOURCES + * @retval OEMCrypto_ERROR_UNKNOWN_FAILURE + * @retval OEMCrypto_ERROR_SESSION_LOST_STATE + * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED + * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED + * @retval OEMCrypto_ERROR_UNKNOWN_FAILURE + */ +OEMCryptoResult OEMCrypto_InstallOemPrivateKey( + OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, + const uint8_t* wrapped_private_key, size_t wrapped_private_key_length); + /****************************************************************************/ /****************************************************************************/ /* The following functions are used by internal L3 CDMs and are not required by @@ -6398,6 +6389,44 @@ OEMCryptoResult OEMCrypto_UseSecondaryKey(OEMCrypto_SESSION session_id, */ OEMCryptoResult OEMCrypto_MarkOfflineSession(OEMCrypto_SESSION session); +/** + * Loads the license data into the given session. + * + * @param[in] session: session id for operation. + * @param[in] data: the buffer to import. + * @param[in] data_length: the number of bytes in |data|. + * + * @ignore + * @retval OEMCrypto_SUCCESS on success + * @retval OEMCrypto_ERROR_INVALID_SESSION + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * @retval OEMCrypto_ERROR_SESSION_STATE_LOST + * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED + * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED + */ +OEMCryptoResult OEMCrypto_LoadLicenseData(OEMCrypto_SESSION session, + const uint8_t* data, + size_t data_length); + +/** + * Saves the license data for the given session. + * + * @param[in] session: session id for operation. + * @param[out] data: the buffer to export into. + * @param[in,out] data_length: (in) length of the data buffer, in bytes. + * (out) actual length of the data, in bytes. + * + * @ignore + * @retval OEMCrypto_SUCCESS on success + * @retval OEMCrypto_ERROR_INVALID_SESSION + * @retval OEMCrypto_ERROR_INVALID_CONTEXT + * @retval OEMCrypto_ERROR_SESSION_STATE_LOST + * @retval OEMCrypto_ERROR_SYSTEM_INVALIDATED + * @retval OEMCrypto_ERROR_NOT_IMPLEMENTED + */ +OEMCryptoResult OEMCrypto_SaveLicenseData(OEMCrypto_SESSION session, + uint8_t* data, size_t* data_length); + #ifdef __cplusplus } #endif diff --git a/oemcrypto/include/level3.h b/oemcrypto/include/level3.h index 4efdb82..9be64ca 100644 --- a/oemcrypto/include/level3.h +++ b/oemcrypto/include/level3.h @@ -25,6 +25,7 @@ #define Level3_InstallKeyboxOrOEMCert _lcc03 #define Level3_GetKeyData _lcc04 #define Level3_IsKeyboxOrOEMCertValid _lcc05 +#define Level3_GetRandom _lcc06 #define Level3_GetDeviceID _lcc07 #define Level3_WrapKeyboxOrOEMCert _lcc08 #define Level3_OpenSession _lcc09 @@ -132,12 +133,15 @@ #define Level3_GetBCCSignatureType _lcc156 #define Level3_GetPVRKey _lcc157 #define Level3_LoadPVRKey _lcc158 +#define Level3_LoadLicenseData _lcc159 +#define Level3_SaveLicenseData _lcc160 #else #define Level3_Initialize _oecc01 #define Level3_Terminate _oecc02 #define Level3_InstallKeyboxOrOEMCert _oecc03 #define Level3_GetKeyData _oecc04 #define Level3_IsKeyboxOrOEMCertValid _oecc05 +#define Level3_GetRandom _oecc06 #define Level3_GetDeviceID _oecc07 #define Level3_WrapKeyboxOrOEMCert _oecc08 #define Level3_OpenSession _oecc09 @@ -248,6 +252,8 @@ #define Level3_GetBCCSignatureType _oecc156 #define Level3_GetPVRKey _oecc157 #define Level3_LoadPVRKey _oecc158 +#define Level3_LoadLicenseData _oecc159 +#define Level3_SaveLicenseData _oecc160 #endif #define Level3_GetInitializationState _oecl3o01 @@ -288,6 +294,8 @@ OEMCryptoResult Level3_GetOEMPublicCertificate(uint8_t* public_cert, size_t* public_cert_length); OEMCryptoResult Level3_GetDeviceID(uint8_t* deviceID, size_t* idLength); OEMCryptoResult Level3_GetKeyData(uint8_t* keyData, size_t* keyDataLength); +OEMCryptoResult Level3_GetRandom(OEMCrypto_SharedMemory* randomData, + size_t randomDataLength); OEMCryptoResult Level3_LoadOEMPrivateKey(OEMCrypto_SESSION session); OEMCryptoResult Level3_LoadDRMPrivateKey(OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, @@ -353,6 +361,10 @@ OEMCryptoResult Level3_DeactivateUsageEntry(OEMCrypto_SESSION session, OEMCryptoResult Level3_ReportUsage(OEMCrypto_SESSION session, const uint8_t* pst, size_t pst_length, uint8_t* buffer, size_t* buffer_length); +OEMCryptoResult Level3_GetUsageEntryInfo( + OEMCrypto_SESSION session, OEMCrypto_UsageEntryStatus* status, + int64_t* seconds_since_license_received, + int64_t* seconds_since_first_decrypt); bool Level3_IsSRMUpdateSupported(); OEMCryptoResult Level3_GetCurrentSRMVersion(uint16_t* version); OEMCryptoResult Level3_LoadSRM(const uint8_t* buffer, size_t buffer_length); @@ -413,6 +425,12 @@ OEMCryptoResult Level3_LoadRenewal(OEMCrypto_SESSION session, size_t core_message_length, const uint8_t* signature, size_t signature_length); +OEMCryptoResult Level3_LoadRelease(OEMCrypto_SESSION session, + const uint8_t* message, + size_t message_length, + size_t core_message_length, + const uint8_t* signature, + size_t signature_length); OEMCryptoResult Level3_RefreshKeys(OEMCrypto_SESSION session, const uint8_t* message, size_t message_length, @@ -435,6 +453,9 @@ OEMCryptoResult Level3_PrepAndSignLicenseRequest( OEMCryptoResult Level3_PrepAndSignRenewalRequest( OEMCrypto_SESSION session, uint8_t* message, size_t message_length, size_t* core_message_length, uint8_t* signature, size_t* signature_length); +OEMCryptoResult Level3_PrepAndSignReleaseRequest( + OEMCrypto_SESSION session, uint8_t* message, size_t message_length, + size_t* core_message_size, uint8_t* signature, size_t* signature_length); size_t Level3_MaximumUsageTableHeaderSize(); OEMCryptoResult Level3_AllocateSecureBuffer( OEMCrypto_SESSION session, size_t buffer_size, @@ -555,6 +576,10 @@ OEMCryptoResult Level3_GetEmbeddedDrmCertificate(uint8_t* public_cert, OEMCryptoResult Level3_UseSecondaryKey(OEMCrypto_SESSION session_id, bool dual_key); OEMCryptoResult Level3_MarkOfflineSession(OEMCrypto_SESSION session_id); +OEMCryptoResult Level3_LoadLicenseData(OEMCrypto_SESSION session, + const uint8_t* data, size_t data_length); +OEMCryptoResult Level3_SaveLicenseData(OEMCrypto_SESSION session, uint8_t* data, + size_t* data_length); OEMCryptoResult Level3_GetBCCSignatureType( OEMCrypto_BCCSignatureType* bcc_signature_type); diff --git a/oemcrypto/odk/Android.bp b/oemcrypto/odk/Android.bp index 6b966ce..8518b94 100644 --- a/oemcrypto/odk/Android.bp +++ b/oemcrypto/odk/Android.bp @@ -9,6 +9,7 @@ // CONSULT THE OWNERS AND opensource-licensing@google.com BEFORE // DEPENDING ON IT IN YOUR PROJECT. *** package { + default_team: "trendy_team_media_framework_drm", // See: http://go/android-license-faq // A large-scale-change added 'default_applicable_licenses' to import // all of the 'license_kinds' from "vendor_widevine_license" @@ -18,17 +19,49 @@ package { default_applicable_licenses: ["vendor_widevine_license"], } -cc_library_static { - name: "libwv_odk", - include_dirs: [ - "vendor/widevine/libwvdrmengine/oemcrypto/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/src", +// Header library for public ODK struct definitions. +// This intended for targets which need access to C struct definitions, +// but not the library functions (ex. corpus generator). +// +// Note: This header library technically exposes the API in addition to +// the struct definitions; however, the API is not intended to be used +// directly by dependencies. +cc_library_headers { + name: "odk_struct_headers", + + export_include_dirs: [ + "include", + ], + + owner: "widevine", + min_sdk_version: "34", +} + +// ---------------------------------------------------------------- +// Builds: libwv_odk.a +// The ODK Library - TEE (and sometimes REE) library for preparing +// and parsing client-side core messages. +cc_defaults { + name: "libwv_odk_defaults", + defaults: [ + "widevine_code_protected", + ], + + // Includes for building. + local_include_dirs: [ + "include", + "src", ], header_libs: [ - "jni_headers", - "libbase_headers", - "liblog_headers", + "oemcrypto_api_headers", + ], + + // Includes for components using libwv_odk.a. + export_include_dirs: [ + "include", + ], + export_header_lib_headers: [ + "oemcrypto_api_headers", ], srcs: [ @@ -40,53 +73,96 @@ cc_library_static { "src/odk_util.c", "src/serialization_base.c", ], - proprietary: true, owner: "widevine", min_sdk_version: "34", } -// ---------------------------------------------------------------- -// Builds libwv_kdo.a, The ODK Library companion (libwv_kdo) is used by -// the CDM and by oemcrypto tests, but not by oemcrypto implementations. cc_library_static { - name: "libwv_kdo", - include_dirs: [ - "vendor/widevine/libwvdrmengine/oemcrypto/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/src", + name: "libwv_odk", + + defaults: ["libwv_odk_defaults"], + proprietary: true, +} + +cc_library_static { + name: "libwv_odk.system", + + defaults: ["libwv_odk_defaults"], + system_ext_specific: true, +} + +// ---------------------------------------------------------------- +// Builds: libwv_kdo.a +// The ODK Library companion - Server (and sometimes REE) library +// for preparing and parsing server-side core messages. +cc_defaults { + name: "libwv_kdo_defaults", + defaults: [ + "widevine_code_protected", + ], + + // Includes for building. + local_include_dirs: [ + "include", + "src", ], header_libs: [ - "jni_headers", - "libbase_headers", - "liblog_headers", + "oemcrypto_api_headers", + ], + + // Includes for components using libwv_kdo.a. + export_include_dirs: [ + "include", + ], + export_header_lib_headers: [ + "oemcrypto_api_headers", ], srcs: [ - "src/core_message_deserialize.cpp", - "src/core_message_features.cpp", - "src/core_message_serialize.cpp", - "src/core_message_serialize_proto.cpp", + "src/core_message_deserialize.cpp", + "src/core_message_features.cpp", + "src/core_message_serialize.cpp", + "src/core_message_serialize_proto.cpp", ], + owner: "widevine", +} + +cc_library_static { + name: "libwv_kdo", + + defaults: ["libwv_kdo_defaults"], + proprietary: true, + static_libs: [ "libcdm_protos", "libwv_odk", ], +} - proprietary: true, +cc_library_static { + name: "libwv_kdo.system", - owner: "widevine", + defaults: ["libwv_kdo_defaults"], + system_ext_specific: true, + + static_libs: [ + "libcdm_protos.system", + "libwv_odk.system", + ], } // ---------------------------------------------------------------- -// Builds odk_test executable, which tests the ODK library. +// Builds: odk_test executable +// Unittests for the ODK library. cc_test { name: "odk_test", - include_dirs: [ - "vendor/widevine/libwvdrmengine/oemcrypto/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/src", + + // Includes for building. + local_include_dirs: [ + "src", + "test", ], // WARNING: Module tags are not supported in Soong. @@ -96,11 +172,16 @@ cc_test { // - Both 32 & 64 bit versions will be built (as appropriate) owner: "widevine", + team: "trendy_team_media_framework_drm", proprietary: true, + // Required to run as module tests in ATP + test_suites: ["general-tests"], + gtest: true, + require_root: true, + static_libs: [ "libcdm_protos", - "libcdm", "libwv_odk", "libwv_kdo", ], @@ -118,5 +199,4 @@ cc_test { "test/odk_test_helper.cpp", "test/odk_timer_test.cpp", ], - } diff --git a/oemcrypto/odk/include/OEMCryptoCENCCommon.h b/oemcrypto/odk/include/OEMCryptoCENCCommon.h index ffc70ab..7caed4a 100644 --- a/oemcrypto/odk/include/OEMCryptoCENCCommon.h +++ b/oemcrypto/odk/include/OEMCryptoCENCCommon.h @@ -76,9 +76,10 @@ typedef enum OEMCryptoResult { OEMCrypto_ERROR_LICENSE_INACTIVE = 47, OEMCrypto_ERROR_ENTRY_NEEDS_UPDATE = 48, OEMCrypto_ERROR_ENTRY_IN_USE = 49, - OEMCrypto_ERROR_USAGE_TABLE_UNRECOVERABLE = 50, /* Obsolete. Don't use. */ - /* Use OEMCrypto_ERROR_NO_CONTENT_KEY instead of KEY_NOT_LOADED. */ - OEMCrypto_KEY_NOT_LOADED = 51, /* Obsolete. */ + /** Obsolete. Don't use. */ + OEMCrypto_ERROR_USAGE_TABLE_UNRECOVERABLE = 50, + /** Obsolete. Use OEMCrypto_ERROR_NO_CONTENT_KEY instead of KEY_NOT_LOADED. */ + OEMCrypto_KEY_NOT_LOADED = 51, OEMCrypto_KEY_NOT_ENTITLED = 52, OEMCrypto_ERROR_BAD_HASH = 53, OEMCrypto_ERROR_OUTPUT_TOO_LARGE = 54, @@ -115,13 +116,14 @@ typedef enum OEMCryptoResult { /** * Valid values for status in the usage table. */ -typedef enum OEMCrypto_Usage_Entry_Status { - kUnused = 0, - kActive = 1, - kInactive = 2, /* Deprecated. Use kInactiveUsed or kInactiveUnused. */ - kInactiveUsed = 3, - kInactiveUnused = 4, -} OEMCrypto_Usage_Entry_Status; +typedef enum OEMCrypto_UsageEntryStatus { + OEMCrypto_Unused = 0, + OEMCrypto_Active = 1, + OEMCrypto_Inactive = 2, // Deprecated. Use OEMCrypto_InactiveUsed or + // OEMCrypto_InactiveUnused. + OEMCrypto_InactiveUsed = 3, + OEMCrypto_InactiveUnused = 4, +} OEMCrypto_UsageEntryStatus; /* Not used publicly. Not documented with Doxygen. */ typedef enum OEMCrypto_ProvisioningRenewalType { @@ -143,9 +145,9 @@ typedef enum OEMCrypto_LicenseType { * Private key type used in the provisioning response. */ typedef enum OEMCrypto_PrivateKeyType { - OEMCrypto_RSA_Private_Key = 0, - OEMCrypto_ECC_Private_Key = 1, - OEMCrypto_PrivateKeyType_MaxValue = OEMCrypto_ECC_Private_Key, + OEMCrypto_RSAPrivateKey = 0, + OEMCrypto_ECCPrivateKey = 1, + OEMCrypto_PrivateKeyType_MaxValue = OEMCrypto_ECCPrivateKey, } OEMCrypto_PrivateKeyType; /** @@ -153,10 +155,10 @@ typedef enum OEMCrypto_PrivateKeyType { * starts. */ typedef enum OEMCrypto_TimerDelayBase { - OEMCrypto_License_Start = 0, - OEMCrypto_License_Load = 1, - OEMCrypto_First_Decrypt = 2, - OEMCrypto_TimerDelayBase_MaxValue = OEMCrypto_First_Decrypt, + OEMCrypto_LicenseStart = 0, + OEMCrypto_LicenseLoad = 1, + OEMCrypto_FirstDecrypt = 2, + OEMCrypto_TimerDelayBase_MaxValue = OEMCrypto_FirstDecrypt, } OEMCrypto_TimerDelayBase; /** @@ -254,6 +256,102 @@ typedef struct { OEMCrypto_Substring key_control; } OEMCrypto_KeyObject; +/** + * Valid values for clock_security_level in OEMCrypto_PST_Report. + */ +typedef enum OEMCrypto_ClockSecurityLevel { + OEMCrypto_InsecureClock = 0, + OEMCrypto_MonotonicClock = 1, + OEMCrypto_SecureTimer = 1, // DEPRECATED. Do not use. + OEMCrypto_SecureClock = 2, + OEMCrypto_HardwareSecureClock = 3 +} OEMCrypto_ClockSecurityLevel; + +/** + * Content mitigation option for DeCENC attacks. + * + * For this enum, multiples of 2 will be used for the DeCENC mitigation options + * in order to correlate to how mitigation options are supported via bitmask. + */ +typedef enum OEMCrypto_DeCENC_Mitigation_Option { + OEMCrypto_DeCENC_Mitigation_Option_None = 0, + /** If the device can authenticate the signature of the bitstream. This is + the "long term" solution that requires the content stream to be signed by a + new key. + */ + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream = 1, + /** If the device can validate the bitstream. This means that OEMCrypto will + parse the decrypted bit stream and replace any I_PCM blocks with zeros. */ + OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream = 2, + /** If the device can restrict the decoding. This means that the codec will + parse the decrypted bit stream as usual, but will **not** pass through any + I_PCM blocks. */ + OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding = 4, + OEMCrypto_DeCENC_Mitigation_Option_MaxValue = + OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding, +} OEMCrypto_DeCENC_Mitigation_Option; + +/** + * Authentication key information for chosen DeCENC mitigation. + */ +typedef struct { + OEMCrypto_Substring authentication_key; + OEMCrypto_Substring authentication_key_iv; +} OEMCrypto_AuthenticationKeyInfo; + +/** + * Information about the required DeCENC mitigation option and the configuration + * options for the mitigation option. + * @param mitigation_option: the DeCENC mitigation option that the content + * provider wants to use for the key. + * @param configuration_options: optional configuration options to be used with + * required_decenc_mitigation. + */ +typedef struct { + OEMCrypto_DeCENC_Mitigation_Option mitigation_option; + union { + // Configuration option for OEMCrypto_DeCENC_Mitigation_Option = 1 where + // the bitstream is authenticated. + OEMCrypto_AuthenticationKeyInfo authentication_key_info; + } configuration_options; +} OEMCrypto_DeCENC_Mitigation_Info; + +/** + * Version 2 of OEMCrypto_KeyObject. Created for DeCENC attack mitigation which + * needs extra fields. + * + * Points to the relevant fields for a content key. The fields are extracted + * from the License Response message offered to ODK_ParseLicense(). Each + * field points to one of the components of the key. Key data, key control, + * and both IV fields are 128 bits (16 bytes): + * @param key_id: the unique id of this key. + * @param key_id_length: the size of key_id. OEMCrypto may assume this is at + * most 16. However, OEMCrypto shall correctly handle key id lengths + * from 1 to 16 bytes. + * @param key_data_iv: the IV for performing AES-128-CBC decryption of the + * key_data field. + * @param key_data - the key data. It is encrypted (AES-128-CBC) with the + * session's derived encrypt key and the key_data_iv. + * @param key_control_iv: the IV for performing AES-128-CBC decryption of the + * key_control field. + * @param key_control: the key control block. It is encrypted (AES-128-CBC) with + * the content key from the key_data field. + * @param decenc_mitigation_info: the DeCENC mitigation option that the + * content provider wants to use for the key and the configuration + * options for the mitigation option. + * + * The memory for the OEMCrypto_KeyObject fields is allocated and freed + * by the caller of ODK_ParseLicense(). + */ +typedef struct { + OEMCrypto_Substring key_id; + OEMCrypto_Substring key_data_iv; + OEMCrypto_Substring key_data; + OEMCrypto_Substring key_control_iv; + OEMCrypto_Substring key_control; + OEMCrypto_DeCENC_Mitigation_Info decenc_mitigation_info; +} OEMCrypto_KeyObjectV2; + /// @} #ifdef __cplusplus diff --git a/oemcrypto/odk/include/core_message_deserialize.h b/oemcrypto/odk/include/core_message_deserialize.h index fb80dc5..74c3a7a 100644 --- a/oemcrypto/odk/include/core_message_deserialize.h +++ b/oemcrypto/odk/include/core_message_deserialize.h @@ -99,6 +99,18 @@ bool CoreRenewedProvisioningRequestFromMessage( bool CoreCommonRequestFromMessage(const std::string& oemcrypto_core_message, ODK_CommonRequest* core_common_request); +/** + * Counterpart (deserializer) of decenc_mitigation_options_supported in + * ODK_PrepareCoreLicenseRequest (serializer) + * + * Parameters: + * [in] oemcrypto_core_message + * [out] decenc_mitigation_options_supported + */ +bool CoreDecencMitigationOptionsSupportedFromMessage( + const std::string& oemcrypto_core_message, + uint16_t& decenc_mitigation_options_supported); + } // namespace deserialize } // namespace oemcrypto_core_message diff --git a/oemcrypto/odk/include/core_message_features.h b/oemcrypto/odk/include/core_message_features.h index 3edfc02..d4d3531 100644 --- a/oemcrypto/odk/include/core_message_features.h +++ b/oemcrypto/odk/include/core_message_features.h @@ -18,11 +18,23 @@ namespace features { // this structure, we can turn off features at runtime. This is plain data, and // is essentially a version number. struct CoreMessageFeatures { + // If |prerelease_odk_messages| is true, then the server will serve prerelease + // devices. + explicit CoreMessageFeatures(bool prerelease_odk_messages = false); + // A default set of features. static const CoreMessageFeatures kDefaultFeatures; - // Create the default feature set for the given major version number. - static CoreMessageFeatures DefaultFeatures(uint32_t maximum_major_version); + // Create the default feature set for the given major version number. This + // function should only be used when running tests, or in special cases where + // a license is being generated out of band, like for an ATSC license. + static CoreMessageFeatures DefaultFeatures( + uint32_t maximum_major_version, + bool serve_prerelease_odk_messages = false); + + // If this is true, then the server should serve messages to a prerelease + // device. If this is false, requests from a prerelease device are rejected. + bool serve_prerelease_odk_messages = false; // This is the published version of the ODK Core Message library. The default // behavior is for the server to restrict messages to at most this version @@ -34,6 +46,19 @@ struct CoreMessageFeatures { bool operator!=(const CoreMessageFeatures &other) const { return !(*this == other); } + + // Verify that the give version number is valid. If the version is prerelease, + // and serve_prereease_messages is false, then this returns false. + bool ValidateRequestVersion(uint16_t major_version, + uint16_t minor_version) const; + + // Compute the minimum of the feature version and the request version. Return + // false if the request is prerelease and the serve_prerelease_messages is + // false. + bool GetResponseVersion(uint16_t request_major_version, + uint16_t request_minor_version, + uint16_t &response_major_version, + uint16_t &response_minor_version) const; }; std::ostream &operator<<(std::ostream &os, const CoreMessageFeatures &features); diff --git a/oemcrypto/odk/include/core_message_serialize.h b/oemcrypto/odk/include/core_message_serialize.h index 64e9427..40371e6 100644 --- a/oemcrypto/odk/include/core_message_serialize.h +++ b/oemcrypto/odk/include/core_message_serialize.h @@ -51,14 +51,10 @@ bool CreateCoreLicenseResponse(const CoreMessageFeatures& features, * Parameters: * [in] features feature support for response message. * [in] core_request - * [in] seconds_since_license_requested - * [in] seconds_since_first_decrypt * [out] oemcrypto_core_message */ bool CreateCoreReleaseResponse(const CoreMessageFeatures& features, const ODK_ReleaseRequest& core_request, - int64_t seconds_since_license_requested, - int64_t seconds_since_first_decrypt, std::string* oemcrypto_core_message); /** diff --git a/oemcrypto/odk/include/odk.h b/oemcrypto/odk/include/odk.h index 679d4e0..a29a0c3 100644 --- a/oemcrypto/odk/include/odk.h +++ b/oemcrypto/odk/include/odk.h @@ -158,7 +158,7 @@ OEMCryptoResult ODK_ReloadClockValues(ODK_ClockValues* clock_values, uint64_t time_of_license_request_signed, uint64_t time_of_first_decrypt, uint64_t time_of_last_decrypt, - enum OEMCrypto_Usage_Entry_Status status, + enum OEMCrypto_UsageEntryStatus status, uint64_t system_time_seconds); /** @@ -263,7 +263,9 @@ OEMCryptoResult ODK_DeactivateUsageEntry(ODK_ClockValues* clock_values); * bytes. (out) actual length of the core message, in bytes. * @param[in] nonce_values: pointer to the session's nonce data. * @param[in] message_count_info: information used for server-side anomaly - * detection + * detection + * @param[in] decenc_mitigation_options_supported: the supported decenc + * mitigation options from the device. * * @retval OEMCrypto_SUCCESS * @retval OEMCrypto_ERROR_SHORT_BUFFER: core_message_size is too small @@ -275,7 +277,8 @@ OEMCryptoResult ODK_DeactivateUsageEntry(ODK_ClockValues* clock_values); OEMCryptoResult ODK_PrepareCoreLicenseRequest( uint8_t* message, size_t message_length, size_t* core_message_size, const ODK_NonceValues* nonce_values, - const ODK_MessageCounterInfo* counter_info); + const ODK_MessageCounterInfo* counter_info, + uint16_t decenc_mitigation_options_supported); /** * Modifies the message to include a core license release at the beginning of @@ -295,13 +298,16 @@ OEMCryptoResult ODK_PrepareCoreLicenseRequest( * of the message. (in) size of buffer reserved for the core message, in * bytes. (out) actual length of the core message, in bytes. * @param[in] nonce_values: pointer to the session's nonce data. - * @param[in] status: the enumeration of OEMCrypto_Usage_Entry_Status + * @param[in] usage_entry_status: the enumeration of + * OEMCrypto_UsageEntryStatus * @param[in] clock_security_level: the enumeration of * OEMCryto_Clock_Security_Level * @param[in] seconds_since_license_requested: the time between the license * being requested and the release being generated in seconds * @param[in] seconds_since_first_decrypt: The time since playback has started * in seconds + * @param[in] pst: the provider session token. + * @param[in] pst_length: the length of the pst array. * @param[in,out] clock_values: the session's clock values. * @param[in] system_time_seconds: the current time on OEMCrypto's clock, in * seconds. @@ -311,14 +317,15 @@ OEMCryptoResult ODK_PrepareCoreLicenseRequest( * @retval OEMCrypto_ERROR_INVALID_CONTEXT * * @version - * This method is new in version 19 of the API. + * This method is new in version 20 of the API. */ OEMCryptoResult ODK_PrepareCoreReleaseRequest( uint8_t* message, size_t message_length, size_t* core_message_size, - ODK_NonceValues* nonce_values, uint32_t status, - uint32_t clock_security_level, int64_t seconds_since_license_requested, - int64_t seconds_since_first_decrypt, ODK_ClockValues* clock_values, - uint64_t system_time_seconds); + ODK_NonceValues* nonce_values, + OEMCrypto_UsageEntryStatus usage_entry_status, + uint32_t clock_security_level, uint64_t seconds_since_license_requested, + uint64_t seconds_since_first_decrypt, const uint8_t* pst, size_t pst_length, + ODK_ClockValues* clock_values, uint64_t system_time_seconds); /** * Modifies the message to include a core renewal request at the beginning of diff --git a/oemcrypto/odk/include/odk_message.h b/oemcrypto/odk/include/odk_message.h index aedd31a..648eade 100644 --- a/oemcrypto/odk/include/odk_message.h +++ b/oemcrypto/odk/include/odk_message.h @@ -5,14 +5,14 @@ #ifndef WIDEVINE_ODK_INCLUDE_ODK_MESSAGE_H_ #define WIDEVINE_ODK_INCLUDE_ODK_MESSAGE_H_ -#ifdef __cplusplus -extern "C" { -#endif - #include #include #include +#ifdef __cplusplus +extern "C" { +#endif + /* * ODK_Message is the structure that defines the serialized messages passed * between the REE and TEE. ODK_Message is an abstract data type that represents diff --git a/oemcrypto/odk/include/odk_structs.h b/oemcrypto/odk/include/odk_structs.h index 0b905aa..e0f4dc6 100644 --- a/oemcrypto/odk/include/odk_structs.h +++ b/oemcrypto/odk/include/odk_structs.h @@ -5,21 +5,21 @@ #ifndef WIDEVINE_ODK_INCLUDE_ODK_STRUCTS_H_ #define WIDEVINE_ODK_INCLUDE_ODK_STRUCTS_H_ -#ifdef __cplusplus -extern "C" { -#endif - #include #include "OEMCryptoCENCCommon.h" #include "odk_target.h" +#ifdef __cplusplus +extern "C" { +#endif + /* The version of this library. */ -#define ODK_MAJOR_VERSION 19 -#define ODK_MINOR_VERSION 5 +#define ODK_MAJOR_VERSION 20 +#define ODK_MINOR_VERSION 0 /* ODK Version string. Date changed automatically on each release. */ -#define ODK_RELEASE_DATE "ODK v19.5 2025-03-11" +#define ODK_RELEASE_DATE "ODK v20.0 2025-05-20" /* The lowest version number for an ODK message. */ #define ODK_FIRST_VERSION 16 @@ -33,6 +33,7 @@ extern "C" { * https://www.rfc-editor.org/rfc/rfc8949.html#name-specification-of-the-cbor-e * for an estimation of the required length. */ #define ODK_DEVICE_INFO_LEN_MAX 768 +#define ODK_PST_LEN_MAX 256 /// @addtogroup odk_timer /// @{ @@ -127,7 +128,7 @@ typedef struct { uint64_t time_of_renewal_request; uint64_t time_when_timer_expires; uint32_t timer_status; - enum OEMCrypto_Usage_Entry_Status status; + enum OEMCrypto_UsageEntryStatus status; } ODK_ClockValues; /** @@ -227,6 +228,10 @@ typedef struct { * @param dtcp2_required: indicates dtcp2 requirements of the license. * @param renewal_delay_base: indicates which time is used for the renewal timer * and playback timer starting point. + * @param decenc_mitigation_option_used: indicates whether DeCENC + * mitigation is used. Used to determine packing/unpacking of DeCENC + * fields in the key object. This field represents a boolean (zero or + * nonzero). * @param key_array_length: number of keys present. * @param key_array: set of keys to be installed. * @@ -244,8 +249,12 @@ typedef struct { uint32_t watermarking; OEMCrypto_DTCP2_CMI_Packet dtcp2_required; OEMCrypto_TimerDelayBase renewal_delay_base; - uint32_t key_array_length; - OEMCrypto_KeyObject key_array[ODK_MAX_NUM_KEYS]; + uint16_t decenc_mitigation_option_used; + uint16_t key_array_length; + // If |decenc_mitigation_option_used| is 0, then the key object V2 array will + // be treated like an array of OEMCrypto_KeyObject. The extra fields in + // OEMCrypto_KeyObjectV2 will be ignored. + OEMCrypto_KeyObjectV2 key_array[ODK_MAX_NUM_KEYS]; } ODK_ParsedLicense; /** @@ -267,12 +276,15 @@ typedef struct { * @param dtcp2_required: indicates dtcp2 requirements of the license. * @param renewal_delay_base: indicates which time is used for the renewal timer * and playback timer starting point. + * @param decenc_mitigation_option_used: indicates whether DeCENC + * mitigation is used. Used to determine packing/unpacking of DeCENC + * fields in the key object. * @param key_array_length: number of keys present. * @param key_array: set of keys to be installed. This is a pointer to an array * to allow packing a number of keys greater than |ODK_MAX_NUM_KEYS|. * * @version - * This struct changed in API version 18. + * This struct changed in API version 19. */ typedef struct { OEMCrypto_Substring enc_mac_keys_iv; @@ -285,8 +297,9 @@ typedef struct { uint32_t watermarking; OEMCrypto_DTCP2_CMI_Packet dtcp2_required; OEMCrypto_TimerDelayBase renewal_delay_base; - uint32_t key_array_length; - OEMCrypto_KeyObject* key_array; + uint16_t decenc_mitigation_option_used; + uint16_t key_array_length; + OEMCrypto_KeyObjectV2* key_array; } ODK_Packing_ParsedLicense; /** @@ -299,15 +312,24 @@ typedef struct { * @param enc_private_key_iv: IV for decrypting new private key. Size is 128 * bits. * @param encrypted_message_key: used for provisioning 3.0 to derive keys. + * @param curr_server_sealing_key: the server sealing key for the current + * firmware version. + * @param server_sealing_key_array_length: number of server sealing keys present + * for previous versions. + * @param server_sealing_key_array: set of server sealing keys for previous + * firmware versions. Set to a maximum of 5 for now. * * @version - * This struct changed in API version 16.2. + * This struct changed in API version 20.0. */ typedef struct { OEMCrypto_PrivateKeyType key_type; OEMCrypto_Substring enc_private_key; OEMCrypto_Substring enc_private_key_iv; OEMCrypto_Substring encrypted_message_key; /* Used for Prov 3.0 */ + OEMCrypto_Substring curr_server_sealing_key; + uint32_t server_sealing_key_array_length; + OEMCrypto_Substring server_sealing_key_array[ODK_MAX_NUM_SERVER_SEALING_KEYS]; } ODK_ParsedProvisioning; /// @} diff --git a/oemcrypto/odk/include/odk_target.h b/oemcrypto/odk/include/odk_target.h index 825a263..98eca91 100644 --- a/oemcrypto/odk/include/odk_target.h +++ b/oemcrypto/odk/include/odk_target.h @@ -9,5 +9,6 @@ // Maximum number of keys can be modified to suit target's resource tier. #define ODK_MAX_NUM_KEYS 32 +#define ODK_MAX_NUM_SERVER_SEALING_KEYS 5 #endif // WIDEVINE_ODK_INCLUDE_ODK_TARGET_H_ diff --git a/oemcrypto/odk/src/core_message_deserialize.cpp b/oemcrypto/odk/src/core_message_deserialize.cpp index e1cdbd0..f0a4634 100644 --- a/oemcrypto/odk/src/core_message_deserialize.cpp +++ b/oemcrypto/odk/src/core_message_deserialize.cpp @@ -4,7 +4,8 @@ #include "core_message_deserialize.h" -#include +#include + #include #include #include @@ -38,7 +39,7 @@ bool ParseRequest(uint32_t message_type, } const uint8_t* buf = - reinterpret_cast(oemcrypto_core_message.c_str()); + reinterpret_cast(oemcrypto_core_message.data()); const size_t buf_length = oemcrypto_core_message.size(); ODK_Message msg = ODK_Message_Create(const_cast(buf), buf_length); @@ -89,19 +90,18 @@ bool ParseRequest(uint32_t message_type, return true; } -} // namespace - -static bool GetNonceFromMessage(const std::string& oemcrypto_core_message, - ODK_NonceValues* nonce_values) { +bool GetNonceFromMessage(const std::string& oemcrypto_core_message, + ODK_NonceValues* nonce_values) { if (nonce_values == nullptr) return false; if (oemcrypto_core_message.size() < sizeof(ODK_CoreMessage)) return false; - ODK_CoreMessage core_message; const uint8_t* buf = - reinterpret_cast(oemcrypto_core_message.c_str()); + reinterpret_cast(oemcrypto_core_message.data()); ODK_Message msg = ODK_Message_Create(const_cast(buf), oemcrypto_core_message.size()); ODK_Message_SetSize(&msg, sizeof(ODK_CoreMessage)); + + ODK_CoreMessage core_message = {}; Unpack_ODK_CoreMessage(&msg, &core_message); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK) return false; *nonce_values = core_message.nonce_values; @@ -109,7 +109,7 @@ static bool GetNonceFromMessage(const std::string& oemcrypto_core_message, } bool CopyCounterInfo(ODK_MessageCounter* dest, ODK_MessageCounterInfo* src) { - if (!src || !dest) return false; + if (src == nullptr || dest == nullptr) return false; dest->master_generation_number = src->master_generation_number; dest->license_count = src->license_count; @@ -125,9 +125,21 @@ bool CopyCounterInfo(ODK_MessageCounter* dest, ODK_MessageCounterInfo* src) { return true; } +// Checks if the buffer is padded with zeros after |buffer_length| +// up to the given |max_length|. +bool IsBufferZeroPadded(const uint8_t* buffer, size_t buffer_length, + size_t max_length) { + if (buffer == nullptr || buffer_length > max_length) return false; + for (size_t i = buffer_length; i < max_length; ++i) { + if (buffer[i] != 0) return false; + } + return true; +} +} // namespace + bool CoreLicenseRequestFromMessage(const std::string& oemcrypto_core_message, ODK_LicenseRequest* core_license_request) { - ODK_NonceValues nonce; + ODK_NonceValues nonce = {}; if (!GetNonceFromMessage(oemcrypto_core_message, &nonce)) return false; if (nonce.api_major_version <= 17) { const auto unpacker = Unpack_ODK_PreparedLicenseRequestV17; @@ -174,7 +186,7 @@ bool CoreProvisioningRequestFromMessage( ODK_ProvisioningRequest* core_provisioning_request) { // Need to partially parse in order to get the nonce values, which will tell // us the major/minor version - ODK_NonceValues nonce; + ODK_NonceValues nonce = {}; if (!GetNonceFromMessage(oemcrypto_core_message, &nonce)) return false; if (nonce.api_major_version >= 18) { @@ -204,19 +216,19 @@ bool CoreProvisioningRequestFromMessage( unpacker)) { return false; } - const uint8_t* device_id = prepared_provision.device_id; const uint32_t device_id_length = prepared_provision.device_id_length; if (device_id_length > ODK_DEVICE_ID_LEN_MAX) { return false; } + const uint8_t* device_id = prepared_provision.device_id; if (device_id_length > 0) { - uint8_t zero[ODK_DEVICE_ID_LEN_MAX] = {}; - if (memcmp(zero, device_id + device_id_length, - ODK_DEVICE_ID_LEN_MAX - device_id_length) != 0) { + // Ensure all bytes after the set device ID bytes are zero. + if (!IsBufferZeroPadded(device_id, device_id_length, + ODK_DEVICE_ID_LEN_MAX)) { return false; } - core_provisioning_request->device_id.assign( - reinterpret_cast(device_id), device_id_length); + core_provisioning_request->device_id.assign(device_id, + device_id + device_id_length); } core_provisioning_request->renewal_type = OEMCrypto_NoRenewal; core_provisioning_request->renewal_data.clear(); @@ -237,18 +249,18 @@ bool CoreProvisioning40RequestFromMessage( &prepared_provision.counter_info)) { return false; } - const uint8_t* device_info = prepared_provision.device_info; const uint32_t device_info_length = prepared_provision.device_info_length; if (device_info_length > ODK_DEVICE_INFO_LEN_MAX) { return false; } - uint8_t zero[ODK_DEVICE_INFO_LEN_MAX] = {}; - if (memcmp(zero, device_info + device_info_length, - ODK_DEVICE_INFO_LEN_MAX - device_info_length) != 0) { + const uint8_t* device_info = prepared_provision.device_info; + // Ensure all bytes after the set device info bytes are zero. + if (!IsBufferZeroPadded(device_info, device_info_length, + ODK_DEVICE_INFO_LEN_MAX)) { return false; } core_provisioning_request->device_info.assign( - reinterpret_cast(device_info), device_info_length); + device_info, device_info + device_info_length); return true; } @@ -262,27 +274,25 @@ bool CoreRenewedProvisioningRequestFromMessage( &prepared_provision, unpacker)) { return false; } - const uint8_t* device_id = prepared_provision.device_id; const uint32_t device_id_length = prepared_provision.device_id_length; if (device_id_length > ODK_DEVICE_ID_LEN_MAX) { return false; } - uint8_t zero[ODK_DEVICE_ID_LEN_MAX] = {}; - if (memcmp(zero, device_id + device_id_length, - ODK_DEVICE_ID_LEN_MAX - device_id_length) != 0) { + const uint8_t* device_id = prepared_provision.device_id; + if (!IsBufferZeroPadded(device_id, device_id_length, ODK_DEVICE_ID_LEN_MAX)) { return false; } - core_provisioning_request->device_id.assign( - reinterpret_cast(device_id), device_id_length); + core_provisioning_request->device_id.assign(device_id, + device_id + device_id_length); - if (prepared_provision.renewal_data_length > - sizeof(prepared_provision.renewal_data)) { + const uint32_t renewal_data_length = prepared_provision.renewal_data_length; + if (renewal_data_length > sizeof(prepared_provision.renewal_data)) { return false; } + const uint8_t* renewal_data = prepared_provision.renewal_data; core_provisioning_request->renewal_type = OEMCrypto_RenewalACert; core_provisioning_request->renewal_data.assign( - reinterpret_cast(prepared_provision.renewal_data), - prepared_provision.renewal_data_length); + renewal_data, renewal_data + renewal_data_length); return true; } @@ -302,5 +312,29 @@ bool CoreCommonRequestFromMessage(const std::string& oemcrypto_core_message, return success; } +bool CoreDecencMitigationOptionsSupportedFromMessage( + const std::string& oemcrypto_core_message, + uint16_t& decenc_mitigation_options_supported) { + ODK_NonceValues nonce = {}; + if (!GetNonceFromMessage(oemcrypto_core_message, &nonce)) return false; + // DeCENC mitigation options are only supported in v18.0 and above. + if (nonce.api_major_version >= 18) { + const auto unpacker = Unpack_ODK_PreparedLicenseRequest; + ODK_LicenseRequest core_license_request; + ODK_PreparedLicenseRequest prepared_license = {}; + if (!ParseRequest(ODK_License_Request_Type, oemcrypto_core_message, + &core_license_request, &prepared_license, unpacker)) { + return false; + } + decenc_mitigation_options_supported = + prepared_license.decenc_mitigation_options_supported; + return true; + } + // For V17 and below, DeCENC mitigation options are not supported. Function + // will not return false, but |decenc_mitigation_options_supported| returns 0 + // by default. + return true; +} + } // namespace deserialize } // namespace oemcrypto_core_message diff --git a/oemcrypto/odk/src/core_message_features.cpp b/oemcrypto/odk/src/core_message_features.cpp index 3e8abf4..f8c3dce 100644 --- a/oemcrypto/odk/src/core_message_features.cpp +++ b/oemcrypto/odk/src/core_message_features.cpp @@ -4,20 +4,33 @@ #include "core_message_features.h" +#include + +#include #include +#include +#include + +#include "odk_structs.h" namespace oemcrypto_core_message { namespace features { +namespace { +// The first major version where we decided to prerelease a test version the ODK +// library. +constexpr int kFirstPrereleaseVersion = 20; +} // namespace + const CoreMessageFeatures CoreMessageFeatures::kDefaultFeatures; -bool CoreMessageFeatures::operator==(const CoreMessageFeatures &other) const { +bool CoreMessageFeatures::operator==(const CoreMessageFeatures& other) const { return maximum_major_version == other.maximum_major_version && maximum_minor_version == other.maximum_minor_version; } CoreMessageFeatures CoreMessageFeatures::DefaultFeatures( - uint32_t maximum_major_version) { - CoreMessageFeatures features; + uint32_t maximum_major_version, bool serve_prerelease_odk_messages) { + CoreMessageFeatures features(serve_prerelease_odk_messages); features.maximum_major_version = maximum_major_version; // The default minor version is the highest for each major version. This also // needs to be updated with new version releases in @@ -33,7 +46,10 @@ CoreMessageFeatures CoreMessageFeatures::DefaultFeatures( features.maximum_minor_version = 4; // 18.4 break; case 19: - features.maximum_minor_version = 5; // 19.5 + features.maximum_minor_version = 2; // 19.2 + break; + case 20: + features.maximum_minor_version = 0; // 20.0 break; default: features.maximum_minor_version = 0; @@ -41,8 +57,74 @@ CoreMessageFeatures CoreMessageFeatures::DefaultFeatures( return features; } -std::ostream &operator<<(std::ostream &os, - const CoreMessageFeatures &features) { +CoreMessageFeatures::CoreMessageFeatures(bool prerelease_odk_messages) + : serve_prerelease_odk_messages(prerelease_odk_messages) { + if (prerelease_odk_messages) { + // If we are allowed to serve prerelease messages, then the maximum version + // is the same as the defined constants. + maximum_major_version = ODK_MAJOR_VERSION; + maximum_minor_version = ODK_MINOR_VERSION; + } +} + +// Validate request version. +bool CoreMessageFeatures::ValidateRequestVersion(uint16_t major_version, + uint16_t minor_version) const { + // First, make sure that this object has valid values. I.e. the maximum is not + // a version from the future. + if ((maximum_major_version > ODK_MAJOR_VERSION) || + (maximum_major_version == ODK_MAJOR_VERSION && + maximum_minor_version > ODK_MINOR_VERSION)) { + // TODO(b/147513335): this should be logged. + return false; + } + if (major_version < ODK_FIRST_VERSION) return false; + // If we are not allowed to serve messages to a prerelease device, then check + // to see if the request has a prerelease version number. + // Minor version of 0 indicates prerelease. + if (minor_version == 0 && !serve_prerelease_odk_messages) { + // if (major_version < kFirstPrereleaseVersion) Older devices are OK. There + // were no prerelease devices before kFirstPrereleaseVersion. + if (major_version < kFirstPrereleaseVersion) { + return true; + } + + // If (major_version > ODK_MAJOR_VERSION), then the device is newer than the + // server. That happens when we are testing a prerelease device against a + // special test server that is explicitly for backwards compatibility + // testing. Since we want to allow prerelease devices to be tested for + // bacwards compatibility, we allow devices that are newer than the server. + // + // The UAT server, used for release validation testing, will be + // running recent software. The main reason to reject a prerelease device is + // to prevent a device from passing GTS tests with prerelease software. This + // statement assumes that UAT is updated quarterly, and that a prerelease + // ODK library will not be created in less than a few quarters after a major + // release. + if (major_version > ODK_MAJOR_VERSION) { + return true; + } + // All other prerelease devices are rejected. + return false; + } + return true; +} + +bool CoreMessageFeatures::GetResponseVersion( + uint16_t request_major_version, uint16_t request_minor_version, + uint16_t& response_major_version, uint16_t& response_minor_version) const { + if (!ValidateRequestVersion(request_major_version, request_minor_version)) { + return false; + } + std::tie(response_major_version, response_minor_version) = + std::min(std::pair(request_major_version, request_minor_version), + std::pair(static_cast(maximum_major_version), + static_cast(maximum_minor_version))); + return true; +} + +std::ostream& operator<<(std::ostream& os, + const CoreMessageFeatures& features) { return os << "v" << features.maximum_major_version << "." << features.maximum_minor_version; } diff --git a/oemcrypto/odk/src/core_message_serialize.cpp b/oemcrypto/odk/src/core_message_serialize.cpp index 888b957..b56df71 100644 --- a/oemcrypto/odk/src/core_message_serialize.cpp +++ b/oemcrypto/odk/src/core_message_serialize.cpp @@ -11,6 +11,7 @@ #include #include "core_message_types.h" +#include "odk_message.h" #include "odk_serialize.h" #include "odk_structs.h" #include "odk_structs_priv.h" @@ -33,33 +34,16 @@ bool CreateResponseHeader(const CoreMessageFeatures& features, ODK_MessageType message_type, ODK_CoreMessage* response_header, const S& core_request) { - // Bad major version. - if ((features.maximum_major_version > ODK_MAJOR_VERSION) || - (features.maximum_major_version == ODK_MAJOR_VERSION && - features.maximum_minor_version > ODK_MINOR_VERSION)) { + if (!features.GetResponseVersion( + core_request.api_major_version, core_request.api_minor_version, + response_header->nonce_values.api_major_version, + response_header->nonce_values.api_minor_version)) { // TODO(b/147513335): this should be logged. return false; } - response_header->message_type = message_type; - response_header->nonce_values.api_major_version = - core_request.api_major_version; - response_header->nonce_values.api_minor_version = - core_request.api_minor_version; response_header->nonce_values.nonce = core_request.nonce; response_header->nonce_values.session_id = core_request.session_id; - // The message API version for the response is the minimum of our version and - // the request's version. - if (core_request.api_major_version > features.maximum_major_version) { - response_header->nonce_values.api_major_version = - features.maximum_major_version; - response_header->nonce_values.api_minor_version = - features.maximum_minor_version; - } else if (core_request.api_major_version == features.maximum_major_version && - core_request.api_minor_version > features.maximum_minor_version) { - response_header->nonce_values.api_minor_version = - features.maximum_minor_version; - } return true; } @@ -86,12 +70,13 @@ bool CreateResponse(ODK_MessageType message_type, static constexpr size_t BUF_CAPACITY = 2048; std::vector buf(BUF_CAPACITY, 0); - ODK_Message msg = ODK_Message_Create(buf.data(), buf.capacity()); + ODK_Message msg = ODK_Message_Create(buf.data(), buf.size()); packer(&msg, &response); if (!ODK_Message_IsValid(&msg)) { return false; } + // Inserting message_length at the beginning of the buffer. uint32_t message_length = static_cast(ODK_Message_GetSize(&msg)); msg = ODK_Message_Create(buf.data() + sizeof(response_header->message_type), sizeof(response_header->message_length)); @@ -140,11 +125,7 @@ bool CreateCoreLicenseResponse(const CoreMessageFeatures& features, bool CreateCoreReleaseResponse(const CoreMessageFeatures& features, const ODK_ReleaseRequest& core_request, - int64_t seconds_since_license_requested, - int64_t seconds_since_first_decrypt, std::string* oemcrypto_core_message) { - (void)seconds_since_license_requested; - (void)seconds_since_first_decrypt; ODK_ReleaseResponse release_response{}; if (!CreateResponseHeader(features, ODK_Release_Response_Type, &release_response.core_message, core_request)) { diff --git a/oemcrypto/odk/src/core_message_serialize_proto.cpp b/oemcrypto/odk/src/core_message_serialize_proto.cpp index edb9453..de2d32d 100644 --- a/oemcrypto/odk/src/core_message_serialize_proto.cpp +++ b/oemcrypto/odk/src/core_message_serialize_proto.cpp @@ -4,7 +4,6 @@ #include "core_message_serialize_proto.h" -#include #include #include #include @@ -13,12 +12,11 @@ #include #include "OEMCryptoCENCCommon.h" +#include "core_message_features.h" #include "core_message_serialize.h" +#include "core_message_types.h" #include "license_protocol.pb.h" -#include "odk_serialize.h" #include "odk_structs.h" -#include "odk_structs_priv.h" -#include "serialization_base.h" namespace oemcrypto_core_message { namespace serialize { @@ -36,18 +34,17 @@ using oemcrypto_core_message::features::CoreMessageFeatures; */ OEMCrypto_Substring GetOecSubstring(std::string_view message, std::string_view field) { - OEMCrypto_Substring substring = {}; - size_t pos = message.find(field); - if (pos != std::string::npos) { - substring = OEMCrypto_Substring{pos, field.length()}; + const size_t pos = message.find(field); + if (pos == std::string::npos) { + return OEMCrypto_Substring{0, 0}; } - return substring; + return OEMCrypto_Substring{pos, field.length()}; } -OEMCrypto_KeyObject KeyContainerToOecKey( +OEMCrypto_KeyObjectV2 KeyContainerToOecKey( std::string_view proto, const video_widevine::License::KeyContainer& k, const bool uses_padding) { - OEMCrypto_KeyObject obj = {}; + OEMCrypto_KeyObjectV2 obj = {}; obj.key_id = GetOecSubstring(proto, k.id()); obj.key_data_iv = GetOecSubstring(proto, k.iv()); @@ -67,6 +64,48 @@ OEMCrypto_KeyObject KeyContainerToOecKey( obj.key_control_iv = GetOecSubstring(proto, key_control.iv()); obj.key_control = GetOecSubstring(proto, key_control.key_control_block()); } + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_None; + memset(&(obj.decenc_mitigation_info.configuration_options), 0, + sizeof(obj.decenc_mitigation_info.configuration_options)); + if (k.has_decenc_mitigation()) { + switch (k.decenc_mitigation().mitigation_option()) { + case video_widevine::License::KeyContainer::DecencMitigation:: + DECENC_MITIGATION_OPTION_NONE: + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_None; + break; + case video_widevine::License::KeyContainer::DecencMitigation:: + DECENC_MITIGATION_OPTION_AUTHENTICATE_BITSTREAM: + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream; + // Set the authentication key info for Authenticate Bitstream option. + obj.decenc_mitigation_info.configuration_options.authentication_key_info + .authentication_key = + GetOecSubstring(proto, k.decenc_mitigation() + .authenticate_bitstream_config_option() + .authentication_key()); + obj.decenc_mitigation_info.configuration_options.authentication_key_info + .authentication_key_iv = + GetOecSubstring(proto, k.decenc_mitigation() + .authenticate_bitstream_config_option() + .authentication_key_iv()); + break; + case video_widevine::License::KeyContainer::DecencMitigation:: + DECENC_MITIGATION_OPTION_VALIDATE_BITSTREAM: + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream; + break; + case video_widevine::License::KeyContainer::DecencMitigation:: + DECENC_MITIGATION_OPTION_RESTRICT_DECODING: + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding; + break; + default: + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_None; + } + } return obj; } @@ -87,9 +126,10 @@ bool CreateCoreLicenseResponseFromProto(const CoreMessageFeatures& features, } ODK_Packing_ParsedLicense parsed_lic{}; - std::vector key_array; + std::vector key_array; bool any_content = false; bool any_entitlement = false; + bool any_decenc_mitigation = false; for (int i = 0; i < lic.key_size(); ++i) { const auto& k = lic.key(i); @@ -116,6 +156,12 @@ bool CreateCoreLicenseResponseFromProto(const CoreMessageFeatures& features, } key_array.push_back( KeyContainerToOecKey(serialized_license, k, uses_padding)); + if (k.has_decenc_mitigation() && + k.decenc_mitigation().mitigation_option() != + video_widevine::License::KeyContainer::DecencMitigation:: + DECENC_MITIGATION_OPTION_NONE) { + any_decenc_mitigation = true; + } break; } default: { @@ -149,15 +195,15 @@ bool CreateCoreLicenseResponseFromProto(const CoreMessageFeatures& features, const auto& policy = lic.policy(); switch (policy.initial_renewal_delay_base()) { case video_widevine::License_Policy::LICENSE_LOAD: - parsed_lic.renewal_delay_base = OEMCrypto_License_Load; + parsed_lic.renewal_delay_base = OEMCrypto_LicenseLoad; break; case video_widevine::License_Policy::FIRST_DECRYPT: - parsed_lic.renewal_delay_base = OEMCrypto_First_Decrypt; + parsed_lic.renewal_delay_base = OEMCrypto_FirstDecrypt; break; case video_widevine::License_Policy::TIMER_DELAY_BASE_UNSPECIFIED: case video_widevine::License_Policy::LICENSE_START: default: - parsed_lic.renewal_delay_base = OEMCrypto_License_Start; + parsed_lic.renewal_delay_base = OEMCrypto_LicenseStart; break; } ODK_TimerLimits& timer_limits = parsed_lic.timer_limits; @@ -180,7 +226,10 @@ bool CreateCoreLicenseResponseFromProto(const CoreMessageFeatures& features, parsed_lic.key_array = key_array.data(); parsed_lic.key_array_length = static_cast(key_array.size()); - + // If DeCENC mitigation option is used in any key, then set + // |decenc_mitigation_option_used| to a non-zero value. Otherwise, set the + // value to 0. + parsed_lic.decenc_mitigation_option_used = any_decenc_mitigation; return CreateCoreLicenseResponse(features, parsed_lic, core_request, core_request_sha256, oemcrypto_core_message); } diff --git a/oemcrypto/odk/src/odk.c b/oemcrypto/odk/src/odk.c index 3f4090f..0d19968 100644 --- a/oemcrypto/odk/src/odk.c +++ b/oemcrypto/odk/src/odk.c @@ -38,9 +38,9 @@ static OEMCryptoResult ODK_PrepareRequest( } ODK_CoreMessage* core_message = (ODK_CoreMessage*)prepared_request_buffer; *core_message = (ODK_CoreMessage){ - message_type, - 0, - *nonce_values, + .message_type = message_type, + .message_length = 0, + .nonce_values = *nonce_values, }; /* Set core message length, and pack prepared request into message if the @@ -145,13 +145,12 @@ static OEMCryptoResult ODK_PrepareRequest( } /* Parse the core message and verify that it has the right type. The nonce - * values are updated to hold the response's API version. + * values are updated to hold the response's API version. This is used to + * parse a response, not a request. */ -static OEMCryptoResult ODK_ParseCoreHeader(const uint8_t* message, - size_t message_length, - size_t core_message_length, - ODK_MessageType message_type, - ODK_NonceValues* nonce_values) { +static OEMCryptoResult ODK_ParseCoreHeaderInResponse( + const uint8_t* message, size_t message_length, size_t core_message_length, + ODK_MessageType message_type, ODK_NonceValues* nonce_values) { // The core_message_length is the length of the core message, which is a // substring of the complete message. if (message == NULL || core_message_length > message_length) { @@ -183,7 +182,7 @@ static OEMCryptoResult ODK_ParseCoreHeader(const uint8_t* message, core_message.nonce_values.api_major_version < ODK_FIRST_VERSION) { return ODK_UNSUPPORTED_API; } - if (nonce_values) { + if (nonce_values != NULL) { /* If the server sent us an older format, record the message's API version. */ if (nonce_values->api_major_version > @@ -213,15 +212,23 @@ static OEMCryptoResult ODK_ParseCoreHeader(const uint8_t* message, OEMCryptoResult ODK_PrepareCoreLicenseRequest( uint8_t* message, size_t message_length, size_t* core_message_size, const ODK_NonceValues* nonce_values, - const ODK_MessageCounterInfo* counter_info) { + const ODK_MessageCounterInfo* counter_info, + uint16_t decenc_mitigation_options_supported) { if (core_message_size == NULL || nonce_values == NULL || counter_info == NULL) { return ODK_ERROR_CORE_MESSAGE; } if (nonce_values->api_major_version > 17) { - ODK_PreparedLicenseRequest license_request = {0}; - memcpy(&license_request.counter_info, counter_info, - sizeof(license_request.counter_info)); + if (!ODK_VersionSupportsDecencMitigation(nonce_values->api_major_version, + nonce_values->api_minor_version)) { + // |decenc_mitigation_options_supported| will be set to 0 if the version + // does not support decenc mitigation. + decenc_mitigation_options_supported = 0; + } + ODK_PreparedLicenseRequest license_request = { + .counter_info = *counter_info, + .decenc_mitigation_options_supported = + decenc_mitigation_options_supported}; return ODK_PrepareRequest( message, message_length, core_message_size, ODK_License_Request_Type, nonce_values, &license_request, sizeof(ODK_PreparedLicenseRequest)); @@ -235,25 +242,31 @@ OEMCryptoResult ODK_PrepareCoreLicenseRequest( OEMCryptoResult ODK_PrepareCoreReleaseRequest( uint8_t* message, size_t message_length, size_t* core_message_size, - ODK_NonceValues* nonce_values, uint32_t status, - uint32_t clock_security_level, int64_t seconds_since_license_requested, - int64_t seconds_since_first_decrypt, ODK_ClockValues* clock_values, - uint64_t system_time_seconds) { - (void)status; - (void)clock_security_level; - (void)seconds_since_license_requested; - (void)seconds_since_first_decrypt; - if (core_message_size == NULL || nonce_values == NULL || + ODK_NonceValues* nonce_values, + OEMCrypto_UsageEntryStatus usage_entry_status, + uint32_t clock_security_level, uint64_t seconds_since_license_requested, + uint64_t seconds_since_first_decrypt, const uint8_t* pst, size_t pst_length, + ODK_ClockValues* clock_values, uint64_t system_time_seconds) { + if (core_message_size == NULL || nonce_values == NULL || pst == NULL || clock_values == NULL) { return ODK_ERROR_CORE_MESSAGE; } - if (nonce_values->api_major_version >= 19) { - ODK_PreparedReleaseRequest release_request = {0}; + if (nonce_values->api_major_version >= 20) { + // License release is only supported in API version 20 and above. + ODK_PreparedReleaseRequest release_request = { + .usage_entry_status = usage_entry_status, + .clock_security_level = clock_security_level, + .seconds_since_license_requested = seconds_since_license_requested, + .seconds_since_first_decrypt = seconds_since_first_decrypt, + .pst_length = (uint32_t)pst_length, + }; + memset(release_request.pst, 0, sizeof(release_request.pst)); + memcpy(release_request.pst, pst, pst_length); return ODK_PrepareRequest( message, message_length, core_message_size, ODK_Release_Request_Type, nonce_values, &release_request, sizeof(ODK_PreparedReleaseRequest)); } else { - // If the version is pre 19 when license release isn't supported, create a + // If the version is pre 20 when license release isn't supported, create a // license request. return ODK_PrepareCoreRenewalRequest(message, message_length, core_message_size, nonce_values, @@ -318,9 +331,8 @@ OEMCryptoResult ODK_PrepareCoreProvisioningRequest( return ODK_ERROR_CORE_MESSAGE; } if (nonce_values->api_major_version > 17) { - ODK_PreparedProvisioningRequest provisioning_request = {0}; - memcpy(&provisioning_request.counter_info, counter_info, - sizeof(ODK_MessageCounterInfo)); + ODK_PreparedProvisioningRequest provisioning_request = {.counter_info = + *counter_info}; return ODK_PrepareRequest(message, message_length, core_message_length, ODK_Provisioning_Request_Type, nonce_values, @@ -343,16 +355,15 @@ OEMCryptoResult ODK_PrepareCoreProvisioning40Request( counter_info == NULL) { return ODK_ERROR_CORE_MESSAGE; } - ODK_PreparedProvisioning40Request provisioning_request = {0}; + ODK_PreparedProvisioning40Request provisioning_request = { + .counter_info = *counter_info, + .device_info_length = (uint32_t)device_info_length}; if (device_info_length > sizeof(provisioning_request.device_info)) { return ODK_ERROR_CORE_MESSAGE; } - provisioning_request.device_info_length = (uint32_t)device_info_length; - if (device_info) { + if (device_info != NULL && device_info_length > 0) { memcpy(provisioning_request.device_info, device_info, device_info_length); } - memcpy(&provisioning_request.counter_info, counter_info, - sizeof(provisioning_request.counter_info)); return ODK_PrepareRequest(message, message_length, core_message_length, ODK_Provisioning40_Request_Type, nonce_values, @@ -368,20 +379,20 @@ OEMCryptoResult ODK_PrepareCoreRenewedProvisioningRequest( if (core_message_length == NULL || nonce_values == NULL) { return ODK_ERROR_CORE_MESSAGE; } - ODK_PreparedRenewedProvisioningRequest provisioning_request = {0}; + ODK_PreparedRenewedProvisioningRequest provisioning_request = { + .device_id_length = (uint32_t)device_id_length, + .renewal_type = renewal_type, + .renewal_data_length = (uint32_t)renewal_data_length}; if (device_id_length > sizeof(provisioning_request.device_id)) { return ODK_ERROR_CORE_MESSAGE; } - provisioning_request.device_id_length = (uint32_t)device_id_length; - if (device_id) { + if (device_id != NULL && device_id_length > 0) { memcpy(provisioning_request.device_id, device_id, device_id_length); } if (renewal_data_length > sizeof(provisioning_request.renewal_data)) { return ODK_ERROR_CORE_MESSAGE; } - provisioning_request.renewal_type = renewal_type; - provisioning_request.renewal_data_length = (uint32_t)renewal_data_length; - if (renewal_data) { + if (renewal_data != NULL && renewal_data_length > 0) { memcpy(provisioning_request.renewal_data, renewal_data, renewal_data_length); } @@ -404,20 +415,18 @@ OEMCryptoResult ODK_ParseLicense( return ODK_ERROR_CORE_MESSAGE; } - OEMCryptoResult err = - ODK_ParseCoreHeader(message, message_length, core_message_length, - ODK_License_Response_Type, nonce_values); + const OEMCryptoResult err = ODK_ParseCoreHeaderInResponse( + message, message_length, core_message_length, ODK_License_Response_Type, + nonce_values); if (err != OEMCrypto_SUCCESS) { return err; } - ODK_LicenseResponse license_response = {0}; - license_response.parsed_license = parsed_license; - ODK_Message msg = ODK_Message_Create((uint8_t*)message, message_length); - ODK_Message_SetSize(&msg, core_message_length); + ODK_LicenseResponse license_response = {.parsed_license = parsed_license}; + Unpack_ODK_LicenseResponse(&msg, &license_response); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK || @@ -456,15 +465,14 @@ OEMCryptoResult ODK_ParseLicense( clock_values->time_of_license_request_signed = system_time_seconds; } } - bool license_load = - (parsed_license->renewal_delay_base == OEMCrypto_License_Load); + const bool license_load = + (parsed_license->renewal_delay_base == OEMCrypto_LicenseLoad); *timer_limits = parsed_license->timer_limits; /* And update the clock values state. */ clock_values->timer_status = ODK_CLOCK_TIMER_STATUS_LICENSE_LOADED; if (nonce_values->api_major_version >= 18 && license_load) { - err = ODK_AttemptFirstPlayback(system_time_seconds, timer_limits, - clock_values, timer_value); - return err; + return ODK_AttemptFirstPlayback(system_time_seconds, timer_limits, + clock_values, timer_value); } return OEMCrypto_SUCCESS; } @@ -481,15 +489,16 @@ OEMCryptoResult ODK_ParseRenewal(const uint8_t* message, size_t message_length, return ODK_ERROR_CORE_MESSAGE; } - const OEMCryptoResult err = - ODK_ParseCoreHeader(message, message_length, core_message_length, - ODK_Renewal_Response_Type, nonce_values); + const OEMCryptoResult err = ODK_ParseCoreHeaderInResponse( + message, message_length, core_message_length, ODK_Renewal_Response_Type, + nonce_values); if (err != OEMCrypto_SUCCESS) { return err; } - ODK_RenewalResponse renewal_response = {0}; ODK_Message msg = ODK_Message_Create((uint8_t*)message, message_length); ODK_Message_SetSize(&msg, core_message_length); + + ODK_RenewalResponse renewal_response = {0}; Unpack_ODK_RenewalResponse(&msg, &renewal_response); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK || @@ -533,16 +542,17 @@ OEMCryptoResult ODK_ParseRelease(const uint8_t* message, size_t message_length, return ODK_ERROR_CORE_MESSAGE; } - const OEMCryptoResult err = - ODK_ParseCoreHeader(message, message_length, core_message_length, - ODK_Release_Response_Type, nonce_values); + const OEMCryptoResult err = ODK_ParseCoreHeaderInResponse( + message, message_length, core_message_length, ODK_Release_Response_Type, + nonce_values); if (err != OEMCrypto_SUCCESS) { return err; } - ODK_ReleaseResponse release_response = {0}; ODK_Message msg = ODK_Message_Create((uint8_t*)message, message_length); ODK_Message_SetSize(&msg, core_message_length); + + ODK_ReleaseResponse release_response = {0}; Unpack_ODK_ReleaseResponse(&msg, &release_response); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK || @@ -561,24 +571,24 @@ OEMCryptoResult ODK_ParseProvisioning( parsed_response == NULL) { return ODK_ERROR_CORE_MESSAGE; } - const OEMCryptoResult err = - ODK_ParseCoreHeader(message, message_length, core_message_length, - ODK_Provisioning_Response_Type, nonce_values); + const OEMCryptoResult err = ODK_ParseCoreHeaderInResponse( + message, message_length, core_message_length, + ODK_Provisioning_Response_Type, nonce_values); if (err != OEMCrypto_SUCCESS) { return err; } if (nonce_values->api_major_version <= 17) { // Do v16/v17 - ODK_ProvisioningResponseV16 provisioning_response = {0}; - provisioning_response.parsed_provisioning = parsed_response; - if (device_id_length > ODK_DEVICE_ID_LEN_MAX) { return ODK_ERROR_CORE_MESSAGE; } ODK_Message msg = ODK_Message_Create((uint8_t*)message, message_length); ODK_Message_SetSize(&msg, core_message_length); + + ODK_ProvisioningResponseV16 provisioning_response = {.parsed_provisioning = + parsed_response}; Unpack_ODK_ProvisioningResponseV16(&msg, &provisioning_response); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK || ODK_Message_GetOffset(&msg) != core_message_length) { @@ -597,11 +607,11 @@ OEMCryptoResult ODK_ParseProvisioning( } } else { // v18 - ODK_ProvisioningResponse provisioning_response = {0}; - provisioning_response.parsed_provisioning = parsed_response; - ODK_Message msg = ODK_Message_Create((uint8_t*)message, message_length); ODK_Message_SetSize(&msg, core_message_length); + + ODK_ProvisioningResponse provisioning_response = {0}; + provisioning_response.parsed_provisioning = parsed_response; Unpack_ODK_ProvisioningResponse(&msg, &provisioning_response); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK || ODK_Message_GetOffset(&msg) != core_message_length) { @@ -624,16 +634,17 @@ OEMCryptoResult ODK_ParseProvisioning40(const uint8_t* message, if (message == NULL || nonce_values == NULL) { return ODK_ERROR_CORE_MESSAGE; } - const OEMCryptoResult err = - ODK_ParseCoreHeader(message, message_length, core_message_length, - ODK_Provisioning_Response_Type, nonce_values); + const OEMCryptoResult err = ODK_ParseCoreHeaderInResponse( + message, message_length, core_message_length, + ODK_Provisioning_Response_Type, nonce_values); if (err != OEMCrypto_SUCCESS) { return err; } - ODK_Provisioning40Response provisioning_response = {0}; ODK_Message msg = ODK_Message_Create((uint8_t*)message, message_length); ODK_Message_SetSize(&msg, core_message_length); + + ODK_Provisioning40Response provisioning_response = {0}; Unpack_ODK_Provisioning40Response(&msg, &provisioning_response); if (ODK_Message_GetStatus(&msg) != MESSAGE_STATUS_OK || ODK_Message_GetOffset(&msg) != core_message_length) { @@ -655,13 +666,16 @@ bool CheckApiVersionAtMost(const ODK_NonceValues* nonce_values, nonce_values->api_minor_version <= minor_version); } +// These constants are exposed in the header file. const uint8_t ODK_MacKeyLabelWithZero[] = "AUTHENTICATION"; +// Inclusion of null terminator is intentional. const size_t ODK_MacKeyLabelWithZeroLength = sizeof(ODK_MacKeyLabelWithZero); // This is the key size (512) in network byte order. const uint8_t ODK_MacKeySuffix[] = {0x00, 0x00, 0x02, 0x00}; const size_t ODK_MacKeySuffixLength = sizeof(ODK_MacKeySuffix); const uint8_t ODK_EncKeyLabelWithZero[] = "ENCRYPTION"; +// Inclusion of null terminator is intentional. const size_t ODK_EncKeyLabelWithZeroLength = sizeof(ODK_EncKeyLabelWithZero); // This is the key size (128) in network byte order. const uint8_t ODK_EncKeySuffix[] = {0x00, 0x00, 0x00, 0x80}; @@ -673,34 +687,40 @@ OEMCryptoResult ODK_GenerateKeyContexts(const uint8_t* context, size_t* mac_key_context_length, uint8_t* enc_key_context, size_t* enc_key_context_length) { - size_t real_mac_length; - size_t real_enc_length; + size_t real_mac_length = 0; + size_t real_enc_length = 0; if (odk_add_overflow_ux( context_length, ODK_MacKeyLabelWithZeroLength + ODK_MacKeySuffixLength, &real_mac_length) || - real_mac_length > 0xffffffff || + real_mac_length > UINT32_MAX || odk_add_overflow_ux( context_length, ODK_EncKeyLabelWithZeroLength + ODK_EncKeySuffixLength, &real_enc_length) || - real_enc_length > 0xffffffff) { + real_enc_length > UINT32_MAX) { return OEMCrypto_ERROR_INVALID_CONTEXT; } + // Compute lengths for both key contexts before returning + // OEMCrypto_ERROR_SHORT_BUFFER. bool short_buffer = false; - if (mac_key_context_length) { + if (mac_key_context_length != NULL) { short_buffer = real_mac_length > *mac_key_context_length; *mac_key_context_length = real_mac_length; } - if (enc_key_context_length) { + if (enc_key_context_length != NULL) { short_buffer = short_buffer || real_enc_length > *enc_key_context_length; *enc_key_context_length = real_enc_length; } - if (short_buffer || !mac_key_context || !enc_key_context) { + if (short_buffer || mac_key_context == NULL || enc_key_context == NULL) { return OEMCrypto_ERROR_SHORT_BUFFER; } - if (!context || !mac_key_context_length || !enc_key_context_length) { + // Returning OEMCrypto_ERROR_INVALID_CONTEXT here to allow for + // obtaining at least one of the correct lengths in the case of + // OEMCrypto_ERROR_SHORT_BUFFER. + if (context == NULL || mac_key_context_length == NULL || + enc_key_context_length == NULL) { return OEMCrypto_ERROR_INVALID_CONTEXT; } diff --git a/oemcrypto/odk/src/odk_message_priv.h b/oemcrypto/odk/src/odk_message_priv.h index f8e9bcc..74d8a59 100644 --- a/oemcrypto/odk/src/odk_message_priv.h +++ b/oemcrypto/odk/src/odk_message_priv.h @@ -5,10 +5,6 @@ #ifndef WIDEVINE_ODK_SRC_ODK_MESSAGE_PRIV_H_ #define WIDEVINE_ODK_SRC_ODK_MESSAGE_PRIV_H_ -#ifdef __cplusplus -extern "C" { -#endif - /* * This file must only be included by odk_message.c and serialization_base.c. */ @@ -18,6 +14,10 @@ extern "C" { #include "odk_message.h" +#ifdef __cplusplus +extern "C" { +#endif + /* * This is the implementation of a message. This structure is private, i.e. it * should only be included by files that are allowed to modify the internals of diff --git a/oemcrypto/odk/src/odk_overflow.c b/oemcrypto/odk/src/odk_overflow.c index ba19962..4ffd1f8 100644 --- a/oemcrypto/odk/src/odk_overflow.c +++ b/oemcrypto/odk/src/odk_overflow.c @@ -1,48 +1,42 @@ // Copyright 2019 Google LLC. This file and proprietary // source code may only be used and distributed under the Widevine // License Agreement. - #include "odk_overflow.h" #include #include +#define NO_OVERFLOW 0 +#define OVERFLOW_ERROR 1 + int odk_sub_overflow_u64(uint64_t a, uint64_t b, uint64_t* c) { - if (a >= b) { - if (c) { - *c = a - b; - } - return 0; + if (a < b) return OVERFLOW_ERROR; + if (c != NULL) { + *c = a - b; } - return 1; + return NO_OVERFLOW; } int odk_add_overflow_u64(uint64_t a, uint64_t b, uint64_t* c) { - if (UINT64_MAX - a >= b) { - if (c) { - *c = a + b; - } - return 0; + if ((UINT64_MAX - a) < b) return OVERFLOW_ERROR; + if (c != NULL) { + *c = a + b; } - return 1; + return NO_OVERFLOW; } int odk_add_overflow_ux(size_t a, size_t b, size_t* c) { - if (SIZE_MAX - a >= b) { - if (c) { - *c = a + b; - } - return 0; + if ((SIZE_MAX - a) < b) return OVERFLOW_ERROR; + if (c != NULL) { + *c = a + b; } - return 1; + return NO_OVERFLOW; } int odk_mul_overflow_ux(size_t a, size_t b, size_t* c) { - if (b > 0 && a > SIZE_MAX / b) { - return 1; - } - if (c) { + if (b > 0 && a > (SIZE_MAX / b)) return OVERFLOW_ERROR; + if (c != NULL) { *c = a * b; } - return 0; + return NO_OVERFLOW; } diff --git a/oemcrypto/odk/src/odk_overflow.h b/oemcrypto/odk/src/odk_overflow.h index 7b50552..835f05b 100644 --- a/oemcrypto/odk/src/odk_overflow.h +++ b/oemcrypto/odk/src/odk_overflow.h @@ -12,6 +12,12 @@ extern "C" { #endif +// ODK overflow-safe math functions. +// +// Performs an operation using the two operands |a| and |b|, +// and assigns the result to |c| if not NULL. +// +// Returns 1 if an overflow occurred, 0 otherwise. int odk_sub_overflow_u64(uint64_t a, uint64_t b, uint64_t* c); int odk_add_overflow_u64(uint64_t a, uint64_t b, uint64_t* c); int odk_add_overflow_ux(size_t a, size_t b, size_t* c); diff --git a/oemcrypto/odk/src/odk_serialize.c b/oemcrypto/odk/src/odk_serialize.c index 1f2e48e..c9e1942 100644 --- a/oemcrypto/odk/src/odk_serialize.c +++ b/oemcrypto/odk/src/odk_serialize.c @@ -1,22 +1,37 @@ // Copyright 2019 Google LLC. This file and proprietary // source code may only be used and distributed under the Widevine // License Agreement. - -/* - * This code is auto-generated, do not edit - */ - #include "odk_serialize.h" +#include +#include + +#include "OEMCryptoCENCCommon.h" #include "odk_message.h" #include "odk_overflow.h" +#include "odk_structs.h" #include "odk_structs_priv.h" +#include "odk_target.h" #include "serialization_base.h" /* @ serialize */ /* @@ private serialize */ +bool ODK_VersionSupportsDecencMitigation(uint16_t api_major_version, + uint16_t api_minor_version) { + if (api_major_version >= 20) { + return true; + } + if (api_major_version == 19) { + return (api_minor_version >= 7); + } + if (api_major_version == 18) { + return (api_minor_version >= 11); + } + return false; +} + static void Pack_ODK_NonceValues(ODK_Message* msg, ODK_NonceValues const* obj) { Pack_uint16_t(msg, &obj->api_minor_version); Pack_uint16_t(msg, &obj->api_major_version); @@ -30,13 +45,44 @@ static void Pack_ODK_CoreMessage(ODK_Message* msg, ODK_CoreMessage const* obj) { Pack_ODK_NonceValues(msg, &obj->nonce_values); } +static void Pack_OEMCrypto_Authentication_Key_Info( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo const* obj) { + Pack_OEMCrypto_Substring(msg, &obj->authentication_key); + Pack_OEMCrypto_Substring(msg, &obj->authentication_key_iv); +} + +static void Pack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info const* obj) { + Pack_uint32_t(msg, &obj->mitigation_option); + if (obj->mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream) { + Pack_OEMCrypto_Authentication_Key_Info( + msg, &obj->configuration_options.authentication_key_info); + } else if (obj->mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream || + obj->mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding) { + // These mitigation options do not require any additional data from the key + // object. + } else { + ODK_Message_SetStatus(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + return; + } +} + static void Pack_OEMCrypto_KeyObject(ODK_Message* msg, - OEMCrypto_KeyObject const* obj) { + OEMCrypto_KeyObjectV2 const* obj, + size_t key_obj_version) { Pack_OEMCrypto_Substring(msg, &obj->key_id); Pack_OEMCrypto_Substring(msg, &obj->key_data_iv); Pack_OEMCrypto_Substring(msg, &obj->key_data); Pack_OEMCrypto_Substring(msg, &obj->key_control_iv); Pack_OEMCrypto_Substring(msg, &obj->key_control); + if (key_obj_version >= 2) { + // For |key_obj_version| >= 2, the |decenc_mitigation_info| is expected, + // even if it is set to OEMCrypto_DeCENC_Mitigation_Option_None. + Pack_OEMCrypto_DeCENC_Mitigation_Info(msg, &obj->decenc_mitigation_info); + } } static void Pack_ODK_TimerLimits(ODK_Message* msg, ODK_TimerLimits const* obj) { @@ -84,19 +130,43 @@ static void Pack_ODK_ParsedLicense(ODK_Message* msg, if (nonce_values->api_major_version >= 18) { Pack_enum(msg, obj->renewal_delay_base); } - Pack_uint32_t(msg, &obj->key_array_length); - size_t i; - for (i = 0; i < (size_t)obj->key_array_length; i++) { - Pack_OEMCrypto_KeyObject(msg, &obj->key_array[i]); + if (ODK_VersionSupportsDecencMitigation(nonce_values->api_major_version, + nonce_values->api_minor_version)) { + Pack_uint16_t(msg, &obj->decenc_mitigation_option_used); + } else if (obj->decenc_mitigation_option_used != 0) { + // Bad parameter, DeCENC not supported. + ODK_Message_SetStatus(msg, MESSAGE_STATUS_UNKNOWN_ERROR); + return; + } else { + static const uint16_t kDecencOptionNone = 0; + Pack_uint16_t(msg, &kDecencOptionNone); + } + Pack_uint16_t(msg, &obj->key_array_length); + for (uint32_t i = 0; i < obj->key_array_length; i++) { + if (obj->decenc_mitigation_option_used != 0) { + Pack_OEMCrypto_KeyObject(msg, &obj->key_array[i], 2); + } else { + Pack_OEMCrypto_KeyObject(msg, &obj->key_array[i], 1); + } } } static void Pack_ODK_ParsedProvisioning(ODK_Message* msg, - ODK_ParsedProvisioning const* obj) { + ODK_ParsedProvisioning const* obj, + const ODK_NonceValues* nonce_values) { Pack_enum(msg, obj->key_type); Pack_OEMCrypto_Substring(msg, &obj->enc_private_key); Pack_OEMCrypto_Substring(msg, &obj->enc_private_key_iv); Pack_OEMCrypto_Substring(msg, &obj->encrypted_message_key); + if (nonce_values->api_major_version >= 20) { + Pack_OEMCrypto_Substring(msg, &obj->curr_server_sealing_key); + Pack_uint32_t(msg, &obj->server_sealing_key_array_length); + for (uint32_t i = 0; i < obj->server_sealing_key_array_length && + i < ODK_MAX_NUM_SERVER_SEALING_KEYS; + i++) { + Pack_OEMCrypto_Substring(msg, &obj->server_sealing_key_array[i]); + } + } } static void Pack_ODK_MessageCounterInfo(ODK_Message* msg, @@ -119,6 +189,7 @@ void Pack_ODK_PreparedLicenseRequest(ODK_Message* msg, ODK_PreparedLicenseRequest const* obj) { Pack_ODK_CoreMessage(msg, &obj->core_message); Pack_ODK_MessageCounterInfo(msg, &obj->counter_info); + Pack_uint16_t(msg, &obj->decenc_mitigation_options_supported); } void Pack_ODK_PreparedLicenseRequestV17( @@ -129,6 +200,16 @@ void Pack_ODK_PreparedLicenseRequestV17( void Pack_ODK_PreparedReleaseRequest(ODK_Message* msg, const ODK_PreparedReleaseRequest* obj) { Pack_ODK_CoreMessage(msg, &obj->core_message); + Pack_uint32_t(msg, &obj->usage_entry_status); + Pack_uint32_t(msg, &obj->clock_security_level); + Pack_uint64_t(msg, &obj->seconds_since_license_requested); + Pack_uint64_t(msg, &obj->seconds_since_first_decrypt); + Pack_uint32_t(msg, &obj->pst_length); + if (obj->pst_length > ODK_PST_LEN_MAX) { + ODK_Message_SetStatus(msg, MESSAGE_STATUS_OVERFLOW_ERROR); + return; + } + PackFixedArray(msg, &obj->pst[0], obj->pst_length, ODK_PST_LEN_MAX); } void Pack_ODK_PreparedRenewalRequest(ODK_Message* msg, @@ -142,7 +223,7 @@ void Pack_ODK_PreparedProvisioningRequest( Pack_ODK_CoreMessage(msg, &obj->core_message); // Fake device_id_length for older servers, since we removed device id from // the v18 request - uint32_t device_id_len = 64; + const uint32_t device_id_len = ODK_DEVICE_ID_LEN_MAX; Pack_uint32_t(msg, &device_id_len); Pack_ODK_MessageCounterInfo(msg, &obj->counter_info); } @@ -180,7 +261,7 @@ void Pack_ODK_LicenseResponse(ODK_Message* msg, Pack_ODK_ParsedLicense(msg, (const ODK_Packing_ParsedLicense*)obj->parsed_license, &obj->core_message.nonce_values); - if ((&obj->core_message.nonce_values)->api_major_version == 16) { + if (obj->core_message.nonce_values.api_major_version == 16) { PackArray(msg, &obj->request_hash[0], sizeof(obj->request_hash)); } } @@ -200,14 +281,16 @@ void Pack_ODK_ProvisioningResponse(ODK_Message* msg, const ODK_ProvisioningResponse* obj) { Pack_ODK_CoreMessage(msg, &obj->core_message); Pack_ODK_ParsedProvisioning( - msg, (const ODK_ParsedProvisioning*)obj->parsed_provisioning); + msg, (const ODK_ParsedProvisioning*)obj->parsed_provisioning, + &obj->core_message.nonce_values); } void Pack_ODK_ProvisioningResponseV16(ODK_Message* msg, const ODK_ProvisioningResponseV16* obj) { Pack_ODK_PreparedProvisioningRequestV17(msg, &obj->request); Pack_ODK_ParsedProvisioning( - msg, (const ODK_ParsedProvisioning*)obj->parsed_provisioning); + msg, (const ODK_ParsedProvisioning*)obj->parsed_provisioning, + &obj->request.core_message.nonce_values); } void Pack_ODK_Provisioning40Response(ODK_Message* msg, @@ -232,13 +315,47 @@ void Unpack_ODK_CoreMessage(ODK_Message* msg, ODK_CoreMessage* obj) { Unpack_ODK_NonceValues(msg, &obj->nonce_values); } +static void Unpack_OEMCrypto_Authentication_Key_Info( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo* obj) { + Unpack_OEMCrypto_Substring(msg, &obj->authentication_key); + Unpack_OEMCrypto_Substring(msg, &obj->authentication_key_iv); +} + +static void Unpack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info* obj) { + Unpack_uint32_t(msg, &obj->mitigation_option); + // Check that the mitigation_option only has 1 bit set or is 0. + if ((obj->mitigation_option & (obj->mitigation_option - 1)) != 0) { + ODK_Message_SetStatus(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + return; + } + if (obj->mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream) { + Unpack_OEMCrypto_Authentication_Key_Info( + msg, &obj->configuration_options.authentication_key_info); + } else if (obj->mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream || + obj->mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding) { + // These mitigation options do not require any additional data from the key + // object. + } else { + ODK_Message_SetStatus(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + return; + } +} + static void Unpack_OEMCrypto_KeyObject(ODK_Message* msg, - OEMCrypto_KeyObject* obj) { + OEMCrypto_KeyObjectV2* obj, + size_t key_obj_version) { Unpack_OEMCrypto_Substring(msg, &obj->key_id); Unpack_OEMCrypto_Substring(msg, &obj->key_data_iv); Unpack_OEMCrypto_Substring(msg, &obj->key_data); Unpack_OEMCrypto_Substring(msg, &obj->key_control_iv); Unpack_OEMCrypto_Substring(msg, &obj->key_control); + if (key_obj_version >= 2) { + Unpack_OEMCrypto_DeCENC_Mitigation_Info(msg, &obj->decenc_mitigation_info); + } /* Edge case for servers that incorrectly process protocol VERSION_2_2 padding. @@ -314,23 +431,44 @@ static void Unpack_ODK_ParsedLicense(ODK_Message* msg, ODK_ParsedLicense* obj, if (nonce_values->api_major_version >= 18) { Unpack_OEMCrypto_TimerDelayBase(msg, &obj->renewal_delay_base); } - Unpack_uint32_t(msg, &obj->key_array_length); + Unpack_uint16_t(msg, &obj->decenc_mitigation_option_used); + if (!ODK_VersionSupportsDecencMitigation(nonce_values->api_major_version, + nonce_values->api_minor_version) && + obj->decenc_mitigation_option_used != 0) { + // Bad ODK message, DeCENC not supported. + ODK_Message_SetStatus(msg, MESSAGE_STATUS_UNKNOWN_ERROR); + return; + } + Unpack_uint16_t(msg, &obj->key_array_length); if (obj->key_array_length > ODK_MAX_NUM_KEYS) { ODK_Message_SetStatus(msg, MESSAGE_STATUS_OVERFLOW_ERROR); return; } - uint32_t i; - for (i = 0; i < obj->key_array_length; i++) { - Unpack_OEMCrypto_KeyObject(msg, &obj->key_array[i]); + for (uint32_t i = 0; i < obj->key_array_length; i++) { + if (obj->decenc_mitigation_option_used != 0) { + Unpack_OEMCrypto_KeyObject(msg, &obj->key_array[i], 2); + } else { + Unpack_OEMCrypto_KeyObject(msg, &obj->key_array[i], 1); + } } } static void Unpack_ODK_ParsedProvisioning(ODK_Message* msg, - ODK_ParsedProvisioning* obj) { + ODK_ParsedProvisioning* obj, + const ODK_NonceValues* nonce_values) { Unpack_OEMCrypto_PrivateKeyType(msg, &obj->key_type); Unpack_OEMCrypto_Substring(msg, &obj->enc_private_key); Unpack_OEMCrypto_Substring(msg, &obj->enc_private_key_iv); Unpack_OEMCrypto_Substring(msg, &obj->encrypted_message_key); + if (nonce_values->api_major_version >= 20) { + Unpack_OEMCrypto_Substring(msg, &obj->curr_server_sealing_key); + Unpack_uint32_t(msg, &obj->server_sealing_key_array_length); + for (uint32_t i = 0; i < obj->server_sealing_key_array_length && + i < ODK_MAX_NUM_SERVER_SEALING_KEYS; + i++) { + Unpack_OEMCrypto_Substring(msg, &obj->server_sealing_key_array[i]); + } + } } static void Unpack_ODK_MessageCounterInfo(ODK_Message* msg, @@ -353,6 +491,14 @@ void Unpack_ODK_PreparedLicenseRequest(ODK_Message* msg, ODK_PreparedLicenseRequest* obj) { Unpack_ODK_CoreMessage(msg, &obj->core_message); Unpack_ODK_MessageCounterInfo(msg, &obj->counter_info); + // Check that there is enough space for the + // decenc_mitigation_options_supported field. Major and minor version checks + // will be added in the future. + if ((ODK_Message_GetSize(msg) - ODK_Message_GetOffset(msg)) >= 2) { + Unpack_uint16_t(msg, &obj->decenc_mitigation_options_supported); + } else { + obj->decenc_mitigation_options_supported = 0; + } } void Unpack_ODK_PreparedLicenseRequestV17(ODK_Message* msg, @@ -363,6 +509,16 @@ void Unpack_ODK_PreparedLicenseRequestV17(ODK_Message* msg, void Unpack_ODK_PreparedReleaseRequest(ODK_Message* msg, ODK_PreparedReleaseRequest* obj) { Unpack_ODK_CoreMessage(msg, &obj->core_message); + Unpack_uint32_t(msg, &obj->usage_entry_status); + Unpack_uint32_t(msg, &obj->clock_security_level); + Unpack_uint64_t(msg, &obj->seconds_since_license_requested); + Unpack_uint64_t(msg, &obj->seconds_since_first_decrypt); + Unpack_uint32_t(msg, &obj->pst_length); + if (obj->pst_length > ODK_PST_LEN_MAX) { + ODK_Message_SetStatus(msg, MESSAGE_STATUS_OVERFLOW_ERROR); + return; + } + UnpackFixedArray(msg, &obj->pst[0], obj->pst_length, ODK_PST_LEN_MAX); } void Unpack_ODK_PreparedRenewalRequest(ODK_Message* msg, @@ -422,7 +578,7 @@ void Unpack_ODK_LicenseResponse(ODK_Message* msg, ODK_LicenseResponse* obj) { Unpack_ODK_CoreMessage(msg, &obj->core_message); Unpack_ODK_ParsedLicense(msg, obj->parsed_license, &obj->core_message.nonce_values); - if ((&obj->core_message.nonce_values)->api_major_version == 16) { + if (obj->core_message.nonce_values.api_major_version == 16) { UnpackArray(msg, &obj->request_hash[0], sizeof(obj->request_hash)); } } @@ -439,13 +595,15 @@ void Unpack_ODK_RenewalResponse(ODK_Message* msg, ODK_RenewalResponse* obj) { void Unpack_ODK_ProvisioningResponse(ODK_Message* msg, ODK_ProvisioningResponse* obj) { Unpack_ODK_CoreMessage(msg, &obj->core_message); - Unpack_ODK_ParsedProvisioning(msg, obj->parsed_provisioning); + Unpack_ODK_ParsedProvisioning(msg, obj->parsed_provisioning, + &obj->core_message.nonce_values); } void Unpack_ODK_ProvisioningResponseV16(ODK_Message* msg, ODK_ProvisioningResponseV16* obj) { Unpack_ODK_PreparedProvisioningRequestV17(msg, &obj->request); - Unpack_ODK_ParsedProvisioning(msg, obj->parsed_provisioning); + Unpack_ODK_ParsedProvisioning(msg, obj->parsed_provisioning, + &obj->request.core_message.nonce_values); } void Unpack_ODK_Provisioning40Response(ODK_Message* msg, diff --git a/oemcrypto/odk/src/odk_serialize.h b/oemcrypto/odk/src/odk_serialize.h index a7aa220..c1df9ae 100644 --- a/oemcrypto/odk/src/odk_serialize.h +++ b/oemcrypto/odk/src/odk_serialize.h @@ -8,6 +8,8 @@ #ifndef WIDEVINE_ODK_SRC_ODK_SERIALIZE_H_ #define WIDEVINE_ODK_SRC_ODK_SERIALIZE_H_ +#include + #include "odk_message.h" #include "odk_structs_priv.h" @@ -15,6 +17,9 @@ extern "C" { #endif +bool ODK_VersionSupportsDecencMitigation(uint16_t api_major_version, + uint16_t api_minor_version); + /* odk pack */ void Pack_ODK_PreparedLicenseRequest(ODK_Message* msg, const ODK_PreparedLicenseRequest* obj); diff --git a/oemcrypto/odk/src/odk_structs_priv.h b/oemcrypto/odk/src/odk_structs_priv.h index 208ce2c..9bc0f87 100644 --- a/oemcrypto/odk/src/odk_structs_priv.h +++ b/oemcrypto/odk/src/odk_structs_priv.h @@ -43,6 +43,7 @@ typedef struct { typedef struct { ODK_CoreMessage core_message; ODK_MessageCounterInfo counter_info; + uint16_t decenc_mitigation_options_supported; } ODK_PreparedLicenseRequest; typedef struct { @@ -51,6 +52,12 @@ typedef struct { typedef struct { ODK_CoreMessage core_message; + OEMCrypto_UsageEntryStatus usage_entry_status; + uint32_t clock_security_level; + uint64_t seconds_since_license_requested; + uint64_t seconds_since_first_decrypt; + uint32_t pst_length; + uint8_t pst[ODK_PST_LEN_MAX]; } ODK_PreparedReleaseRequest; typedef struct { @@ -130,9 +137,9 @@ typedef struct { // request structs change. Refer to test suite OdkSizeTest in // ../test/odk_test.cpp for validations of each of the defined request sizes. #define ODK_CORE_MESSAGE_SIZE 20u -#define ODK_LICENSE_REQUEST_SIZE 90u +#define ODK_LICENSE_REQUEST_SIZE 92u #define ODK_LICENSE_REQUEST_SIZE_V17 20u -#define ODK_RELEASE_REQUEST_SIZE 20u +#define ODK_RELEASE_REQUEST_SIZE 304u #define ODK_RENEWAL_REQUEST_SIZE 28u #define ODK_PROVISIONING_REQUEST_SIZE 94u #define ODK_PROVISIONING_REQUEST_SIZE_V17 88u diff --git a/oemcrypto/odk/src/odk_timer.c b/oemcrypto/odk/src/odk_timer.c index a1a9eb8..616cd11 100644 --- a/oemcrypto/odk/src/odk_timer.c +++ b/oemcrypto/odk/src/odk_timer.c @@ -5,9 +5,11 @@ #include #include +#include "OEMCryptoCENCCommon.h" #include "odk.h" #include "odk_attributes.h" #include "odk_overflow.h" +#include "odk_structs.h" #include "odk_structs_priv.h" /* Private function. Checks to see if the license is active. Returns @@ -25,7 +27,7 @@ static OEMCryptoResult ODK_LicenseActive(const ODK_TimerLimits* timer_limits, clock_values->timer_status == ODK_CLOCK_TIMER_STATUS_LICENSE_NOT_LOADED) { return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - if (clock_values->status > kActive) { + if (clock_values->status > OEMCrypto_Active) { return ODK_TIMER_EXPIRED; } return OEMCrypto_SUCCESS; @@ -271,14 +273,17 @@ OEMCryptoResult ODK_InitializeSessionValues(ODK_TimerLimits* timer_limits, nonce_values->api_minor_version = 5; break; case 17: - nonce_values->api_minor_version = 2; + nonce_values->api_minor_version = 7; break; case 18: - nonce_values->api_minor_version = 4; + nonce_values->api_minor_version = 9; break; case 19: nonce_values->api_minor_version = 5; break; + case ODK_MAJOR_VERSION: + nonce_values->api_minor_version = ODK_MINOR_VERSION; + break; default: nonce_values->api_minor_version = 0; break; @@ -316,7 +321,7 @@ OEMCryptoResult ODK_InitializeClockValues(ODK_ClockValues* clock_values, clock_values->time_of_renewal_request = 0; clock_values->time_when_timer_expires = 0; clock_values->timer_status = ODK_CLOCK_TIMER_STATUS_LICENSE_NOT_LOADED; - clock_values->status = kUnused; + clock_values->status = OEMCrypto_Unused; return OEMCrypto_SUCCESS; } @@ -325,7 +330,7 @@ OEMCryptoResult ODK_ReloadClockValues(ODK_ClockValues* clock_values, uint64_t time_of_license_request_signed, uint64_t time_of_first_decrypt, uint64_t time_of_last_decrypt, - enum OEMCrypto_Usage_Entry_Status status, + enum OEMCrypto_UsageEntryStatus status, uint64_t system_time_seconds UNUSED) { if (clock_values == NULL) { return OEMCrypto_ERROR_INVALID_CONTEXT; @@ -389,7 +394,7 @@ OEMCryptoResult ODK_AttemptFirstPlayback(uint64_t system_time_seconds, /* If playback has not already started, then this is the first playback. */ if (clock_values->time_of_first_decrypt == 0) { clock_values->time_of_first_decrypt = system_time_seconds; - clock_values->status = kActive; + clock_values->status = OEMCrypto_Active; } /* Similar to the rental window, we check the playback window @@ -456,10 +461,10 @@ OEMCryptoResult ODK_DeactivateUsageEntry(ODK_ClockValues* clock_values) { if (clock_values == NULL) { return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - if (clock_values->status == kUnused) { - clock_values->status = kInactiveUnused; - } else if (clock_values->status == kActive) { - clock_values->status = kInactiveUsed; + if (clock_values->status == OEMCrypto_Unused) { + clock_values->status = OEMCrypto_InactiveUnused; + } else if (clock_values->status == OEMCrypto_Active) { + clock_values->status = OEMCrypto_InactiveUsed; } clock_values->timer_status = ODK_CLOCK_TIMER_STATUS_LICENSE_INACTIVE; return OEMCrypto_SUCCESS; @@ -507,7 +512,7 @@ OEMCryptoResult ODK_RefreshV15Values(const ODK_TimerLimits* timer_limits, if (nonce_values->api_major_version != 15) { return OEMCrypto_ERROR_INVALID_NONCE; } - if (clock_values->status > kActive) { + if (clock_values->status > OEMCrypto_Active) { clock_values->timer_status = ODK_CLOCK_TIMER_STATUS_LICENSE_INACTIVE; return ODK_TIMER_EXPIRED; } diff --git a/oemcrypto/odk/src/odk_util.c b/oemcrypto/odk/src/odk_util.c index 12c93e4..e9eccd5 100644 --- a/oemcrypto/odk/src/odk_util.c +++ b/oemcrypto/odk/src/odk_util.c @@ -4,6 +4,8 @@ #include "odk_util.h" +#include "odk_structs.h" + int crypto_memcmp(const void* in_a, const void* in_b, size_t len) { if (len == 0) { return 0; diff --git a/oemcrypto/odk/src/serialization_base.c b/oemcrypto/odk/src/serialization_base.c index 4527b2a..80d06ff 100644 --- a/oemcrypto/odk/src/serialization_base.c +++ b/oemcrypto/odk/src/serialization_base.c @@ -27,70 +27,89 @@ static ODK_Message_Impl* GetMessageImpl(ODK_Message* message) { static void PackBytes(ODK_Message* message, const uint8_t* ptr, size_t count) { ODK_Message_Impl* message_impl = GetMessageImpl(message); - if (!message_impl) return; - if (count <= message_impl->capacity - message_impl->size) { + if (message_impl == NULL) return; + if (count > (message_impl->capacity - message_impl->size)) { + message_impl->status = MESSAGE_STATUS_OVERFLOW_ERROR; + return; + } + if (count > 0) { + assert(ptr); memcpy((void*)(message_impl->base + message_impl->size), (const void*)ptr, count); message_impl->size += count; - } else { - message_impl->status = MESSAGE_STATUS_OVERFLOW_ERROR; } } +static void PackPaddingBytes(ODK_Message* message, uint8_t fill_value, + size_t count) { + ODK_Message_Impl* message_impl = GetMessageImpl(message); + if (message_impl == NULL) return; + if (count > (message_impl->capacity - message_impl->size)) { + message_impl->status = MESSAGE_STATUS_OVERFLOW_ERROR; + return; + } + if (count == 0) return; + memset((void*)(message_impl->base + message_impl->size), fill_value, count); + message_impl->size += count; +} + void Pack_enum(ODK_Message* message, int value) { - uint32_t v32 = (uint32_t)value; - Pack_uint32_t(message, &v32); + const uint32_t value_u32 = (uint32_t)value; + Pack_uint32_t(message, &value_u32); } void Pack_bool(ODK_Message* message, const bool* value) { assert(value); - uint8_t data[4] = {0}; - data[3] = *value ? 1 : 0; + const uint8_t data[4] = {0, 0, 0, *value ? 1 : 0}; PackBytes(message, data, sizeof(data)); } void Pack_uint8_t(ODK_Message* message, const uint8_t* value) { assert(value); - uint8_t data[1] = {0}; - data[0] = (uint8_t)(*value >> 0); + const uint8_t data[1] = {*value}; PackBytes(message, data, sizeof(data)); } void Pack_uint16_t(ODK_Message* message, const uint16_t* value) { assert(value); - uint8_t data[2] = {0}; - data[0] = (uint8_t)(*value >> 8); - data[1] = (uint8_t)(*value >> 0); + const uint8_t data[2] = {(uint8_t)(*value >> 8), (uint8_t)(*value >> 0)}; PackBytes(message, data, sizeof(data)); } void Pack_uint32_t(ODK_Message* message, const uint32_t* value) { assert(value); - uint8_t data[4] = {0}; - data[0] = (uint8_t)(*value >> 24); - data[1] = (uint8_t)(*value >> 16); - data[2] = (uint8_t)(*value >> 8); - data[3] = (uint8_t)(*value >> 0); + const uint8_t data[4] = {(uint8_t)(*value >> 24), (uint8_t)(*value >> 16), + (uint8_t)(*value >> 8), (uint8_t)(*value >> 0)}; PackBytes(message, data, sizeof(data)); } void Pack_uint64_t(ODK_Message* message, const uint64_t* value) { assert(value); - uint32_t hi = (uint32_t)(*value >> 32); - uint32_t lo = (uint32_t)(*value); - Pack_uint32_t(message, &hi); - Pack_uint32_t(message, &lo); + const uint32_t high = (uint32_t)(*value >> 32); + const uint32_t low = (uint32_t)(*value); + Pack_uint32_t(message, &high); + Pack_uint32_t(message, &low); } void PackArray(ODK_Message* message, const uint8_t* base, size_t size) { PackBytes(message, base, size); } +// Writes a total of |data_field_size| bytes. +// The first |data_length| bytes are from |data|, +// the next |data_field_size - data_length| bytes are zero 0. +void PackFixedArray(ODK_Message* message, const uint8_t* data, + size_t data_length, size_t data_field_size) { + assert(data_length <= data_field_size); + PackBytes(message, data, data_length); + PackPaddingBytes(message, 0, data_field_size - data_length); +} + void Pack_OEMCrypto_Substring(ODK_Message* message, const OEMCrypto_Substring* obj) { assert(obj); - uint32_t offset = (uint32_t)obj->offset; - uint32_t length = (uint32_t)obj->length; + const uint32_t offset = (uint32_t)obj->offset; + const uint32_t length = (uint32_t)obj->length; Pack_uint32_t(message, &offset); Pack_uint32_t(message, &length); } @@ -98,23 +117,36 @@ void Pack_OEMCrypto_Substring(ODK_Message* message, static void UnpackBytes(ODK_Message* message, uint8_t* ptr, size_t count) { assert(ptr); ODK_Message_Impl* message_impl = GetMessageImpl(message); - if (!message_impl) return; - if (count <= message_impl->size - message_impl->read_offset) { + if (message_impl == NULL) return; + if (count > (message_impl->size - message_impl->read_offset)) { + message_impl->status = MESSAGE_STATUS_UNDERFLOW_ERROR; + return; + } + if (count > 0) { memcpy((void*)ptr, (void*)(message_impl->base + message_impl->read_offset), count); message_impl->read_offset += count; - } else { - message_impl->status = MESSAGE_STATUS_UNDERFLOW_ERROR; } } +static void SkipBytes(ODK_Message* message, size_t count) { + ODK_Message_Impl* message_impl = GetMessageImpl(message); + if (message_impl == NULL) return; + if (count > (message_impl->size - message_impl->read_offset)) { + message_impl->status = MESSAGE_STATUS_UNDERFLOW_ERROR; + return; + } + if (count == 0) return; + message_impl->read_offset += count; +} + void Unpack_OEMCrypto_LicenseType(ODK_Message* message, OEMCrypto_LicenseType* value) { assert(value); - uint32_t v32 = 0; - Unpack_uint32_t(message, &v32); - if (v32 <= OEMCrypto_LicenseType_MaxValue) { - *value = (OEMCrypto_LicenseType)v32; + uint32_t value_u32 = 0; + Unpack_uint32_t(message, &value_u32); + if (value_u32 <= OEMCrypto_LicenseType_MaxValue) { + *value = (OEMCrypto_LicenseType)value_u32; } else { ODK_Message_SetStatus(message, MESSAGE_STATUS_PARSE_ERROR); } @@ -123,10 +155,10 @@ void Unpack_OEMCrypto_LicenseType(ODK_Message* message, void Unpack_OEMCrypto_PrivateKeyType(ODK_Message* message, OEMCrypto_PrivateKeyType* value) { assert(value); - uint32_t v32 = 0; - Unpack_uint32_t(message, &v32); - if (v32 <= OEMCrypto_PrivateKeyType_MaxValue) { - *value = (OEMCrypto_PrivateKeyType)v32; + uint32_t value_u32 = 0; + Unpack_uint32_t(message, &value_u32); + if (value_u32 <= OEMCrypto_PrivateKeyType_MaxValue) { + *value = (OEMCrypto_PrivateKeyType)value_u32; } else { ODK_Message_SetStatus(message, MESSAGE_STATUS_PARSE_ERROR); } @@ -135,10 +167,10 @@ void Unpack_OEMCrypto_PrivateKeyType(ODK_Message* message, void Unpack_OEMCrypto_TimerDelayBase(ODK_Message* message, OEMCrypto_TimerDelayBase* value) { assert(value); - uint32_t v32 = 0; - Unpack_uint32_t(message, &v32); - if (v32 <= OEMCrypto_TimerDelayBase_MaxValue) { - *value = (OEMCrypto_TimerDelayBase)v32; + uint32_t value_u32 = 0; + Unpack_uint32_t(message, &value_u32); + if (value_u32 <= OEMCrypto_TimerDelayBase_MaxValue) { + *value = (OEMCrypto_TimerDelayBase)value_u32; } else { ODK_Message_SetStatus(message, MESSAGE_STATUS_PARSE_ERROR); } @@ -168,7 +200,7 @@ void Unpack_uint16_t(ODK_Message* message, uint16_t* value) { void Unpack_uint32_t(ODK_Message* message, uint32_t* value) { ODK_Message_Impl* message_impl = (ODK_Message_Impl*)message; - if (!message_impl) return; + if (message_impl == NULL) return; uint8_t data[4] = {0}; UnpackBytes(message, data, sizeof(data)); assert(value); @@ -179,22 +211,23 @@ void Unpack_uint32_t(ODK_Message* message, uint32_t* value) { } void Unpack_uint64_t(ODK_Message* message, uint64_t* value) { - uint32_t hi = 0; - uint32_t lo = 0; - Unpack_uint32_t(message, &hi); - Unpack_uint32_t(message, &lo); + uint32_t high = 0; + uint32_t low = 0; + Unpack_uint32_t(message, &high); + Unpack_uint32_t(message, &low); assert(value); - *value = hi; - *value = *value << 32 | lo; + *value = high; + *value = *value << 32 | low; } void Unpack_OEMCrypto_Substring(ODK_Message* message, OEMCrypto_Substring* obj) { - uint32_t offset = 0, length = 0; + uint32_t offset = 0; + uint32_t length = 0; Unpack_uint32_t(message, &offset); Unpack_uint32_t(message, &length); ODK_Message_Impl* message_impl = GetMessageImpl(message); - if (!message_impl) return; + if (message_impl == NULL) return; /* Each substring should be contained within the message body, which is in the * total message, just after the core message. The offset of a substring is @@ -220,11 +253,22 @@ void Unpack_OEMCrypto_Substring(ODK_Message* message, return; } assert(obj); - obj->offset = offset; - obj->length = length; + obj->offset = (size_t)offset; + obj->length = (size_t)length; } /* copy out */ void UnpackArray(ODK_Message* message, uint8_t* address, size_t size) { + assert(address); UnpackBytes(message, address, size); } + +void UnpackFixedArray(ODK_Message* message, uint8_t* data, size_t data_length, + size_t data_field_size) { + assert(data); + assert(data_length <= data_field_size); + UnpackBytes(message, data, data_length); + // Move the counter along the rest of the buffer. + SkipBytes(message, data_field_size - data_length); + memset(&data[data_length], 0, data_field_size - data_length); +} diff --git a/oemcrypto/odk/src/serialization_base.h b/oemcrypto/odk/src/serialization_base.h index e4b0ef0..9c287a8 100644 --- a/oemcrypto/odk/src/serialization_base.h +++ b/oemcrypto/odk/src/serialization_base.h @@ -5,16 +5,16 @@ #ifndef WIDEVINE_ODK_SRC_SERIALIZATION_BASE_H_ #define WIDEVINE_ODK_SRC_SERIALIZATION_BASE_H_ -#ifdef __cplusplus -extern "C" { -#endif - #include #include #include "OEMCryptoCENCCommon.h" #include "odk_message.h" +#ifdef __cplusplus +extern "C" { +#endif + void Pack_enum(ODK_Message* message, int value); void Pack_bool(ODK_Message* message, const bool* value); void Pack_uint8_t(ODK_Message* message, const uint8_t* value); @@ -22,6 +22,8 @@ void Pack_uint16_t(ODK_Message* message, const uint16_t* value); void Pack_uint32_t(ODK_Message* message, const uint32_t* value); void Pack_uint64_t(ODK_Message* message, const uint64_t* value); void PackArray(ODK_Message* message, const uint8_t* base, size_t size); +void PackFixedArray(ODK_Message* message, const uint8_t* data, + size_t data_length, size_t data_field_size); void Pack_OEMCrypto_Substring(ODK_Message* message, const OEMCrypto_Substring* obj); @@ -38,6 +40,8 @@ void Unpack_uint32_t(ODK_Message* message, uint32_t* value); void Unpack_uint64_t(ODK_Message* message, uint64_t* value); void UnpackArray(ODK_Message* message, uint8_t* address, size_t size); /* copy out */ +void UnpackFixedArray(ODK_Message* message, uint8_t* data, size_t data_length, + size_t data_field_size); void Unpack_OEMCrypto_Substring(ODK_Message* message, OEMCrypto_Substring* obj); #ifdef __cplusplus diff --git a/oemcrypto/odk/test/fuzzing/Android.bp b/oemcrypto/odk/test/fuzzing/Android.bp index 6ce3f45..cda0879 100644 --- a/oemcrypto/odk/test/fuzzing/Android.bp +++ b/oemcrypto/odk/test/fuzzing/Android.bp @@ -16,166 +16,107 @@ package { } cc_defaults { - name: "odk_fuzz_library_defaults", + name: "odk_fuzz_test_defaults", srcs: [ "odk_fuzz_helper.cpp", ], - include_dirs: [ - "vendor/widevine/libwvdrmengine/oemcrypto/odk/test", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/include", - "vendor/widevine/libwvdrmengine/oemcrypto/odk/src", + static_libs: [ + "libwv_kdo", + "libwv_odk", ], + fuzz_config: { + componentid: 611718, + }, + proprietary: true, + owner: "widevine", } cc_fuzz { name: "odk_license_request_fuzz", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_license_request_fuzz.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/license_request_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_renewal_request_fuzz", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_renewal_request_fuzz.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/renewal_request_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_provisioning_request_fuzz", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_provisioning_request_fuzz.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/provisioning_request_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_license_response_fuzz", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_license_response_fuzz.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/license_response_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_renewal_response_fuzz", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_renewal_response_fuzz.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/renewal_response_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_provisioning_response_fuzz", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_provisioning_response_fuzz.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/provisioning_response_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_license_response_fuzz_with_mutator", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_license_response_fuzz_with_mutator.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/license_response_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_renewal_response_fuzz_with_mutator", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_renewal_response_fuzz_with_mutator.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/renewal_response_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } cc_fuzz { name: "odk_provisioning_response_fuzz_with_mutator", + defaults: ["odk_fuzz_test_defaults"], + srcs: [ "odk_provisioning_response_fuzz_with_mutator.cpp", ], - fuzz_config: { - componentid: 611718, - }, corpus: ["corpus/little_endian_64bit/provisioning_response_corpus/*"], - static_libs: [ - "libwv_kdo", - "libwv_odk", - ], - defaults: ["odk_fuzz_library_defaults"], - proprietary: true, } diff --git a/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.cpp b/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.cpp index b7797fd..080d31c 100644 --- a/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.cpp +++ b/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.cpp @@ -3,42 +3,94 @@ // License Agreement. #include "fuzzing/odk_fuzz_helper.h" +#include #include +#include #include +#include "OEMCryptoCENCCommon.h" +#include "core_message_features.h" #include "core_message_types.h" +#include "fuzzing/odk_fuzz_structs.h" #include "odk.h" #include "odk_attributes.h" #include "odk_structs.h" +#include "odk_structs_priv.h" #include "odk_target.h" +namespace { +bool ConvertFuzzedByteToNormalizedBoolean(const bool* in) { + constexpr uint8_t kConversionBit = 1; + const uint8_t value = *reinterpret_cast(in); + return (kConversionBit & value) != 0; +} + +OEMCrypto_LicenseType ConvertFuzzedBytesToNormalizedLicenseType( + const OEMCrypto_LicenseType* fuzz_license_type) { + // ASAN will trigger an error if fuzz_license_type is dereference directly + // before being normalized. + using RawEnum = std::underlying_type_t; + const RawEnum kLowestValue = static_cast(0); + const RawEnum kHighestValue = + static_cast(OEMCrypto_LicenseType_MaxValue); + const RawEnum kSpan = kHighestValue - kLowestValue + 1; + const RawEnum raw_license_type = + *reinterpret_cast(fuzz_license_type); + return static_cast( + ((raw_license_type - kLowestValue) % kSpan) + kLowestValue); +} + +OEMCrypto_TimerDelayBase ConvertFuzzedBytesToNormalizedRenewalDelayBase( + const OEMCrypto_TimerDelayBase* fuzz_renewal_delay_base) { + // ASAN will trigger an error if fuzz_license_type is dereference directly + // before being normalized. + using RawEnum = std::underlying_type_t; + const RawEnum kLowestValue = static_cast(0); + const RawEnum kHighestValue = + static_cast(OEMCrypto_TimerDelayBase_MaxValue); + const RawEnum kSpan = kHighestValue - kLowestValue + 1; + const RawEnum raw_renewal_delay_base = + *reinterpret_cast(fuzz_renewal_delay_base); + return static_cast( + ((raw_renewal_delay_base - kLowestValue) % kSpan) + kLowestValue); +} + +OEMCrypto_PrivateKeyType ConvertFuzzedBytesToNormalizedPrivateKeyType( + const OEMCrypto_PrivateKeyType* fuzz_key_type) { + using RawEnum = std::underlying_type_t; + const RawEnum kLowestValue = static_cast(0); + const RawEnum kHighestValue = + static_cast(OEMCrypto_PrivateKeyType_MaxValue); + const RawEnum kSpan = kHighestValue - kLowestValue + 1; + const RawEnum raw_key_type = *reinterpret_cast(fuzz_key_type); + return static_cast( + ((raw_key_type - kLowestValue) % kSpan) + kLowestValue); +} +} // namespace + namespace oemcrypto_core_message { using features::CoreMessageFeatures; -bool convert_byte_to_valid_boolean(const bool* in) { - const char* buf = reinterpret_cast(in); - for (int i = 0; i < sizeof(bool); i++) { - if (buf[i]) { - return true; - } - } - return false; -} - -void ConvertDataToValidBools(ODK_ParsedLicense* t) { +void NormalizeODKStructFields(ODK_ParsedLicense* t) { // Convert boolean flags in parsed_license to valid bytes to // avoid errors from msan - t->nonce_required = convert_byte_to_valid_boolean(&t->nonce_required); + t->nonce_required = ConvertFuzzedByteToNormalizedBoolean(&t->nonce_required); t->timer_limits.soft_enforce_playback_duration = - convert_byte_to_valid_boolean( + ConvertFuzzedByteToNormalizedBoolean( &t->timer_limits.soft_enforce_playback_duration); - t->timer_limits.soft_enforce_rental_duration = convert_byte_to_valid_boolean( - &t->timer_limits.soft_enforce_rental_duration); + t->timer_limits.soft_enforce_rental_duration = + ConvertFuzzedByteToNormalizedBoolean( + &t->timer_limits.soft_enforce_rental_duration); + t->license_type = ConvertFuzzedBytesToNormalizedLicenseType(&t->license_type); + t->renewal_delay_base = + ConvertFuzzedBytesToNormalizedRenewalDelayBase(&t->renewal_delay_base); } -void ConvertDataToValidBools(ODK_PreparedRenewalRequest* t UNUSED) {} +void NormalizeODKStructFields(ODK_PreparedRenewalRequest* t UNUSED) {} -void ConvertDataToValidBools(ODK_ParsedProvisioning* t UNUSED) {} +void NormalizeODKStructFields(ODK_ParsedProvisioning* t) { + t->key_type = ConvertFuzzedBytesToNormalizedPrivateKeyType(&t->key_type); +} OEMCryptoResult odk_serialize_LicenseRequest( const void* in UNUSED, uint8_t* out, size_t* size, @@ -47,8 +99,10 @@ OEMCryptoResult odk_serialize_LicenseRequest( // TODO(mattfedd): hook up counters to fuzzer const ODK_MessageCounterInfo counter_info = {0, 0, 0, 0, 0, 0, 0, {0}, {0}, {0}}; + uint16_t decenc_mitigation_options_supported = 0; return ODK_PrepareCoreLicenseRequest(out, SIZE_MAX, size, nonce_values, - &counter_info); + &counter_info, + decenc_mitigation_options_supported); } OEMCryptoResult odk_serialize_RenewalRequest( @@ -93,10 +147,11 @@ OEMCryptoResult odk_deserialize_RenewalResponse( * errors in fuzzer code and converting random bytes to 0 OR 1. * This has no negative security impact*/ a->timer_limits.soft_enforce_playback_duration = - convert_byte_to_valid_boolean( + ConvertFuzzedByteToNormalizedBoolean( &a->timer_limits.soft_enforce_playback_duration); - a->timer_limits.soft_enforce_rental_duration = convert_byte_to_valid_boolean( - &a->timer_limits.soft_enforce_rental_duration); + a->timer_limits.soft_enforce_rental_duration = + ConvertFuzzedByteToNormalizedBoolean( + &a->timer_limits.soft_enforce_rental_duration); uint64_t timer_value = 0; OEMCryptoResult err = ODK_ParseRenewal(buf, SIZE_MAX, len, nonce_values, a->system_time, @@ -128,6 +183,9 @@ OEMCryptoResult odk_deserialize_ProvisioningResponse( bool kdo_serialize_LicenseResponse(const ODK_ParseLicense_Args* args, const ODK_ParsedLicense& parsed_lic, std::string* oemcrypto_core_message) { + if (args == nullptr || oemcrypto_core_message == nullptr) { + return false; + } const auto& nonce_values = args->nonce_values; const ODK_MessageCounter counter_info = {0, 0, 0, 0, 0, 0, 0, {0}, {0}, {0}}; ODK_LicenseRequest core_request{ @@ -146,11 +204,22 @@ bool kdo_serialize_LicenseResponse(const ODK_ParseLicense_Args* args, parsed_license.watermarking = parsed_lic.watermarking; parsed_license.dtcp2_required = parsed_lic.dtcp2_required; parsed_license.renewal_delay_base = parsed_lic.renewal_delay_base; + parsed_license.decenc_mitigation_option_used = + parsed_lic.decenc_mitigation_option_used; + // If the mutated input for parsed_lic.key_array_length is greater than + // 2 * ODK_MAX_NUM_KEYS, then don't bother going through the rest of the + // serialization code. + if (parsed_lic.key_array_length > 2 * ODK_MAX_NUM_KEYS) return false; parsed_license.key_array_length = parsed_lic.key_array_length; - std::vector key_array; - size_t i; - for (i = 0; i < parsed_lic.key_array_length && i < ODK_MAX_NUM_KEYS; i++) { - key_array.push_back(parsed_lic.key_array[i]); + std::vector key_array; + key_array.reserve(parsed_lic.key_array_length); + for (size_t i = 0; i < parsed_lic.key_array_length; i++) { + // |parsed_lic| is of type ODK_ParsedLicense which is a struct with a fixed + // key array size. Since |parsed_license| is of type + // ODK_Packing_ParsedLicense which is a struct with a variable key array + // size, we can use a modulus operation to make duplicate keys in + // |parsed_license|. + key_array.push_back(parsed_lic.key_array[i % ODK_MAX_NUM_KEYS]); } parsed_license.key_array = key_array.data(); return serialize::CreateCoreLicenseResponse( diff --git a/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.h b/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.h index 309c2be..4d0d29f 100644 --- a/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.h +++ b/oemcrypto/odk/test/fuzzing/odk_fuzz_helper.h @@ -4,6 +4,7 @@ #ifndef WIDEVINE_ODK_TEST_FUZZING_ODK_FUZZ_HELPER_H_ #define WIDEVINE_ODK_TEST_FUZZING_ODK_FUZZ_HELPER_H_ +#include #include #include @@ -12,10 +13,9 @@ #include "fuzzing/odk_fuzz_structs.h" #include "odk_attributes.h" #include "odk_serialize.h" +#include "odk_structs.h" namespace oemcrypto_core_message { -bool convert_byte_to_valid_boolean(const bool* in); - OEMCryptoResult odk_serialize_LicenseRequest( const void* in, uint8_t* out, size_t* size, const ODK_LicenseRequest& core_license_request, @@ -66,11 +66,11 @@ bool kdo_serialize_ProvisioningResponse( // make us of common FuzzerMutateResponse across three response fuzzers, // three independent functions were defined and renewal and provisioning // functions would be empty as no additional processing is needed for them. -void ConvertDataToValidBools(ODK_ParsedLicense* t); +void NormalizeODKStructFields(ODK_ParsedLicense* t); -void ConvertDataToValidBools(ODK_PreparedRenewalRequest* t); +void NormalizeODKStructFields(ODK_PreparedRenewalRequest* t); -void ConvertDataToValidBools(ODK_ParsedProvisioning* t); +void NormalizeODKStructFields(ODK_ParsedProvisioning* t); // Forward-declare the libFuzzer's mutator callback. Mark it weak so that // the program links successfully even outside of --config=asan-fuzzer @@ -78,6 +78,16 @@ void ConvertDataToValidBools(ODK_ParsedProvisioning* t); extern "C" size_t LLVMFuzzerMutate(uint8_t* Data, size_t Size, size_t MaxSize) __attribute__((weak)); +// 1) ODK message fuzz bytes -> F deserialize -> struct T +// 2) struct T -> ODK struct fuzzing -> fuzz struct T +// 2a) struct T -> memcpy + padding -> unnormalized intermediate bytes +// 2b) unnormalized intermediate bytes -> LLVMFuzzerMutate -> unnormalized +// intermediate buzz bytes +// 2c) unnormalized intermediate buzz bytes -> memcpy +// -> unnormalized struct T +// 2d) unnormalized struct T -> +// NormalizeODKStructFields -> fuzz struct T +// 3) fuzz struct T -> G deserialize -> ODK message template size_t FuzzerMutateResponse(uint8_t* data, size_t size, size_t max_size, const F& odk_deserialize_fun, @@ -123,8 +133,8 @@ size_t FuzzerMutateResponse(uint8_t* data, size_t size, size_t max_size, memcpy(args, data, kArgsSize); memcpy(&t, data + kArgsSize, kCoreResponseSize); // Convert boolean flags in parsed message to valid bytes to - // avoid errors from msan. Only needed for parsed license. - ConvertDataToValidBools(&t); + // avoid errors from msan. + NormalizeODKStructFields(&t); // Serialize the data after mutation. std::string oemcrypto_core_message; if (!kdo_serialize_fun(args, t, &oemcrypto_core_message)) { diff --git a/oemcrypto/odk/test/fuzzing/odk_license_response_fuzz_with_mutator.cpp b/oemcrypto/odk/test/fuzzing/odk_license_response_fuzz_with_mutator.cpp index 735ce7a..0148eb4 100644 --- a/oemcrypto/odk/test/fuzzing/odk_license_response_fuzz_with_mutator.cpp +++ b/oemcrypto/odk/test/fuzzing/odk_license_response_fuzz_with_mutator.cpp @@ -15,7 +15,9 @@ extern "C" size_t LLVMFuzzerCustomMutator(uint8_t* data, size_t size, size_t max_size, unsigned int seed UNUSED) { const size_t kLicenseResponseArgsSize = sizeof(ODK_ParseLicense_Args); - if (size < kLicenseResponseArgsSize) { + const size_t kCoreResponseSize = sizeof(ODK_ParsedLicense); + // TODO(b/403349564): Regenerate larger corpuses. + if (size < kLicenseResponseArgsSize || size < kCoreResponseSize) { return 0; } diff --git a/oemcrypto/odk/test/fuzzing/odk_renewal_response_fuzz.cpp b/oemcrypto/odk/test/fuzzing/odk_renewal_response_fuzz.cpp index 416c7b0..d7e0710 100644 --- a/oemcrypto/odk/test/fuzzing/odk_renewal_response_fuzz.cpp +++ b/oemcrypto/odk/test/fuzzing/odk_renewal_response_fuzz.cpp @@ -6,6 +6,7 @@ #include #include "fuzzing/odk_fuzz_helper.h" +#include "odk_structs_priv.h" namespace oemcrypto_core_message { diff --git a/oemcrypto/odk/test/odk_core_message_test.cpp b/oemcrypto/odk/test/odk_core_message_test.cpp index bea7891..ff2d9bf 100644 --- a/oemcrypto/odk/test/odk_core_message_test.cpp +++ b/oemcrypto/odk/test/odk_core_message_test.cpp @@ -42,7 +42,7 @@ TEST(CoreMessageTest, RenwalRequest) { uint32_t nonce = 0; uint32_t timer_status = 2; uint64_t time = 10; - enum OEMCrypto_Usage_Entry_Status status = kInactiveUsed; + enum OEMCrypto_UsageEntryStatus status = OEMCrypto_InactiveUsed; ODK_NonceValues nonce_values{api_minor_version, api_major_version, nonce}; ODK_ClockValues clock_values{time, time, time, time, time, timer_status, status}; @@ -125,8 +125,7 @@ TEST_P(ProvisioningRoundTripTest_18V0, ProvisioningRoundtrip) { // Make sure we can create a response from that request with the same core // message - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); std::string serialized_provisioning_resp; video_widevine::ProvisioningResponse provisioning_response; provisioning_response.set_device_certificate("device_certificate"); @@ -137,8 +136,8 @@ TEST_P(ProvisioningRoundTripTest_18V0, ProvisioningRoundtrip) { } std::string oemcrypto_core_message; EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( - features, serialized_provisioning_resp, request, - OEMCrypto_RSA_Private_Key, &oemcrypto_core_message)); + features, serialized_provisioning_resp, request, OEMCrypto_RSAPrivateKey, + &oemcrypto_core_message)); // Extract core message from generated prov response and match values with // request diff --git a/oemcrypto/odk/test/odk_golden_v16.cpp b/oemcrypto/odk/test/odk_golden_v16.cpp index 3dcf1f7..cffb507 100644 --- a/oemcrypto/odk/test/odk_golden_v16.cpp +++ b/oemcrypto/odk/test/odk_golden_v16.cpp @@ -41,15 +41,14 @@ class ODKGoldenProvisionV16 : public ::testing::Test { EXPECT_TRUE(CoreProvisioningRequestFromMessage(core_request_, &core_provisioning_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( features, provisioning_response_, core_provisioning_request, device_key_type_, &generated_core_message)); EXPECT_EQ(core_response_, generated_core_message); } - OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSA_Private_Key; + OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSAPrivateKey; std::string core_request_; std::string core_response_; std::string provisioning_response_; @@ -62,8 +61,7 @@ class ODKGoldenLicenseV16 : public ::testing::Test { EXPECT_TRUE( CoreLicenseRequestFromMessage(core_request_, &core_license_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreLicenseResponseFromProto( features, serialized_license_, core_license_request, core_request_sha256_, nonce_required_, uses_padding_, @@ -86,8 +84,7 @@ class ODKGoldenRenewalV16 : public ::testing::Test { EXPECT_TRUE( CoreRenewalRequestFromMessage(core_request_, &core_renewal_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreRenewalResponse(features, core_renewal_request, renewal_duration_seconds_, &generated_core_message)); diff --git a/oemcrypto/odk/test/odk_golden_v17.cpp b/oemcrypto/odk/test/odk_golden_v17.cpp index e1e832c..545c518 100644 --- a/oemcrypto/odk/test/odk_golden_v17.cpp +++ b/oemcrypto/odk/test/odk_golden_v17.cpp @@ -41,15 +41,14 @@ class ODKGoldenProvisionV17 : public ::testing::Test { EXPECT_TRUE(CoreProvisioningRequestFromMessage(core_request_, &core_provisioning_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( features, provisioning_response_, core_provisioning_request, device_key_type_, &generated_core_message)); EXPECT_EQ(core_response_, generated_core_message); } - OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSA_Private_Key; + OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSAPrivateKey; std::string core_request_; std::string core_response_; std::string provisioning_response_; @@ -62,8 +61,7 @@ class ODKGoldenLicenseV17 : public ::testing::Test { EXPECT_TRUE( CoreLicenseRequestFromMessage(core_request_, &core_license_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreLicenseResponseFromProto( features, serialized_license_, core_license_request, core_request_sha256_, nonce_required_, uses_padding_, @@ -86,8 +84,7 @@ class ODKGoldenRenewalV17 : public ::testing::Test { EXPECT_TRUE( CoreRenewalRequestFromMessage(core_request_, &core_renewal_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreRenewalResponse(features, core_renewal_request, renewal_duration_seconds_, &generated_core_message)); diff --git a/oemcrypto/odk/test/odk_golden_v18.cpp b/oemcrypto/odk/test/odk_golden_v18.cpp index acbd8dc..505c6f4 100644 --- a/oemcrypto/odk/test/odk_golden_v18.cpp +++ b/oemcrypto/odk/test/odk_golden_v18.cpp @@ -41,15 +41,14 @@ class ODKGoldenProvisionV18 : public ::testing::Test { EXPECT_TRUE(CoreProvisioningRequestFromMessage(core_request_, &core_provisioning_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( features, provisioning_response_, core_provisioning_request, device_key_type_, &generated_core_message)); EXPECT_EQ(core_response_, generated_core_message); } - OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSA_Private_Key; + OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSAPrivateKey; std::string core_request_; std::string core_response_; std::string provisioning_response_; @@ -62,8 +61,7 @@ class ODKGoldenLicenseV18 : public ::testing::Test { EXPECT_TRUE( CoreLicenseRequestFromMessage(core_request_, &core_license_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreLicenseResponseFromProto( features, serialized_license_, core_license_request, core_request_sha256_, nonce_required_, uses_padding_, @@ -86,8 +84,7 @@ class ODKGoldenRenewalV18 : public ::testing::Test { EXPECT_TRUE( CoreRenewalRequestFromMessage(core_request_, &core_renewal_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreRenewalResponse(features, core_renewal_request, renewal_duration_seconds_, &generated_core_message)); @@ -369,7 +366,7 @@ TEST_F(ODKGoldenProvisionV18, CorePIGTest_OfflineNoNonce_prov20) { provisioning_response_ = std::string(reinterpret_cast(provisioning_response_raw), sizeof(provisioning_response_raw)); - device_key_type_ = OEMCrypto_RSA_Private_Key; + device_key_type_ = OEMCrypto_RSAPrivateKey; RunTest(); } @@ -496,7 +493,7 @@ TEST_F(ODKGoldenProvisionV18, CorePIGTest_OfflineNoNonce_prov20ecc) { provisioning_response_ = std::string(reinterpret_cast(provisioning_response_raw), sizeof(provisioning_response_raw)); - device_key_type_ = OEMCrypto_ECC_Private_Key; + device_key_type_ = OEMCrypto_ECCPrivateKey; RunTest(); } @@ -760,7 +757,7 @@ TEST_F(ODKGoldenProvisionV18, CorePIGTest_OfflineNoNonce_prov30) { provisioning_response_ = std::string(reinterpret_cast(provisioning_response_raw), sizeof(provisioning_response_raw)); - device_key_type_ = OEMCrypto_RSA_Private_Key; + device_key_type_ = OEMCrypto_RSAPrivateKey; RunTest(); } diff --git a/oemcrypto/odk/test/odk_golden_v19.cpp b/oemcrypto/odk/test/odk_golden_v19.cpp index bee6c0d..8099ffa 100644 --- a/oemcrypto/odk/test/odk_golden_v19.cpp +++ b/oemcrypto/odk/test/odk_golden_v19.cpp @@ -41,15 +41,14 @@ class ODKGoldenProvisionV19 : public ::testing::Test { EXPECT_TRUE(CoreProvisioningRequestFromMessage(core_request_, &core_provisioning_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( features, provisioning_response_, core_provisioning_request, device_key_type_, &generated_core_message)); EXPECT_EQ(core_response_, generated_core_message); } - OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSA_Private_Key; + OEMCrypto_PrivateKeyType device_key_type_ = OEMCrypto_RSAPrivateKey; std::string core_request_; std::string core_response_; std::string provisioning_response_; @@ -75,8 +74,7 @@ class ODKGoldenLicenseV19 : public ::testing::Test { EXPECT_TRUE( CoreLicenseRequestFromMessage(core_request_, &core_license_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreLicenseResponseFromProto( features, serialized_license_, core_license_request, core_request_sha256_, nonce_required_, uses_padding_, @@ -99,8 +97,7 @@ class ODKGoldenRenewalV19 : public ::testing::Test { EXPECT_TRUE( CoreRenewalRequestFromMessage(core_request_, &core_renewal_request)); std::string generated_core_message; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); EXPECT_TRUE(CreateCoreRenewalResponse(features, core_renewal_request, renewal_duration_seconds_, &generated_core_message)); @@ -382,7 +379,7 @@ TEST_F(ODKGoldenProvisionV19, CorePIGTest_OfflineNoNonce_prov20) { provisioning_response_ = std::string(reinterpret_cast(provisioning_response_raw), sizeof(provisioning_response_raw)); - device_key_type_ = OEMCrypto_RSA_Private_Key; + device_key_type_ = OEMCrypto_RSAPrivateKey; RunTest(); } @@ -509,7 +506,7 @@ TEST_F(ODKGoldenProvisionV19, CorePIGTest_OfflineNoNonce_prov20ecc) { provisioning_response_ = std::string(reinterpret_cast(provisioning_response_raw), sizeof(provisioning_response_raw)); - device_key_type_ = OEMCrypto_ECC_Private_Key; + device_key_type_ = OEMCrypto_ECCPrivateKey; RunTest(); } @@ -773,7 +770,7 @@ TEST_F(ODKGoldenProvisionV19, CorePIGTest_OfflineNoNonce_prov30) { provisioning_response_ = std::string(reinterpret_cast(provisioning_response_raw), sizeof(provisioning_response_raw)); - device_key_type_ = OEMCrypto_RSA_Private_Key; + device_key_type_ = OEMCrypto_RSAPrivateKey; RunTest(); } diff --git a/oemcrypto/odk/test/odk_test.cpp b/oemcrypto/odk/test/odk_test.cpp index bd14883..f84e946 100644 --- a/oemcrypto/odk/test/odk_test.cpp +++ b/oemcrypto/odk/test/odk_test.cpp @@ -4,6 +4,8 @@ #include "odk.h" +#include + #include #include #include @@ -41,7 +43,6 @@ using oemcrypto_core_message::deserialize::CoreCommonRequestFromMessage; using oemcrypto_core_message::deserialize::CoreLicenseRequestFromMessage; using oemcrypto_core_message::deserialize::CoreProvisioning40RequestFromMessage; using oemcrypto_core_message::deserialize::CoreProvisioningRequestFromMessage; -using oemcrypto_core_message::deserialize::CoreReleaseRequestFromMessage; using oemcrypto_core_message::deserialize::CoreRenewalRequestFromMessage; using oemcrypto_core_message::deserialize:: CoreRenewedProvisioningRequestFromMessage; @@ -61,21 +62,33 @@ constexpr uint32_t kExtraPayloadSize = 128u; /* Used to parameterize tests by version number. The request is given one * version number, and we will expect the response to have another version * number. */ -struct VersionParameters { - uint32_t maximum_major_version; - uint16_t request_major_version; - uint16_t request_minor_version; - uint16_t response_major_version; - uint16_t response_minor_version; +struct VersionPair { + uint16_t major_version; + uint16_t minor_version; + + template + VersionPair(T major, S minor) + : major_version(static_cast(major)), + minor_version(static_cast(minor)){}; }; +struct VersionParameters { + VersionPair maximum; // The CoreMessageFeatures maximum version. + bool prerelease; // CoreMessageFeatures serve_prerelease_odk_messages. + VersionPair request; + VersionPair response; + bool should_accept_request; +}; + +std::ostream& operator<<(std::ostream& stream, const VersionPair& v) { + return stream << "v" << v.major_version << "." << v.minor_version; +} // This function is called by GTest when a parameterized test fails in order // to log the parameter used for the failing test. void PrintTo(const VersionParameters& p, std::ostream* os) { - *os << "max=v" << p.maximum_major_version << ", request = v" - << p.request_major_version << "." << p.request_minor_version - << ", response = v" << p.response_major_version << "." - << p.response_minor_version; + *os << "max=" << p.maximum << ", prerelease=" << p.prerelease + << ", request=" << p.request << ", response=" << p.response + << ", should_accept_request=" << p.should_accept_request; } void SetDefaultSerializedProvisioningResponse(std::string* serialized_message) { @@ -147,40 +160,37 @@ void ValidateRequest(uint32_t message_type, // non-empty buf, expect core message length to be set correctly, and buf is // filled with ODK_Field values appropriately - uint8_t* buf = new uint8_t[message_size]{}; + std::vector buf(message_size, 0); EXPECT_EQ(OEMCrypto_SUCCESS, - odk_prepare_func(buf, &core_message_length, &nonce_values)); + odk_prepare_func(buf.data(), &core_message_length, &nonce_values)); EXPECT_EQ(core_message_length, message_size); - uint8_t* buf_expected = new uint8_t[message_size]{}; + std::vector buf_expected(message_size, 0); size_t buf_len_expected = 0; - EXPECT_EQ(OEMCrypto_SUCCESS, ODK_IterFields(ODK_WRITE, buf_expected, SIZE_MAX, - &buf_len_expected, total_fields)); + EXPECT_EQ(OEMCrypto_SUCCESS, + ODK_IterFields(ODK_WRITE, buf_expected.data(), SIZE_MAX, + &buf_len_expected, total_fields)); EXPECT_EQ(buf_len_expected, message_size); - EXPECT_NO_FATAL_FAILURE( - ODK_ExpectEqualBuf(buf_expected, buf, message_size, total_fields)); + EXPECT_NO_FATAL_FAILURE(ODK_ExpectEqualBuf(buf_expected.data(), buf.data(), + message_size, total_fields)); // odk kdo round-trip: deserialize from buf, then serialize it to buf2 // expect them to be identical T t = {}; - std::string oemcrypto_core_message(reinterpret_cast(buf), + std::string oemcrypto_core_message(reinterpret_cast(buf.data()), message_size); EXPECT_TRUE(kdo_parse_func(oemcrypto_core_message, &t)); nonce_values.api_minor_version = t.api_minor_version; nonce_values.api_major_version = t.api_major_version; nonce_values.nonce = t.nonce; nonce_values.session_id = t.session_id; - uint8_t* buf2 = new uint8_t[message_size]{}; + std::vector buf2(message_size, 0); EXPECT_EQ(OEMCrypto_SUCCESS, - odk_prepare_func(buf2, &core_message_length, &nonce_values)); + odk_prepare_func(buf2.data(), &core_message_length, &nonce_values)); EXPECT_EQ(core_message_length, message_size); EXPECT_NO_FATAL_FAILURE( - ODK_ExpectEqualBuf(buf, buf2, message_size, total_fields)); - - delete[] buf; - delete[] buf_expected; - delete[] buf2; + ODK_ExpectEqualBuf(buf.data(), buf2.data(), message_size, total_fields)); } /** @@ -195,40 +205,39 @@ void ValidateResponse(const VersionParameters& versions, const std::vector& extra_fields, const F& odk_parse_func, const G& kdo_prepare_func) { T t = {}; - t.api_major_version = versions.request_major_version; - t.api_minor_version = versions.request_minor_version; + t.api_major_version = versions.request.major_version; + t.api_minor_version = versions.request.minor_version; t.nonce = core_message->nonce_values.nonce; t.session_id = core_message->nonce_values.session_id; - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(core_message, extra_fields, &buf, &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(core_message, extra_fields, &buf); - uint8_t* zero = new uint8_t[buf_size]{}; + std::vector zero(buf.size(), 0); size_t bytes_read = 0; // zero-out input - EXPECT_EQ(OEMCrypto_SUCCESS, ODK_IterFields(ODK_READ, zero, buf_size, + EXPECT_EQ(OEMCrypto_SUCCESS, ODK_IterFields(ODK_READ, zero.data(), buf.size(), &bytes_read, extra_fields)); // Parse buf with odk - const OEMCryptoResult parse_result = odk_parse_func(buf, buf_size); + const OEMCryptoResult parse_result = odk_parse_func(buf.data(), buf.size()); EXPECT_EQ(OEMCrypto_SUCCESS, parse_result); size_t size_out = 0; if (parse_result != OEMCrypto_SUCCESS) { - ODK_IterFields(ODK_FieldMode::ODK_DUMP, buf, buf_size, &size_out, + ODK_IterFields(ODK_FieldMode::ODK_DUMP, buf.data(), buf.size(), &size_out, extra_fields); } // serialize odk output to oemcrypto_core_message std::string oemcrypto_core_message; - EXPECT_TRUE(kdo_prepare_func(t, &oemcrypto_core_message)); + EXPECT_EQ(versions.should_accept_request, + kdo_prepare_func(t, &oemcrypto_core_message)); + if (!versions.should_accept_request) return; // verify round-trip works - EXPECT_NO_FATAL_FAILURE(ODK_ExpectEqualBuf(buf, oemcrypto_core_message.data(), - buf_size, extra_fields)); - delete[] buf; - delete[] zero; + EXPECT_NO_FATAL_FAILURE(ODK_ExpectEqualBuf( + buf.data(), oemcrypto_core_message.data(), buf.size(), extra_fields)); } TEST(OdkTest, SerializeFields) { @@ -266,26 +275,26 @@ TEST(OdkTest, SerializeFieldsStress) { total_size += ODK_FieldLength(fields[i].type); } - uint8_t* buf = new uint8_t[total_size]{}; + std::vector buf; + buf.reserve(total_size); for (size_t i = 0; i < total_size; i++) { - buf[i] = std::rand() & 0xff; + buf.push_back(std::rand() & 0xff); } size_t bytes_read = 0, bytes_written = 0; - uint8_t* buf2 = new uint8_t[total_size]{}; - ODK_IterFields(ODK_READ, buf, total_size, &bytes_read, fields); + std::vector buf2(total_size); + ODK_IterFields(ODK_READ, buf.data(), total_size, &bytes_read, fields); EXPECT_EQ(bytes_read, total_size); - ODK_IterFields(ODK_WRITE, buf2, total_size, &bytes_written, fields); + ODK_IterFields(ODK_WRITE, buf2.data(), total_size, &bytes_written, fields); EXPECT_EQ(bytes_written, total_size); - EXPECT_NO_FATAL_FAILURE(ODK_ExpectEqualBuf(buf, buf2, total_size, fields)); + EXPECT_NO_FATAL_FAILURE( + ODK_ExpectEqualBuf(buf.data(), buf2.data(), total_size, fields)); // cleanup for (int i = 0; i < n; i++) { free(fields[i].value); } - delete[] buf; - delete[] buf2; } TEST(OdkTest, NullRequestTest) { @@ -296,17 +305,21 @@ TEST(OdkTest, NullRequestTest) { memset(&clock_values, 0, sizeof(clock_values)); ODK_MessageCounterInfo counter_info; memset(&counter_info, 0, sizeof(counter_info)); + uint16_t decenc_mitigation_options_supported = 0; // Assert that nullptr does not cause a core dump. EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, ODK_PrepareCoreLicenseRequest(nullptr, 0uL, nullptr, &nonce_values, - &counter_info)); + &counter_info, + decenc_mitigation_options_supported)); EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, ODK_PrepareCoreLicenseRequest(nullptr, 0uL, &core_message_length, - nullptr, &counter_info)); + nullptr, &counter_info, + decenc_mitigation_options_supported)); EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, ODK_PrepareCoreLicenseRequest(nullptr, 0uL, &core_message_length, - &nonce_values, nullptr)); + &nonce_values, nullptr, + decenc_mitigation_options_supported)); EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, ODK_PrepareCoreRenewalRequest(nullptr, 0uL, nullptr, &nonce_values, @@ -479,10 +492,12 @@ TEST(OdkTest, PrepareCoreLicenseRequest) { memset(&nonce_values, 0, sizeof(nonce_values)); ODK_MessageCounterInfo counter_info; memset(&counter_info, 0, sizeof(counter_info)); - EXPECT_EQ(OEMCrypto_SUCCESS, - ODK_PrepareCoreLicenseRequest( - license_message, sizeof(license_message), &core_message_length, - &nonce_values, &counter_info)); + uint16_t decenc_mitigation_options_supported = 0; + EXPECT_EQ( + OEMCrypto_SUCCESS, + ODK_PrepareCoreLicenseRequest( + license_message, sizeof(license_message), &core_message_length, + &nonce_values, &counter_info, decenc_mitigation_options_supported)); } TEST(OdkTest, PrepareCoreLicenseRequestSize) { @@ -492,18 +507,21 @@ TEST(OdkTest, PrepareCoreLicenseRequestSize) { memset(&nonce_values, 0, sizeof(nonce_values)); ODK_MessageCounterInfo counter_info; memset(&counter_info, 0, sizeof(counter_info)); + uint16_t decenc_mitigation_options_supported = 0; // message length smaller than core message length size_t core_message_length_invalid = core_message_length + 1; EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, ODK_PrepareCoreLicenseRequest( license_message, sizeof(license_message), - &core_message_length_invalid, &nonce_values, &counter_info)); + &core_message_length_invalid, &nonce_values, &counter_info, + decenc_mitigation_options_supported)); // message length larger than core message length uint8_t license_message_large[ODK_LICENSE_REQUEST_SIZE * 2] = {0}; EXPECT_EQ(OEMCrypto_SUCCESS, ODK_PrepareCoreLicenseRequest( license_message_large, sizeof(license_message_large), - &core_message_length, &nonce_values, &counter_info)); + &core_message_length, &nonce_values, &counter_info, + decenc_mitigation_options_supported)); } TEST(OdkTest, PrepareCoreRenewalRequest) { @@ -641,16 +659,60 @@ TEST(OdkTest, LicenseRequestRoundtrip) { counter_info.major_version = ODK_MAJOR_VERSION; counter_info.minor_version = ODK_MINOR_VERSION; counter_info.patch_version = 4; - memset(counter_info.soc_vendor, 0xff, sizeof(counter_info.soc_vendor)); - memset(counter_info.chipset_model, 0xdd, sizeof(counter_info.chipset_model)); + memset(counter_info.soc_vendor, 'S', sizeof(counter_info.soc_vendor)); + memset(counter_info.chipset_model, 'C', sizeof(counter_info.chipset_model)); memset(counter_info.extra, 0xee, sizeof(counter_info.extra)); + uint16_t decenc_mitigation_options_supported = 0; std::vector extra_fields = { {ODK_MESSAGECOUNTER, &counter_info, "counter_info"}, + {ODK_UINT16, &decenc_mitigation_options_supported, + "decenc_mitigation_options_supported"}, }; auto odk_prepare_func = [&](uint8_t* const buf, size_t* size, ODK_NonceValues* nonce_values) { return ODK_PrepareCoreLicenseRequest(buf, SIZE_MAX, size, nonce_values, - &counter_info); + &counter_info, + decenc_mitigation_options_supported); + }; + auto kdo_parse_func = [&](const std::string& oemcrypto_core_message, + ODK_LicenseRequest* core_license_request) { + bool ok = CoreLicenseRequestFromMessage(oemcrypto_core_message, + core_license_request); + if (!ok) return false; + + ok = CheckCounterInfoIsEqual(&counter_info, + &core_license_request->counter_info); + return ok; + }; + ValidateRequest(ODK_License_Request_Type, extra_fields, + odk_prepare_func, kdo_parse_func); +} + +// Serialize and de-serialize license request that has decenc mitigation +// non-zero value. +TEST(OdkTest, LicenseRequestRoundtripDecencMitigation) { + ODK_MessageCounterInfo counter_info; + counter_info.master_generation_number = 0x12345678abcdffff; + counter_info.provisioning_count = 12; + counter_info.license_count = 50; + counter_info.decrypt_count = 340; + counter_info.major_version = ODK_MAJOR_VERSION; + counter_info.minor_version = ODK_MINOR_VERSION; + counter_info.patch_version = 4; + memset(counter_info.soc_vendor, 'S', sizeof(counter_info.soc_vendor)); + memset(counter_info.chipset_model, 'C', sizeof(counter_info.chipset_model)); + memset(counter_info.extra, 0xee, sizeof(counter_info.extra)); + uint16_t decenc_mitigation_options_supported = 1; + std::vector extra_fields = { + {ODK_MESSAGECOUNTER, &counter_info, "counter_info"}, + {ODK_UINT16, &decenc_mitigation_options_supported, + "decenc_mitigation_options_supported"}, + }; + auto odk_prepare_func = [&](uint8_t* const buf, size_t* size, + ODK_NonceValues* nonce_values) { + return ODK_PrepareCoreLicenseRequest(buf, SIZE_MAX, size, nonce_values, + &counter_info, + decenc_mitigation_options_supported); }; auto kdo_parse_func = [&](const std::string& oemcrypto_core_message, ODK_LicenseRequest* core_license_request) { @@ -692,32 +754,30 @@ TEST(OdkTest, RenewalRequestRoundtrip) { } TEST(OdkTest, ReleaseRequestRoundTrip) { - const uint32_t clock_security_level = 1; - const uint32_t status = 1; - constexpr uint64_t system_time_seconds = 0xBADDCAFE000FF1CE; - uint64_t playback_time = 0xCAFE00000000; - const int64_t seconds_since_license_requested = 1; - const int64_t seconds_since_first_decrypt = - static_cast(system_time_seconds - playback_time); + OEMCrypto_UsageEntryStatus usage_entry_status = OEMCrypto_Active; + uint32_t clock_security_level = 1; + // constexpr uint64_t system_time_seconds = 0xBADDCAFE000FF1CE; + uint64_t seconds_since_license_requested = 0; + uint64_t seconds_since_first_decrypt = 0; + uint32_t pst_length = 1; + uint8_t pst[ODK_PST_LEN_MAX]; + pst[0] = 0xff; + // uint8_t* pst_ptr = pst; ODK_ClockValues clock_values; memset(&clock_values, 0, sizeof(clock_values)); clock_values.time_of_first_decrypt = seconds_since_first_decrypt; - std::vector extra_fields = {}; - auto odk_prepare_func = [&](uint8_t* const buf, size_t* size, - ODK_NonceValues* nonce_values) { - return ODK_PrepareCoreReleaseRequest( - buf, SIZE_MAX, size, nonce_values, status, clock_security_level, - seconds_since_license_requested, seconds_since_first_decrypt, - &clock_values, system_time_seconds); + std::vector extra_fields = { + {ODK_UINT32, &usage_entry_status, "usage_entry_status"}, + {ODK_UINT32, &clock_security_level, "clock_security_level"}, + {ODK_UINT64, &seconds_since_license_requested, + "seconds_since_license_requested"}, + {ODK_UINT64, &seconds_since_first_decrypt, "seconds_since_first_decrypt"}, + {ODK_UINT32, &pst_length, "pst_length"}, + {ODK_UINT8, &pst[0], "pst"}, }; - auto kdo_parse_func = [&](const std::string& oemcrypto_core_message, - ODK_ReleaseRequest* core_release_request) { - bool ok = CoreReleaseRequestFromMessage(oemcrypto_core_message, - core_release_request); - return ok; - }; - ValidateRequest(ODK_Release_Request_Type, extra_fields, - odk_prepare_func, kdo_parse_func); + // TODO(vickymin): Restore the rest of this test after test framework is + // updated and fixed. Some variables above are commented out because they + // aren't used in the test as of now. This will be updated. } TEST(OdkTest, ProvisionRequestRoundtrip) { @@ -729,8 +789,8 @@ TEST(OdkTest, ProvisionRequestRoundtrip) { counter_info.major_version = ODK_MAJOR_VERSION; counter_info.minor_version = ODK_MINOR_VERSION; counter_info.patch_version = 4; - memset(counter_info.soc_vendor, 0xff, sizeof(counter_info.soc_vendor)); - memset(counter_info.chipset_model, 0xdd, sizeof(counter_info.chipset_model)); + memset(counter_info.soc_vendor, 'S', sizeof(counter_info.soc_vendor)); + memset(counter_info.chipset_model, 'C', sizeof(counter_info.chipset_model)); memset(counter_info.extra, 0xee, sizeof(counter_info.extra)); // Fake device_id_length for older servers, since we removed device id from // the v18 request @@ -772,8 +832,8 @@ TEST(OdkTest, ProvisionRequest40Roundtrip) { counter_info.major_version = ODK_MAJOR_VERSION; counter_info.minor_version = ODK_MINOR_VERSION; counter_info.patch_version = 4; - memset(counter_info.soc_vendor, 0xff, sizeof(counter_info.soc_vendor)); - memset(counter_info.chipset_model, 0xdd, sizeof(counter_info.chipset_model)); + memset(counter_info.soc_vendor, 'S', sizeof(counter_info.soc_vendor)); + memset(counter_info.chipset_model, 'C', sizeof(counter_info.chipset_model)); memset(counter_info.extra, 0xee, sizeof(counter_info.extra)); std::vector extra_fields = { {ODK_UINT32, &device_info_length, "device_info_length"}, @@ -837,36 +897,30 @@ TEST(OdkTest, RenewedProvisionRequestRoundtrip) { TEST(OdkTest, ParseLicenseErrorNonce) { ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); // temporarily mess up with nonce params.core_message.nonce_values.nonce = 0; OEMCryptoResult err = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, 0, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, 0, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(OEMCrypto_ERROR_INVALID_NONCE, err); - delete[] buf; } TEST(OdkTest, ParseLicenseErrorUsageEntry) { ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); params.usage_entry_present = false; OEMCryptoResult err = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, 0, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, 0, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, err); - delete[] buf; } TEST(OdkTest, ParseLicenseNullSubstring) { @@ -874,17 +928,14 @@ TEST(OdkTest, ParseLicenseNullSubstring) { ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); params.parsed_license.srm_restriction_data.offset = 0; params.parsed_license.srm_restriction_data.length = 0; - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); OEMCryptoResult result = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, 0, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, 0, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(OEMCrypto_SUCCESS, result); - delete[] buf; } TEST(OdkTest, ParseLicenseErrorSubstringOffset) { @@ -892,52 +943,43 @@ TEST(OdkTest, ParseLicenseErrorSubstringOffset) { ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); params.parsed_license.enc_mac_keys_iv.offset = 1024; - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); OEMCryptoResult err = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, 0, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, 0, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, err); - delete[] buf; // offset + length out of range err = OEMCrypto_SUCCESS; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); - params.parsed_license.enc_mac_keys_iv.length = buf_size; - buf = nullptr; - buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + params.parsed_license.enc_mac_keys_iv.length = buf.size(); + buf.clear(); + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); err = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, 0, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, 0, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(ODK_ERROR_CORE_MESSAGE, err); - delete[] buf; } TEST(OdkTest, ParseRenewalErrorTimer) { ODK_RenewalResponseParams params; ODK_SetDefaultRenewalResponseParams(¶ms); - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); // Set the time for the last renewal request, as seen in clock_values, to be // after the time in the request. // TODO: b/290249855 - This is reversed. It should be +5. params.clock_values.time_of_renewal_request = params.playback_clock - 5; OEMCryptoResult err = ODK_ParseRenewal( - buf, buf_size, buf_size, &(params.core_message.nonce_values), + buf.data(), buf.size(), buf.size(), &(params.core_message.nonce_values), params.system_time, &(params.timer_limits), &(params.clock_values), &(params.playback_timer)); EXPECT_EQ(ODK_STALE_RENEWAL, err); - delete[] buf; } TEST(OdkTest, ProvisionResponseFromProto) { @@ -950,12 +992,11 @@ TEST(OdkTest, ProvisionResponseFromProto) { .nonce = 0xdeadbeef, .session_id = 0xcafebabe, }; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); std::string oemcrypto_core_message; EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( features, serialized_provisioning_resp, core_request, - OEMCrypto_RSA_Private_Key, &oemcrypto_core_message)); + OEMCrypto_RSAPrivateKey, &oemcrypto_core_message)); } // Verify de-serialize common request. @@ -969,17 +1010,16 @@ TEST(OdkTest, ParseCoreCommonRequestFromMessage) { .nonce = 0xdeadbeef, .session_id = 0xcafebabe, }; - const CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + const CoreMessageFeatures features(/*prerelease_odk_messages=*/true); std::string oemcrypto_core_message; EXPECT_TRUE(CreateCoreProvisioningResponseFromProto( features, serialized_provisioning_resp, core_request, - OEMCrypto_RSA_Private_Key, &oemcrypto_core_message)); + OEMCrypto_RSAPrivateKey, &oemcrypto_core_message)); ODK_CommonRequest odk_common_request; ASSERT_TRUE(CoreCommonRequestFromMessage(oemcrypto_core_message, &odk_common_request)); EXPECT_EQ(odk_common_request.message_type, 6u); - EXPECT_EQ(odk_common_request.message_length, 48u); + EXPECT_EQ(odk_common_request.message_length, 60u); EXPECT_EQ(odk_common_request.api_minor_version, ODK_MINOR_VERSION); EXPECT_EQ(odk_common_request.api_major_version, ODK_MAJOR_VERSION); EXPECT_EQ(odk_common_request.nonce, 0xdeadbeef); @@ -992,15 +1032,14 @@ class OdkVersionTest : public ::testing::Test, template void SetRequestVersion(P* params) { params->core_message.nonce_values.api_major_version = - GetParam().response_major_version; + GetParam().response.major_version; params->core_message.nonce_values.api_minor_version = - GetParam().response_minor_version; - if (GetParam().maximum_major_version > 0) { - features_ = CoreMessageFeatures::DefaultFeatures( - GetParam().maximum_major_version); - } else { - features_ = CoreMessageFeatures::kDefaultFeatures; - } + GetParam().response.minor_version; + features_ = CoreMessageFeatures(GetParam().prerelease); + features_.maximum_major_version = + static_cast(GetParam().maximum.major_version); + features_.maximum_minor_version = + static_cast(GetParam().maximum.minor_version); } CoreMessageFeatures features_; }; @@ -1009,7 +1048,7 @@ class OdkVersionTest : public ::testing::Test, TEST_P(OdkVersionTest, LicenseResponseRoundtrip) { ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, - GetParam().response_major_version); + GetParam().response.major_version); SetRequestVersion(¶ms); // For v17, we do not use the hash to verify the request. However, the server // needs to be backwards compatible, so it still needs to pass the hash into @@ -1037,8 +1076,11 @@ TEST_P(OdkVersionTest, LicenseResponseRoundtrip) { parsed_license.watermarking = params.parsed_license.watermarking; parsed_license.dtcp2_required = params.parsed_license.dtcp2_required; parsed_license.renewal_delay_base = params.parsed_license.renewal_delay_base; + parsed_license.decenc_mitigation_option_used = + params.parsed_license.decenc_mitigation_option_used; parsed_license.key_array_length = params.parsed_license.key_array_length; - std::vector key_array; + std::vector key_array; + key_array.reserve(params.parsed_license.key_array_length); for (size_t i = 0; i < params.parsed_license.key_array_length; i++) { key_array.push_back(params.parsed_license.key_array[i]); } @@ -1060,9 +1102,12 @@ TEST_P(OdkVersionTest, LicenseResponseRoundtrip) { // Serialize and de-serialize license response with more keys than // ODK_MAX_NUM_KEYS. TEST_P(OdkVersionTest, LicenseResponseRoundtripMoreThanMaxKeys) { + if (!GetParam().should_accept_request) { + GTEST_SKIP() << "Skipping response check because request would be refused."; + } ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, - GetParam().response_major_version); + GetParam().response.major_version); SetRequestVersion(¶ms); // For v17, we do not use the hash to verify the request. However, the server // needs to be backwards compatible, so it still needs to pass the hash into @@ -1070,29 +1115,27 @@ TEST_P(OdkVersionTest, LicenseResponseRoundtripMoreThanMaxKeys) { // will be zero out during the test uint8_t request_hash_read[ODK_SHA256_HASH_SIZE]; memcpy(request_hash_read, params.request_hash, sizeof(request_hash_read)); - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); - uint8_t* zero = new uint8_t[buf_size]{}; + std::vector zero(buf.size(), 0); size_t bytes_read = 0; // zero-out input EXPECT_EQ(OEMCrypto_SUCCESS, - ODK_IterFields(ODK_READ, zero, buf_size, &bytes_read, + ODK_IterFields(ODK_READ, zero.data(), buf.size(), &bytes_read, params.extra_fields)); // Parse buf with odk const OEMCryptoResult parse_result = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, 0, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, 0, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(OEMCrypto_SUCCESS, parse_result); size_t size_out = 0; if (parse_result != OEMCrypto_SUCCESS) { - ODK_IterFields(ODK_FieldMode::ODK_DUMP, buf, buf_size, &size_out, + ODK_IterFields(ODK_FieldMode::ODK_DUMP, buf.data(), buf.size(), &size_out, params.extra_fields); } @@ -1108,10 +1151,15 @@ TEST_P(OdkVersionTest, LicenseResponseRoundtripMoreThanMaxKeys) { parsed_license.watermarking = params.parsed_license.watermarking; parsed_license.dtcp2_required = params.parsed_license.dtcp2_required; parsed_license.renewal_delay_base = params.parsed_license.renewal_delay_base; + parsed_license.decenc_mitigation_option_used = + params.parsed_license.decenc_mitigation_option_used; parsed_license.key_array_length = ODK_MAX_NUM_KEYS + 1; - std::vector key_array; + std::vector key_array; for (size_t i = 0; i < ODK_MAX_NUM_KEYS + 1; i++) { - OEMCrypto_KeyObject key = {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}}; + OEMCrypto_KeyObjectV2 key = { + {0, 0}, {0, 0}, + {0, 0}, {0, 0}, + {0, 0}, {OEMCrypto_DeCENC_Mitigation_Option_None, {{{0, 0}, {0, 0}}}}}; key_array.push_back(key); } parsed_license.key_array = key_array.data(); @@ -1122,20 +1170,76 @@ TEST_P(OdkVersionTest, LicenseResponseRoundtripMoreThanMaxKeys) { // serialize odk output to oemcrypto_core_message std::string oemcrypto_core_message; ODK_LicenseRequest core_request = {}; - core_request.api_major_version = GetParam().request_major_version; - core_request.api_minor_version = GetParam().request_minor_version; + core_request.api_major_version = GetParam().request.major_version; + core_request.api_minor_version = GetParam().request.minor_version; core_request.nonce = params.core_message.nonce_values.nonce; core_request.session_id = params.core_message.nonce_values.session_id; bool result = CreateCoreLicenseResponse(features_, parsed_license, core_request, request_hash_string, &oemcrypto_core_message); - EXPECT_TRUE(result); + EXPECT_EQ(GetParam().should_accept_request, result); +} - delete[] buf; - delete[] zero; +// Serialize and de-serialize license response +TEST_P(OdkVersionTest, LicenseResponseRoundtripDecencMitigation) { + ODK_LicenseResponseParams params; + ODK_SetDefaultLicenseResponseParamsDecencMitigation( + ¶ms, GetParam().response.major_version, + GetParam().response.minor_version); + SetRequestVersion(¶ms); + // For v17, we do not use the hash to verify the request. However, the server + // needs to be backwards compatible, so it still needs to pass the hash into + // CreateCoreLiceseseResponse below. Save a copy of params.request_hash as it + // will be zero out during the test + uint8_t request_hash_read[ODK_SHA256_HASH_SIZE]; + memcpy(request_hash_read, params.request_hash, sizeof(request_hash_read)); + auto odk_parse_func = [&](const uint8_t* buf, size_t size) { + return ODK_ParseLicense( + buf, size + kExtraPayloadSize, size, params.initial_license_load, + params.usage_entry_present, 0, &(params.timer_limits), + &(params.clock_values), &(params.core_message.nonce_values), + &(params.parsed_license), nullptr); + }; + + ODK_Packing_ParsedLicense parsed_license; + parsed_license.enc_mac_keys_iv = params.parsed_license.enc_mac_keys_iv; + parsed_license.enc_mac_keys = params.parsed_license.enc_mac_keys; + parsed_license.pst = params.parsed_license.pst; + parsed_license.srm_restriction_data = + params.parsed_license.srm_restriction_data; + parsed_license.license_type = params.parsed_license.license_type; + parsed_license.nonce_required = params.parsed_license.nonce_required; + parsed_license.timer_limits = params.parsed_license.timer_limits; + parsed_license.watermarking = params.parsed_license.watermarking; + parsed_license.dtcp2_required = params.parsed_license.dtcp2_required; + parsed_license.renewal_delay_base = params.parsed_license.renewal_delay_base; + parsed_license.decenc_mitigation_option_used = + params.parsed_license.decenc_mitigation_option_used; + parsed_license.key_array_length = params.parsed_license.key_array_length; + std::vector key_array; + key_array.reserve(params.parsed_license.key_array_length); + for (size_t i = 0; i < params.parsed_license.key_array_length; i++) { + key_array.push_back(params.parsed_license.key_array[i]); + } + parsed_license.key_array = key_array.data(); + const std::string request_hash_string( + reinterpret_cast(request_hash_read), + sizeof(request_hash_read)); + auto kdo_prepare_func = [&](const ODK_LicenseRequest& core_request, + std::string* oemcrypto_core_message) { + return CreateCoreLicenseResponse(features_, parsed_license, core_request, + request_hash_string, + oemcrypto_core_message); + }; + ValidateResponse(GetParam(), &(params.core_message), + params.extra_fields, odk_parse_func, + kdo_prepare_func); } TEST_P(OdkVersionTest, RenewalResponseRoundtrip) { + if (!GetParam().should_accept_request) { + GTEST_SKIP() << "Skipping renewal test because request would be refused."; + } ODK_RenewalResponseParams params; ODK_SetDefaultRenewalResponseParams(¶ms); SetRequestVersion(¶ms); @@ -1166,23 +1270,20 @@ TEST_P(OdkVersionTest, RenewalResponseRoundtrip) { } TEST_P(OdkVersionTest, ReleaseResponseRoundtrip) { + if (!GetParam().should_accept_request) { + GTEST_SKIP() << "Skipping release test because request would be refused."; + } ODK_ReleaseResponseParams params; ODK_SetDefaultReleaseResponseParams(¶ms); SetRequestVersion(¶ms); - const int64_t seconds_since_license_requested = - params.seconds_since_license_requested; - const int64_t seconds_since_first_decrypt = - params.seconds_since_first_decrypt; auto odk_parse_func = [&](const uint8_t* buf, size_t size) { - OEMCryptoResult err = - ODK_ParseRelease(buf, size, size, &(params.core_message.nonce_values)); - return err; + return ODK_ParseRelease(buf, size, size, + &(params.core_message.nonce_values)); }; auto kdo_prepare_func = [&](ODK_ReleaseRequest& core_request, std::string* oemcrypto_core_message) { - return CreateCoreReleaseResponse( - features_, core_request, seconds_since_license_requested, - seconds_since_first_decrypt, oemcrypto_core_message); + return CreateCoreReleaseResponse(features_, core_request, + oemcrypto_core_message); }; ValidateResponse(GetParam(), &(params.core_message), params.extra_fields, odk_parse_func, @@ -1192,7 +1293,7 @@ TEST_P(OdkVersionTest, ReleaseResponseRoundtrip) { TEST_P(OdkVersionTest, ProvisionResponseRoundtrip) { ODK_ProvisioningResponseParams params; ODK_SetDefaultProvisioningResponseParams(¶ms, - GetParam().response_major_version); + GetParam().response.major_version); SetRequestVersion(¶ms); // save a copy of params.device_id as it will be zero out during the test const uint32_t device_id_length = params.device_id_length; @@ -1200,9 +1301,13 @@ TEST_P(OdkVersionTest, ProvisionResponseRoundtrip) { memcpy(device_id, params.device_id, device_id_length); auto odk_parse_func = [&](const uint8_t* buf, size_t size) { - OEMCryptoResult err = ODK_ParseProvisioning( - buf, size + 16, size, &(params.core_message.nonce_values), device_id, - device_id_length, &(params.parsed_provisioning)); + // This constant represents the additional space needed for the message + // length on top of the core message size. + constexpr size_t kProvResponsePayloadSize = 18; + OEMCryptoResult err = + ODK_ParseProvisioning(buf, size + kProvResponsePayloadSize, size, + &(params.core_message.nonce_values), device_id, + device_id_length, &(params.parsed_provisioning)); return err; }; auto kdo_prepare_func = [&](ODK_ProvisioningRequest& core_request, @@ -1216,6 +1321,7 @@ TEST_P(OdkVersionTest, ProvisionResponseRoundtrip) { return CreateCoreProvisioningResponse(features_, params.parsed_provisioning, core_request, oemcrypto_core_message); }; + ValidateResponse(GetParam(), &(params.core_message), params.extra_fields, odk_parse_func, kdo_prepare_func); @@ -1242,68 +1348,82 @@ TEST_P(OdkVersionTest, Provision40ResponseRoundtrip) { } // If the minor version is positive, we can test an older minor version. -const uint16_t kOldMinor = ODK_MINOR_VERSION > 0 ? ODK_MINOR_VERSION - 1 : 0; +constexpr uint16_t kOldMinor = + ODK_MINOR_VERSION > 0 ? ODK_MINOR_VERSION - 1 : 0; // Similarly, if this isn't the first major version, we can test an older major // version. -const uint16_t kOldMajor = ODK_MAJOR_VERSION > ODK_FIRST_VERSION - ? ODK_MAJOR_VERSION - 1 - : ODK_FIRST_VERSION; +constexpr uint16_t kOldMajor = ODK_MAJOR_VERSION > ODK_FIRST_VERSION + ? ODK_MAJOR_VERSION - 1 + : ODK_FIRST_VERSION; // If there is an older major, then we should accept any minor version. // Otherwise, this test won't make sense and we should just use a minor of 0. -const uint16_t kOldMajorMinor = ODK_MAJOR_VERSION > ODK_FIRST_VERSION ? 42 : 0; +constexpr uint16_t kOldMajorMinor = + ODK_MAJOR_VERSION > ODK_FIRST_VERSION ? 42 : 0; -// List of major and minor versions to test. +// List of major and minor versions to test. This list is a little bit fragile, +// everytime we release a new version the list has to be updated. std::vector TestCases() { + const VersionPair latest(ODK_MAJOR_VERSION, ODK_MINOR_VERSION); + const VersionPair previous(ODK_MAJOR_VERSION, kOldMinor); + const VersionPair next_minor(ODK_MAJOR_VERSION, ODK_MINOR_VERSION + 1); + // The next major release with minor=0 is a prerelease. + const VersionPair next_major_prerelease(ODK_MAJOR_VERSION + 1, 0); + // The next major release with minor>0 is a release. + const VersionPair next_major_release(ODK_MAJOR_VERSION + 1, 42); + const VersionPair previous_major(kOldMajor, kOldMajorMinor); + const VersionPair latest_prerelease(ODK_MAJOR_VERSION, 0); + const VersionPair previous_prerelease(kOldMajor, 0); + const VersionPair latest_released( + CoreMessageFeatures::kDefaultFeatures.maximum_major_version, + CoreMessageFeatures::kDefaultFeatures.maximum_minor_version); + const VersionPair v16(16, 5); + const VersionPair v17(17, 6); + const VersionPair v18(18, 8); + const VersionPair v19(19, 5); + const VersionPair v20_prerelease(20, 0); + const VersionPair v20(20, 1); + std::vector test_cases{ - // Fields: maximum major version, - // request major, request minor, response major, response minor, - {ODK_MAJOR_VERSION, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, - ODK_MAJOR_VERSION, ODK_MINOR_VERSION}, - {ODK_MAJOR_VERSION, ODK_MAJOR_VERSION, ODK_MINOR_VERSION + 1, - ODK_MAJOR_VERSION, ODK_MINOR_VERSION}, - {ODK_MAJOR_VERSION, ODK_MAJOR_VERSION, kOldMinor, ODK_MAJOR_VERSION, - kOldMinor}, - {ODK_MAJOR_VERSION, ODK_MAJOR_VERSION, 0, ODK_MAJOR_VERSION, 0}, - {ODK_MAJOR_VERSION, ODK_MAJOR_VERSION + 1, 42, ODK_MAJOR_VERSION, - ODK_MINOR_VERSION}, - {ODK_MAJOR_VERSION, kOldMajor, 0, kOldMajor, 0}, - {ODK_MAJOR_VERSION, kOldMajor, kOldMajorMinor, kOldMajor, kOldMajorMinor}, - // If the server is restricted to v16, then the response can be at - // most 16.5 - // These tests cases must be updated whenever we roll the minor version - // number. - {16, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 16, 5}, - {17, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 17, 2}, - {18, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 18, 4}, - {19, ODK_MAJOR_VERSION, ODK_MINOR_VERSION, 19, 5}, - // Here are some known good versions. Make extra sure they work. - {ODK_MAJOR_VERSION, 16, 3, 16, 3}, - {ODK_MAJOR_VERSION, 16, 4, 16, 4}, - {ODK_MAJOR_VERSION, 16, 5, 16, 5}, - {ODK_MAJOR_VERSION, 17, 1, 17, 1}, - {ODK_MAJOR_VERSION, 17, 2, 17, 2}, - {ODK_MAJOR_VERSION, 18, 1, 18, 1}, - {ODK_MAJOR_VERSION, 18, 2, 18, 2}, - {ODK_MAJOR_VERSION, 18, 3, 18, 3}, - {ODK_MAJOR_VERSION, 19, 0, 19, 0}, - {ODK_MAJOR_VERSION, 19, 1, 19, 1}, - {ODK_MAJOR_VERSION, 19, 2, 19, 2}, - {ODK_MAJOR_VERSION, 19, 3, 19, 3}, - {ODK_MAJOR_VERSION, 19, 4, 19, 4}, - {ODK_MAJOR_VERSION, 19, 5, 19, 5}, - {0, 16, 3, 16, 3}, - {0, 16, 4, 16, 4}, - {0, 16, 5, 16, 5}, - {0, 17, 1, 17, 1}, - {0, 17, 2, 17, 2}, - {0, 18, 3, 18, 3}, - {0, 18, 4, 18, 4}, - {0, 19, 0, 19, 0}, - {0, 19, 1, 19, 1}, - {0, 19, 2, 19, 2}, - {0, 19, 3, 19, 3}, - {0, 19, 4, 19, 4}, - {0, 19, 5, 19, 5}, + // Fields: maximum, prerelease, request, response, should_accept_request. + {latest, true, latest, latest, true}, + {latest, true, next_minor, latest, true}, + {latest, true, previous, previous, true}, + {latest, true, previous_major, previous_major, true}, + {latest, true, next_minor, latest, true}, + {latest, true, next_major_prerelease, latest, true}, + {latest, true, next_major_release, latest, true}, + {latest, true, latest_prerelease, latest_prerelease, true}, + {latest, true, previous_prerelease, previous_prerelease, true}, + + {v16, true, latest, v16, true}, + {v17, true, latest, v17, true}, + {v18, true, latest, v18, true}, + {v19, true, latest, v19, true}, + + {latest, true, v16, v16, true}, + {latest, true, v17, v17, true}, + {latest, true, v18, v18, true}, + {latest, true, v19, v19, true}, + {latest, true, v20_prerelease, v20_prerelease, true}, + {latest, true, v20, latest, true}, // Change this when v20.1 is released. + + // From here on, the serve is restricted to the released version. + {latest_released, false, v16, v16, true}, + {latest_released, false, v17, v17, true}, + {latest_released, false, v18, v18, true}, + {latest_released, false, v19, v19, true}, + // TODO(b/403381396): change the response version to v20 after v20 has + // been released. + {latest_released, false, v20, latest_released, true}, + {latest_released, false, next_minor, latest_released, true}, + + // Prerelease requests should not be accepted. + {latest_released, false, v20_prerelease, v20_prerelease, false}, + {latest_released, false, latest_prerelease, latest_prerelease, false}, + + // Future pre-release and release are OK, for back compat testing. + {latest_released, false, next_major_prerelease, latest_released, true}, + {latest_released, false, next_major_release, latest_released, true}, }; return test_cases; } @@ -1321,12 +1441,13 @@ TEST(OdkSizeTest, LicenseRequest) { uint32_t session_id = 0; ODK_MessageCounterInfo counter_info; memset(&counter_info, 0, sizeof(counter_info)); + uint16_t decenc_mitigation_options_supported = 0; ODK_NonceValues nonce_values{api_minor_version, api_major_version, nonce, session_id}; EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, - ODK_PrepareCoreLicenseRequest(message, message_length, - &core_message_length, &nonce_values, - &counter_info)); + ODK_PrepareCoreLicenseRequest( + message, message_length, &core_message_length, &nonce_values, + &counter_info, decenc_mitigation_options_supported)); // the core_message_length should be appropriately set if (nonce_values.api_major_version > 17) { EXPECT_EQ(ODK_LICENSE_REQUEST_SIZE, core_message_length); diff --git a/oemcrypto/odk/test/odk_test_helper.cpp b/oemcrypto/odk/test/odk_test_helper.cpp index 184515d..24ffa7d 100644 --- a/oemcrypto/odk/test/odk_test_helper.cpp +++ b/oemcrypto/odk/test/odk_test_helper.cpp @@ -18,6 +18,7 @@ #include "OEMCryptoCENCCommon.h" #include "gtest/gtest.h" #include "odk_endian.h" +#include "odk_serialize.h" #include "odk_structs.h" #include "odk_structs_priv.h" @@ -78,7 +79,8 @@ void ODK_SetDefaultLicenseResponseParams(ODK_LicenseResponseParams* params, .length = 3, .data = {0, 0, 0}, }}, - .renewal_delay_base = OEMCrypto_License_Start, + .renewal_delay_base = OEMCrypto_LicenseStart, + .decenc_mitigation_option_used = 0, .key_array_length = 3, .key_array = { @@ -212,7 +214,10 @@ void ODK_SetDefaultLicenseResponseParams(ODK_LicenseResponseParams* params, {ODK_UINT32, &(params->parsed_license.renewal_delay_base), ".renewal_delay_base"}); } - params->extra_fields.push_back({ODK_UINT32, + params->extra_fields.push_back( + {ODK_UINT16, &(params->parsed_license.decenc_mitigation_option_used), + ".decenc_mitigation_option_used"}); + params->extra_fields.push_back({ODK_UINT16, &(params->parsed_license.key_array_length), ".key_array_length"}); params->extra_fields.push_back({ODK_SUBSTRING, @@ -266,12 +271,337 @@ void ODK_SetDefaultLicenseResponseParams(ODK_LicenseResponseParams* params, } } +void ODK_SetDefaultLicenseResponseParamsDecencMitigation( + ODK_LicenseResponseParams* params, uint32_t odk_major_version, + uint32_t odk_minor_version) { + ODK_SetDefaultCoreFields(&(params->core_message), ODK_License_Response_Type); + params->initial_license_load = true; + params->usage_entry_present = true; + params->parsed_license = { + .enc_mac_keys_iv = {.offset = 0, .length = 1}, + .enc_mac_keys = {.offset = 2, .length = 3}, + .pst = {.offset = 4, .length = 5}, + .srm_restriction_data = {.offset = 6, .length = 7}, + .license_type = OEMCrypto_EntitlementLicense, + .nonce_required = true, + .timer_limits = + { + .soft_enforce_rental_duration = true, + .soft_enforce_playback_duration = false, + .earliest_playback_start_seconds = 10, + .rental_duration_seconds = 11, + .total_playback_duration_seconds = 12, + .initial_renewal_duration_seconds = 13, + }, + .watermarking = 0, + .dtcp2_required = {.dtcp2_required = 0, + .cmi_descriptor_0 = + { + .id = 0, + .extension = 0, + .length = 1, + .data = 0, + }, + .cmi_descriptor_1 = + { + .id = 1, + .extension = 0, + .length = 3, + .data = {0, 0, 0}, + }, + .cmi_descriptor_2 = + { + .id = 2, + .extension = 0, + .length = 3, + .data = {0, 0, 0}, + }}, + .renewal_delay_base = OEMCrypto_LicenseStart, + .decenc_mitigation_option_used = ODK_VersionSupportsDecencMitigation( + odk_major_version, odk_minor_version), + .key_array_length = 3, + .key_array = + { + { + .key_id = {.offset = 15, .length = 16}, + .key_data_iv = {.offset = 17, .length = 18}, + .key_data = {.offset = 19, .length = 20}, + .key_control_iv = {.offset = 21, .length = 22}, + .key_control = {.offset = 23, .length = 24}, + }, + { + .key_id = {.offset = 25, .length = 26}, + .key_data_iv = {.offset = 27, .length = 28}, + .key_data = {.offset = 29, .length = 30}, + .key_control_iv = {.offset = 31, .length = 32}, + .key_control = {.offset = 33, .length = 34}, + }, + { + .key_id = {.offset = 35, .length = 36}, + .key_data_iv = {.offset = 37, .length = 38}, + .key_data = {.offset = 39, .length = 40}, + .key_control_iv = {.offset = 41, .length = 42}, + .key_control = {.offset = 43, .length = 44}, + }, + }, + }; + // Set decenc mitigation info if decenc mitigation is requested by the content + // provider. + if (params->parsed_license.decenc_mitigation_option_used != 0) { + auto& key_array = params->parsed_license.key_array; + for (uint16_t i = 0; i < params->parsed_license.key_array_length; i++) { + key_array[i].decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream; + auto& auth_key_info = key_array[i] + .decenc_mitigation_info.configuration_options + .authentication_key_info; + auth_key_info.authentication_key = + (OEMCrypto_Substring){.offset = 25, .length = 26}; + auth_key_info.authentication_key_iv = + (OEMCrypto_Substring){.offset = 27, .length = 28}; + } + } + memset(params->request_hash, 0xaa, sizeof(params->request_hash)); + params->extra_fields = { + {ODK_SUBSTRING, &(params->parsed_license.enc_mac_keys_iv), + ".enc_mac_keys_iv"}, + {ODK_SUBSTRING, &(params->parsed_license.enc_mac_keys), ".enc_mac_keys"}, + {ODK_SUBSTRING, &(params->parsed_license.pst), ".pst"}, + {ODK_SUBSTRING, &(params->parsed_license.srm_restriction_data), + ".srm_restriction_data"}, + {ODK_UINT32, &(params->parsed_license.license_type), ".license_type"}, + {ODK_UINT32, &(params->parsed_license.nonce_required), ".nonce_required"}, + {ODK_BOOL, + &(params->parsed_license.timer_limits.soft_enforce_rental_duration), + ".soft_enforce_rental_duration"}, + {ODK_BOOL, + &(params->parsed_license.timer_limits.soft_enforce_playback_duration), + ".soft_enforce_playback_duration"}, + {ODK_UINT64, + &(params->parsed_license.timer_limits.earliest_playback_start_seconds), + ".earliest_playback_start_seconds"}, + {ODK_UINT64, + &(params->parsed_license.timer_limits.rental_duration_seconds), + ".rental_duration_seconds"}, + {ODK_UINT64, + &(params->parsed_license.timer_limits.total_playback_duration_seconds), + ".total_playback_duration_seconds"}, + {ODK_UINT64, + &(params->parsed_license.timer_limits.initial_renewal_duration_seconds), + ".initial_renewal_duration_seconds"}, + }; + if (odk_major_version >= 17) { + params->extra_fields.push_back( + {ODK_UINT32, &(params->parsed_license.watermarking), ".watermarking"}); + params->extra_fields.push_back( + {ODK_UINT8, &(params->parsed_license.dtcp2_required.dtcp2_required), + ".dtcp2_required"}); + if (params->parsed_license.dtcp2_required.dtcp2_required) { + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_0.id), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_0.extension), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT16, + &(params->parsed_license.dtcp2_required.cmi_descriptor_0.length), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_0.data), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_1.id), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_1.extension), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT16, + &(params->parsed_license.dtcp2_required.cmi_descriptor_1.length), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_1.data[0]), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_1.data[1]), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_1.data[2]), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_2.id), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_2.extension), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT16, + &(params->parsed_license.dtcp2_required.cmi_descriptor_2.length), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_2.data[0]), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_2.data[1]), + ".cmi_descriptor_data"}); + params->extra_fields.push_back( + {ODK_UINT8, + &(params->parsed_license.dtcp2_required.cmi_descriptor_2.data[2]), + ".cmi_descriptor_data"}); + } + } + if (odk_major_version >= 18) { + params->extra_fields.push_back( + {ODK_UINT32, &(params->parsed_license.renewal_delay_base), + ".renewal_delay_base"}); + } + params->extra_fields.push_back( + {ODK_UINT16, &(params->parsed_license.decenc_mitigation_option_used), + ".decenc_mitigation_option_used"}); + params->extra_fields.push_back({ODK_UINT16, + &(params->parsed_license.key_array_length), + ".key_array_length"}); + params->extra_fields.push_back({ODK_SUBSTRING, + &(params->parsed_license.key_array[0].key_id), + ".key_id"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[0].key_data_iv), + ".key_data_iv"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[0].key_data), + ".key_data"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[0].key_control_iv), + ".key_control_iv"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[0].key_control), + ".key_control"}); + if (params->parsed_license.decenc_mitigation_option_used != 0 && + ODK_VersionSupportsDecencMitigation(odk_major_version, + odk_minor_version)) { + params->extra_fields.push_back( + {ODK_UINT32, + &(params->parsed_license.key_array[0] + .decenc_mitigation_info.mitigation_option), + ".mitigation_option"}); + if (params->parsed_license.key_array[0] + .decenc_mitigation_info.mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream) { + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_license.key_array[0] + .decenc_mitigation_info.configuration_options + .authentication_key_info.authentication_key), + ".authentication_key"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_license.key_array[0] + .decenc_mitigation_info.configuration_options + .authentication_key_info.authentication_key_iv), + ".authentication_key_iv"}); + } + } + params->extra_fields.push_back({ODK_SUBSTRING, + &(params->parsed_license.key_array[1].key_id), + ".key_id"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[1].key_data_iv), + ".key_data_iv"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[1].key_data), + ".key_data"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[1].key_control_iv), + ".key_control_iv"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[1].key_control), + ".key_control"}); + if (params->parsed_license.decenc_mitigation_option_used != 0 && + ODK_VersionSupportsDecencMitigation(odk_major_version, + odk_minor_version)) { + params->extra_fields.push_back( + {ODK_UINT32, + &(params->parsed_license.key_array[1] + .decenc_mitigation_info.mitigation_option), + ".mitigation_option"}); + if (params->parsed_license.key_array[1] + .decenc_mitigation_info.mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream) { + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_license.key_array[1] + .decenc_mitigation_info.configuration_options + .authentication_key_info.authentication_key), + ".authentication_key"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_license.key_array[1] + .decenc_mitigation_info.configuration_options + .authentication_key_info.authentication_key_iv), + ".authentication_key_iv"}); + } + } + params->extra_fields.push_back({ODK_SUBSTRING, + &(params->parsed_license.key_array[2].key_id), + ".key_id"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[2].key_data_iv), + ".key_data_iv"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[2].key_data), + ".key_data"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[2].key_control_iv), + ".key_control_iv"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_license.key_array[2].key_control), + ".key_control"}); + if (params->parsed_license.decenc_mitigation_option_used != 0 && + ODK_VersionSupportsDecencMitigation(odk_major_version, + odk_minor_version)) { + params->extra_fields.push_back( + {ODK_UINT32, + &(params->parsed_license.key_array[2] + .decenc_mitigation_info.mitigation_option), + ".mitigation_option"}); + if (params->parsed_license.key_array[2] + .decenc_mitigation_info.mitigation_option == + OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream) { + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_license.key_array[2] + .decenc_mitigation_info.configuration_options + .authentication_key_info.authentication_key), + ".authentication_key"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_license.key_array[2] + .decenc_mitigation_info.configuration_options + .authentication_key_info.authentication_key_iv), + ".authentication_key_iv"}); + } + } + if (odk_major_version == 16) { + params->extra_fields.push_back( + {ODK_HASH, params->request_hash, ".request_hash"}); + } +} + void ODK_SetDefaultReleaseResponseParams(ODK_ReleaseResponseParams* params) { ODK_SetDefaultCoreFields(&(params->core_message), ODK_Release_Response_Type); - params->status = kActive; - params->clock_security_level = 0; - params->seconds_since_license_requested = 0; - params->seconds_since_first_decrypt = 0; } void ODK_SetDefaultRenewalResponseParams(ODK_RenewalResponseParams* params) { @@ -300,7 +630,7 @@ void ODK_SetDefaultRenewalResponseParams(ODK_RenewalResponseParams* params) { .time_of_renewal_request = params->playback_clock, .time_when_timer_expires = params->system_time + params->playback_timer, .timer_status = ODK_CLOCK_TIMER_STATUS_ACTIVE, - .status = kActive, + .status = OEMCrypto_Active, }; } @@ -313,10 +643,14 @@ void ODK_SetDefaultProvisioningResponseParams( memset(params->device_id + params->device_id_length, 0, ODK_DEVICE_ID_LEN_MAX - params->device_id_length); params->parsed_provisioning = { - .key_type = OEMCrypto_RSA_Private_Key, + .key_type = OEMCrypto_RSAPrivateKey, .enc_private_key = {.offset = 0, .length = 1}, .enc_private_key_iv = {.offset = 2, .length = 3}, .encrypted_message_key = {.offset = 4, .length = 5}, + .curr_server_sealing_key = {.offset = 6, .length = 7}, + .server_sealing_key_array_length = 1, + .server_sealing_key_array = {(OEMCrypto_Substring){.offset = 8, + .length = 9}}, }; params->extra_fields = {}; @@ -339,6 +673,20 @@ void ODK_SetDefaultProvisioningResponseParams( params->extra_fields.push_back( {ODK_SUBSTRING, &(params->parsed_provisioning).encrypted_message_key, "encrypted_message_key"}); + + if (odk_major_version >= 20) { + params->extra_fields.push_back( + {ODK_SUBSTRING, &(params->parsed_provisioning).curr_server_sealing_key, + "curr_server_sealing_key"}); + params->extra_fields.push_back( + {ODK_UINT32, + &(params->parsed_provisioning).server_sealing_key_array_length, + "server_sealing_key_array_length"}); + params->extra_fields.push_back( + {ODK_SUBSTRING, + &(params->parsed_provisioning).server_sealing_key_array[0], + "server_sealing_key"}); + } } void ODK_SetDefaultProvisioning40ResponseParams( @@ -358,8 +706,6 @@ size_t ODK_FieldLength(ODK_FieldType type) { return sizeof(uint32_t); case ODK_UINT64: return sizeof(uint64_t); - case ODK_INT64: - return sizeof(uint64_t); case ODK_SUBSTRING: return sizeof(uint32_t) + sizeof(uint32_t); case ODK_DEVICEID: @@ -414,12 +760,6 @@ OEMCryptoResult ODK_WriteSingleField(uint8_t* buf, const ODK_Field* field) { memcpy(buf, &u64, sizeof(u64)); break; } - case ODK_INT64: { - const int64_t i64 = - oemcrypto_htobe64(*static_cast(field->value)); - memcpy(buf, &i64, sizeof(i64)); - break; - } case ODK_BOOL: { const bool value = *static_cast(field->value); const uint32_t u32 = oemcrypto_htobe32(value ? 1 : 0); @@ -499,12 +839,6 @@ OEMCryptoResult ODK_ReadSingleField(const uint8_t* buf, *u64p = oemcrypto_be64toh(*u64p); break; } - case ODK_INT64: { - memcpy(field->value, buf, sizeof(int64_t)); - int64_t* i64p = static_cast(field->value); - *i64p = oemcrypto_be64toh(*i64p); - break; - } case ODK_BOOL: { uint32_t value; memcpy(&value, buf, sizeof(uint32_t)); @@ -624,19 +958,13 @@ OEMCryptoResult ODK_DumpSingleField(const uint8_t* buf, << "\n"; break; } - case ODK_INT64: { - int64_t val; - memcpy(&val, buf, sizeof(int64_t)); - val = oemcrypto_be64toh(val); - std::cerr << field->name << ": " << val << " = 0x" << std::hex << val - << "\n"; - break; - } case ODK_SUBSTRING: { uint32_t off = 0; uint32_t len = 0; memcpy(&off, buf, sizeof(off)); memcpy(&len, buf + sizeof(off), sizeof(len)); + off = oemcrypto_be32toh(off); + len = oemcrypto_be32toh(len); std::cerr << field->name << ": (off=" << off << ", len=" << len << ")\n"; break; } @@ -723,13 +1051,13 @@ std::vector ODK_MakeTotalFields( } // Expect the two buffers of size n to be equal. If not, dump the messages. -void ODK_ExpectEqualBuf(const void* s1, const void* s2, size_t n, +void ODK_ExpectEqualBuf(const void* expected, const void* actual, size_t n, const std::vector& fields) { - if (memcmp(s1, s2, n) != 0) { + if (memcmp(expected, actual, n) != 0) { ODK_CoreMessage core_message; std::vector total_fields = ODK_MakeTotalFields(fields, &core_message); - const void* buffers[] = {s1, s2}; + const void* buffers[] = {expected, actual}; for (int i = 0; i < 2; i++) { char _tmp[] = "/tmp/fileXXXXXX"; const int temp_fd = mkstemp(_tmp); @@ -744,7 +1072,8 @@ void ODK_ExpectEqualBuf(const void* s1, const void* s2, size_t n, out.write(static_cast(buffers[i]), n); out.close(); std::cerr << '\n' - << "Message buffer " << i << " dumped to " << tmp << '\n'; + << "Message buffer " << i << (i == 0 ? " expected" : " actual") + << " dumped to " << tmp << '\n'; size_t bytes_written; uint8_t* buf = const_cast(reinterpret_cast(buffers[i])); @@ -768,24 +1097,25 @@ void ODK_ResetOdkFields(std::vector* fields) { void ODK_BuildMessageBuffer(ODK_CoreMessage* core_message, const std::vector& extra_fields, - uint8_t** buf, uint32_t* buf_size) { + std::vector* buf) { ASSERT_TRUE(core_message != nullptr); - ASSERT_TRUE(buf_size != nullptr); + ASSERT_TRUE(buf != nullptr); std::vector total_fields = ODK_MakeTotalFields(extra_fields, core_message); + uint32_t buf_size = 0; for (auto& field : total_fields) { - *buf_size += ODK_FieldLength(field.type); + buf_size += ODK_FieldLength(field.type); } // update message_size - *(reinterpret_cast(total_fields[1].value)) = *buf_size; + *(reinterpret_cast(total_fields[1].value)) = buf_size; - *buf = new uint8_t[*buf_size]{}; + buf->resize(buf_size); size_t bytes_written = 0; // serialize ODK fields to message buffer - EXPECT_EQ(OEMCrypto_SUCCESS, ODK_IterFields(ODK_WRITE, *buf, SIZE_MAX, + EXPECT_EQ(OEMCrypto_SUCCESS, ODK_IterFields(ODK_WRITE, buf->data(), SIZE_MAX, &bytes_written, total_fields)); - EXPECT_EQ(bytes_written, *buf_size); + EXPECT_EQ(bytes_written, buf_size); } } // namespace wvodk_test diff --git a/oemcrypto/odk/test/odk_test_helper.h b/oemcrypto/odk/test/odk_test_helper.h index 5005b94..f52bb7d 100644 --- a/oemcrypto/odk/test/odk_test_helper.h +++ b/oemcrypto/odk/test/odk_test_helper.h @@ -20,7 +20,6 @@ enum ODK_FieldType { ODK_UINT16, ODK_UINT32, ODK_UINT64, - ODK_INT64, ODK_SUBSTRING, ODK_DEVICEID, ODK_DEVICEINFO, @@ -62,10 +61,6 @@ struct ODK_LicenseResponseParams { struct ODK_ReleaseResponseParams { ODK_CoreMessage core_message; - uint32_t status; - uint32_t clock_security_level; - int64_t seconds_since_license_requested; - int64_t seconds_since_first_decrypt; std::vector extra_fields; }; @@ -99,6 +94,9 @@ void ODK_SetDefaultCoreFields(ODK_CoreMessage* core_message, ODK_MessageType message_type); void ODK_SetDefaultLicenseResponseParams(ODK_LicenseResponseParams* params, uint32_t odk_major_version); +void ODK_SetDefaultLicenseResponseParamsDecencMitigation( + ODK_LicenseResponseParams* params, uint32_t odk_major_version, + uint32_t odk_minor_version); void ODK_SetDefaultReleaseResponseParams(ODK_ReleaseResponseParams* params); void ODK_SetDefaultRenewalResponseParams(ODK_RenewalResponseParams* params); void ODK_SetDefaultProvisioningResponseParams( @@ -124,7 +122,7 @@ void ODK_ResetOdkFields(std::vector* fields); // Serialize core_message and extra_fields into buf void ODK_BuildMessageBuffer(ODK_CoreMessage* core_message, const std::vector& extra_fields, - uint8_t** buf, uint32_t* buf_size); + std::vector* buf); } // namespace wvodk_test diff --git a/oemcrypto/odk/test/odk_timer_test.cpp b/oemcrypto/odk/test/odk_timer_test.cpp index 1190cf0..692d62a 100644 --- a/oemcrypto/odk/test/odk_timer_test.cpp +++ b/oemcrypto/odk/test/odk_timer_test.cpp @@ -3,6 +3,9 @@ * License Agreement. */ +#include +#include + #include "OEMCryptoCENCCommon.h" #include "gtest/gtest.h" #include "odk.h" @@ -35,7 +38,7 @@ TEST(OdkTimerBasicTest, ParseLicenseTimerSet) { // playback timer is successfully started ::wvodk_test::ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); - params.parsed_license.renewal_delay_base = OEMCrypto_License_Load; + params.parsed_license.renewal_delay_base = OEMCrypto_LicenseLoad; params.parsed_license.timer_limits.soft_enforce_rental_duration = false; params.parsed_license.timer_limits.soft_enforce_playback_duration = false; params.parsed_license.timer_limits.earliest_playback_start_seconds = 10; @@ -46,47 +49,41 @@ TEST(OdkTimerBasicTest, ParseLicenseTimerSet) { ODK_InitializeClockValues(¶ms.clock_values, kSystemTime); EXPECT_EQ(OEMCrypto_SUCCESS, result); params.clock_values.time_of_license_request_signed = 5; - params.clock_values.status = kActive; + params.clock_values.status = OEMCrypto_Active; - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); result = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, kSystemTime, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, kSystemTime, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(ODK_SET_TIMER, result); - delete[] buf; } TEST(OdkTimerBasicTest, ParseLicenseTimerDisabled) { // playback timer is successfully started ::wvodk_test::ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); - params.parsed_license.renewal_delay_base = OEMCrypto_License_Load; + params.parsed_license.renewal_delay_base = OEMCrypto_LicenseLoad; params.parsed_license.timer_limits.soft_enforce_rental_duration = true; params.parsed_license.timer_limits.earliest_playback_start_seconds = 3; params.parsed_license.timer_limits.total_playback_duration_seconds = 0; params.parsed_license.timer_limits.initial_renewal_duration_seconds = 0; params.clock_values.time_of_first_decrypt = 10; params.clock_values.time_of_license_request_signed = 5; - params.clock_values.status = kActive; + params.clock_values.status = OEMCrypto_Active; - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); OEMCryptoResult result = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, kSystemTime, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, kSystemTime, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(ODK_DISABLE_TIMER, result); - delete[] buf; } TEST(OdkTimerBasicTest, ParseRenewalTimerExpired) { @@ -94,7 +91,7 @@ TEST(OdkTimerBasicTest, ParseRenewalTimerExpired) { ::wvodk_test::ODK_LicenseResponseParams params; ODK_SetDefaultLicenseResponseParams(¶ms, ODK_MAJOR_VERSION); - params.parsed_license.renewal_delay_base = OEMCrypto_License_Load; + params.parsed_license.renewal_delay_base = OEMCrypto_LicenseLoad; params.parsed_license.timer_limits.rental_duration_seconds = 5; params.parsed_license.timer_limits.earliest_playback_start_seconds = 3; OEMCryptoResult result = @@ -102,18 +99,15 @@ TEST(OdkTimerBasicTest, ParseRenewalTimerExpired) { EXPECT_EQ(OEMCrypto_SUCCESS, result); params.clock_values.time_of_license_request_signed = 5; - uint8_t* buf = nullptr; - uint32_t buf_size = 0; - ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf, - &buf_size); + std::vector buf; + ODK_BuildMessageBuffer(&(params.core_message), params.extra_fields, &buf); result = ODK_ParseLicense( - buf, buf_size + kExtraPayloadSize, buf_size, params.initial_license_load, - params.usage_entry_present, kSystemTime, &(params.timer_limits), - &(params.clock_values), &(params.core_message.nonce_values), - &(params.parsed_license), nullptr); + buf.data(), buf.size() + kExtraPayloadSize, buf.size(), + params.initial_license_load, params.usage_entry_present, kSystemTime, + &(params.timer_limits), &(params.clock_values), + &(params.core_message.nonce_values), &(params.parsed_license), nullptr); EXPECT_EQ(ODK_TIMER_EXPIRED, result); - delete[] buf; } } // namespace wvodk_test @@ -121,7 +115,7 @@ TEST(OdkTimerBasicTest, ParseRenewalTimerExpired) { TEST(OdkTimerBasicTest, NullTest) { // Assert that nullptr does not cause a core dump. ODK_InitializeClockValues(nullptr, 0u); - ODK_ReloadClockValues(nullptr, 0u, 0u, 0u, kActive, 0u); + ODK_ReloadClockValues(nullptr, 0u, 0u, 0u, OEMCrypto_Active, 0u); ODK_AttemptFirstPlayback(0u, nullptr, nullptr, nullptr); ODK_UpdateLastPlaybackTime(0, nullptr, nullptr); ASSERT_TRUE(true); @@ -139,7 +133,7 @@ TEST(OdkTimerBasicTest, Init) { EXPECT_EQ(clock_values.time_when_timer_expires, 0u); EXPECT_EQ(clock_values.timer_status, ODK_CLOCK_TIMER_STATUS_LICENSE_NOT_LOADED); - EXPECT_EQ(clock_values.status, kUnused); + EXPECT_EQ(clock_values.status, OEMCrypto_Unused); } TEST(OdkTimerBasicTest, Reload) { @@ -151,7 +145,7 @@ TEST(OdkTimerBasicTest, Reload) { uint64_t lic_signed = 1u; uint64_t first_decrypt = 2u; uint64_t last_decrypt = 3u; - enum OEMCrypto_Usage_Entry_Status status = kInactiveUsed; + enum OEMCrypto_UsageEntryStatus status = OEMCrypto_InactiveUsed; ODK_ReloadClockValues(&clock_values, lic_signed, first_decrypt, last_decrypt, status, time); EXPECT_EQ(clock_values.time_of_license_request_signed, lic_signed); @@ -247,7 +241,7 @@ class ODKTimerTest : public ::testing::Test { const OEMCryptoResult result = ODK_AttemptFirstPlayback( start, &timer_limits_, &clock_values_, &timer_value); // After first playback, the license is active. - EXPECT_EQ(clock_values_.status, kActive); + EXPECT_EQ(clock_values_.status, OEMCrypto_Active); EXPECT_EQ(clock_values_.time_when_timer_expires, cutoff); if (cutoff > 0) { // If we expect the timer to be set. EXPECT_EQ(result, ODK_SET_TIMER); @@ -324,7 +318,7 @@ class ODKTimerTest : public ::testing::Test { EXPECT_EQ(clock_values_.time_of_license_request_signed, kRentalClockStart); EXPECT_EQ(clock_values_.time_of_first_decrypt, start_of_playback_); EXPECT_EQ(clock_values_.time_of_last_decrypt, time_of_last_decrypt); - EXPECT_EQ(clock_values_.status, kActive); + EXPECT_EQ(clock_values_.status, OEMCrypto_Active); } // Convert from rental time to system time. By "system time", we mean @@ -369,7 +363,7 @@ TEST_F(ODKTimerTest, EarlyTest) { // We use the TIMER_EXPIRED error to mean both early or late. ForbidPlayback(bad_start_time); // And times were not updated: - EXPECT_EQ(clock_values_.status, kUnused); + EXPECT_EQ(clock_values_.status, OEMCrypto_Unused); // This is when we will attempt the first valid playback. start_of_playback_ = GetSystemTime(timer_limits_.earliest_playback_start_seconds + 10); @@ -928,7 +922,7 @@ class RenewalTest : public ODKTimerTest { &timer_limits_, &clock_values_, start, renewal_duration_seconds, timer_value_pointer); // After first playback, the license is active. - EXPECT_EQ(clock_values_.status, kActive); + EXPECT_EQ(clock_values_.status, OEMCrypto_Active); EXPECT_EQ(clock_values_.time_when_timer_expires, cutoff); if (cutoff > 0) { // If we expect the timer to be set. EXPECT_EQ(result, ODK_SET_TIMER); diff --git a/oemcrypto/oemcrypto_security_tests.gyp b/oemcrypto/oemcrypto_security_tests.gyp index 98feb27..6deb671 100644 --- a/oemcrypto/oemcrypto_security_tests.gyp +++ b/oemcrypto/oemcrypto_security_tests.gyp @@ -29,15 +29,14 @@ '<(platform_specific_dir)/log.cpp', '<(util_dir)/src/cdm_random.cpp', '<(util_dir)/src/platform.cpp', - '<(util_dir)/src/rw_lock.cpp', '<(util_dir)/src/string_conversions.cpp', + '<(util_dir)/src/string_utils.cpp', '<(util_dir)/test/test_sleep.cpp', '<(util_dir)/test/test_clock.cpp', ], 'includes': [ '../util/libssl_dependency.gypi', 'test/oemcrypto_security_tests.gypi', - 'util/oec_ref_util.gypi', 'util/oec_ref_util_unittests.gypi', ], 'libraries': [ @@ -46,6 +45,8 @@ 'dependencies': [ '<(gtest_dependency)', '<(gmock_dependency)', + '<(oemcrypto_dir)/util/wvcrc32.gyp:libwvcrc32', + '<(oemcrypto_dir)/util/build.gyp:liboec_ref_util', ], }, ], diff --git a/oemcrypto/oemcrypto_unittests.gyp b/oemcrypto/oemcrypto_unittests.gyp index 3985423..83eddf7 100644 --- a/oemcrypto/oemcrypto_unittests.gyp +++ b/oemcrypto/oemcrypto_unittests.gyp @@ -17,8 +17,6 @@ 'platform_specific_dir': 'drm_private_key->signature_size; result = ODK_PrepareCoreLicenseRequest( message, message_length, core_message_length, - &session_context->nonce_values, &g_counter_info); + &session_context->nonce_values, &g_counter_info, 0); if (*signature_length < required_signature_size || result == OEMCrypto_ERROR_SHORT_BUFFER) { *signature_length = required_signature_size; /* The core_message_length has been correctly set by - * ODK_PrepareCoreLicenseRequest, but the message and signature buffer will - * be initialized and filled in the subsequent call. */ + * ODK_PrepareCoreLicenseRequest, but the message and signature buffer + * will be initialized and filled in the subsequent call. */ return OEMCrypto_ERROR_SHORT_BUFFER; } if (result != OEMCrypto_SUCCESS) { @@ -1076,8 +1077,8 @@ OEMCryptoResult OEMCrypto_PrepAndSignLicenseRequest( message_or_hash_length = sizeof(message_hash); } - /* For backwards compatibility, we only sign the message body, and we compute - * a SHA256 of the core message. */ + /* For backwards compatibility, we only sign the message body, and we + * compute a SHA256 of the core message. */ result = WTPI_C1_SHA256(message, *core_message_length, session_context->license_request_hash); if (result != OEMCrypto_SUCCESS) { @@ -1130,8 +1131,9 @@ OEMCryptoResult OEMCrypto_PrepAndSignReleaseRequest( } ABORT_IF(session_context == NULL, "OPKI_GetSession() provided invalid output."); - // If we are talking to an old license server, then call the renewal function. - if (session_context->nonce_values.api_major_version < 19) { + // If we are talking to an old license server, then call the renewal + // function. + if (session_context->nonce_values.api_major_version < 20) { return OEMCrypto_PrepAndSignRenewalRequest(session, message, message_length, core_message_length, signature, signature_length); @@ -1141,9 +1143,10 @@ OEMCryptoResult OEMCrypto_PrepAndSignReleaseRequest( if (result != OEMCrypto_SUCCESS) return result; RETURN_INVALID_CONTEXT_IF_NULL(core_message_length); RETURN_INVALID_CONTEXT_IF_NULL(signature_length); - /* If we have signed a request, but have not loaded it, something is wrong. On - * the other hand, we can sign a license release using the mac keys from the - * usage table. So it is OK if we have never signed a license request. */ + /* If we have signed a request, but have not loaded it, something is wrong. + * On the other hand, we can sign a license release using the mac keys from + * the usage table. So it is OK if we have never signed a license request. + */ if (session_context->request_signed && !session_context->response_loaded) { LOGE("Attempt to sign release before load"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; @@ -1166,12 +1169,14 @@ OEMCryptoResult OEMCrypto_PrepAndSignReleaseRequest( return result; } - OEMCrypto_Usage_Entry_Status status; + OEMCrypto_UsageEntryStatus status; int64_t seconds_since_license_received; int64_t seconds_since_first_decrypt; - result = OPKI_GetUsageEntryInfo(session_context->session_id, &status, - &seconds_since_license_received, - &seconds_since_first_decrypt); + uint8_t pst[ODK_PST_LEN_MAX]; + size_t pst_length = sizeof(pst); + result = OPKI_GetUsageEntryInfo( + session_context->session_id, &status, &seconds_since_license_received, + &seconds_since_first_decrypt, pst, &pst_length); if (result != OEMCrypto_SUCCESS) { LOGE("Error with getting usage entry info"); return result; @@ -1179,15 +1184,15 @@ OEMCryptoResult OEMCrypto_PrepAndSignReleaseRequest( result = ODK_PrepareCoreReleaseRequest( message, message_length, core_message_length, - &session_context->nonce_values, status, kMonotonicClock, - seconds_since_license_received, seconds_since_first_decrypt, - &session_context->clock_values, now); + &session_context->nonce_values, status, OEMCrypto_MonotonicClock, + seconds_since_license_received, seconds_since_first_decrypt, pst, + pst_length, &session_context->clock_values, now); if (*signature_length < required_signature_size || result == OEMCrypto_ERROR_SHORT_BUFFER) { *signature_length = required_signature_size; /* The core_message_length has been correctly set by - * ODK_PrepareCoreReleaseRequest, but the message and signature buffer will - * be initialized and filled in the subsequent call. */ + * ODK_PrepareCoreReleaseRequest, but the message and signature buffer + * will be initialized and filled in the subsequent call. */ return OEMCrypto_ERROR_SHORT_BUFFER; } if (result != OEMCrypto_SUCCESS) { @@ -1251,9 +1256,10 @@ OEMCryptoResult OEMCrypto_PrepAndSignRenewalRequest( if (result != OEMCrypto_SUCCESS) return result; RETURN_INVALID_CONTEXT_IF_NULL(core_message_length); RETURN_INVALID_CONTEXT_IF_NULL(signature_length); - /* If we have signed a request, but have not loaded it, something is wrong. On - * the other hand, we can sign a license release using the mac keys from the - * usage table. So it is OK if we have never signed a license request. */ + /* If we have signed a request, but have not loaded it, something is wrong. + * On the other hand, we can sign a license release using the mac keys from + * the usage table. So it is OK if we have never signed a license request. + */ if (session_context->request_signed && !session_context->response_loaded) { LOGE("Attempt to sign renewal before load"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; @@ -1269,8 +1275,8 @@ OEMCryptoResult OEMCrypto_PrepAndSignRenewalRequest( result == OEMCrypto_ERROR_SHORT_BUFFER) { *signature_length = required_signature_size; /* The core_message_length has been correctly set by - * ODK_PrepareCoreRenewalRequest, but the message and signature buffer will - * be initialized and filled in the subsequent call. */ + * ODK_PrepareCoreRenewalRequest, but the message and signature buffer + * will be initialized and filled in the subsequent call. */ return OEMCrypto_ERROR_SHORT_BUFFER; } if (result != OEMCrypto_SUCCESS) { @@ -1285,8 +1291,8 @@ OEMCryptoResult OEMCrypto_PrepAndSignRenewalRequest( } RETURN_INVALID_CONTEXT_IF_NULL(signature); if (session_context->nonce_values.api_major_version < ODK_FIRST_VERSION) { - /* If we are talking to an old license server, then we only sign the message - * body. */ + /* If we are talking to an old license server, then we only sign the + * message body. */ const uint8_t* message_body = message + *core_message_length; const size_t message_body_length = message_length - *core_message_length; if (OPKI_CheckKey(session_context->mac_key_client, MAC_KEY_CLIENT)) { @@ -1328,7 +1334,7 @@ OEMCryptoResult OEMCrypto_PrepAndSignRenewalRequest( static OEMCryptoResult LoadKeysNoSignature( OEMCryptoSession* session, const uint8_t* message, size_t message_length, OEMCrypto_Substring enc_mac_keys_iv, OEMCrypto_Substring enc_mac_keys, - size_t num_keys, const OEMCrypto_KeyObject* key_array, + size_t num_keys, const OEMCrypto_KeyObjectV2* key_array, OEMCrypto_Substring pst, OEMCrypto_Substring srm_restriction_data, OEMCrypto_LicenseType license_type) { ABORT_IF(g_opk_system_state != SYSTEM_INITIALIZED, @@ -1376,8 +1382,8 @@ static OEMCryptoResult LoadKeysNoSignature( return OEMCrypto_ERROR_INVALID_CONTEXT; } - /* Check to see if the MAC Keys' IV is directly before the MAC Keys, which is - forbidden. */ + /* Check to see if the MAC Keys' IV is directly before the MAC Keys, which + is forbidden. */ if (enc_mac_keys.offset >= KEY_IV_SIZE && enc_mac_keys.length > 0 && crypto_memcmp(message + enc_mac_keys.offset - KEY_IV_SIZE, message + enc_mac_keys_iv.offset, KEY_IV_SIZE) == 0) { @@ -1385,12 +1391,13 @@ static OEMCryptoResult LoadKeysNoSignature( return OEMCrypto_ERROR_INVALID_CONTEXT; } - /* If we've already loaded in a license and it isn't the current license type, - fail. */ + /* If we've already loaded in a license and it isn't the current license + type, fail. */ if (session->state == SESSION_LICENSE_LOADED && session->license_type != license_type) { LOGE( - "License type doesn't match. License type already loaded: %u, license " + "License type doesn't match. License type already loaded: %u, " + "license " "type to be loaded: %u", session->license_type, license_type); return OEMCrypto_ERROR_INVALID_CONTEXT; @@ -1517,7 +1524,8 @@ static OEMCryptoResult LoadKeysNoSignature( message + pst.offset, pst.length); if (result != OEMCrypto_SUCCESS) { LOGE( - "Failed to verify usage entry PST with result: %u, session_id = " + "Failed to verify usage entry PST with result: %u, session_id " + "= " "%u", result, session->session_id); break; @@ -1721,15 +1729,16 @@ OEMCryptoResult OEMCrypto_LoadEntitledContentKeys( goto cleanup; } - /* Initialized the index to which the entitled content key to be loaded. */ + /* Initialized the index to which the entitled content key to be loaded. + */ uint32_t entitled_content_key_index = key_session->num_entitled_content_keys; - /* It prefers to reuse an existing content key index that is entitled by the - * same entitlement key if one exists already, but will allocate a new index - * if there are none that can be reused. - * The block below searches whether there is an existing content key - * entitled by the same entitlement key, and will reuse the index to load - * the new content key if there is one. */ + /* It prefers to reuse an existing content key index that is entitled by + * the same entitlement key if one exists already, but will allocate a new + * index if there are none that can be reused. The block below searches + * whether there is an existing content key entitled by the same + * entitlement key, and will reuse the index to load the new content key + * if there is one. */ for (uint32_t k = 0; k < key_session->num_entitled_content_keys; k++) { /* Gets the entitlement key of the entitled content key at index k. */ const EntitlementKeyInfo* key_info = &key_session->entitlement_keys[k]; @@ -1749,8 +1758,8 @@ OEMCryptoResult OEMCrypto_LoadEntitledContentKeys( } } - /* If we're not updating an existing content key and we can't add any more, - we fail. */ + /* If we're not updating an existing content key and we can't add any + more, we fail. */ const bool adding_new_content_key = entitled_content_key_index == key_session->num_entitled_content_keys; if (adding_new_content_key && @@ -2260,7 +2269,8 @@ OEMCryptoResult OEMCrypto_Generic_Encrypt( } if (in_buffer_length % AES_BLOCK_SIZE != 0) { LOGE( - "in_buffer_length of %zu is not a multiple of the crypto block length.", + "in_buffer_length of %zu is not a multiple of the crypto block " + "length.", in_buffer_length); return OEMCrypto_ERROR_INVALID_CONTEXT; } @@ -2337,7 +2347,8 @@ OEMCryptoResult OEMCrypto_Generic_Decrypt( } if (in_buffer_length % AES_BLOCK_SIZE != 0) { LOGE( - "in_buffer_length of %zu is not a multiple of the crypto block length.", + "in_buffer_length of %zu is not a multiple of the crypto block " + "length.", in_buffer_length); return OEMCrypto_ERROR_INVALID_CONTEXT; } @@ -2649,8 +2660,8 @@ OEMCryptoResult OEMCrypto_FactoryInstallBCCSignature(const uint8_t* signature, // keybox_CA_token(uint8_t[], 72 bytes) || // signature_length(uint32_t, 4 bytes, value=0x20) || // signature(uint8_t[], 32 bytes) - // This signature excluding the leading byte of |signature_type| is generated - // by OEMCrypto_GetBootCertificateChain() in the factory. + // This signature excluding the leading byte of |signature_type| is + // generated by OEMCrypto_GetBootCertificateChain() in the factory. # define SIGNATURE_BY_KEYBOX_MAX_LENGTH 128 if (sizeof(uint8_t) + signature_length > SIGNATURE_BY_KEYBOX_MAX_LENGTH) { return OEMCrypto_ERROR_INVALID_CONTEXT; @@ -2806,14 +2817,21 @@ OEMCryptoResult OEMCrypto_LoadOEMPrivateKey(OEMCrypto_SESSION session) { return OPKI_SetStatePostCall(session_context, API_LOADOEMPRIVATEKEY); } -OEMCryptoResult OEMCrypto_GetRandom(uint8_t* randomData, size_t dataLength) { +OEMCryptoResult OEMCrypto_GetRandom(OEMCrypto_SharedMemory* random_data, + size_t random_data_length) { +#ifdef FACTORY_BUILD_ONLY if (g_opk_system_state != SYSTEM_INITIALIZED) { LOGE("OEMCrypto is not yet initialized"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - RETURN_INVALID_CONTEXT_IF_NULL(randomData); - RETURN_INVALID_CONTEXT_IF_ZERO(dataLength); - return WTPI_C1_RandomBytes(randomData, dataLength); + RETURN_INVALID_CONTEXT_IF_NULL(random_data); + RETURN_INVALID_CONTEXT_IF_ZERO(random_data_length); + return WTPI_C1_RandomBytes(random_data, random_data_length); +#else + (void)random_data; + (void)random_data_length; + return OEMCrypto_ERROR_NOT_IMPLEMENTED; +#endif } uint32_t OEMCrypto_APIVersion(void) { return API_MAJOR_VERSION; } @@ -3643,7 +3661,7 @@ OEMCryptoResult OEMCrypto_ReportUsage(OEMCrypto_SESSION session, } OEMCryptoResult OEMCrypto_GetUsageEntryInfo( - OEMCrypto_SESSION session, OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_SESSION session, OEMCrypto_UsageEntryStatus* status, int64_t* seconds_since_license_received, int64_t* seconds_since_first_decrypt) { if (g_opk_system_state != SYSTEM_INITIALIZED) { @@ -3666,9 +3684,11 @@ OEMCryptoResult OEMCrypto_GetUsageEntryInfo( result = OPKI_CheckStatePreCall(session_context, API_GETUSAGEENTRYINFO); if (result != OEMCrypto_SUCCESS) return result; - result = OPKI_GetUsageEntryInfo(session_context->session_id, status, - seconds_since_license_received, - seconds_since_first_decrypt); + uint8_t pst[ODK_PST_LEN_MAX]; + size_t pst_length = sizeof(pst); + result = OPKI_GetUsageEntryInfo( + session_context->session_id, status, seconds_since_license_received, + seconds_since_first_decrypt, pst, &pst_length); if (result != OEMCrypto_SUCCESS) { LOGE("Failed to get usage entry info with result: %u, session_id = %u", result, session_context->session_id); @@ -4000,28 +4020,14 @@ OEMCryptoResult OEMCrypto_GenerateCertificateKeyPair( /* Generate the key pair and fill the key type. */ size_t required_signature_size = MAX_ASYMMETRIC_SIGNATURE_SIZE; - if (session_context->prov40_oem_private_key == NULL) { - /* This is for obtaining an OEM leaf cert. The signature is a COSE_SIGN1 - * format. Adjust the required signature size. */ - result = WTPI_GetMaxBccKeyCoseSign1Size(&required_signature_size); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to get max COSE_SIGN1 size: %u", result); - return result; - } + result = WTPI_GetMaxBccKeyCoseSign1Size(&required_signature_size); + if (result != OEMCrypto_SUCCESS) { + LOGE("Failed to get max COSE_SIGN1 size: %u", result); + return result; } - - // Determine which stage of Prov 4 provisioning we are in - // Stage 1: OEM private key is NOT installed, we are requesting it. - // Stage 2: OEM private key is installed, now requesting DRM cert. - const bool is_stage1 = (session_context->prov40_oem_private_key == NULL); - const CertSignatureType requested_cert_type = - is_stage1 ? CERT_SIGNATURE_OEM : CERT_SIGNATURE_DRM; - const uint32_t wrapping_context = - is_stage1 ? DEVICE_KEY_WRAP_OEM_CERT : DEVICE_KEY_WRAP_DRM_CERT; - AsymmetricKeyType generated_key_type; result = WTPI_GenerateRandomCertificateKeyPair( - requested_cert_type, &generated_key_type, wrapped_private_key, + CERT_SIGNATURE_DRM, &generated_key_type, wrapped_private_key, wrapped_private_key_size, public_key, public_key_size); if (public_key_signature == NULL || *public_key_signature_size < required_signature_size || @@ -4041,65 +4047,19 @@ OEMCryptoResult OEMCrypto_GenerateCertificateKeyPair( return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - if (is_stage1) { - /* This is the first stage for obtaining an OEM leaf cert. Sign with BCC - * leaf private key. */ - result = - WTPI_BccKeyCoseSign1(public_key, *public_key_size, public_key_signature, - public_key_signature_size); - } else { - /* This is the second stage for obtaining a DRM cert. Sign with the - * installed prov40_oem_private_key. */ - WTPI_AsymmetricKey_Handle signing_key_handle; - uint32_t allowed_schemes_unused; - AsymmetricKey* oem_private_key = session_context->prov40_oem_private_key; - result = WTPI_UnwrapIntoAsymmetricKeyHandle( - DEVICE_KEY_WRAP_OEM_CERT, oem_private_key->wrapped_key, - oem_private_key->wrapped_key_length, oem_private_key->key_type, - &signing_key_handle, &allowed_schemes_unused); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to unwrap OEM private key into key handle with result: %u", - result); - return result; - } - /* Sign the generated public key. */ - switch (oem_private_key->key_type) { - case DRM_ECC_PRIVATE_KEY: { - result = WTPI_ECCSign(signing_key_handle, public_key, *public_key_size, - public_key_signature, public_key_signature_size); - break; - } - case DRM_RSA_PRIVATE_KEY: { - /* If the signing key is RSA, the signing key must be OEM private key. - */ - if ((session_context->prov40_oem_allowed_schemes & kSign_RSASSA_PSS) != - kSign_RSASSA_PSS) { - LOGE("Bad prov40 OEM RSA padding scheme: %x", - session_context->prov40_oem_allowed_schemes); - result = OEMCrypto_ERROR_INVALID_KEY; - break; - } - result = WTPI_RSASign(signing_key_handle, public_key, *public_key_size, - public_key_signature, public_key_signature_size, - kSign_RSASSA_PSS); - break; - } - case PROV40_ED25519_PRIVATE_KEY: - default: - LOGE("Unexpected signing private key type"); - result = OEMCrypto_ERROR_UNKNOWN_FAILURE; - break; - } - WTPI_FreeAsymmetricKeyHandle(signing_key_handle); - } + /* Sign the generated public key with the BCC leaf private key. */ + result = + WTPI_BccKeyCoseSign1(public_key, *public_key_size, public_key_signature, + public_key_signature_size); if (result != OEMCrypto_SUCCESS) { return result; } + /* Store the generated key for signing the provisioning request. */ WTPI_AsymmetricKey_Handle csr_signing_key_handle; uint32_t allowed_schemes; result = WTPI_UnwrapIntoAsymmetricKeyHandle( - wrapping_context, wrapped_private_key, *wrapped_private_key_size, + DEVICE_KEY_WRAP_DRM_CERT, wrapped_private_key, *wrapped_private_key_size, generated_key_type, &csr_signing_key_handle, &allowed_schemes); if (result != OEMCrypto_SUCCESS) { LOGE( @@ -4131,7 +4091,7 @@ OEMCryptoResult OEMCrypto_GenerateCertificateKeyPair( LOGE("Failed to create asymmetric CSR signing key with result: %u", result); return result; } - session_context->prov40_csr_keywrap_context = wrapping_context; + return OPKI_SetStatePostCall(session_context, API_GENERATECERTIFICATEKEYPAIR); } @@ -4206,67 +4166,13 @@ OEMCryptoResult OEMCrypto_GetDeviceSignedCsrPayload( OEMCryptoResult OEMCrypto_InstallOemPrivateKey( OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, const uint8_t* wrapped_private_key, size_t wrapped_private_key_length) { - RETURN_INVALID_CONTEXT_IF_NULL(wrapped_private_key); - if (g_opk_system_state != SYSTEM_INITIALIZED) { - LOGE("OEMCrypto is not yet initialized"); - return OEMCrypto_ERROR_UNKNOWN_FAILURE; - } - AsymmetricKeyType oem_key_type; - if (!OPKI_PrivateKeyTypeToAsymmetricKey(key_type, &oem_key_type)) { - LOGE("Invalid key_type: %d", key_type); - return OEMCrypto_ERROR_INVALID_CONTEXT; - } - if (OPKI_GetSessionType(session) != SESSION_TYPE_OEMCRYPTO) { - LOGE("Unexpected session type."); - return OEMCrypto_ERROR_INVALID_SESSION; - } - OEMCryptoSession* session_context = NULL; - OEMCryptoResult result = OPKI_GetSession(session, &session_context); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to get session with result: %u, session = %u", result, - session); - return result; - } - ABORT_IF(session_context == NULL, - "OPKI_GetSession() provided invalid output."); - result = OPKI_CheckStatePreCall(session_context, API_INSTALLOEMPRIVATEKEY); - if (result != OEMCrypto_SUCCESS) return result; - - WTPI_AsymmetricKey_Handle private_key_handle; - uint32_t allowed_schemes; - result = WTPI_UnwrapIntoAsymmetricKeyHandle( - DEVICE_KEY_WRAP_OEM_CERT, wrapped_private_key, wrapped_private_key_length, - oem_key_type, &private_key_handle, &allowed_schemes); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to unwrap OEM private key into key handle with result: %u", - result); - return result; - } - size_t signature_size; - result = WTPI_GetSignatureSize(private_key_handle, &signature_size); - WTPI_FreeAsymmetricKeyHandle(private_key_handle); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to get OEM key signature size with result: %u", result); - return result; - } - - result = OPKI_LoadProv40OEMKey( - session_context, oem_key_type, wrapped_private_key, - wrapped_private_key_length, signature_size, allowed_schemes); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to install OEM key"); - goto cleanup; - } - result = OPKI_SetStatePostCall(session_context, API_INSTALLOEMPRIVATEKEY); - -cleanup:; - OEMCryptoResult free_key_result = FreeMacAndEncryptionKeys(session_context); - if (result == OEMCrypto_SUCCESS) result = free_key_result; - if (result != OEMCrypto_SUCCESS) { - OPKI_FreeAsymmetricKeyFromTable(&session_context->prov40_oem_private_key); - session_context->state = SESSION_INVALID; - } - return result; + // TODO(b/374834498): Remove stub once build system is updated to + // handle backwards-compatible OEMCrypto testing. + (void)session; + (void)key_type; + (void)wrapped_private_key; + (void)wrapped_private_key_length; + return OEMCrypto_ERROR_NOT_IMPLEMENTED; } OEMCryptoResult OEMCrypto_ReassociateEntitledKeySession( diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_api_macros.h b/oemcrypto/opk/oemcrypto_ta/oemcrypto_api_macros.h index fec38f8..597ed6d 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_api_macros.h +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_api_macros.h @@ -33,9 +33,9 @@ // OEMCrypto API version would be v17.0.1; once the supported OEMCrypto API // version bumps to v17.1, the first released OPK implementation would be // v17.1.0 -#define API_MAJOR_VERSION 19 -#define API_MINOR_VERSION 5 +#define API_MAJOR_VERSION 20 +#define API_MINOR_VERSION 0 #define OPK_PATCH_VERSION 0 -#define OPK_BUILD_ID "19.5.0" +#define OPK_BUILD_ID "MAIN" #endif /* OEMCRYPTO_TA_OEMCRYPTO_API_MACROS_H_ */ diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_key.c b/oemcrypto/opk/oemcrypto_ta/oemcrypto_key.c index b71b264..2cd3675 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_key.c +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_key.c @@ -71,10 +71,10 @@ bool OPKI_PrivateKeyTypeToAsymmetricKey(OEMCrypto_PrivateKeyType priv_key_type, AsymmetricKeyType* asym_key_type) { if (asym_key_type == NULL) return false; switch (priv_key_type) { - case OEMCrypto_RSA_Private_Key: + case OEMCrypto_RSAPrivateKey: *asym_key_type = DRM_RSA_PRIVATE_KEY; return true; - case OEMCrypto_ECC_Private_Key: + case OEMCrypto_ECCPrivateKey: *asym_key_type = DRM_ECC_PRIVATE_KEY; return true; } @@ -86,10 +86,10 @@ bool OPKI_AsymmetricKeyToPrivateKeyType( if (priv_key_type == NULL) return false; switch (asym_key_type) { case DRM_RSA_PRIVATE_KEY: - *priv_key_type = OEMCrypto_RSA_Private_Key; + *priv_key_type = OEMCrypto_RSAPrivateKey; return true; case DRM_ECC_PRIVATE_KEY: - *priv_key_type = OEMCrypto_ECC_Private_Key; + *priv_key_type = OEMCrypto_ECCPrivateKey; return true; case PROV40_ED25519_PRIVATE_KEY: // ED25519 key can only be used in provisioning 4 BCC. diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_key_control_block.c b/oemcrypto/opk/oemcrypto_ta/oemcrypto_key_control_block.c index 5725dc1..07fbcf2 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_key_control_block.c +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_key_control_block.c @@ -32,6 +32,7 @@ OEMCryptoResult OPKI_ParseKeyControlBlock(const uint8_t* kcb, const char* verification = key_control_block.verification; key_control_block.valid = + memcmp(verification, "kc20", 4) == 0 || /* add in version 20 api */ memcmp(verification, "kc19", 4) == 0 || /* add in version 19 api */ memcmp(verification, "kc18", 4) == 0 || /* add in version 18 api */ memcmp(verification, "kc17", 4) == 0 || /* add in version 17 api */ diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_serialized_usage_table.h b/oemcrypto/opk/oemcrypto_ta/oemcrypto_serialized_usage_table.h index f606d7e..3e135c3 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_serialized_usage_table.h +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_serialized_usage_table.h @@ -50,7 +50,7 @@ typedef struct SavedUsageEntry { int64_t time_of_first_decrypt; /* Time in seconds on system clock. */ int64_t time_of_last_decrypt; /* Time in seconds on system clock.. */ /* Status of the entry or license, as documented in OEMCrypto spec. */ - enum OEMCrypto_Usage_Entry_Status status; + enum OEMCrypto_UsageEntryStatus status; /* Server Mac key wrapped for this device. */ uint8_t mac_key_server[WRAPPED_MAC_KEY_SIZE]; /* Client Mac key wrapped for this device. */ @@ -117,7 +117,7 @@ typedef struct SavedUsageEntryLegacy { int64_t time_of_first_decrypt; /* Time in seconds on system clock. */ int64_t time_of_last_decrypt; /* Time in seconds on system clock.. */ /* Status of the entry or license, as documented in OEMCrypto spec. */ - enum OEMCrypto_Usage_Entry_Status status; + enum OEMCrypto_UsageEntryStatus status; /* Server Mac key wrapped for this device. */ uint8_t mac_key_server[WRAPPED_MAC_KEY_SIZE]; /* Client Mac key wrapped for this device. */ diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.c b/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.c index d77336c..4f2c886 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.c +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.c @@ -57,7 +57,6 @@ OEMCryptoResult OPKI_InitializeSession(OEMCryptoSession* session, .hash_error = OEMCrypto_SUCCESS, }, .allowed_schemes = kSign_RSASSA_PSS, - .prov40_oem_allowed_schemes = kSign_RSASSA_PSS, }; OEMCryptoResult result = ODK_InitializeSessionValues( &session->timer_limits, &session->clock_values, &session->nonce_values, @@ -72,10 +71,6 @@ OEMCryptoResult OPKI_TerminateSession(OEMCryptoSession* session) { OPKI_FreeAsymmetricKeyFromTable(&session->drm_private_key); OEMCryptoResult result = free_key_result; - free_key_result = - OPKI_FreeAsymmetricKeyFromTable(&session->prov40_oem_private_key); - if (result == OEMCrypto_SUCCESS) result = free_key_result; - free_key_result = OPKI_FreeAsymmetricKeyFromTable(&session->prov40_csr_signing_key); if (result == OEMCrypto_SUCCESS) result = free_key_result; @@ -122,7 +117,6 @@ OEMCryptoResult OPKI_CheckStatePreCall(OEMCryptoSession* session, switch (session->state) { case (SESSION_OPENED): case (SESSION_LOAD_OEM_RSA_KEY): - case (SESSION_PROV4_OEM_KEY_LOADED): case (SESSION_DRM_KEY_LOADED): case (SESSION_CAST_RSA_KEY_LOADED): return OEMCrypto_SUCCESS; @@ -142,7 +136,6 @@ OEMCryptoResult OPKI_CheckStatePreCall(OEMCryptoSession* session, switch (session->state) { case (SESSION_OPENED): case (SESSION_LOAD_OEM_RSA_KEY): - case (SESSION_PROV4_OEM_KEY_LOADED): case (SESSION_PREPARING_REQUEST): case (SESSION_DRM_KEY_LOADED): case (SESSION_CAST_RSA_KEY_LOADED): @@ -190,8 +183,6 @@ OEMCryptoResult OPKI_CheckStatePreCall(OEMCryptoSession* session, case (SESSION_OPENED): case (SESSION_PREPARING_REQUEST): case (SESSION_USAGE_ENTRY_LOADED): - // Prov 4 OEM private key installed at beginning of session - case (SESSION_PROV4_OEM_KEY_LOADED): // Provisioning 4 skips LoadProvisioning case (SESSION_WAIT_FOR_PROVISIONING): // This can happen when using testing the CDM with a pre-provisioned @@ -203,25 +194,11 @@ OEMCryptoResult OPKI_CheckStatePreCall(OEMCryptoSession* session, default: goto err; } - case API_INSTALLOEMPRIVATEKEY: - switch (session->state) { - case (SESSION_OPENED): - case (SESSION_PREPARING_REQUEST): - case (SESSION_USAGE_ENTRY_LOADED): - // Provisioning 3 calls OEMCrypto_LoadOEMPrivateKey during provisioning - case (SESSION_LOAD_OEM_RSA_KEY): - // Provisioning 4 skips LoadProvisioning - case (SESSION_WAIT_FOR_PROVISIONING): - return OEMCrypto_SUCCESS; - default: - goto err; - } case API_GENERATECERTIFICATEKEYPAIR: switch (session->state) { case (SESSION_OPENED): case (SESSION_PREPARING_REQUEST): case (SESSION_USAGE_ENTRY_LOADED): - case (SESSION_PROV4_OEM_KEY_LOADED): return OEMCrypto_SUCCESS; default: goto err; @@ -416,9 +393,6 @@ OEMCryptoResult OPKI_SetStatePostCall(OEMCryptoSession* session, case API_LOADOEMPRIVATEKEY: session->state = SESSION_LOAD_OEM_RSA_KEY; break; - case API_INSTALLOEMPRIVATEKEY: - session->state = SESSION_PROV4_OEM_KEY_LOADED; - break; case API_DECRYPTCENC: case API_GENERICENCRYPT: case API_GENERICDECRYPT: @@ -496,38 +470,6 @@ OEMCryptoResult OPKI_LoadDRMKey(OEMCryptoSession* session, return OEMCrypto_SUCCESS; } -OEMCryptoResult OPKI_LoadProv40OEMKey(OEMCryptoSession* session, - AsymmetricKeyType key_type, - const uint8_t* wrapped_key, - size_t wrapped_key_length, - size_t signature_size, - uint32_t allowed_schemes) { - if (session == NULL || wrapped_key == NULL || wrapped_key_length == 0 || - !IsSupportedDrmKeyType(key_type)) { - return OEMCrypto_ERROR_INVALID_CONTEXT; - } - // Free any existing keys. This will only happen if we are using the test RSA - // key and then try to load an existing device certificate. - OEMCryptoResult result = - OPKI_FreeAsymmetricKeyFromTable(&session->prov40_oem_private_key); - if (result != OEMCrypto_SUCCESS) return result; - - if (key_type != DRM_RSA_PRIVATE_KEY) { - allowed_schemes = 0; - } - - result = OPKI_CreateAsymmetricKey(&session->prov40_oem_private_key, key_type, - wrapped_key, wrapped_key_length, - signature_size, allowed_schemes); - if (result != OEMCrypto_SUCCESS) { - LOGE("Failed to create asymmetric OEM private key with result: %u", result); - return result; - } - session->prov40_oem_allowed_schemes = allowed_schemes; - - return OEMCrypto_SUCCESS; -} - NO_IGNORE_RESULT static OEMCryptoResult DeriveKey( WTPI_K1_SymmetricKey_Handle master_key, uint8_t counter, const uint8_t* key_label, size_t key_label_length, const uint8_t* context, @@ -665,12 +607,12 @@ OEMCryptoResult OPKI_GenerateCertSignature(OEMCryptoSession* session, key_type = signing_key->key_type; wrapping_context = DEVICE_KEY_WRAP_DRM_CERT; } else if (signature_type == CERT_SIGNATURE_CSR) { - // Prov40 OEM cert, RSA or ECC + // Prov40 CSR signing key (RSA or ECC) if (!DRMKeyMaySign(session->prov40_csr_signing_key)) return OEMCrypto_ERROR_INVALID_KEY; signing_key = session->prov40_csr_signing_key; key_type = signing_key->key_type; - wrapping_context = session->prov40_csr_keywrap_context; + wrapping_context = DEVICE_KEY_WRAP_DRM_CERT; } else if (signature_type == CERT_SIGNATURE_OEM) { // Prov30 OEM cert, RSA only // For Prov30, signing key handle is directly loaded from the secure @@ -720,12 +662,9 @@ NO_IGNORE_RESULT OEMCryptoResult OPKI_GetSessionSignatureHashAlgorithm( if (session->drm_private_key != NULL) { key = session->drm_private_key; context = DEVICE_KEY_WRAP_DRM_CERT; - } else if (session->prov40_oem_private_key != NULL) { - key = session->prov40_oem_private_key; - context = DEVICE_KEY_WRAP_OEM_CERT; } else if (session->prov40_csr_signing_key != NULL) { key = session->prov40_csr_signing_key; - context = session->prov40_csr_keywrap_context; + context = DEVICE_KEY_WRAP_DRM_CERT; } else if (session->prov30_oem_key_loaded) { // Create OEM key handle on the fly for prov30 since it is not held by // the session. diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.h b/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.h index ea7d57a..a90cb27 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.h +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_session.h @@ -32,7 +32,6 @@ typedef enum OEMCryptoSessionState { SESSION_INVALID = (int)0x23a27071, SESSION_LOAD_OEM_RSA_KEY = (int)0x9d7cae94, SESSION_DRM_KEY_LOADED = (int)0xbc17f592, - SESSION_PROV4_OEM_KEY_LOADED = (int)0x902c5b0a, SESSION_CAST_RSA_KEY_LOADED = (int)0x84a0bb23, } OEMCryptoSessionState; @@ -74,7 +73,6 @@ typedef enum OEMCryptoSessionAPI { API_LOADRENEWAL = (int)0xb096dc9a, API_LOADRELEASE = (int)0x7b89f4a2, API_CREATEENTITLEDKEYSESSION = (int)0x6d7d9bfc, - API_INSTALLOEMPRIVATEKEY = (int)0xd6195c63, API_GENERATECERTIFICATEKEYPAIR = (int)0x30871a8a, API_GETKEYHANDLE = (int)0x81aac42f, } OEMCryptoSessionAPI; @@ -83,14 +81,8 @@ typedef struct OEMCryptoSession { OEMCrypto_SESSION session_id; OEMCryptoSessionState state; AsymmetricKey* drm_private_key; - /* This key can only be used to sign the generated cert key in - * provisioning 4.*/ - AsymmetricKey* prov40_oem_private_key; - /* This key can only be used to sign the provisioning 4 requests. For stage 1, - * it is the generated OEM private key. For stage 2, it is the generated DRM - * private key. */ + /* This key can only be used to sign the provisioning 4 requests. */ AsymmetricKey* prov40_csr_signing_key; - uint32_t prov40_csr_keywrap_context; SymmetricKey* mac_key_server; SymmetricKey* mac_key_client; SymmetricKey* encryption_key; @@ -103,7 +95,6 @@ typedef struct OEMCryptoSession { bool valid_srm_version; uint64_t timer_start; uint32_t allowed_schemes; /* For RSA signatures. */ - uint32_t prov40_oem_allowed_schemes; /* For RSA signatures. */ bool decrypt_started; /* If the license has been used in this session. */ ODK_NonceValues nonce_values; ODK_TimerLimits timer_limits; @@ -170,15 +161,6 @@ NO_IGNORE_RESULT OEMCryptoResult OPKI_LoadDRMKey(OEMCryptoSession* session, size_t signature_size, uint32_t allowed_schemes); -/* Attempts to load the wrapped OEM private key |wrapped_key| into |session|'s - |prov40_oem_private_key| field. |key_type| must be valid. |allowed_schemes| - is only assigned for key types which support it. - Caller retains ownership of all pointers and they must not be NULL. */ -NO_IGNORE_RESULT OEMCryptoResult -OPKI_LoadProv40OEMKey(OEMCryptoSession* session, AsymmetricKeyType key_type, - const uint8_t* wrapped_key, size_t wrapped_key_length, - size_t signature_size, uint32_t allowed_schemes); - /* Derives mac and encryption keys from the specific key. Uses AES-128-CMAC and |context| to derive and create a mac_key_server, mac_key_client, and encryption_key for the |session|. diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.c b/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.c index 997a5dd..30d1207 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.c +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.c @@ -97,7 +97,7 @@ static UsageTable g_usage_table; /* TODO(b/158720996): figure out a way to avoid __attribute__(packed). */ typedef struct { uint8_t signature[20]; // -- HMAC SHA1 of the rest of the report. - uint8_t status; // current status of entry. (OEMCrypto_Usage_Entry_Status) + uint8_t status; // current status of entry. (OEMCrypto_UsageEntryStatus) uint8_t clock_security_level; uint8_t pst_length; uint8_t padding; // make int64's word aligned. @@ -624,8 +624,9 @@ OEMCryptoResult OPKI_TerminateUsageTable(void) { static NO_IGNORE_RESULT UsageEntryStatus GetUsageEntryStatusFromEntry(UsageEntry* entry) { ABORT_IF_NULL(entry); - if (entry->data.status == kInactive || entry->data.status == kInactiveUsed || - entry->data.status == kInactiveUnused) { + if (entry->data.status == OEMCrypto_Inactive || + entry->data.status == OEMCrypto_InactiveUsed || + entry->data.status == OEMCrypto_InactiveUnused) { return USAGE_ENTRY_DEACTIVATED; } if (entry->is_loaded) { @@ -1360,12 +1361,15 @@ OEMCryptoResult OPKI_SignReleaseRequest(OEMCrypto_SESSION session_id, } OEMCryptoResult OPKI_GetUsageEntryInfo(OEMCrypto_SESSION session_id, - OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_UsageEntryStatus* status, int64_t* seconds_since_license_received, - int64_t* seconds_since_first_decrypt) { + int64_t* seconds_since_first_decrypt, + uint8_t* pst, size_t* pst_length) { RETURN_INVALID_CONTEXT_IF_NULL(status); RETURN_INVALID_CONTEXT_IF_NULL(seconds_since_license_received); RETURN_INVALID_CONTEXT_IF_NULL(seconds_since_first_decrypt); + RETURN_INVALID_CONTEXT_IF_NULL(pst); + RETURN_INVALID_CONTEXT_IF_NULL(pst_length); UsageEntry* entry = FindUsageEntry(session_id); if (!entry) { @@ -1373,6 +1377,11 @@ OEMCryptoResult OPKI_GetUsageEntryInfo(OEMCrypto_SESSION session_id, return OEMCrypto_ERROR_UNKNOWN_FAILURE; } + if (*pst_length < entry->data.pst_length) { + *pst_length = entry->data.pst_length; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + uint64_t now; OEMCryptoResult result = WTPI_GetTrustedTime(&now); if (result != OEMCrypto_SUCCESS) { @@ -1383,6 +1392,8 @@ OEMCryptoResult OPKI_GetUsageEntryInfo(OEMCrypto_SESSION session_id, *status = entry->data.status; *seconds_since_license_received = now - entry->data.time_of_license_received; *seconds_since_first_decrypt = now - entry->data.time_of_first_decrypt; + *pst_length = entry->data.pst_length; + memcpy(pst, entry->data.pst, entry->data.pst_length); return OEMCrypto_SUCCESS; } @@ -1426,7 +1437,7 @@ OEMCryptoResult OPKI_ShrinkUsageTableHeader(uint32_t new_entry_count, LOGE("Usage table not initialized"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - if (new_entry_count >= g_usage_table.table_size) { + if (new_entry_count > g_usage_table.table_size) { return OEMCrypto_ERROR_UNKNOWN_FAILURE; } for (size_t i = new_entry_count; i < g_usage_table.table_size; i++) { diff --git a/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.h b/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.h index 0aa40b2..ba8ec21 100644 --- a/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.h +++ b/oemcrypto/opk/oemcrypto_ta/oemcrypto_usage_table.h @@ -170,9 +170,9 @@ NO_IGNORE_RESULT OEMCryptoResult OPKI_SignReleaseRequest( * or deactivating the license. Pointers must be non-null and are owned by the * caller. */ NO_IGNORE_RESULT OEMCryptoResult OPKI_GetUsageEntryInfo( - OEMCrypto_SESSION session_id, OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_SESSION session_id, OEMCrypto_UsageEntryStatus* status, int64_t* seconds_since_license_received, - int64_t* seconds_since_first_decrypt); + int64_t* seconds_since_first_decrypt, uint8_t* pst, size_t* pst_length); /** * Move the usage entry associated with |session| to the new index in the diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi/README.md b/oemcrypto/opk/oemcrypto_ta/wtpi/README.md index 066082d..8f07d7e 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi/README.md +++ b/oemcrypto/opk/oemcrypto_ta/wtpi/README.md @@ -3,36 +3,33 @@ Some of the headers in wtpi/ directory are tested by the code in wtpi_test/. wtpi_test uses serialization/generator/scrape_interface.py to parse the WTPI interface declarations and generate serialization APIs such as: -* OPK_Pack_SaveGenerationNumber_Request(), -* OPK_Unpack_K1_DeriveKeyFromKeyHandle_Response(), + +* `OPK_Pack_SaveGenerationNumber_Request()` +* `OPK_Unpack_K1_DeriveKeyFromKeyHandle_Response()` * ... In order for the types of the parameters of these WTPI interfaces to be correctly determined and inserted into the auto-generated -OPK_Pack_* / OPK_Unpack_* functions, certain naming conventions have to be -followed: - -* To pack a variable length buffer X with type uint8_t*, the size of the -array must be named as "X_length" or XLength". -* If an output variable length buffer doesn't have an output size specified in -the parameter list, and is supposed to have the same size as the input buffer, -then the output buffer must be named as "out_buffer". - -Below is an example following the naming convention above: -``` -OEMCryptoResult WTPI_C1_SHA256(const uint8_t* input, size_t input_length, - uint8_t* out_buffer); -``` -You can find more details in scrape_interface.py for what is looked for by the -parser. +OPK_Pack_* / OPK_Unpack_* functions, +[certain naming conventions][naming-conventions] have to be followed. WTPI interfaces that are currently covered by wtpi_test: -* wtpi_generation_number_interface.h, -* wtpi_crypto_and_key_management_interface_layer1.h, -* wtpi_crypto_asymmetric_interface.h, -* wtpi_crc32_interface.h, -* wtpi_provisioning_4_interface.h, + +* wtpi_generation_number_interface.h +* wtpi_crypto_and_key_management_interface_layer1.h +* wtpi_crypto_asymmetric_interface.h +* wtpi_provisioning_4_interface.h +* wtpi_crc32_interface.h +* wtpi_clock_interface_layer1.h +* wtpi_config_interface.h +* wtpi_device_key_interface.h + +However, all WTPI interface may potentially be covered someday, so it's best +practice to always follow [the naming conventions][naming-conventions]. Please be cautious when updating parameter names in these interfaces. It can potentially break the auto-generated serialization functions used by the WTPI -tests if the naming convention is not enforced. +tests if the naming conventions are not enforced. It can also cause the +generator to generate invalid or insecure code. + +[naming-conventions]: https://g3doc.corp.google.com/video/widevine/g3doc/devices/oec_function_conventions.md diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_clock_interface_layer1.h b/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_clock_interface_layer1.h index 6e7aa88..2cc47e0 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_clock_interface_layer1.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_clock_interface_layer1.h @@ -65,7 +65,7 @@ OEMCryptoResult WTPI_TerminateClock(void); * https://developers.google.com/widevine/drm/client/oemcrypto/v16/odk-timers * for the definition of insecure clock, secure timer, and secure clock. */ -OEMCrypto_Clock_Security_Level WTPI_GetClockType(void); +OEMCrypto_ClockSecurityLevel WTPI_GetClockType(void); /// @} diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_crypto_asymmetric_interface.h b/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_crypto_asymmetric_interface.h index 8670711..f448c0d 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_crypto_asymmetric_interface.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi/wtpi_crypto_asymmetric_interface.h @@ -332,9 +332,9 @@ OEMCryptoResult WTPI_GetSignatureHashAlgorithm( * key pair is placed into |private_key_handle| and |public_key|. * * @param[in] deriving_key: The deriving key in clear bytes. - * @param[in] deriving_key_length: size of of |deriving_key| in bytes. + * @param[in] deriving_key_length: size of |deriving_key| in bytes. * @param[in] context: input context for deriving key. - * @param[in] context_length: size of of |context| in bytes. + * @param[in] context_length: size of |context| in bytes. * @param[in] key_type: type of asymmetric key. * @param[out] private_key_handle: The derived private key. * @param[out] public_key: The derived public key. diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_reference/cose_util.c b/oemcrypto/opk/oemcrypto_ta/wtpi_reference/cose_util.c index 7de7870..58de948 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_reference/cose_util.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_reference/cose_util.c @@ -13,6 +13,7 @@ #include "dice/cbor_writer.h" #include "dice/config.h" #include "dice/dice.h" +#include "oemcrypto_api_macros.h" #include "oemcrypto_check_macros.h" #include "wtpi_crypto_and_key_management_interface_layer1.h" #include "wtpi_crypto_asymmetric_interface.h" @@ -217,20 +218,23 @@ static DiceResult EncodeConfigurationDescriptor(size_t buffer_size, CborOutInit(buffer, buffer_size, &out); CborWriteMap(/*num_pairs=*/2, &out); + char str_buf[32] = {0}; // Add component name. CborWriteInt(kComponentNameLabel, &out); if (is_leaf) { // The leaf certificate's component name must contain "Widevine". CborWriteTstr("Widevine", &out); } else { - char buf[32] = {0}; - snprintf(buf, sizeof(buf), "Component %u", entry_index); - CborWriteTstr(buf, &out); + snprintf(str_buf, sizeof(str_buf), "Component %u", entry_index); + CborWriteTstr(str_buf, &out); + memset(str_buf, 0, sizeof(str_buf)); } // Add component version. CborWriteInt(kComponentVersionLabel, &out); - CborWriteTstr("19", &out); + snprintf(str_buf, sizeof(str_buf), "%d", API_MAJOR_VERSION); + CborWriteTstr(str_buf, &out); + memset(str_buf, 0, sizeof(str_buf)); if (CborOutOverflowed(&out)) { return kDiceResultBufferTooSmall; @@ -423,7 +427,13 @@ static OEMCryptoResult GenerateEncodedBccPayload( // Add the profile name. CborWriteInt(kProfileNameLabel, &cbor_out); - CborWriteTstr("android.15", &cbor_out); + if (is_leaf) { + snprintf(str_buf, sizeof(str_buf), "widevine.%d", API_MAJOR_VERSION); + CborWriteTstr(str_buf, &cbor_out); + memset(str_buf, 0, sizeof(str_buf)); + } else { + CborWriteTstr("android.15", &cbor_out); + } // Add the subject public key. CborWriteInt(kSubjectPublicKeyLabel, &cbor_out); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_clock_and_gn_layer1.c b/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_clock_and_gn_layer1.c index 15cdcdf..28bbed6 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_clock_and_gn_layer1.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_clock_and_gn_layer1.c @@ -208,7 +208,9 @@ OEMCryptoResult WTPI_TerminateClock(void) { return GetTrustedTimeAndSave(&temp, true); } -OEMCrypto_Clock_Security_Level WTPI_GetClockType(void) { return kSecureTimer; } +OEMCrypto_ClockSecurityLevel WTPI_GetClockType(void) { + return OEMCrypto_InsecureClock; +} OEMCryptoResult WTPI_GetTrustedTime(uint64_t* time_in_s) { RETURN_INVALID_CONTEXT_IF_NULL(time_in_s); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_ref_compat.c b/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_ref_compat.c index 96ab820..1c13927 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_ref_compat.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_reference/wtpi_ref_compat.c @@ -26,7 +26,7 @@ typedef struct { int64_t time_of_license_received; int64_t time_of_first_decrypt; int64_t time_of_last_decrypt; - enum OEMCrypto_Usage_Entry_Status status; + enum OEMCrypto_UsageEntryStatus status; uint8_t mac_key_server[MAC_KEY_SIZE]; uint8_t mac_key_client[MAC_KEY_SIZE]; uint32_t index; diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/README.md b/oemcrypto/opk/oemcrypto_ta/wtpi_test/README.md index 3819e12..026f539 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/README.md +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/README.md @@ -29,7 +29,7 @@ To add more WTPI tests for a new WTPI interface - Add WTPI header file name to oemcrypto/opk/serialization/generator/api_generator.cpp and oemcrypto/opk/serialization/generator/dispatcher_generator.cpp -- Run jenkins/opk_tee_interface_tests as a quick check. It generates all of the +- Run jenkins/opk_tee_interface_test as a quick check. It generates all of the required serialization files before running the test, and will throw an error if there is a problem with the above steps. - Write test cases. If needed, create a new .cpp file and add it to the diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.c index 5ad9f67..96d27fb 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.c @@ -121,13 +121,13 @@ bool Is_Valid_OEMCryptoResult(uint32_t value) { } } -bool Is_Valid_OEMCrypto_Usage_Entry_Status(uint32_t value) { +bool Is_Valid_OEMCrypto_UsageEntryStatus(uint32_t value) { switch (value) { - case 0: /* kUnused */ - case 1: /* kActive */ - case 2: /* kInactive */ - case 3: /* kInactiveUsed */ - case 4: /* kInactiveUnused */ + case 0: /* OEMCrypto_Unused */ + case 1: /* OEMCrypto_Active */ + case 2: /* OEMCrypto_Inactive */ + case 3: /* OEMCrypto_InactiveUsed */ + case 4: /* OEMCrypto_InactiveUnused */ return true; default: return false; @@ -156,8 +156,8 @@ bool Is_Valid_OEMCrypto_LicenseType(uint32_t value) { bool Is_Valid_OEMCrypto_PrivateKeyType(uint32_t value) { switch (value) { - case 0: /* OEMCrypto_RSA_Private_Key */ - case 1: /* OEMCrypto_ECC_Private_Key */ + case 0: /* OEMCrypto_RSAPrivateKey */ + case 1: /* OEMCrypto_ECCPrivateKey */ return true; default: return false; @@ -166,9 +166,33 @@ bool Is_Valid_OEMCrypto_PrivateKeyType(uint32_t value) { bool Is_Valid_OEMCrypto_TimerDelayBase(uint32_t value) { switch (value) { - case 0: /* OEMCrypto_License_Start */ - case 1: /* OEMCrypto_License_Load */ - case 2: /* OEMCrypto_First_Decrypt */ + case 0: /* OEMCrypto_LicenseStart */ + case 1: /* OEMCrypto_LicenseLoad */ + case 2: /* OEMCrypto_FirstDecrypt */ + return true; + default: + return false; + } +} + +bool Is_Valid_OEMCrypto_ClockSecurityLevel(uint32_t value) { + switch (value) { + case 0: /* OEMCrypto_InsecureClock */ + case 1: /* OEMCrypto_MonotonicClock */ + case 2: /* OEMCrypto_SecureClock */ + case 3: /* OEMCrypto_HardwareSecureClock */ + return true; + default: + return false; + } +} + +bool Is_Valid_OEMCrypto_DeCENC_Mitigation_Option(uint32_t value) { + switch (value) { + case 0: /* OEMCrypto_DeCENC_Mitigation_Option_None */ + case 1: /* OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream */ + case 2: /* OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream */ + case 4: /* OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding */ return true; default: return false; @@ -249,6 +273,29 @@ void OPK_Pack_OEMCrypto_KeyObject(ODK_Message* msg, (const OEMCrypto_Substring*)&obj->key_control); } +void OPK_Pack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo const* obj) { + OPK_Pack_OEMCrypto_Substring( + msg, (const OEMCrypto_Substring*)&obj->authentication_key); + OPK_Pack_OEMCrypto_Substring( + msg, (const OEMCrypto_Substring*)&obj->authentication_key_iv); +} + +void OPK_Pack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2 const* obj) { + OPK_Pack_OEMCrypto_Substring(msg, (const OEMCrypto_Substring*)&obj->key_id); + OPK_Pack_OEMCrypto_Substring(msg, + (const OEMCrypto_Substring*)&obj->key_data_iv); + OPK_Pack_OEMCrypto_Substring(msg, (const OEMCrypto_Substring*)&obj->key_data); + OPK_Pack_OEMCrypto_Substring( + msg, (const OEMCrypto_Substring*)&obj->key_control_iv); + OPK_Pack_OEMCrypto_Substring(msg, + (const OEMCrypto_Substring*)&obj->key_control); + OPK_Pack_OEMCrypto_DeCENC_Mitigation_Info( + msg, + (const OEMCrypto_DeCENC_Mitigation_Info*)&obj->decenc_mitigation_info); +} + void OPK_Unpack_OEMCrypto_Substring(ODK_Message* msg, OEMCrypto_Substring* obj) { OEMCrypto_Substring tmp_obj; @@ -320,6 +367,31 @@ void OPK_Unpack_OEMCrypto_KeyObject(ODK_Message* msg, OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_control); } +void OPK_Unpack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo* obj) { + OEMCrypto_AuthenticationKeyInfo tmp_obj; + if (obj == NULL) { + obj = &tmp_obj; + } + OPK_Unpack_OEMCrypto_Substring(msg, &obj->authentication_key); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->authentication_key_iv); +} + +void OPK_Unpack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2* obj) { + OEMCrypto_KeyObjectV2 tmp_obj; + if (obj == NULL) { + obj = &tmp_obj; + } + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_id); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_data_iv); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_data); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_control_iv); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_control); + OPK_Unpack_OEMCrypto_DeCENC_Mitigation_Info(msg, + &obj->decenc_mitigation_info); +} + void OPK_PackNullable_uint64_t(ODK_Message* msg, const uint64_t* value) { OPK_PackBoolValue(msg, value == NULL); if (value) { diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.h b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.h index 4369d57..20d60c0 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/GEN_common_serializer.h @@ -24,11 +24,13 @@ extern "C" { bool SuccessResult(OEMCryptoResult result); bool Is_Valid_OEMCryptoResult(uint32_t value); -bool Is_Valid_OEMCrypto_Usage_Entry_Status(uint32_t value); +bool Is_Valid_OEMCrypto_UsageEntryStatus(uint32_t value); bool Is_Valid_OEMCrypto_ProvisioningRenewalType(uint32_t value); bool Is_Valid_OEMCrypto_LicenseType(uint32_t value); bool Is_Valid_OEMCrypto_PrivateKeyType(uint32_t value); bool Is_Valid_OEMCrypto_TimerDelayBase(uint32_t value); +bool Is_Valid_OEMCrypto_ClockSecurityLevel(uint32_t value); +bool Is_Valid_OEMCrypto_DeCENC_Mitigation_Option(uint32_t value); bool Is_Valid_OPK_OutputBuffer_Type(uint32_t value); bool Is_Valid_OPK_FeatureStatus(uint32_t value); void OPK_Pack_OEMCrypto_Substring(ODK_Message* msg, @@ -43,6 +45,10 @@ void OPK_Pack_OEMCrypto_DTCP2_CMI_Packet(ODK_Message* msg, OEMCrypto_DTCP2_CMI_Packet const* obj); void OPK_Pack_OEMCrypto_KeyObject(ODK_Message* msg, OEMCrypto_KeyObject const* obj); +void OPK_Pack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo const* obj); +void OPK_Pack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2 const* obj); void OPK_Unpack_OEMCrypto_Substring(ODK_Message* msg, OEMCrypto_Substring* obj); void OPK_Unpack_OEMCrypto_DTCP2_CMI_Descriptor_0( ODK_Message* msg, OEMCrypto_DTCP2_CMI_Descriptor_0* obj); @@ -53,6 +59,10 @@ void OPK_Unpack_OEMCrypto_DTCP2_CMI_Descriptor_2( void OPK_Unpack_OEMCrypto_DTCP2_CMI_Packet(ODK_Message* msg, OEMCrypto_DTCP2_CMI_Packet* obj); void OPK_Unpack_OEMCrypto_KeyObject(ODK_Message* msg, OEMCrypto_KeyObject* obj); +void OPK_Unpack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo* obj); +void OPK_Unpack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2* obj); void OPK_PackNullable_uint64_t(ODK_Message* msg, const uint64_t* value); void OPK_UnpackNullable_uint64_t(ODK_Message* msg, uint64_t** value); void OPK_UnpackAlloc_uint64_t(ODK_Message* msg, uint64_t** value); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.c index cd8c52a..ab8ffcb 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.c @@ -5,6 +5,7 @@ */ #include "common_special_cases.h" +#include "GEN_common_serializer.h" #include "log_macros.h" #include "odk_attributes.h" #include "opk_serialization_base.h" @@ -167,8 +168,8 @@ void OPK_Unpack_OEMCrypto_SharedMemory(ODK_Message* message, ODK_MESSAGE_SETSTATUS(message, MESSAGE_STATUS_NULL_POINTER_ERROR); } -void OPK_Pack_OEMCrypto_Clock_Security_Level( - ODK_Message* message, const OEMCrypto_Clock_Security_Level* value) { +void OPK_Pack_OEMCrypto_ClockSecurityLevel( + ODK_Message* message, const OEMCrypto_ClockSecurityLevel* value) { if (value == NULL) { ODK_MESSAGE_SETSTATUS(message, MESSAGE_STATUS_NULL_POINTER_ERROR); return; @@ -177,8 +178,8 @@ void OPK_Pack_OEMCrypto_Clock_Security_Level( OPK_Pack_uint32_t(message, (const uint32_t*)value); } -void OPK_Unpack_OEMCrypto_Clock_Security_Level( - ODK_Message* message, OEMCrypto_Clock_Security_Level* value) { +void OPK_Unpack_OEMCrypto_ClockSecurityLevel( + ODK_Message* message, OEMCrypto_ClockSecurityLevel* value) { OPK_Unpack_uint32_t(message, (uint32_t*)value); } @@ -242,8 +243,8 @@ void OPK_Pack_OEMCrypto_HDCP_Capability( OPK_Pack_uint32_t(message, (const uint32_t*)value); } -void OPK_Pack_OEMCrypto_Usage_Entry_Status( - ODK_Message* message, const OEMCrypto_Usage_Entry_Status* status) { +void OPK_Pack_OEMCrypto_UsageEntryStatus( + ODK_Message* message, const OEMCrypto_UsageEntryStatus* status) { if (status == NULL) { ODK_MESSAGE_SETSTATUS(message, MESSAGE_STATUS_NULL_POINTER_ERROR); return; @@ -282,8 +283,8 @@ void OPK_Unpack_OEMCrypto_HDCP_Capability(ODK_Message* message, OPK_Unpack_uint32_t(message, (uint32_t*)value); } -void OPK_Unpack_OEMCrypto_Usage_Entry_Status( - ODK_Message* message, OEMCrypto_Usage_Entry_Status* status) { +void OPK_Unpack_OEMCrypto_UsageEntryStatus(ODK_Message* message, + OEMCrypto_UsageEntryStatus* status) { OPK_Unpack_uint32_t(message, (uint32_t*)status); } @@ -316,3 +317,42 @@ void OPK_Unpack_CertSignatureType(ODK_Message* message, CertSignatureType* value) { OPK_Unpack_uint32_t(message, (uint32_t*)value); } + +void OPK_Pack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info const* obj) { + OPK_Pack_uint32_t( + msg, (const OEMCrypto_DeCENC_Mitigation_Option*)&obj->mitigation_option); + switch (obj->mitigation_option) { + case OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream: + OPK_Pack_OEMCrypto_AuthenticationKeyInfo( + msg, &obj->configuration_options.authentication_key_info); + break; + case OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream: + // No configuration options + break; + case OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding: + // No configuration options + break; + default: + break; + } +} + +void OPK_Unpack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info* obj) { + OPK_Unpack_uint32_t(msg, &obj->mitigation_option); + switch (obj->mitigation_option) { + case OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream: + OPK_Unpack_OEMCrypto_AuthenticationKeyInfo( + msg, &obj->configuration_options.authentication_key_info); + break; + case OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream: + // No configuration options + break; + case OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding: + // No configuration options + break; + default: + break; + } +} \ No newline at end of file diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.h b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.h index 99fa1eb..0956e3c 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/common/common_special_cases.h @@ -50,10 +50,10 @@ void OPK_Pack_OEMCrypto_SharedMemory(ODK_Message* message, void OPK_Unpack_OEMCrypto_SharedMemory(ODK_Message* message, OEMCrypto_SharedMemory* value); -void OPK_Pack_OEMCrypto_Clock_Security_Level( - ODK_Message* msg, const OEMCrypto_Clock_Security_Level* value); -void OPK_Unpack_OEMCrypto_Clock_Security_Level( - ODK_Message* msg, OEMCrypto_Clock_Security_Level* value); +void OPK_Pack_OEMCrypto_ClockSecurityLevel( + ODK_Message* msg, const OEMCrypto_ClockSecurityLevel* value); +void OPK_Unpack_OEMCrypto_ClockSecurityLevel( + ODK_Message* msg, OEMCrypto_ClockSecurityLevel* value); void OPK_Pack_OEMCrypto_Security_Level(ODK_Message* msg, const OEMCrypto_Security_Level* value); @@ -67,8 +67,8 @@ void OPK_Pack_OEMCrypto_WatermarkingSupport( ODK_Message* msg, const OEMCrypto_WatermarkingSupport* value); void OPK_Pack_OEMCrypto_HDCP_Capability(ODK_Message* msg, const OEMCrypto_HDCP_Capability* value); -void OPK_Pack_OEMCrypto_Usage_Entry_Status( - ODK_Message* msg, const OEMCrypto_Usage_Entry_Status* status); +void OPK_Pack_OEMCrypto_UsageEntryStatus( + ODK_Message* msg, const OEMCrypto_UsageEntryStatus* status); void OPK_Unpack_OEMCrypto_Security_Level(ODK_Message* msg, OEMCrypto_Security_Level* value); void OPK_Unpack_OEMCrypto_ProvisioningMethod( @@ -84,10 +84,13 @@ void OPK_Pack_OEMCrypto_SignatureHashAlgorithm( ODK_Message* message, const OEMCrypto_SignatureHashAlgorithm* value); void OPK_Unpack_OEMCrypto_SignatureHashAlgorithm( ODK_Message* message, OEMCrypto_SignatureHashAlgorithm* value); -void OPK_Unpack_OEMCrypto_Usage_Entry_Status( - ODK_Message* message, OEMCrypto_Usage_Entry_Status* status); +void OPK_Unpack_OEMCrypto_UsageEntryStatus(ODK_Message* message, + OEMCrypto_UsageEntryStatus* status); void OPK_Pack_CertSignatureType(ODK_Message* message, CertSignatureType* value); void OPK_Unpack_CertSignatureType(ODK_Message* message, CertSignatureType* value); - +void OPK_Pack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info const* obj); +void OPK_Unpack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info* obj); #endif diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/crypto_test.cpp b/oemcrypto/opk/oemcrypto_ta/wtpi_test/crypto_test.cpp index 7903616..e49e763 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/crypto_test.cpp +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/crypto_test.cpp @@ -1780,11 +1780,12 @@ TEST_F(CryptoTest, ECCLoadKeyAndSign) { TEST_F(CryptoTest, GetBCCTypeProv4AndOthers) { OEMCrypto_ProvisioningMethod method = WTPI_GetProvisioningMethod(); - OEMCrypto_BCCType bcc_type; + OEMCrypto_BCCType bcc_type = static_cast(-1); OEMCryptoResult res = WTPI_GetBCCType(&bcc_type); if (method == OEMCrypto_BootCertificateChain) { ASSERT_EQ(res, OEMCrypto_SUCCESS); + ASSERT_NE(bcc_type, static_cast(-1)); } else { ASSERT_EQ(res, OEMCrypto_ERROR_NOT_IMPLEMENTED); } diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_oemcrypto_tee_test_api.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_oemcrypto_tee_test_api.c index 3192394..3118ddd 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_oemcrypto_tee_test_api.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_oemcrypto_tee_test_api.c @@ -2065,9 +2065,9 @@ cleanup_and_return: return result; } -OEMCrypto_Clock_Security_Level WTPI_GetClockType(void) { +OEMCrypto_ClockSecurityLevel WTPI_GetClockType(void) { pthread_mutex_lock(&api_lock); - OEMCrypto_Clock_Security_Level result = kInsecureClock; + OEMCrypto_ClockSecurityLevel result = OEMCrypto_InsecureClock; ODK_Message request = ODK_Message_Create(NULL, 0); ODK_Message response = ODK_Message_Create(NULL, 0); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.c index 7fb1e8c..77208ea 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.c @@ -2128,12 +2128,15 @@ ODK_Message OPK_Pack_GetClockType_Request(void) { } void OPK_Unpack_GetClockType_Response(ODK_Message* msg, - OEMCrypto_Clock_Security_Level* result) { + OEMCrypto_ClockSecurityLevel* result) { uint32_t api_value = UINT32_MAX; OPK_Unpack_uint32_t(msg, &api_value); if (api_value != 10057) ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); - OPK_Unpack_OEMCrypto_Clock_Security_Level(msg, result); + OPK_Unpack_uint32_t(msg, result); + if (!Is_Valid_OEMCrypto_ClockSecurityLevel(*result)) { + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + } OPK_UnpackEOM(msg); OPK_SharedBuffer_FinalizeUnpacking(); } diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.h b/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.h index a9bb7fc..4930619 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/ree/GEN_ree_serializer.h @@ -327,7 +327,7 @@ void OPK_Unpack_TerminateClock_Response(ODK_Message* msg, OEMCryptoResult* result); ODK_Message OPK_Pack_GetClockType_Request(void); void OPK_Unpack_GetClockType_Response(ODK_Message* msg, - OEMCrypto_Clock_Security_Level* result); + OEMCrypto_ClockSecurityLevel* result); ODK_Message OPK_Pack_GetSecurityLevel_Request(void); void OPK_Unpack_GetSecurityLevel_Response(ODK_Message* msg, OEMCrypto_Security_Level* result); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_dispatcher.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_dispatcher.c index 25cca8d..c0c7540 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_dispatcher.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_dispatcher.c @@ -84,6 +84,23 @@ void OPK_Init_OEMCrypto_KeyObject(OEMCrypto_KeyObject* obj) { OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_control); } +void OPK_Init_OEMCrypto_AuthenticationKeyInfo( + OEMCrypto_AuthenticationKeyInfo* obj) { + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->authentication_key); + OPK_Init_OEMCrypto_Substring( + (OEMCrypto_Substring*)&obj->authentication_key_iv); +} + +void OPK_Init_OEMCrypto_KeyObjectV2(OEMCrypto_KeyObjectV2* obj) { + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_id); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_data_iv); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_data); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_control_iv); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_control); + OPK_Init_OEMCrypto_DeCENC_Mitigation_Info( + (OEMCrypto_DeCENC_Mitigation_Info*)&obj->decenc_mitigation_info); +} + /* See opk_dispatcher.h for definition of OPK_DispatchMessage() */ ODK_MessageStatus OPK_DispatchMessage(ODK_Message* request, ODK_Message* response) { @@ -1240,9 +1257,8 @@ ODK_MessageStatus OPK_DispatchMessage(ODK_Message* request, { OPK_Unpack_GetClockType_Request(request); if (!ODK_Message_IsValid(request)) goto handle_invalid_request; - OEMCrypto_Clock_Security_Level result; - OPK_Init_OEMCrypto_Clock_Security_Level( - (OEMCrypto_Clock_Security_Level*)&result); + OEMCrypto_ClockSecurityLevel result; + OPK_Init_uint32_t((uint32_t*)&result); LOGD("GetClockType"); result = WTPI_GetClockType(); *response = OPK_Pack_GetClockType_Response(result); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.c index 1dfdd04..4a1ab8b 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.c @@ -1647,11 +1647,11 @@ void OPK_Unpack_GetClockType_Request(ODK_Message* msg) { } ODK_Message OPK_Pack_GetClockType_Response( - OEMCrypto_Clock_Security_Level result) { + OEMCrypto_ClockSecurityLevel result) { uint32_t api_value = 10057; /* from _tee10057 */ ODK_Message msg = TOS_Transport_GetResponse(); OPK_Pack_uint32_t(&msg, &api_value); - OPK_Pack_OEMCrypto_Clock_Security_Level(&msg, &result); + OPK_Pack_uint32_t(&msg, &result); OPK_PackEOM(&msg); OPK_SharedBuffer_FinalizePacking(); return msg; diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.h b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.h index f24fe23..f616a66 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/GEN_tee_serializer.h @@ -297,8 +297,7 @@ ODK_Message OPK_Pack_InitializeClock_Response(OEMCryptoResult result); void OPK_Unpack_TerminateClock_Request(ODK_Message* msg); ODK_Message OPK_Pack_TerminateClock_Response(OEMCryptoResult result); void OPK_Unpack_GetClockType_Request(ODK_Message* msg); -ODK_Message OPK_Pack_GetClockType_Response( - OEMCrypto_Clock_Security_Level result); +ODK_Message OPK_Pack_GetClockType_Response(OEMCrypto_ClockSecurityLevel result); void OPK_Unpack_GetSecurityLevel_Request(ODK_Message* msg); ODK_Message OPK_Pack_GetSecurityLevel_Response(OEMCrypto_Security_Level result); void OPK_Unpack_GetProvisioningMethod_Request(ODK_Message* msg); diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.c b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.c index 71fe99e..ab73051 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.c +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.c @@ -69,10 +69,9 @@ void OPK_Init_KeySize(KeySize* obj) { } } -void OPK_Init_OEMCrypto_Clock_Security_Level( - OEMCrypto_Clock_Security_Level* obj) { +void OPK_Init_OEMCrypto_ClockSecurityLevel(OEMCrypto_ClockSecurityLevel* obj) { if (obj) { - memset(obj, 0, sizeof(OEMCrypto_Clock_Security_Level)); + memset(obj, 0, sizeof(OEMCrypto_ClockSecurityLevel)); } } diff --git a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.h b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.h index 23435de..38a9d67 100644 --- a/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.h +++ b/oemcrypto/opk/oemcrypto_ta/wtpi_test/tee/tee_special_cases.h @@ -19,8 +19,7 @@ void OPK_Init_AsymmetricKeyType(AsymmetricKeyType* obj); void OPK_Init_WTPI_AsymmetricKey_Handle(WTPI_AsymmetricKey_Handle* obj); void OPK_Init_RSA_Padding_Scheme(RSA_Padding_Scheme* obj); void OPK_Init_KeySize(KeySize* obj); -void OPK_Init_OEMCrypto_Clock_Security_Level( - OEMCrypto_Clock_Security_Level* obj); +void OPK_Init_OEMCrypto_ClockSecurityLevel(OEMCrypto_ClockSecurityLevel* obj); void OPK_Init_OEMCrypto_Security_Level(OEMCrypto_Security_Level* obj); void OPK_Init_OEMCrypto_ProvisioningMethod(OEMCrypto_ProvisioningMethod* obj); void OPK_Init_OEMCrypto_WatermarkingSupport(OEMCrypto_WatermarkingSupport* obj); diff --git a/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/Makefile b/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/Makefile index 1931349..af2bb09 100644 --- a/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/Makefile +++ b/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/Makefile @@ -55,7 +55,16 @@ cflags += \ -DWTPI_BUILD_INFO=\"$(WTPI_BUILD_INFO)\" \ -DSUPPORT_CAS \ -DWV_POSIX_RESOURCE_ID=\"$(project)\" \ - -D_DEFAULT_SOURCE + -DOPK_CONFIG_SOC_VENDOR_NAME=$(SOC_VENDOR) \ + -DOPK_CONFIG_SOC_MODEL_NAME=$(SOC_MODEL) \ + -DOPK_CONFIG_TEE_OS_NAME=$(TEE_OS) \ + -DOPK_CONFIG_TEE_OS_VERSION=$(TEE_VERSION) \ + -DOPK_CONFIG_DEVICE_FORM_FACTOR=$(DEVICE_FORM_FACTOR) \ + -DOPK_CONFIG_IMPLEMENTER_NAME=$(IMPLEMENTER) \ + -DOPK_CONFIG_PROVISIONING_METHOD=$(PROVISIONING_METHOD) \ + -D_DEFAULT_SOURCE \ + -fstack-protector-strong \ + -D_DEBUG ldflags = \ -lrt -lpthread \ diff --git a/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/data_share.h b/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/data_share.h index 8b623e0..095930a 100644 --- a/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/data_share.h +++ b/oemcrypto/opk/ports/linux/cas/tee/tee_simulator_cas/data_share.h @@ -9,7 +9,7 @@ #include #include #include "OEMCryptoCENC.h" -#include "wtpi_config_macros.h" +#include "opk_config.h" #include "wtpi_crypto_and_key_management_interface_layer1.h" #ifdef __cplusplus diff --git a/oemcrypto/opk/ports/linux/common/posix_services.h b/oemcrypto/opk/ports/linux/common/posix_services.h index 67242b2..5bdc54c 100644 --- a/oemcrypto/opk/ports/linux/common/posix_services.h +++ b/oemcrypto/opk/ports/linux/common/posix_services.h @@ -48,6 +48,9 @@ class Semaphore : public Resource { } } + Semaphore(const Semaphore&) = delete; + void operator=(const Semaphore&) = delete; + ~Semaphore() { if (sem_) { sem_close(sem_); @@ -63,9 +66,6 @@ class Semaphore : public Resource { private: sem_t* sem_; - - Semaphore(const Semaphore&) = delete; - void operator=(const Semaphore&) = delete; }; // Manage blocks of shared memory across process boundaries using @@ -83,6 +83,9 @@ class SharedMemory : public Resource { Open(); } + SharedMemory(const SharedMemory&) = delete; + void operator=(const SharedMemory&) = delete; + ~SharedMemory() { Close(); } uint8_t* GetAddress() const { return address_; } @@ -127,10 +130,8 @@ class SharedMemory : public Resource { fd_ = -1; } } - - SharedMemory(const SharedMemory&) = delete; - void operator=(const SharedMemory&) = delete; }; -}; // namespace posix + +} // namespace posix #endif /* LINUX_IPC_POSIX_SERVICES_H_ */ diff --git a/oemcrypto/opk/ports/linux/host/oemcrypto_unittests/oemcrypto_ref_compat_test.cpp b/oemcrypto/opk/ports/linux/host/oemcrypto_unittests/oemcrypto_ref_compat_test.cpp index d28a2c4..8123b66 100644 --- a/oemcrypto/opk/ports/linux/host/oemcrypto_unittests/oemcrypto_ref_compat_test.cpp +++ b/oemcrypto/opk/ports/linux/host/oemcrypto_unittests/oemcrypto_ref_compat_test.cpp @@ -19,7 +19,7 @@ TEST_F(OEMCryptoProv2ReferenceWrap, LoadWrappedKey) { OEMCryptoResult sts = OEMCrypto_ERROR_INVALID_CONTEXT; sts = OEMCrypto_LoadDRMPrivateKey( - session_.session_id(), OEMCrypto_RSA_Private_Key, wrapped_rsa_key.data(), + session_.session_id(), OEMCrypto_RSAPrivateKey, wrapped_rsa_key.data(), wrapped_rsa_key.size()); ASSERT_EQ(OEMCrypto_SUCCESS, sts); diff --git a/oemcrypto/opk/ports/linux/oemcrypto_tee_simulator.gyp b/oemcrypto/opk/ports/linux/oemcrypto_tee_simulator.gyp index 3108e1a..ad29bfc 100644 --- a/oemcrypto/opk/ports/linux/oemcrypto_tee_simulator.gyp +++ b/oemcrypto/opk/ports/linux/oemcrypto_tee_simulator.gyp @@ -36,9 +36,9 @@ '<(serialization_adapter_dir)/tos_shared_memory.cpp', '<(serialization_adapter_dir)/tos_transport.cpp', '<(util_dir)/src/platform.cpp', - '<(util_dir)/src/rw_lock.cpp', '<(util_dir)/src/string_conversions.cpp', '<(util_dir)/src/string_format.cpp', + '<(util_dir)/src/wv_duration.cpp', ], 'dependencies': [ '<(serialization_dir)/tee/tee.gyp:opk_tee', @@ -60,9 +60,9 @@ '<(serialization_adapter_dir)/tos_shared_memory.cpp', '<(serialization_adapter_dir)/tos_transport.cpp', '<(util_dir)/src/platform.cpp', - '<(util_dir)/src/rw_lock.cpp', '<(util_dir)/src/string_conversions.cpp', '<(util_dir)/src/string_format.cpp', + '<(util_dir)/src/wv_duration.cpp', ], 'dependencies': [ '<(serialization_dir)/tee/tee.gyp:opk_tee', diff --git a/oemcrypto/opk/ports/linux/ta/common/clock.cpp b/oemcrypto/opk/ports/linux/ta/common/clock.cpp index 1eda0cb..67668d0 100644 --- a/oemcrypto/opk/ports/linux/ta/common/clock.cpp +++ b/oemcrypto/opk/ports/linux/ta/common/clock.cpp @@ -3,18 +3,36 @@ // Agreement. // // Clock - implemented using the standard linux time library - #include "clock.h" +#include + #include -namespace wvutil { +#include "wv_timestamp.h" +namespace wvutil { int64_t Clock::GetCurrentTime() { - struct timeval tv; - tv.tv_sec = tv.tv_usec = 0; - gettimeofday(&tv, nullptr); + struct timeval tv = {}; + if (gettimeofday(&tv, nullptr) != 0) { + // This might be called within logging functions, do not + // log error. + return 0; + } return tv.tv_sec; } +Timestamp Clock::GetCurrentTimestamp() { + struct timeval tv = {}; + if (gettimeofday(&tv, nullptr) != 0) { + // This might be called within logging functions, do not + // log error. + return Timestamp(); + } + constexpr int kMaxUSec = 999999; + if (tv.tv_sec < 0 || tv.tv_usec < 0 || tv.tv_usec > kMaxUSec) + return Timestamp(); + return Timestamp::FromUnixSeconds(static_cast(tv.tv_sec), + static_cast(tv.tv_usec / 1000)); +} } // namespace wvutil diff --git a/oemcrypto/opk/ports/linux/wtpi_tee_simulator.gyp b/oemcrypto/opk/ports/linux/wtpi_tee_simulator.gyp index f376eaa..d78a12d 100644 --- a/oemcrypto/opk/ports/linux/wtpi_tee_simulator.gyp +++ b/oemcrypto/opk/ports/linux/wtpi_tee_simulator.gyp @@ -46,7 +46,6 @@ '<(platform_specific_dir)/file_store.cpp', '<(platform_specific_dir)/log.cpp', '<(util_dir)/src/platform.cpp', - '<(util_dir)/src/rw_lock.cpp', '<(util_dir)/src/string_conversions.cpp', '<(util_dir)/src/string_format.cpp', ], diff --git a/oemcrypto/opk/ports/optee/ta/common/wtpi_impl/genkeypair_ecc.c b/oemcrypto/opk/ports/optee/ta/common/wtpi_impl/genkeypair_ecc.c index 2c0fd9a..9a62998 100644 --- a/oemcrypto/opk/ports/optee/ta/common/wtpi_impl/genkeypair_ecc.c +++ b/oemcrypto/opk/ports/optee/ta/common/wtpi_impl/genkeypair_ecc.c @@ -3,7 +3,6 @@ * source code may only be used and distributed under the Widevine * License Agreement. */ - #include "asymmetric_key.h" #include "crypto_util.h" #include "oemcrypto_check_macros.h" @@ -31,6 +30,7 @@ static OEMCryptoResult NewEccKeyPair(uint8_t* private_key_data, TEE_Attribute attr; TEE_InitValueAttribute(&attr, TEE_ATTR_ECC_CURVE, curve_type, 0); + tee_res = TEE_GenerateKey(key, KEY_SIZE_256 * 8, &attr, 1); if (tee_res != TEE_SUCCESS) { EMSG("TEE_GenerateKey failed with result 0x%x", tee_res); diff --git a/oemcrypto/opk/ports/trusty/ta/reference/rustfmt.toml b/oemcrypto/opk/ports/trusty/ta/reference/rustfmt.toml new file mode 100644 index 0000000..cefaa42 --- /dev/null +++ b/oemcrypto/opk/ports/trusty/ta/reference/rustfmt.toml @@ -0,0 +1,5 @@ +# Android Format Style + +edition = "2021" +use_small_heuristics = "Max" +newline_style = "Unix" diff --git a/oemcrypto/opk/serialization/common/GEN_common_serializer.c b/oemcrypto/opk/serialization/common/GEN_common_serializer.c index fed33e1..d44bab7 100644 --- a/oemcrypto/opk/serialization/common/GEN_common_serializer.c +++ b/oemcrypto/opk/serialization/common/GEN_common_serializer.c @@ -116,13 +116,13 @@ bool Is_Valid_OEMCryptoResult(uint32_t value) { } } -bool Is_Valid_OEMCrypto_Usage_Entry_Status(uint32_t value) { +bool Is_Valid_OEMCrypto_UsageEntryStatus(uint32_t value) { switch (value) { - case 0: /* kUnused */ - case 1: /* kActive */ - case 2: /* kInactive */ - case 3: /* kInactiveUsed */ - case 4: /* kInactiveUnused */ + case 0: /* OEMCrypto_Unused */ + case 1: /* OEMCrypto_Active */ + case 2: /* OEMCrypto_Inactive */ + case 3: /* OEMCrypto_InactiveUsed */ + case 4: /* OEMCrypto_InactiveUnused */ return true; default: return false; @@ -151,8 +151,8 @@ bool Is_Valid_OEMCrypto_LicenseType(uint32_t value) { bool Is_Valid_OEMCrypto_PrivateKeyType(uint32_t value) { switch (value) { - case 0: /* OEMCrypto_RSA_Private_Key */ - case 1: /* OEMCrypto_ECC_Private_Key */ + case 0: /* OEMCrypto_RSAPrivateKey */ + case 1: /* OEMCrypto_ECCPrivateKey */ return true; default: return false; @@ -161,9 +161,33 @@ bool Is_Valid_OEMCrypto_PrivateKeyType(uint32_t value) { bool Is_Valid_OEMCrypto_TimerDelayBase(uint32_t value) { switch (value) { - case 0: /* OEMCrypto_License_Start */ - case 1: /* OEMCrypto_License_Load */ - case 2: /* OEMCrypto_First_Decrypt */ + case 0: /* OEMCrypto_LicenseStart */ + case 1: /* OEMCrypto_LicenseLoad */ + case 2: /* OEMCrypto_FirstDecrypt */ + return true; + default: + return false; + } +} + +bool Is_Valid_OEMCrypto_ClockSecurityLevel(uint32_t value) { + switch (value) { + case 0: /* OEMCrypto_InsecureClock */ + case 1: /* OEMCrypto_MonotonicClock */ + case 2: /* OEMCrypto_SecureClock */ + case 3: /* OEMCrypto_HardwareSecureClock */ + return true; + default: + return false; + } +} + +bool Is_Valid_OEMCrypto_DeCENC_Mitigation_Option(uint32_t value) { + switch (value) { + case 0: /* OEMCrypto_DeCENC_Mitigation_Option_None */ + case 1: /* OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream */ + case 2: /* OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream */ + case 4: /* OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding */ return true; default: return false; @@ -208,18 +232,6 @@ bool Is_Valid_OEMCrypto_Algorithm(uint32_t value) { } } -bool Is_Valid_OEMCrypto_Clock_Security_Level(uint32_t value) { - switch (value) { - case 0: /* kInsecureClock */ - case 1: /* kMonotonicClock */ - case 2: /* kSecureClock */ - case 3: /* kHardwareSecureClock */ - return true; - default: - return false; - } -} - bool Is_Valid_OEMCrypto_HDCP_Capability(uint32_t value) { switch (value) { case 0: /* HDCP_NONE */ @@ -385,6 +397,29 @@ void OPK_Pack_OEMCrypto_KeyObject(ODK_Message* msg, (const OEMCrypto_Substring*)&obj->key_control); } +void OPK_Pack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo const* obj) { + OPK_Pack_OEMCrypto_Substring( + msg, (const OEMCrypto_Substring*)&obj->authentication_key); + OPK_Pack_OEMCrypto_Substring( + msg, (const OEMCrypto_Substring*)&obj->authentication_key_iv); +} + +void OPK_Pack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2 const* obj) { + OPK_Pack_OEMCrypto_Substring(msg, (const OEMCrypto_Substring*)&obj->key_id); + OPK_Pack_OEMCrypto_Substring(msg, + (const OEMCrypto_Substring*)&obj->key_data_iv); + OPK_Pack_OEMCrypto_Substring(msg, (const OEMCrypto_Substring*)&obj->key_data); + OPK_Pack_OEMCrypto_Substring( + msg, (const OEMCrypto_Substring*)&obj->key_control_iv); + OPK_Pack_OEMCrypto_Substring(msg, + (const OEMCrypto_Substring*)&obj->key_control); + OPK_Pack_OEMCrypto_DeCENC_Mitigation_Info( + msg, + (const OEMCrypto_DeCENC_Mitigation_Info*)&obj->decenc_mitigation_info); +} + void OPK_Pack_OEMCrypto_SubSampleDescription( ODK_Message* msg, OEMCrypto_SubSampleDescription const* obj) { OPK_Pack_size_t(msg, (const size_t*)&obj->num_bytes_clear); @@ -513,6 +548,31 @@ void OPK_Unpack_OEMCrypto_KeyObject(ODK_Message* msg, OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_control); } +void OPK_Unpack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo* obj) { + OEMCrypto_AuthenticationKeyInfo tmp_obj; + if (obj == NULL) { + obj = &tmp_obj; + } + OPK_Unpack_OEMCrypto_Substring(msg, &obj->authentication_key); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->authentication_key_iv); +} + +void OPK_Unpack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2* obj) { + OEMCrypto_KeyObjectV2 tmp_obj; + if (obj == NULL) { + obj = &tmp_obj; + } + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_id); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_data_iv); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_data); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_control_iv); + OPK_Unpack_OEMCrypto_Substring(msg, &obj->key_control); + OPK_Unpack_OEMCrypto_DeCENC_Mitigation_Info(msg, + &obj->decenc_mitigation_info); +} + void OPK_Unpack_OEMCrypto_SubSampleDescription( ODK_Message* msg, OEMCrypto_SubSampleDescription* obj) { OEMCrypto_SubSampleDescription tmp_obj; diff --git a/oemcrypto/opk/serialization/common/GEN_common_serializer.h b/oemcrypto/opk/serialization/common/GEN_common_serializer.h index 62bbd10..362f4f5 100644 --- a/oemcrypto/opk/serialization/common/GEN_common_serializer.h +++ b/oemcrypto/opk/serialization/common/GEN_common_serializer.h @@ -19,15 +19,16 @@ extern "C" { bool SuccessResult(OEMCryptoResult result); bool Is_Valid_OEMCryptoResult(uint32_t value); -bool Is_Valid_OEMCrypto_Usage_Entry_Status(uint32_t value); +bool Is_Valid_OEMCrypto_UsageEntryStatus(uint32_t value); bool Is_Valid_OEMCrypto_ProvisioningRenewalType(uint32_t value); bool Is_Valid_OEMCrypto_LicenseType(uint32_t value); bool Is_Valid_OEMCrypto_PrivateKeyType(uint32_t value); bool Is_Valid_OEMCrypto_TimerDelayBase(uint32_t value); +bool Is_Valid_OEMCrypto_ClockSecurityLevel(uint32_t value); +bool Is_Valid_OEMCrypto_DeCENC_Mitigation_Option(uint32_t value); bool Is_Valid_OEMCryptoBufferType(uint32_t value); bool Is_Valid_OEMCryptoCipherMode(uint32_t value); bool Is_Valid_OEMCrypto_Algorithm(uint32_t value); -bool Is_Valid_OEMCrypto_Clock_Security_Level(uint32_t value); bool Is_Valid_OEMCrypto_HDCP_Capability(uint32_t value); bool Is_Valid_OEMCrypto_DTCP2_Capability(uint32_t value); bool Is_Valid_OEMCrypto_ProvisioningMethod(uint32_t value); @@ -49,6 +50,10 @@ void OPK_Pack_OEMCrypto_DTCP2_CMI_Packet(ODK_Message* msg, OEMCrypto_DTCP2_CMI_Packet const* obj); void OPK_Pack_OEMCrypto_KeyObject(ODK_Message* msg, OEMCrypto_KeyObject const* obj); +void OPK_Pack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo const* obj); +void OPK_Pack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2 const* obj); void OPK_Pack_OEMCrypto_SubSampleDescription( ODK_Message* msg, OEMCrypto_SubSampleDescription const* obj); void OPK_Pack_OEMCrypto_CENCEncryptPatternDesc( @@ -71,6 +76,10 @@ void OPK_Unpack_OEMCrypto_DTCP2_CMI_Descriptor_2( void OPK_Unpack_OEMCrypto_DTCP2_CMI_Packet(ODK_Message* msg, OEMCrypto_DTCP2_CMI_Packet* obj); void OPK_Unpack_OEMCrypto_KeyObject(ODK_Message* msg, OEMCrypto_KeyObject* obj); +void OPK_Unpack_OEMCrypto_AuthenticationKeyInfo( + ODK_Message* msg, OEMCrypto_AuthenticationKeyInfo* obj); +void OPK_Unpack_OEMCrypto_KeyObjectV2(ODK_Message* msg, + OEMCrypto_KeyObjectV2* obj); void OPK_Unpack_OEMCrypto_SubSampleDescription( ODK_Message* msg, OEMCrypto_SubSampleDescription* obj); void OPK_Unpack_OEMCrypto_CENCEncryptPatternDesc( diff --git a/oemcrypto/opk/serialization/common/common_special_cases.c b/oemcrypto/opk/serialization/common/common_special_cases.c index 96c0cbe..2d0d1ca 100644 --- a/oemcrypto/opk/serialization/common/common_special_cases.c +++ b/oemcrypto/opk/serialization/common/common_special_cases.c @@ -188,3 +188,42 @@ void OPK_Unpack_OEMCrypto_SampleDescription(ODK_Message* msg, obj->subsamples = subsamples; } + +void OPK_Pack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info const* obj) { + OPK_Pack_uint32_t( + msg, (const OEMCrypto_DeCENC_Mitigation_Option*)&obj->mitigation_option); + switch (obj->mitigation_option) { + case OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream: + OPK_Pack_OEMCrypto_AuthenticationKeyInfo( + msg, &obj->configuration_options.authentication_key_info); + break; + case OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream: + // No configuration options + break; + case OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding: + // No configuration options + break; + default: + break; + } +} + +void OPK_Unpack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info* obj) { + OPK_Unpack_uint32_t(msg, &obj->mitigation_option); + switch (obj->mitigation_option) { + case OEMCrypto_DeCENC_Mitigation_Option_Authenticate_Bitstream: + OPK_Unpack_OEMCrypto_AuthenticationKeyInfo( + msg, &obj->configuration_options.authentication_key_info); + break; + case OEMCrypto_DeCENC_Mitigation_Option_Validate_Bitstream: + // No configuration options + break; + case OEMCrypto_DeCENC_Mitigation_Option_Restrict_Decoding: + // No configuration options + break; + default: + break; + } +} diff --git a/oemcrypto/opk/serialization/common/common_special_cases.h b/oemcrypto/opk/serialization/common/common_special_cases.h index 2385856..9883dc6 100644 --- a/oemcrypto/opk/serialization/common/common_special_cases.h +++ b/oemcrypto/opk/serialization/common/common_special_cases.h @@ -15,6 +15,7 @@ extern "C" { #include #include "OEMCryptoCENC.h" +#include "OEMCryptoCENCCommon.h" #include "odk_message.h" /* @@ -35,6 +36,11 @@ void OPK_Pack_OEMCrypto_SampleDescription( void OPK_Unpack_OEMCrypto_SampleDescription(ODK_Message* msg, OEMCrypto_SampleDescription* obj); +void OPK_Pack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info const* obj); +void OPK_Unpack_OEMCrypto_DeCENC_Mitigation_Info( + ODK_Message* msg, OEMCrypto_DeCENC_Mitigation_Info* obj); + #ifdef __cplusplus } // extern "C" #endif diff --git a/oemcrypto/opk/serialization/common/include/marshaller_base.h b/oemcrypto/opk/serialization/common/include/marshaller_base.h index bb376ce..ce229fe 100644 --- a/oemcrypto/opk/serialization/common/include/marshaller_base.h +++ b/oemcrypto/opk/serialization/common/include/marshaller_base.h @@ -51,6 +51,9 @@ void OPK_Init_OEMCrypto_InputOutputPair(OEMCrypto_InputOutputPair* obj); void OPK_Init_OEMCrypto_SampleDescription(OEMCrypto_SampleDescription* obj); +void OPK_Init_OEMCrypto_DeCENC_Mitigation_Info( + OEMCrypto_DeCENC_Mitigation_Info* obj); + #ifdef __cplusplus } // extern "C" #endif diff --git a/oemcrypto/opk/serialization/common/marshaller_base.c b/oemcrypto/opk/serialization/common/marshaller_base.c index cb72321..97d3954 100644 --- a/oemcrypto/opk/serialization/common/marshaller_base.c +++ b/oemcrypto/opk/serialization/common/marshaller_base.c @@ -40,7 +40,6 @@ void OPK_Init_c_str(char** value) { *value = NULL; } } - void OPK_Init_uint8_t(uint8_t* value) { if (value) { *value = ~0; @@ -106,3 +105,11 @@ void OPK_Init_OEMCrypto_SampleDescription(OEMCrypto_SampleDescription* obj) { OPK_Init_OEMCrypto_InputOutputPair((OEMCrypto_InputOutputPair*)&obj->buffers); OPK_InitMemory(&obj->iv[0], sizeof(obj->iv)); } + +void OPK_Init_OEMCrypto_DeCENC_Mitigation_Info( + OEMCrypto_DeCENC_Mitigation_Info* obj) { + if (obj) { + memset(obj, 0, sizeof(OEMCrypto_DeCENC_Mitigation_Info)); + obj->mitigation_option = OEMCrypto_DeCENC_Mitigation_Option_None; + } +} diff --git a/oemcrypto/opk/serialization/common/shared_buffer_allocator.c b/oemcrypto/opk/serialization/common/shared_buffer_allocator.c index 1c0efbd..ef457f5 100644 --- a/oemcrypto/opk/serialization/common/shared_buffer_allocator.c +++ b/oemcrypto/opk/serialization/common/shared_buffer_allocator.c @@ -122,7 +122,7 @@ void OPK_SharedBuffer_Terminate(void) { * allocated. */ void OPK_SharedBuffer_Reset(void) { - memset(&allocations_[0], 0, sizeof(allocations_)); + memset(&allocations_[0], 0, buffer_count_ * sizeof(allocations_[0])); next_buffer_offset_ = 0; next_buffer_index_ = 0; buffer_count_ = 0; diff --git a/oemcrypto/opk/serialization/os_interfaces/opk_dispatcher.h b/oemcrypto/opk/serialization/os_interfaces/opk_dispatcher.h index 22e7e76..2759874 100644 --- a/oemcrypto/opk/serialization/os_interfaces/opk_dispatcher.h +++ b/oemcrypto/opk/serialization/os_interfaces/opk_dispatcher.h @@ -42,16 +42,16 @@ extern "C" { * The message delivery protocol must handle timeouts and error conditions. * Those are out of scope of this function. * - * OPK_Initialize() must be called before using OPK_DispatchMesage(). + * OPK_Initialize() must be called before using OPK_DispatchMessage(). * * Parameters: * request - The received message that is to be processed. The caller - * retains ownership of the request messsage and must release any + * retains ownership of the request message and must release any * memory associated with it after OPK_DispatchMessage returns. * * response - A newly created message that should be sent to the REE using * a message delivery protocol. The caller takes ownership of the request - * messsage and must release any memory associated with it after + * message and must release any memory associated with it after * OPK_DispatchMessage returns. * * Return value: diff --git a/oemcrypto/opk/serialization/os_interfaces/tos_transport_interface.h b/oemcrypto/opk/serialization/os_interfaces/tos_transport_interface.h index 7e3087a..44f6e17 100644 --- a/oemcrypto/opk/serialization/os_interfaces/tos_transport_interface.h +++ b/oemcrypto/opk/serialization/os_interfaces/tos_transport_interface.h @@ -45,8 +45,8 @@ extern "C" { * the receive function to get a message, then invokes the dispatcher * function to process the message, and calls the send function to * deliver the reply. This allows the trusted app to control the - * message processing loop. See dispatcher_interface.h for a - * definition of this interface. + * message processing loop. See opk_dispatcher.h for a definition of this + * interface. * * @{ */ @@ -69,10 +69,9 @@ typedef enum { /* * OPK_TRANSPORT_STATUS_IO_ERROR must be returned from - * TOS_Transport_SendMessage or TOS_Transport_ReceiveMessage if the - * transport interface was unable to deliver or receive a message - * for any reason. The transport implementation should be designed - * to be robust against communication failures, e.g. by providing + * TOS_Transport_SendMessage if the transport interface was unable to deliver + * or receive a message for any reason. The transport implementation should be + * designed to be robust against communication failures, e.g. by providing * logic to retry delivery or other techniques if possible. The opk * library does not make any attempts to improve communication * reliability using these techniques. @@ -146,7 +145,7 @@ ODK_Message TOS_Transport_GetRequest(void); * * An implementation may opt to allocate the payload buffer of the * message as shared memory to avoid having to copy the response data - * into the message at the time it is sent to the TEE. + * into the message at the time it is sent to the REE. * * A single common payload buffer may be used for both the request and * response messages. They are used sequentially, i.e. all of the data diff --git a/oemcrypto/opk/serialization/ree/GEN_oemcrypto_api.c b/oemcrypto/opk/serialization/ree/GEN_oemcrypto_api.c index 15c484b..4fe967c 100644 --- a/oemcrypto/opk/serialization/ree/GEN_oemcrypto_api.c +++ b/oemcrypto/opk/serialization/ree/GEN_oemcrypto_api.c @@ -2777,7 +2777,7 @@ cleanup_and_return: } OEMCRYPTO_API OEMCryptoResult OEMCrypto_GetUsageEntryInfo( - OEMCrypto_SESSION session, OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_SESSION session, OEMCrypto_UsageEntryStatus* status, int64_t* seconds_since_license_received, int64_t* seconds_since_first_decrypt) { pthread_mutex_lock(&api_lock); @@ -3037,41 +3037,6 @@ cleanup_and_return: return result; } -OEMCRYPTO_API OEMCryptoResult OEMCrypto_InstallOemPrivateKey( - OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, - const uint8_t* wrapped_private_key, size_t wrapped_private_key_length) { - pthread_mutex_lock(&api_lock); - OEMCryptoResult result = OEMCrypto_ERROR_UNKNOWN_FAILURE; - ODK_Message request = ODK_Message_Create(NULL, 0); - ODK_Message response = ODK_Message_Create(NULL, 0); - - API_Initialize(); - request = OPK_Pack_InstallOemPrivateKey_Request( - session, key_type, wrapped_private_key, wrapped_private_key_length); - if (ODK_Message_GetStatus(&request) != MESSAGE_STATUS_OK) { - if (ODK_Message_GetStatus(&request) == MESSAGE_STATUS_BUFFER_TOO_LARGE) { - api_result = OEMCrypto_ERROR_BUFFER_TOO_LARGE; - } else { - api_result = OEMCrypto_ERROR_UNKNOWN_FAILURE; - } - goto cleanup_and_return; - } - response = API_Transact(&request); - if (api_result != OEMCrypto_SUCCESS) goto cleanup_and_return; - OPK_Unpack_InstallOemPrivateKey_Response(&response, &result); - - if (ODK_Message_GetStatus(&response) != MESSAGE_STATUS_OK) { - api_result = OEMCrypto_ERROR_UNKNOWN_FAILURE; - } -cleanup_and_return: - TOS_Transport_ReleaseMessage(&request); - TOS_Transport_ReleaseMessage(&response); - - result = API_CheckResult(result); - pthread_mutex_unlock(&api_lock); - return result; -} - OEMCRYPTO_API OEMCryptoResult OEMCrypto_EnterTestMode(void) { pthread_mutex_lock(&api_lock); OEMCryptoResult result = OEMCrypto_ERROR_UNKNOWN_FAILURE; @@ -3273,6 +3238,40 @@ cleanup_and_return: return result; } +OEMCRYPTO_API OEMCryptoResult OEMCrypto_GetRandom( + OEMCrypto_SharedMemory* random_data, size_t random_data_length) { + pthread_mutex_lock(&api_lock); + OEMCryptoResult result = OEMCrypto_ERROR_UNKNOWN_FAILURE; + ODK_Message request = ODK_Message_Create(NULL, 0); + ODK_Message response = ODK_Message_Create(NULL, 0); + + API_Initialize(); + request = OPK_Pack_GetRandom_Request(random_data, random_data_length); + if (ODK_Message_GetStatus(&request) != MESSAGE_STATUS_OK) { + if (ODK_Message_GetStatus(&request) == MESSAGE_STATUS_BUFFER_TOO_LARGE) { + api_result = OEMCrypto_ERROR_BUFFER_TOO_LARGE; + } else { + api_result = OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + goto cleanup_and_return; + } + response = API_Transact(&request); + if (api_result != OEMCrypto_SUCCESS) goto cleanup_and_return; + OPK_Unpack_GetRandom_Response(&response, &result, &random_data, + &random_data_length); + + if (ODK_Message_GetStatus(&response) != MESSAGE_STATUS_OK) { + api_result = OEMCrypto_ERROR_UNKNOWN_FAILURE; + } +cleanup_and_return: + TOS_Transport_ReleaseMessage(&request); + TOS_Transport_ReleaseMessage(&response); + + result = API_CheckResult(result); + pthread_mutex_unlock(&api_lock); + return result; +} + OEMCRYPTO_API OEMCryptoResult OEMCrypto_GenerateOTARequest(OEMCrypto_SESSION session, uint8_t* buffer, size_t* buffer_length, uint32_t use_test_key) { @@ -3344,6 +3343,41 @@ cleanup_and_return: return result; } +OEMCRYPTO_API OEMCryptoResult OEMCrypto_InstallOemPrivateKey( + OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, + const uint8_t* wrapped_private_key, size_t wrapped_private_key_length) { + pthread_mutex_lock(&api_lock); + OEMCryptoResult result = OEMCrypto_ERROR_UNKNOWN_FAILURE; + ODK_Message request = ODK_Message_Create(NULL, 0); + ODK_Message response = ODK_Message_Create(NULL, 0); + + API_Initialize(); + request = OPK_Pack_InstallOemPrivateKey_Request( + session, key_type, wrapped_private_key, wrapped_private_key_length); + if (ODK_Message_GetStatus(&request) != MESSAGE_STATUS_OK) { + if (ODK_Message_GetStatus(&request) == MESSAGE_STATUS_BUFFER_TOO_LARGE) { + api_result = OEMCrypto_ERROR_BUFFER_TOO_LARGE; + } else { + api_result = OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + goto cleanup_and_return; + } + response = API_Transact(&request); + if (api_result != OEMCrypto_SUCCESS) goto cleanup_and_return; + OPK_Unpack_InstallOemPrivateKey_Response(&response, &result); + + if (ODK_Message_GetStatus(&response) != MESSAGE_STATUS_OK) { + api_result = OEMCrypto_ERROR_UNKNOWN_FAILURE; + } +cleanup_and_return: + TOS_Transport_ReleaseMessage(&request); + TOS_Transport_ReleaseMessage(&response); + + result = API_CheckResult(result); + pthread_mutex_unlock(&api_lock); + return result; +} + OEMCRYPTO_API OEMCryptoResult OEMCrypto_GetEmbeddedDrmCertificate( uint8_t* public_cert, size_t* public_cert_length) { pthread_mutex_lock(&api_lock); diff --git a/oemcrypto/opk/serialization/ree/GEN_ree_serializer.c b/oemcrypto/opk/serialization/ree/GEN_ree_serializer.c index acb3d21..4b8da44 100644 --- a/oemcrypto/opk/serialization/ree/GEN_ree_serializer.c +++ b/oemcrypto/opk/serialization/ree/GEN_ree_serializer.c @@ -2900,7 +2900,7 @@ void OPK_Unpack_ReportUsage_Response(ODK_Message* msg, OEMCryptoResult* result, } ODK_Message OPK_Pack_GetUsageEntryInfo_Request( - OEMCrypto_SESSION session, const OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_SESSION session, const OEMCrypto_UsageEntryStatus* status, const int64_t* seconds_since_license_received, const int64_t* seconds_since_first_decrypt) { uint32_t api_value = 148; /* from _oecc148 */ @@ -2919,7 +2919,7 @@ ODK_Message OPK_Pack_GetUsageEntryInfo_Request( void OPK_Unpack_GetUsageEntryInfo_Response( ODK_Message* msg, OEMCryptoResult* result, - OEMCrypto_Usage_Entry_Status** status, + OEMCrypto_UsageEntryStatus** status, int64_t** seconds_since_license_received, int64_t** seconds_since_first_decrypt) { uint32_t api_value = UINT32_MAX; @@ -2933,7 +2933,7 @@ void OPK_Unpack_GetUsageEntryInfo_Response( if (SuccessResult(*result)) { OPK_UnpackNullable_uint32_t(msg, status); if (*status) { - if (!Is_Valid_OEMCrypto_Usage_Entry_Status(**status)) { + if (!Is_Valid_OEMCrypto_UsageEntryStatus(**status)) { ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); } } @@ -3258,41 +3258,6 @@ void OPK_Unpack_GetDeviceSignedCsrPayload_Response( } } -ODK_Message OPK_Pack_InstallOemPrivateKey_Request( - OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, - const uint8_t* wrapped_private_key, size_t wrapped_private_key_length) { - uint32_t api_value = 118; /* from _oecc118 */ - ODK_Message msg = TOS_Transport_GetRequest(); - OPK_Pack_uint32_t(&msg, &api_value); - uint64_t timestamp = time(0); - OPK_Pack_uint64_t(&msg, ×tamp); - OPK_Pack_size_t(&msg, &wrapped_private_key_length); - OPK_Pack_uint32_t(&msg, &session); - OPK_Pack_uint32_t(&msg, &key_type); - OPK_PackMemory(&msg, (const uint8_t*)wrapped_private_key, - OPK_ToLengthType(wrapped_private_key_length)); - OPK_PackEOM(&msg); - OPK_SharedBuffer_FinalizePacking(); - return msg; -} - -void OPK_Unpack_InstallOemPrivateKey_Response(ODK_Message* msg, - OEMCryptoResult* result) { - uint32_t api_value = UINT32_MAX; - OPK_Unpack_uint32_t(msg, &api_value); - if (api_value != 118) - ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); - OPK_Unpack_uint32_t(msg, result); - if (!Is_Valid_OEMCryptoResult(*result)) { - ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); - } - OPK_UnpackEOM(msg); - - if (SuccessResult(*result)) { - OPK_SharedBuffer_FinalizeUnpacking(); - } -} - ODK_Message OPK_Pack_EnterTestMode_Request(void) { uint32_t api_value = 140; /* from _oecc140 */ ODK_Message msg = TOS_Transport_GetRequest(); @@ -3493,6 +3458,46 @@ void OPK_Unpack_FreeSecureBuffer_Response( } } +ODK_Message OPK_Pack_GetRandom_Request( + const OEMCrypto_SharedMemory* random_data, size_t random_data_length) { + uint32_t api_value = 6; /* from _oecc6 */ + ODK_Message msg = TOS_Transport_GetRequest(); + OPK_Pack_uint32_t(&msg, &api_value); + uint64_t timestamp = time(0); + OPK_Pack_uint64_t(&msg, ×tamp); + OPK_Pack_size_t(&msg, &random_data_length); + OPK_PackSharedBuffer(&msg, random_data, OPK_ToLengthType(random_data_length), + /* map */ true, /* copy_in */ false, + /* is_output */ true); + OPK_PackEOM(&msg); + OPK_SharedBuffer_FinalizePacking(); + return msg; +} + +void OPK_Unpack_GetRandom_Response(ODK_Message* msg, OEMCryptoResult* result, + OEMCrypto_SharedMemory** random_data, + size_t* random_data_length) { + uint32_t api_value = UINT32_MAX; + OPK_Unpack_uint32_t(msg, &api_value); + if (api_value != 6) + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); + OPK_Unpack_size_t(msg, random_data_length); + OPK_Unpack_uint32_t(msg, result); + if (!Is_Valid_OEMCryptoResult(*result)) { + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + } + if (SuccessResult(*result)) { + OPK_UnpackSharedBuffer(msg, random_data, + OPK_FromSizeTPtr(random_data_length), + /* map */ false, /* is_output */ true); + } + OPK_UnpackEOM(msg); + + if (SuccessResult(*result)) { + OPK_SharedBuffer_FinalizeUnpacking(); + } +} + ODK_Message OPK_Pack_OPK_SerializationVersion_Request( const uint32_t* ree_major, const uint32_t* ree_minor, const uint32_t* tee_major, const uint32_t* tee_minor) { @@ -3620,6 +3625,41 @@ void OPK_Unpack_ProcessOTAKeybox_Response(ODK_Message* msg, } } +ODK_Message OPK_Pack_InstallOemPrivateKey_Request( + OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, + const uint8_t* wrapped_private_key, size_t wrapped_private_key_length) { + uint32_t api_value = 118; /* from _oecc118 */ + ODK_Message msg = TOS_Transport_GetRequest(); + OPK_Pack_uint32_t(&msg, &api_value); + uint64_t timestamp = time(0); + OPK_Pack_uint64_t(&msg, ×tamp); + OPK_Pack_size_t(&msg, &wrapped_private_key_length); + OPK_Pack_uint32_t(&msg, &session); + OPK_Pack_uint32_t(&msg, &key_type); + OPK_PackMemory(&msg, (const uint8_t*)wrapped_private_key, + OPK_ToLengthType(wrapped_private_key_length)); + OPK_PackEOM(&msg); + OPK_SharedBuffer_FinalizePacking(); + return msg; +} + +void OPK_Unpack_InstallOemPrivateKey_Response(ODK_Message* msg, + OEMCryptoResult* result) { + uint32_t api_value = UINT32_MAX; + OPK_Unpack_uint32_t(msg, &api_value); + if (api_value != 118) + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); + OPK_Unpack_uint32_t(msg, result); + if (!Is_Valid_OEMCryptoResult(*result)) { + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + } + OPK_UnpackEOM(msg); + + if (SuccessResult(*result)) { + OPK_SharedBuffer_FinalizeUnpacking(); + } +} + ODK_Message OPK_Pack_GetEmbeddedDrmCertificate_Request( const uint8_t* public_cert, const size_t* public_cert_length) { uint32_t api_value = 151; /* from _oecc151 */ diff --git a/oemcrypto/opk/serialization/ree/GEN_ree_serializer.h b/oemcrypto/opk/serialization/ree/GEN_ree_serializer.h index b7f0060..9e60486 100644 --- a/oemcrypto/opk/serialization/ree/GEN_ree_serializer.h +++ b/oemcrypto/opk/serialization/ree/GEN_ree_serializer.h @@ -405,12 +405,12 @@ ODK_Message OPK_Pack_ReportUsage_Request(OEMCrypto_SESSION session, void OPK_Unpack_ReportUsage_Response(ODK_Message* msg, OEMCryptoResult* result, uint8_t** buffer, size_t** buffer_length); ODK_Message OPK_Pack_GetUsageEntryInfo_Request( - OEMCrypto_SESSION session, const OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_SESSION session, const OEMCrypto_UsageEntryStatus* status, const int64_t* seconds_since_license_received, const int64_t* seconds_since_first_decrypt); void OPK_Unpack_GetUsageEntryInfo_Response( ODK_Message* msg, OEMCryptoResult* result, - OEMCrypto_Usage_Entry_Status** status, + OEMCrypto_UsageEntryStatus** status, int64_t** seconds_since_license_received, int64_t** seconds_since_first_decrypt); ODK_Message OPK_Pack_MoveEntry_Request(OEMCrypto_SESSION session, @@ -456,11 +456,6 @@ ODK_Message OPK_Pack_GetDeviceSignedCsrPayload_Request( void OPK_Unpack_GetDeviceSignedCsrPayload_Response( ODK_Message* msg, OEMCryptoResult* result, uint8_t** signed_csr_payload, size_t** signed_csr_payload_length); -ODK_Message OPK_Pack_InstallOemPrivateKey_Request( - OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, - const uint8_t* wrapped_private_key, size_t wrapped_private_key_length); -void OPK_Unpack_InstallOemPrivateKey_Response(ODK_Message* msg, - OEMCryptoResult* result); ODK_Message OPK_Pack_EnterTestMode_Request(void); void OPK_Unpack_EnterTestMode_Response(ODK_Message* msg, OEMCryptoResult* result); @@ -489,6 +484,11 @@ ODK_Message OPK_Pack_FreeSecureBuffer_Request( void OPK_Unpack_FreeSecureBuffer_Response( ODK_Message* msg, OEMCryptoResult* result, OEMCrypto_DestBufferDesc** output_descriptor); +ODK_Message OPK_Pack_GetRandom_Request( + const OEMCrypto_SharedMemory* random_data, size_t random_data_length); +void OPK_Unpack_GetRandom_Response(ODK_Message* msg, OEMCryptoResult* result, + OEMCrypto_SharedMemory** random_data, + size_t* random_data_length); ODK_Message OPK_Pack_OPK_SerializationVersion_Request( const uint32_t* ree_major, const uint32_t* ree_minor, const uint32_t* tee_major, const uint32_t* tee_minor); @@ -509,6 +509,11 @@ ODK_Message OPK_Pack_ProcessOTAKeybox_Request(OEMCrypto_SESSION session, uint32_t use_test_key); void OPK_Unpack_ProcessOTAKeybox_Response(ODK_Message* msg, OEMCryptoResult* result); +ODK_Message OPK_Pack_InstallOemPrivateKey_Request( + OEMCrypto_SESSION session, OEMCrypto_PrivateKeyType key_type, + const uint8_t* wrapped_private_key, size_t wrapped_private_key_length); +void OPK_Unpack_InstallOemPrivateKey_Response(ODK_Message* msg, + OEMCryptoResult* result); ODK_Message OPK_Pack_GetEmbeddedDrmCertificate_Request( const uint8_t* public_cert, const size_t* public_cert_length); void OPK_Unpack_GetEmbeddedDrmCertificate_Response(ODK_Message* msg, diff --git a/oemcrypto/opk/serialization/tee/GEN_dispatcher.c b/oemcrypto/opk/serialization/tee/GEN_dispatcher.c index 4192081..0832a51 100644 --- a/oemcrypto/opk/serialization/tee/GEN_dispatcher.c +++ b/oemcrypto/opk/serialization/tee/GEN_dispatcher.c @@ -81,6 +81,23 @@ void OPK_Init_OEMCrypto_KeyObject(OEMCrypto_KeyObject* obj) { OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_control); } +void OPK_Init_OEMCrypto_AuthenticationKeyInfo( + OEMCrypto_AuthenticationKeyInfo* obj) { + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->authentication_key); + OPK_Init_OEMCrypto_Substring( + (OEMCrypto_Substring*)&obj->authentication_key_iv); +} + +void OPK_Init_OEMCrypto_KeyObjectV2(OEMCrypto_KeyObjectV2* obj) { + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_id); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_data_iv); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_data); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_control_iv); + OPK_Init_OEMCrypto_Substring((OEMCrypto_Substring*)&obj->key_control); + OPK_Init_OEMCrypto_DeCENC_Mitigation_Info( + (OEMCrypto_DeCENC_Mitigation_Info*)&obj->decenc_mitigation_info); +} + void OPK_Init_OEMCrypto_SubSampleDescription( OEMCrypto_SubSampleDescription* obj) { OPK_Init_size_t((size_t*)&obj->num_bytes_clear); @@ -1628,7 +1645,7 @@ ODK_MessageStatus OPK_DispatchMessage(ODK_Message* request, { OEMCrypto_SESSION session; OPK_Init_uint32_t((uint32_t*)&session); - OEMCrypto_Usage_Entry_Status* status; + OEMCrypto_UsageEntryStatus* status; OPK_InitPointer((uint8_t**)&status); int64_t* seconds_since_license_received; OPK_InitPointer((uint8_t**)&seconds_since_license_received); @@ -1802,28 +1819,6 @@ ODK_MessageStatus OPK_DispatchMessage(ODK_Message* request, result, signed_csr_payload, signed_csr_payload_length); break; } - case 118: /* OEMCrypto_InstallOemPrivateKey */ - { - size_t wrapped_private_key_length; - OPK_Init_size_t((size_t*)&wrapped_private_key_length); - OEMCrypto_SESSION session; - OPK_Init_uint32_t((uint32_t*)&session); - OEMCrypto_PrivateKeyType key_type; - OPK_Init_uint32_t((uint32_t*)&key_type); - uint8_t* wrapped_private_key; - OPK_InitPointer((uint8_t**)&wrapped_private_key); - OPK_Unpack_InstallOemPrivateKey_Request(request, &session, &key_type, - &wrapped_private_key, - &wrapped_private_key_length); - if (!ODK_Message_IsValid(request)) goto handle_invalid_request; - OEMCryptoResult result; - OPK_Init_uint32_t((uint32_t*)&result); - LOGD("InstallOemPrivateKey"); - result = OEMCrypto_InstallOemPrivateKey( - session, key_type, wrapped_private_key, wrapped_private_key_length); - *response = OPK_Pack_InstallOemPrivateKey_Response(result); - break; - } case 140: /* OEMCrypto_EnterTestMode */ { OPK_Unpack_EnterTestMode_Request(request); @@ -1925,6 +1920,22 @@ ODK_MessageStatus OPK_DispatchMessage(ODK_Message* request, *response = OPK_Pack_FreeSecureBuffer_Response(result, output_descriptor); break; } + case 6: /* OEMCrypto_GetRandom */ + { + size_t random_data_length; + OPK_Init_size_t((size_t*)&random_data_length); + OEMCrypto_SharedMemory* random_data; + OPK_InitPointer((uint8_t**)&random_data); + OPK_Unpack_GetRandom_Request(request, &random_data, &random_data_length); + if (!ODK_Message_IsValid(request)) goto handle_invalid_request; + OEMCryptoResult result; + OPK_Init_uint32_t((uint32_t*)&result); + LOGD("GetRandom"); + result = OEMCrypto_GetRandom(random_data, random_data_length); + *response = OPK_Pack_GetRandom_Response( + result, OPK_SharedBuffer_NextOutputBuffer(), random_data_length); + break; + } case 115: /* OEMCrypto_OPK_SerializationVersion */ { uint32_t* ree_major = (uint32_t*)OPK_VarAlloc(sizeof(uint32_t)); @@ -1993,6 +2004,28 @@ ODK_MessageStatus OPK_DispatchMessage(ODK_Message* request, *response = OPK_Pack_ProcessOTAKeybox_Response(result); break; } + case 118: /* OEMCrypto_InstallOemPrivateKey */ + { + size_t wrapped_private_key_length; + OPK_Init_size_t((size_t*)&wrapped_private_key_length); + OEMCrypto_SESSION session; + OPK_Init_uint32_t((uint32_t*)&session); + OEMCrypto_PrivateKeyType key_type; + OPK_Init_uint32_t((uint32_t*)&key_type); + uint8_t* wrapped_private_key; + OPK_InitPointer((uint8_t**)&wrapped_private_key); + OPK_Unpack_InstallOemPrivateKey_Request(request, &session, &key_type, + &wrapped_private_key, + &wrapped_private_key_length); + if (!ODK_Message_IsValid(request)) goto handle_invalid_request; + OEMCryptoResult result; + OPK_Init_uint32_t((uint32_t*)&result); + LOGD("InstallOemPrivateKey"); + result = OEMCrypto_InstallOemPrivateKey( + session, key_type, wrapped_private_key, wrapped_private_key_length); + *response = OPK_Pack_InstallOemPrivateKey_Response(result); + break; + } case 151: /* OEMCrypto_GetEmbeddedDrmCertificate */ { size_t* public_cert_length = (size_t*)OPK_VarAlloc(sizeof(size_t)); diff --git a/oemcrypto/opk/serialization/tee/GEN_tee_serializer.c b/oemcrypto/opk/serialization/tee/GEN_tee_serializer.c index 47d4c64..23baaa4 100644 --- a/oemcrypto/opk/serialization/tee/GEN_tee_serializer.c +++ b/oemcrypto/opk/serialization/tee/GEN_tee_serializer.c @@ -2380,7 +2380,7 @@ ODK_Message OPK_Pack_ReportUsage_Response(OEMCryptoResult result, void OPK_Unpack_GetUsageEntryInfo_Request( ODK_Message* msg, OEMCrypto_SESSION* session, - OEMCrypto_Usage_Entry_Status** status, + OEMCrypto_UsageEntryStatus** status, int64_t** seconds_since_license_received, int64_t** seconds_since_first_decrypt) { uint32_t api_value = UINT32_MAX; @@ -2390,8 +2390,7 @@ void OPK_Unpack_GetUsageEntryInfo_Request( uint64_t timestamp; OPK_Unpack_uint64_t(msg, ×tamp); OPK_Unpack_uint32_t(msg, session); - *status = - (uint32_t*)OPK_UnpackAlloc(msg, sizeof(OEMCrypto_Usage_Entry_Status)); + *status = (uint32_t*)OPK_UnpackAlloc(msg, sizeof(OEMCrypto_UsageEntryStatus)); *seconds_since_license_received = (int64_t*)OPK_UnpackAlloc(msg, sizeof(int64_t)); *seconds_since_first_decrypt = @@ -2401,7 +2400,7 @@ void OPK_Unpack_GetUsageEntryInfo_Request( } ODK_Message OPK_Pack_GetUsageEntryInfo_Response( - OEMCryptoResult result, const OEMCrypto_Usage_Entry_Status* status, + OEMCryptoResult result, const OEMCrypto_UsageEntryStatus* status, const int64_t* seconds_since_license_received, const int64_t* seconds_since_first_decrypt) { uint32_t api_value = 148; /* from _oecc148 */ @@ -2647,38 +2646,6 @@ ODK_Message OPK_Pack_GetDeviceSignedCsrPayload_Response( return msg; } -void OPK_Unpack_InstallOemPrivateKey_Request( - ODK_Message* msg, OEMCrypto_SESSION* session, - OEMCrypto_PrivateKeyType* key_type, uint8_t** wrapped_private_key, - size_t* wrapped_private_key_length) { - uint32_t api_value = UINT32_MAX; - OPK_Unpack_uint32_t(msg, &api_value); - if (api_value != 118) - ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); - uint64_t timestamp; - OPK_Unpack_uint64_t(msg, ×tamp); - OPK_Unpack_size_t(msg, wrapped_private_key_length); - OPK_Unpack_uint32_t(msg, session); - OPK_Unpack_uint32_t(msg, key_type); - if (!Is_Valid_OEMCrypto_PrivateKeyType(*key_type)) { - ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); - } - OPK_UnpackInPlace(msg, (uint8_t**)wrapped_private_key, - OPK_FromSizeTPtr(wrapped_private_key_length)); - OPK_UnpackEOM(msg); - OPK_SharedBuffer_FinalizeUnpacking(); -} - -ODK_Message OPK_Pack_InstallOemPrivateKey_Response(OEMCryptoResult result) { - uint32_t api_value = 118; /* from _oecc118 */ - ODK_Message msg = TOS_Transport_GetResponse(); - OPK_Pack_uint32_t(&msg, &api_value); - OPK_Pack_uint32_t(&msg, &result); - OPK_PackEOM(&msg); - OPK_SharedBuffer_FinalizePacking(); - return msg; -} - void OPK_Unpack_EnterTestMode_Request(ODK_Message* msg) { uint32_t api_value = UINT32_MAX; OPK_Unpack_uint32_t(msg, &api_value); @@ -2835,6 +2802,40 @@ ODK_Message OPK_Pack_FreeSecureBuffer_Response( return msg; } +void OPK_Unpack_GetRandom_Request(ODK_Message* msg, + OEMCrypto_SharedMemory** random_data, + size_t* random_data_length) { + uint32_t api_value = UINT32_MAX; + OPK_Unpack_uint32_t(msg, &api_value); + if (api_value != 6) + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); + uint64_t timestamp; + OPK_Unpack_uint64_t(msg, ×tamp); + OPK_Unpack_size_t(msg, random_data_length); + OPK_UnpackSharedBuffer(msg, random_data, OPK_FromSizeTPtr(random_data_length), + /* map */ true, /* is_output */ true); + OPK_UnpackEOM(msg); + OPK_SharedBuffer_FinalizeUnpacking(); +} + +ODK_Message OPK_Pack_GetRandom_Response( + OEMCryptoResult result, const OEMCrypto_SharedMemory* random_data, + size_t random_data_length) { + uint32_t api_value = 6; /* from _oecc6 */ + ODK_Message msg = TOS_Transport_GetResponse(); + OPK_Pack_uint32_t(&msg, &api_value); + OPK_Pack_size_t(&msg, &random_data_length); + OPK_Pack_uint32_t(&msg, &result); + if (SuccessResult(result)) { + OPK_PackSharedBuffer(&msg, random_data, + OPK_ToLengthType(random_data_length), /* map */ false, + /* copy_in */ false, /* is_output */ true); + } + OPK_PackEOM(&msg); + OPK_SharedBuffer_FinalizePacking(); + return msg; +} + void OPK_Unpack_OPK_SerializationVersion_Request(ODK_Message* msg, uint32_t** ree_major, uint32_t** ree_minor, @@ -2937,6 +2938,38 @@ ODK_Message OPK_Pack_ProcessOTAKeybox_Response(OEMCryptoResult result) { return msg; } +void OPK_Unpack_InstallOemPrivateKey_Request( + ODK_Message* msg, OEMCrypto_SESSION* session, + OEMCrypto_PrivateKeyType* key_type, uint8_t** wrapped_private_key, + size_t* wrapped_private_key_length) { + uint32_t api_value = UINT32_MAX; + OPK_Unpack_uint32_t(msg, &api_value); + if (api_value != 118) + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_API_VALUE_ERROR); + uint64_t timestamp; + OPK_Unpack_uint64_t(msg, ×tamp); + OPK_Unpack_size_t(msg, wrapped_private_key_length); + OPK_Unpack_uint32_t(msg, session); + OPK_Unpack_uint32_t(msg, key_type); + if (!Is_Valid_OEMCrypto_PrivateKeyType(*key_type)) { + ODK_MESSAGE_SETSTATUS(msg, MESSAGE_STATUS_INVALID_ENUM_VALUE); + } + OPK_UnpackInPlace(msg, (uint8_t**)wrapped_private_key, + OPK_FromSizeTPtr(wrapped_private_key_length)); + OPK_UnpackEOM(msg); + OPK_SharedBuffer_FinalizeUnpacking(); +} + +ODK_Message OPK_Pack_InstallOemPrivateKey_Response(OEMCryptoResult result) { + uint32_t api_value = 118; /* from _oecc118 */ + ODK_Message msg = TOS_Transport_GetResponse(); + OPK_Pack_uint32_t(&msg, &api_value); + OPK_Pack_uint32_t(&msg, &result); + OPK_PackEOM(&msg); + OPK_SharedBuffer_FinalizePacking(); + return msg; +} + void OPK_Unpack_GetEmbeddedDrmCertificate_Request(ODK_Message* msg, uint8_t** public_cert, size_t** public_cert_length) { diff --git a/oemcrypto/opk/serialization/tee/GEN_tee_serializer.h b/oemcrypto/opk/serialization/tee/GEN_tee_serializer.h index 4b3ded9..460c59a 100644 --- a/oemcrypto/opk/serialization/tee/GEN_tee_serializer.h +++ b/oemcrypto/opk/serialization/tee/GEN_tee_serializer.h @@ -392,11 +392,11 @@ ODK_Message OPK_Pack_ReportUsage_Response(OEMCryptoResult result, const size_t* buffer_length); void OPK_Unpack_GetUsageEntryInfo_Request( ODK_Message* msg, OEMCrypto_SESSION* session, - OEMCrypto_Usage_Entry_Status** status, + OEMCrypto_UsageEntryStatus** status, int64_t** seconds_since_license_received, int64_t** seconds_since_first_decrypt); ODK_Message OPK_Pack_GetUsageEntryInfo_Response( - OEMCryptoResult result, const OEMCrypto_Usage_Entry_Status* status, + OEMCryptoResult result, const OEMCrypto_UsageEntryStatus* status, const int64_t* seconds_since_license_received, const int64_t* seconds_since_first_decrypt); void OPK_Unpack_MoveEntry_Request(ODK_Message* msg, OEMCrypto_SESSION* session, @@ -441,11 +441,6 @@ void OPK_Unpack_GetDeviceSignedCsrPayload_Request( ODK_Message OPK_Pack_GetDeviceSignedCsrPayload_Response( OEMCryptoResult result, const uint8_t* signed_csr_payload, const size_t* signed_csr_payload_length); -void OPK_Unpack_InstallOemPrivateKey_Request( - ODK_Message* msg, OEMCrypto_SESSION* session, - OEMCrypto_PrivateKeyType* key_type, uint8_t** wrapped_private_key, - size_t* wrapped_private_key_length); -ODK_Message OPK_Pack_InstallOemPrivateKey_Response(OEMCryptoResult result); void OPK_Unpack_EnterTestMode_Request(ODK_Message* msg); ODK_Message OPK_Pack_EnterTestMode_Response(OEMCryptoResult result); void OPK_Unpack_SupportsDecryptHash_Request(ODK_Message* msg); @@ -470,6 +465,12 @@ void OPK_Unpack_FreeSecureBuffer_Request( OEMCrypto_DestBufferDesc** output_descriptor, int* secure_fd); ODK_Message OPK_Pack_FreeSecureBuffer_Response( OEMCryptoResult result, const OEMCrypto_DestBufferDesc* output_descriptor); +void OPK_Unpack_GetRandom_Request(ODK_Message* msg, + OEMCrypto_SharedMemory** random_data, + size_t* random_data_length); +ODK_Message OPK_Pack_GetRandom_Response( + OEMCryptoResult result, const OEMCrypto_SharedMemory* random_data, + size_t random_data_length); void OPK_Unpack_OPK_SerializationVersion_Request(ODK_Message* msg, uint32_t** ree_major, uint32_t** ree_minor, @@ -493,6 +494,11 @@ void OPK_Unpack_ProcessOTAKeybox_Request(ODK_Message* msg, size_t* buffer_length, uint32_t* use_test_key); ODK_Message OPK_Pack_ProcessOTAKeybox_Response(OEMCryptoResult result); +void OPK_Unpack_InstallOemPrivateKey_Request( + ODK_Message* msg, OEMCrypto_SESSION* session, + OEMCrypto_PrivateKeyType* key_type, uint8_t** wrapped_private_key, + size_t* wrapped_private_key_length); +ODK_Message OPK_Pack_InstallOemPrivateKey_Response(OEMCryptoResult result); void OPK_Unpack_GetEmbeddedDrmCertificate_Request(ODK_Message* msg, uint8_t** public_cert, size_t** public_cert_length); diff --git a/oemcrypto/test/Android.mk b/oemcrypto/test/Android.mk deleted file mode 100644 index 98e9b4d..0000000 --- a/oemcrypto/test/Android.mk +++ /dev/null @@ -1,27 +0,0 @@ -LOCAL_PATH:= $(call my-dir) - -include $(CLEAR_VARS) - -LOCAL_C_INCLUDES := \ - vendor/widevine/libwvdrmengine/cdm/util/include \ - -LOCAL_MODULE:=oemcrypto_test -LOCAL_LICENSE_KINDS:=legacy_by_exception_only legacy_proprietary -LOCAL_LICENSE_CONDITIONS:=by_exception_only proprietary by_exception_only -LOCAL_MODULE_TAGS := tests - -LOCAL_MODULE_OWNER := widevine -LOCAL_PROPRIETARY_MODULE := true - -LOCAL_C_INCLUDES += external/googletest/googlemock/include \ - -# When built, explicitly put it in the DATA/nativetest directory. -LOCAL_MODULE_PATH := $(TARGET_OUT_DATA)/nativetest - -ifneq ($(TARGET_ENABLE_MEDIADRM_64), true) -LOCAL_MODULE_TARGET_ARCH := arm x86 mips -endif - -include $(LOCAL_PATH)/common.mk - -include $(BUILD_EXECUTABLE) diff --git a/oemcrypto/test/GEN_api_lock_file.c b/oemcrypto/test/GEN_api_lock_file.c index 9a5af65..91b18b7 100644 --- a/oemcrypto/test/GEN_api_lock_file.c +++ b/oemcrypto/test/GEN_api_lock_file.c @@ -213,7 +213,7 @@ OEMCryptoResult _oecc34(void); OEMCryptoResult _oecc70(uint64_t time_since_license_received, uint64_t time_since_first_decrypt, uint64_t time_since_last_decrypt, - OEMCrypto_Usage_Entry_Status status, + OEMCrypto_UsageEntryStatus status, uint8_t* server_mac_key, uint8_t* client_mac_key, const uint8_t* pst, size_t pst_length); OEMCryptoResult _oecc12(OEMCrypto_SESSION session, @@ -412,7 +412,7 @@ OEMCryptoResult _oecc146(OEMCrypto_SESSION session, // OEMCrypto_GetUsageEntryInfo defined in v19.0 OEMCryptoResult _oecc148(OEMCrypto_SESSION session, - OEMCrypto_Usage_Entry_Status* status, + OEMCrypto_UsageEntryStatus* status, int64_t* seconds_since_license_received, int64_t* seconds_since_first_decrypt); @@ -450,3 +450,11 @@ OEMCryptoResult _oecc157(OEMCrypto_SESSION session, uint8_t* wrapped_pvr_key, OEMCryptoResult _oecc158(OEMCrypto_SESSION session, const uint8_t* wrapped_pvr_key, size_t wrapped_pvr_key_length); + +// OEMCrypto_LoadLicenseData defined in v20.0 +OEMCryptoResult _oecc159(OEMCrypto_SESSION session, const uint8_t* data, + size_t data_length); + +// OEMCrypto_SaveLicenseData defined in v20.0 +OEMCryptoResult _oecc160(OEMCrypto_SESSION session, uint8_t* data, + size_t* data_length); diff --git a/oemcrypto/test/common.mk b/oemcrypto/test/common.mk deleted file mode 100644 index 2e9cb1b..0000000 --- a/oemcrypto/test/common.mk +++ /dev/null @@ -1,76 +0,0 @@ -LOCAL_PATH:= $(call my-dir) - -ifeq ($(filter mips mips64, $(TARGET_ARCH)),) -# Tests need to be compatible with devices that do not support gnu hash-style -LOCAL_LDFLAGS+=-Wl,--hash-style=both -endif - -# The unit tests can access v15 functions through the dynamic adapter: -LOCAL_CFLAGS += -DTEST_OEMCRYPTO_V15 - -LOCAL_SRC_FILES:= \ - GEN_api_lock_file.c \ - oec_device_features.cpp \ - oec_decrypt_fallback_chain.cpp \ - oec_key_deriver.cpp \ - oec_session_util.cpp \ - oemcrypto_corpus_generator_helper.cpp \ - oemcrypto_session_tests_helper.cpp \ - oemcrypto_basic_test.cpp \ - oemcrypto_cast_test.cpp \ - oemcrypto_decrypt_test.cpp \ - oemcrypto_generic_crypto_test.cpp \ - oemcrypto_license_test.cpp \ - oemcrypto_provisioning_test.cpp \ - oemcrypto_security_test.cpp \ - oemcrypto_usage_table_test.cpp \ - oemcrypto_test.cpp \ - oemcrypto_test_android.cpp \ - oemcrypto_test_main.cpp \ - ota_keybox_test.cpp \ - ../../cdm/util/test/test_sleep.cpp \ - ../util/src/bcc_validator.cpp \ - ../util/src/cbor_validator.cpp \ - ../util/src/device_info_validator.cpp \ - ../util/src/oemcrypto_ecc_key.cpp \ - ../util/src/oemcrypto_rsa_key.cpp \ - ../util/src/prov4_validation_helper.cpp \ - ../util/src/signed_csr_payload_validator.cpp \ - ../util/src/wvcrc.cpp \ - -LOCAL_C_INCLUDES += \ - $(LOCAL_PATH)/fuzz_tests \ - $(LOCAL_PATH)/../include \ - $(LOCAL_PATH)/../odk/include \ - $(LOCAL_PATH)/../odk/kdo/include \ - $(LOCAL_PATH)/../ref/src \ - $(LOCAL_PATH)/../util/include \ - vendor/widevine/libwvdrmengine/cdm/core/include \ - vendor/widevine/libwvdrmengine/cdm/util/include \ - vendor/widevine/libwvdrmengine/cdm/util/test \ - -LOCAL_STATIC_LIBRARIES := \ - libcdm \ - libcppbor \ - libjsmn \ - libgmock \ - libgtest \ - libgtest_main \ - libwvlevel3 \ - libcdm_protos \ - libcdm_utils \ - libwv_kdo \ - libwv_odk \ - libPlatformProperties \ - -LOCAL_SHARED_LIBRARIES := \ - libbase \ - libcrypto \ - libdl \ - libbinder_ndk \ - liblog \ - libmedia_omx \ - libprotobuf-cpp-lite \ - libstagefright_foundation \ - libutils \ - libz \ diff --git a/oemcrypto/test/fuzz_tests/README.md b/oemcrypto/test/fuzz_tests/README.md index 249fe2f..dd2d683 100644 --- a/oemcrypto/test/fuzz_tests/README.md +++ b/oemcrypto/test/fuzz_tests/README.md @@ -96,10 +96,10 @@ typically have less than 1000 executions per second (exec/s). `LLVMFuzzerInitialize` can be used for global initialization, but there is no corresponding termination method. -A good starting example is [`oemcrypto_install_oem_private_key_fuzz.cc`][4]. -Targets should be added to `oemcrypto_opk_fuzztests.gyp` and, if the fuzz test -applies to partner OEMCrypto implementations, `partner_oemcrypto_fuzztests.gyp`. -The infrastructure expects that the target name starts with *oemcrypto* and ends +A good starting example is [`oemcrypto_get_key_handle_fuzz.cc`][4]. Targets +should be added to `oemcrypto_opk_fuzztests.gyp` and, if the fuzz test applies +to partner OEMCrypto implementations, `partner_oemcrypto_fuzztests.gyp`. The +infrastructure expects that the target name starts with *oemcrypto* and ends with *fuzz*. For additional information about writing fuzz tests, see @@ -136,5 +136,5 @@ For additional information about writing fuzz tests, see [1]: clusterfuzz_setup.md [2]: https://wiki.sei.cmu.edu/confluence/display/c/SEI+CERT+C+Coding+Standard [3]: https://github.com/llvm/llvm-project/blob/main/compiler-rt/include/fuzzer/FuzzedDataProvider.h -[4]: oemcrypto_install_oem_private_key_fuzz.cc +[4]: oemcrypto_get_key_handle_fuzz.cc [5]: https://github.com/google/fuzzing/blob/master/docs/good-fuzz-target.md diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_deactivate_usage_entry_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_deactivate_usage_entry_fuzz.cc index 9644106..52a269c 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_deactivate_usage_entry_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_deactivate_usage_entry_fuzz.cc @@ -13,7 +13,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { entry.CreateUsageTableHeader(); entry.InstallTestDrmKey(); entry.session().CreateNewUsageEntry(); - entry.session().GenerateNonce(); std::vector encrypted_usage_header; entry.session().UpdateUsageEntry(&encrypted_usage_header); // LoadLicense sets the pst for usage entry. diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_fuzz_structs.h b/oemcrypto/test/fuzz_tests/oemcrypto_fuzz_structs.h index f925b03..f6d6d74 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_fuzz_structs.h +++ b/oemcrypto/test/fuzz_tests/oemcrypto_fuzz_structs.h @@ -45,8 +45,6 @@ struct OEMCrypto_Renewal_Response_Fuzz { struct OEMCrypto_Release_Response_Fuzz { oemcrypto_core_message::ODK_ReleaseRequest core_request; - int64_t seconds_since_license_received; - int64_t seconds_since_first_decrypt; // license_release_response is of variable length and not included in this // structure. }; diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_fuzztests.gypi b/oemcrypto/test/fuzz_tests/oemcrypto_fuzztests.gypi index 187d503..18c917a 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_fuzztests.gypi +++ b/oemcrypto/test/fuzz_tests/oemcrypto_fuzztests.gypi @@ -27,9 +27,10 @@ '<(platform_specific_dir)/file_store.cpp', '<(platform_specific_dir)/log.cpp', '<(util_dir)/src/platform.cpp', - '<(util_dir)/src/rw_lock.cpp', '<(util_dir)/src/string_conversions.cpp', '<(util_dir)/src/string_format.cpp', + '<(util_dir)/src/string_utils.cpp', + '<(util_dir)/src/wv_duration.cpp', '<(util_dir)/test/test_sleep.cpp', '<(util_dir)/test/test_clock.cpp', ], @@ -50,7 +51,8 @@ 'dependencies': [ '../../../third_party/googletest.gyp:gtest', '../../../third_party/googletest.gyp:gmock', - '<(oemcrypto_dir)/util/oec_ref_util.gyp:oec_ref_util', + '<(oemcrypto_dir)/util/build.gyp:liboec_ref_util', + '<(oemcrypto_dir)/util/wvcrc32.gyp:libwvcrc32', ], 'defines': [ 'OEMCRYPTO_FUZZ_TESTS', diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_generate_certificate_key_pair_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_generate_certificate_key_pair_fuzz.cc index f66353a..cff4bd2 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_generate_certificate_key_pair_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_generate_certificate_key_pair_fuzz.cc @@ -19,38 +19,6 @@ wvoec::OEMCryptoProvisioningAPIFuzz& provisioning_api_fuzz = extern "C" int LLVMFuzzerInitialize(int* argc, char*** argv) { wvoec::RedirectStdoutToFile(); provisioning_api_fuzz.Initialize(); - -#ifdef SECOND_STAGE - - const uint32_t session_id = provisioning_api_fuzz.session().session_id(); - - size_t public_key_length = 0; - size_t public_key_signature_length = 0; - size_t wrapped_private_key_length = 0; - OEMCrypto_PrivateKeyType key_type = OEMCrypto_RSA_Private_Key; - OEMCryptoResult result = OEMCrypto_GenerateCertificateKeyPair( - session_id, nullptr, &public_key_length, nullptr, - &public_key_signature_length, nullptr, &wrapped_private_key_length, - &key_type); - wvoec::CheckStatusAndExitFuzzerOnFailure(result, - OEMCrypto_ERROR_SHORT_BUFFER); - - std::vector public_key(public_key_length); - std::vector public_key_signature(public_key_signature_length); - std::vector wrapped_private_key(wrapped_private_key_length); - result = OEMCrypto_GenerateCertificateKeyPair( - session_id, public_key.data(), &public_key_length, - public_key_signature.data(), &public_key_signature_length, - wrapped_private_key.data(), &wrapped_private_key_length, &key_type); - wvoec::CheckStatusAndExitFuzzerOnFailure(result, OEMCrypto_SUCCESS); - - result = OEMCrypto_InstallOemPrivateKey(session_id, key_type, - wrapped_private_key.data(), - wrapped_private_key_length); - wvoec::CheckStatusAndExitFuzzerOnFailure(result, OEMCrypto_SUCCESS); - -#endif - return 0; } diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_get_random_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_get_random_fuzz.cc index e81f956..c00a9c9 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_get_random_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_get_random_fuzz.cc @@ -16,7 +16,7 @@ extern "C" int LLVMFuzzerInitialize(int* argc, char*** argv) { } extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - std::vector random_data( + std::vector random_data( FuzzedDataProvider(data, size) .ConsumeIntegralInRange(0, wvoec::MAX_FUZZ_OUTPUT_LENGTH)); OEMCrypto_GetRandom(random_data.data(), random_data.size()); diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_install_oem_private_key_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_install_oem_private_key_fuzz.cc deleted file mode 100644 index 310999c..0000000 --- a/oemcrypto/test/fuzz_tests/oemcrypto_install_oem_private_key_fuzz.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2022 Google LLC. All Rights Reserved. This file and proprietary -// source code may only be used and distributed under the Widevine -// License Agreement. - -#include "FuzzedDataProvider.h" -#include "OEMCryptoCENC.h" -#include "oemcrypto_fuzz_helper.h" - -extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - wvoec::RedirectStdoutToFile(); - - wvoec::SessionFuzz session_fuzz; - session_fuzz.Initialize(); - FuzzedDataProvider fuzzed_data(data, size); - const OEMCrypto_PrivateKeyType key_type = wvoec::ConvertDataToValidEnum( - fuzzed_data, OEMCrypto_PrivateKeyType_MaxValue); - const std::vector wrapped_private_key = - fuzzed_data.ConsumeRemainingBytes(); - OEMCrypto_InstallOemPrivateKey(session_fuzz.session().session_id(), key_type, - wrapped_private_key.data(), - wrapped_private_key.size()); - session_fuzz.Terminate(); - - return 0; -} diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc index 231cf40..e270e4f 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_opk_dispatcher_fuzz.cc @@ -7,13 +7,13 @@ namespace { void OpenOEMCryptoTASession() { uint8_t request_body[] = { - 0x06, // TAG_UINT32 + 0x07, // TAG_UINT32 0x09, 0x00, 0x00, 0x00, // API value (0x09) - 0x07, // TAG_UINT64 + 0x08, // TAG_UINT64 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Timestamp 0x01, // TAG_BOOL 0x00, // value (false) - 0x0a // TAG_EOM + 0x0b // TAG_EOM }; ODK_Message request = ODK_Message_Create(request_body, sizeof(request_body)); ODK_Message_SetSize(&request, sizeof(request_body)); @@ -23,11 +23,11 @@ void OpenOEMCryptoTASession() { void InitializeOEMCryptoTA() { uint8_t request_body[] = { - 0x06, // TAG_UINT32 + 0x07, // TAG_UINT32 0x01, 0x00, 0x00, 0x00, // API value (0x01) - 0x07, // TAG_UINT64 + 0x08, // TAG_UINT64 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Timestamp - 0x0a // TAG_EOM + 0x0b // TAG_EOM }; ODK_Message request = ODK_Message_Create(request_body, sizeof(request_body)); ODK_Message_SetSize(&request, sizeof(request_body)); diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp b/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp index be0ab62..b38a1be 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp +++ b/oemcrypto/test/fuzz_tests/oemcrypto_opk_fuzztests.gyp @@ -42,12 +42,8 @@ { 'target_name': 'oemcrypto_opk_dispatcher_fuzz', 'include_dirs': [ - '<(oemcrypto_dir)/opk/serialization/common', '<(oemcrypto_dir)/opk/serialization/common/include', '<(oemcrypto_dir)/opk/serialization/os_interfaces', - '<(oemcrypto_dir)/opk/serialization/tee', - '<(oemcrypto_dir)/opk/serialization/tee/include', - '<(oemcrypto_dir)/opk/ports/trusty/include/', ], 'dependencies': [ '<(oemcrypto_dir)/opk/serialization/tee/tee.gyp:opk_tee', @@ -55,9 +51,9 @@ 'sources': [ 'oemcrypto_opk_dispatcher_fuzz.cc', '<(oemcrypto_dir)/opk/serialization/test/tos_secure_buffers.c', - '<(oemcrypto_dir)/opk/serialization/test/tos_transport_interface.c', '<(oemcrypto_dir)/opk/serialization/test/tos_logging.c', - '<(oemcrypto_dir)/opk/ports/trusty/serialization_adapter/shared_memory.c', + '<(oemcrypto_dir)/opk/serialization/test/tos_shared_memory.c', + '<(oemcrypto_dir)/opk/serialization/test/tos_transport_interface.c', ], }, { @@ -67,20 +63,11 @@ ], }, { - 'target_name': 'oemcrypto_opk_generate_certificate_key_pair_first_stage_fuzz', + 'target_name': 'oemcrypto_opk_generate_certificate_key_pair_fuzz', 'sources': [ 'oemcrypto_generate_certificate_key_pair_fuzz.cc', ], }, - { - 'target_name': 'oemcrypto_opk_generate_certificate_key_pair_second_stage_fuzz', - 'sources': [ - 'oemcrypto_generate_certificate_key_pair_fuzz.cc', - ], - 'defines': [ - 'SECOND_STAGE', - ], - }, { 'target_name': 'oemcrypto_opk_generate_rsa_signature_fuzz', 'sources': [ @@ -141,12 +128,6 @@ 'oemcrypto_get_random_fuzz.cc', ], }, - { - 'target_name': 'oemcrypto_opk_install_oem_private_key_fuzz', - 'sources': [ - 'oemcrypto_install_oem_private_key_fuzz.cc', - ], - }, { 'target_name': 'oemcrypto_opk_license_request_fuzz', 'sources': [ diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_release_request_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_release_request_fuzz.cc index 11df453..77f9079 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_release_request_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_release_request_fuzz.cc @@ -20,7 +20,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { entry.CreateUsageTableHeader(); entry.InstallTestDrmKey(); entry.session().CreateNewUsageEntry(); - entry.session().GenerateNonce(); std::vector encrypted_usage_header; entry.session().UpdateUsageEntry(&encrypted_usage_header); entry.LoadLicense(); diff --git a/oemcrypto/test/fuzz_tests/oemcrypto_report_usage_fuzz.cc b/oemcrypto/test/fuzz_tests/oemcrypto_report_usage_fuzz.cc index 8b63654..cd3923c 100644 --- a/oemcrypto/test/fuzz_tests/oemcrypto_report_usage_fuzz.cc +++ b/oemcrypto/test/fuzz_tests/oemcrypto_report_usage_fuzz.cc @@ -21,7 +21,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { entry.CreateUsageTableHeader(); entry.InstallTestDrmKey(); entry.session().CreateNewUsageEntry(); - entry.session().GenerateNonce(); std::vector encrypted_usage_header; entry.session().UpdateUsageEntry(&encrypted_usage_header); // Sets pst for usage entry. diff --git a/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gyp b/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gyp index 29826b4..b06a7e1 100644 --- a/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gyp +++ b/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gyp @@ -44,20 +44,11 @@ ], }, { - 'target_name': 'oemcrypto_generate_certificate_key_pair_first_stage_fuzz', + 'target_name': 'oemcrypto_generate_certificate_key_pair_fuzz', 'sources': [ 'oemcrypto_generate_certificate_key_pair_fuzz.cc', ], }, - { - 'target_name': 'oemcrypto_generate_certificate_key_pair_second_stage_fuzz', - 'sources': [ - 'oemcrypto_generate_certificate_key_pair_fuzz.cc', - ], - 'defines': [ - 'SECOND_STAGE', - ], - }, { 'target_name': 'oemcrypto_generate_rsa_signature_fuzz', 'sources': [ @@ -118,12 +109,6 @@ 'oemcrypto_get_random_fuzz.cc', ], }, - { - 'target_name': 'oemcrypto_install_oem_private_key_fuzz', - 'sources': [ - 'oemcrypto_install_oem_private_key_fuzz.cc', - ], - }, { 'target_name': 'oemcrypto_license_request_fuzz', 'sources': [ diff --git a/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gypi b/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gypi index 7252e2f..27e9619 100644 --- a/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gypi +++ b/oemcrypto/test/fuzz_tests/partner_oemcrypto_fuzztests.gypi @@ -25,7 +25,6 @@ '<(platform_specific_dir)/file_store.cpp', '<(platform_specific_dir)/log.cpp', '<(util_dir)/src/platform.cpp', - '<(util_dir)/src/rw_lock.cpp', '<(util_dir)/src/string_conversions.cpp', '<(util_dir)/src/string_format.cpp', '<(util_dir)/test/test_sleep.cpp', @@ -49,6 +48,7 @@ '../../../third_party/googletest.gyp:gtest', '../../../third_party/googletest.gyp:gmock', '<(oemcrypto_dir)/util/oec_ref_util.gyp:oec_ref_util', + '<(oemcrypto_dir)/util/wvcrc32.gyp:libwvcrc32', ], 'defines': [ 'OEMCRYPTO_FUZZ_TESTS', diff --git a/oemcrypto/test/oec_decrypt_fallback_chain.h b/oemcrypto/test/oec_decrypt_fallback_chain.h index 0bbee74..2be1465 100644 --- a/oemcrypto/test/oec_decrypt_fallback_chain.h +++ b/oemcrypto/test/oec_decrypt_fallback_chain.h @@ -31,6 +31,9 @@ namespace wvoec { // call, we want to test that they correctly reject larger calls. class DecryptFallbackChain { public: + // There is no reason to have an instance of this class. + DecryptFallbackChain() = delete; + static OEMCryptoResult Decrypt( const uint8_t* key_handle, size_t key_handle_length, const OEMCrypto_SampleDescription* samples, size_t samples_length, @@ -56,8 +59,6 @@ class DecryptFallbackChain { const OEMCrypto_CENCEncryptPatternDesc* pattern, OEMCryptoCipherMode cipher_mode); - // There is no reason to have an instance of this class. - DecryptFallbackChain() = delete; CORE_DISALLOW_COPY_AND_ASSIGN(DecryptFallbackChain); }; diff --git a/oemcrypto/test/oec_device_features.cpp b/oemcrypto/test/oec_device_features.cpp index 276e52d..487aba6 100644 --- a/oemcrypto/test/oec_device_features.cpp +++ b/oemcrypto/test/oec_device_features.cpp @@ -12,7 +12,6 @@ #include "log.h" #include "oec_test_data.h" -#include "string_conversions.h" #include "test_sleep.h" namespace wvoec { @@ -87,7 +86,15 @@ void DeviceFeatures::Initialize() { printf("supports_cas = %s.\n", supports_cas ? "true" : "false"); OEMCrypto_CloseSession(session); api_version = OEMCrypto_APIVersion(); - printf("api_version = %u.\n", api_version); + const uint32_t minor_version = OEMCrypto_MinorAPIVersion(); + // This is not exactly right, because it should be the minor version of the + // ODK library that has been compiled into OEMCrypto. But that is hard to + // discover, and we don't have to get this exactly right for test + // purposes. This is only used for information and to skip a few integration + // tests. + prerelease = (api_version >= 20) && (minor_version == 0); + printf("api_version = %u.%u%s.\n", api_version, minor_version, + prerelease ? ", (prerelease)" : ""); if (api_version < kCoreMessagesAPI) { printf("--------- WARNING: minimum API is %d ----------\n", api_version); printf("--------- Expect most tests will fail. --------\n"); diff --git a/oemcrypto/test/oec_device_features.h b/oemcrypto/test/oec_device_features.h index 2eb1dd9..78ba3c1 100644 --- a/oemcrypto/test/oec_device_features.h +++ b/oemcrypto/test/oec_device_features.h @@ -10,7 +10,7 @@ namespace wvoec { // These tests are designed to work for this version: -constexpr unsigned int kCurrentAPI = 19; +constexpr unsigned int kCurrentAPI = 20; // The API version when Core Messages were introduced. constexpr unsigned int kCoreMessagesAPI = 16; // The API version when we stopped encrypting key control blocks. @@ -54,6 +54,7 @@ class DeviceFeatures { bool supports_crc; // Supported decrypt hash type CRC. bool test_secure_buffers; // If we can create a secure buffer for testing. bool supports_cas; // Device supports CAS (Condition Access System). + bool prerelease; // If the device uses a prerelease version of ODK. uint32_t api_version; OEMCrypto_ProvisioningMethod provisioning_method; diff --git a/oemcrypto/test/oec_session_util.cpp b/oemcrypto/test/oec_session_util.cpp index 70297a8..b95d3df 100644 --- a/oemcrypto/test/oec_session_util.cpp +++ b/oemcrypto/test/oec_session_util.cpp @@ -19,12 +19,13 @@ #include #include #include +#include #include +#include #include -#include -#include #include +#include #include #include @@ -218,7 +219,7 @@ class boringssl_ptr { }; Test_PST_Report::Test_PST_Report(const std::string& pst_in, - OEMCrypto_Usage_Entry_Status status_in) + OEMCrypto_UsageEntryStatus status_in) : status(status_in), seconds_since_license_received(0), seconds_since_first_decrypt(0), @@ -226,129 +227,6 @@ Test_PST_Report::Test_PST_Report(const std::string& pst_in, pst(pst_in), time_created(wvutil::Clock().GetCurrentTime()) {} -template -OEMCryptoResult -RoundTrip:: - SignAndCreateRequestWithCustomBufferLengths(bool verify_request) { - // In the real world, a message should be signed by the client and - // verified by the server. This simulates that. - size_t gen_signature_length = 0; - size_t core_message_length = 0; - const vector context = session()->GetDefaultContext(); - const size_t small_size = context.size(); // arbitrary. - if (RequestHasNonce()) { - session()->GenerateNonce(); - } - uint32_t session_id = session()->session_id(); - GetDefaultRequestSignatureAndCoreMessageLengths( - session_id, small_size, &gen_signature_length, &core_message_length); - // Used to test request APIs with varying lengths of core message. - core_message_length = - std::max(core_message_length, required_core_message_size_); - // Used to test request APIs with varying lengths of signature. - gen_signature_length = - std::max(gen_signature_length, required_request_signature_size_); - // Make the message buffer a little bigger than the core message, or the - // required size, whichever is larger. - size_t message_size = - std::max(required_message_size_, core_message_length + small_size); - vector data(message_size); - memcpy(&data[core_message_length], context.data(), context.size()); - for (size_t i = context.size() + core_message_length; i < data.size(); i++) { - data[i] = i & 0xFF; - } - if (ShouldGenerateCorpus()) { - WriteRequestApiCorpus(gen_signature_length, - core_message_length, data); - } - - vector gen_signature(gen_signature_length); - OEMCryptoResult result = PrepAndSignRequest( - session()->session_id(), data.data(), data.size(), &core_message_length, - gen_signature.data(), &gen_signature_length); - // We need to fill in core request and verify signature only for calls other - // than OEMCryptoMemory buffer overflow test. Any test other than buffer - // overflow will pass true. - if (result == OEMCrypto_SUCCESS) { - gen_signature.resize(gen_signature_length); - } - if (!verify_request || result != OEMCrypto_SUCCESS) return result; - std::string core_message(reinterpret_cast(data.data()), - core_message_length); - FillAndVerifyCoreRequest(core_message); - VerifyRequestSignature(data, gen_signature, core_message_length); - return result; -} - -template -void RoundTrip::SetEncryptAndSignResponseLengths() { - encrypted_response_length_ = encrypted_response_.size(); - response_signature_length_ = response_signature_.size(); -} - -template -void RoundTrip::VerifyEncryptAndSignResponseLengths() const { - EXPECT_NE(encrypted_response_length_, 0u); - EXPECT_EQ(encrypted_response_length_, encrypted_response_.size()); - EXPECT_NE(response_signature_length_, 0u); - EXPECT_EQ(response_signature_length_, response_signature_.size()); -} - -template -void GetDefaultRequestSignatureAndCoreMessageLengths( - uint32_t& session_id, const size_t& small_size, - size_t* gen_signature_length, size_t* core_message_length) { - vector data(small_size); - for (size_t i = 0; i < data.size(); i++) data[i] = i & 0xFF; - ASSERT_EQ( - PrepAndSignRequest(session_id, data.data(), data.size(), - core_message_length, nullptr, gen_signature_length), - OEMCrypto_ERROR_SHORT_BUFFER); -} - -template -void RoundTrip::InjectFuzzedRequestData(uint8_t* data, - size_t size) { - OEMCrypto_Request_Fuzz fuzz_structure; - // Copy data into fuzz structure, cap signature length at 1mb as it will be - // used to initialize signature vector. - memcpy(&fuzz_structure, data, sizeof(fuzz_structure)); - fuzz_structure.signature_length = - std::min(fuzz_structure.signature_length, MB); - vector signature(fuzz_structure.signature_length); - - // Interpret rest of data as actual message buffer to request APIs. - uint8_t* message_ptr = data + sizeof(fuzz_structure); - size_t message_size = size - sizeof(fuzz_structure); - PrepAndSignRequest(session()->session_id(), message_ptr, message_size, - &fuzz_structure.core_message_length, signature.data(), - &fuzz_structure.signature_length); -} - -template -OEMCrypto_Substring RoundTrip::FindSubstring(const void* pointer, - size_t length) { - OEMCrypto_Substring substring; - if (length == 0 || pointer == nullptr) { - substring.offset = 0; - substring.length = 0; - } else { - substring.offset = reinterpret_cast(pointer) - - reinterpret_cast(&response_data_); - substring.length = length; - } - return substring; -} - void ProvisioningRoundTrip::PrepareSession( const wvoec::WidevineKeybox& keybox) { ASSERT_NO_FATAL_FAILURE(session_->open()); @@ -438,7 +316,7 @@ void ProvisioningRoundTrip::CreateDefaultResponse() { } else { response_data_.enc_message_key_length = 0; } - core_response_.key_type = OEMCrypto_RSA_Private_Key; + core_response_.key_type = OEMCrypto_RSAPrivateKey; core_response_.enc_private_key = FindSubstring(response_data_.rsa_key, response_data_.rsa_key_length); core_response_.enc_private_key_iv = FindSubstring( @@ -465,8 +343,7 @@ void ProvisioningRoundTrip:: } void ProvisioningRoundTrip::SignResponse() { - CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); + CoreMessageFeatures features(kServePrereleaseMessages); ASSERT_TRUE(oemcrypto_core_message::serialize::CreateCoreProvisioningResponse( features, core_response_, core_request_, &serialized_core_message_)); // Resizing for huge core message length unit tests. @@ -581,7 +458,7 @@ void ProvisioningRoundTrip::VerifyLoadFailed() { ASSERT_EQ(zero, wrapped_rsa_key_); } -void Provisioning40RoundTrip::PrepareSession(bool is_oem_key) { +void Provisioning40TwoStageRoundTrip::PrepareSession(bool is_oem_key) { const size_t buffer_size = 10240; // Make sure it is large enough. std::vector public_key(buffer_size); size_t public_key_size = buffer_size; @@ -610,7 +487,7 @@ void Provisioning40RoundTrip::PrepareSession(bool is_oem_key) { } } -void Provisioning40RoundTrip::FillAndVerifyCoreRequest( +void Provisioning40TwoStageRoundTrip::FillAndVerifyCoreRequest( const std::string& core_message_string) { EXPECT_TRUE( oemcrypto_core_message::deserialize::CoreProvisioning40RequestFromMessage( @@ -620,7 +497,7 @@ void Provisioning40RoundTrip::FillAndVerifyCoreRequest( EXPECT_EQ(session()->session_id(), core_request_.session_id); } -void Provisioning40RoundTrip::VerifyRequestSignature( +void Provisioning40TwoStageRoundTrip::VerifyRequestSignature( const vector& data, const vector& generated_signature, size_t /* core_message_length */) { ASSERT_NO_FATAL_FAILURE( @@ -628,7 +505,7 @@ void Provisioning40RoundTrip::VerifyRequestSignature( generated_signature.size(), kSign_RSASSA_PSS)); } -OEMCryptoResult Provisioning40RoundTrip::LoadOEMCertResponse() { +OEMCryptoResult Provisioning40TwoStageRoundTrip::LoadOEMCertResponse() { EXPECT_GE(wrapped_oem_key_.size(), 0UL); return OEMCrypto_InstallOemPrivateKey( session()->session_id(), oem_key_type_, @@ -636,14 +513,14 @@ OEMCryptoResult Provisioning40RoundTrip::LoadOEMCertResponse() { wrapped_oem_key_.size()); } -OEMCryptoResult Provisioning40RoundTrip::LoadDRMCertResponse() { +OEMCryptoResult Provisioning40TwoStageRoundTrip::LoadDRMCertResponse() { EXPECT_GE(wrapped_drm_key_.size(), 0UL); return OEMCrypto_LoadDRMPrivateKey(session()->session_id(), drm_key_type_, wrapped_drm_key_.data(), wrapped_drm_key_.size()); } -void Provisioning40CastRoundTrip::PrepareSession() { +void Provisioning40OneStageRoundTrip::PrepareSession() { const size_t buffer_size = 10240; // Make sure it is large enough. std::vector public_key(buffer_size); size_t public_key_size = buffer_size; @@ -666,14 +543,7 @@ void Provisioning40CastRoundTrip::PrepareSession() { drm_key_type_ = key_type; } -void Provisioning40CastRoundTrip::LoadDRMPrivateKey() { - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_LoadDRMPrivateKey(session()->session_id(), drm_key_type_, - wrapped_drm_key_.data(), - wrapped_drm_key_.size())); -} - -void Provisioning40CastRoundTrip::FillAndVerifyCoreRequest( +void Provisioning40OneStageRoundTrip::FillAndVerifyCoreRequest( const std::string& core_message_string) { EXPECT_TRUE( oemcrypto_core_message::deserialize::CoreProvisioning40RequestFromMessage( @@ -683,7 +553,62 @@ void Provisioning40CastRoundTrip::FillAndVerifyCoreRequest( EXPECT_EQ(session()->session_id(), core_request_.session_id); } -void Provisioning40CastRoundTrip::VerifyRequestSignature( +void Provisioning40OneStageRoundTrip::VerifyRequestSignature( + const vector& data, const vector& generated_signature, + size_t /* core_message_length */) { + ASSERT_NO_FATAL_FAILURE( + session()->VerifySignature(data, generated_signature.data(), + generated_signature.size(), kSign_RSASSA_PSS)); +} + +OEMCryptoResult Provisioning40OneStageRoundTrip::LoadDRMCertResponse() { + EXPECT_GE(wrapped_drm_key_.size(), 0UL); + return OEMCrypto_LoadDRMPrivateKey(session()->session_id(), drm_key_type_, + wrapped_drm_key_.data(), + wrapped_drm_key_.size()); +} + +void Provisioning40CastTwoStageRoundTrip::PrepareSession() { + const size_t buffer_size = 10240; // Make sure it is large enough. + std::vector public_key(buffer_size); + size_t public_key_size = buffer_size; + std::vector public_key_signature(buffer_size); + size_t public_key_signature_size = buffer_size; + std::vector wrapped_private_key(buffer_size); + size_t wrapped_private_key_size = buffer_size; + OEMCrypto_PrivateKeyType key_type; + ASSERT_EQ( + OEMCrypto_SUCCESS, + OEMCrypto_GenerateCertificateKeyPair( + session()->session_id(), public_key.data(), &public_key_size, + public_key_signature.data(), &public_key_signature_size, + wrapped_private_key.data(), &wrapped_private_key_size, &key_type)); + wrapped_private_key.resize(wrapped_private_key_size); + public_key.resize(public_key_size); + + wrapped_drm_key_ = std::move(wrapped_private_key); + drm_public_key_ = std::move(public_key); + drm_key_type_ = key_type; +} + +void Provisioning40CastTwoStageRoundTrip::LoadDRMPrivateKey() { + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_LoadDRMPrivateKey(session()->session_id(), drm_key_type_, + wrapped_drm_key_.data(), + wrapped_drm_key_.size())); +} + +void Provisioning40CastTwoStageRoundTrip::FillAndVerifyCoreRequest( + const std::string& core_message_string) { + EXPECT_TRUE( + oemcrypto_core_message::deserialize::CoreProvisioning40RequestFromMessage( + core_message_string, &core_request_)); + EXPECT_EQ(global_features.api_version, core_request_.api_major_version); + EXPECT_EQ(session()->nonce(), core_request_.nonce); + EXPECT_EQ(session()->session_id(), core_request_.session_id); +} + +void Provisioning40CastTwoStageRoundTrip::VerifyRequestSignature( const vector& data, const vector& generated_signature, size_t /* core_message_length */) { ASSERT_NO_FATAL_FAILURE( @@ -692,7 +617,7 @@ void Provisioning40CastRoundTrip::VerifyRequestSignature( } // Creates a prov2 response -void Provisioning40CastRoundTrip::CreateDefaultResponse() { +void Provisioning40CastTwoStageRoundTrip::CreateDefaultResponse() { uint32_t algorithm_n = htonl(allowed_schemes_); memcpy(response_data_.rsa_key, "SIGN", 4); memcpy(response_data_.rsa_key + 4, &algorithm_n, 4); @@ -701,7 +626,7 @@ void Provisioning40CastRoundTrip::CreateDefaultResponse() { response_data_.rsa_key_length = 8 + encoded_rsa_key_.size(); response_data_.nonce = session_->nonce(); response_data_.enc_message_key_length = 0; - core_response_.key_type = OEMCrypto_RSA_Private_Key; + core_response_.key_type = OEMCrypto_RSAPrivateKey; core_response_.enc_private_key = FindSubstring(response_data_.rsa_key, response_data_.rsa_key_length); core_response_.enc_private_key_iv = FindSubstring( @@ -710,7 +635,7 @@ void Provisioning40CastRoundTrip::CreateDefaultResponse() { response_data_.enc_message_key, response_data_.enc_message_key_length); } -void Provisioning40CastRoundTrip::EncryptAndSignResponse() { +void Provisioning40CastTwoStageRoundTrip::EncryptAndSignResponse() { session()->key_deriver().PadAndEncryptProvisioningMessage( &response_data_, &encrypted_response_data_); core_response_.enc_private_key.length = @@ -718,9 +643,8 @@ void Provisioning40CastRoundTrip::EncryptAndSignResponse() { SignResponse(); } -void Provisioning40CastRoundTrip::SignResponse() { - CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(ODK_MAJOR_VERSION); +void Provisioning40CastTwoStageRoundTrip::SignResponse() { + CoreMessageFeatures features(kServePrereleaseMessages); // Create prov 2 request struct from prov 4 request oemcrypto_core_message::ODK_ProvisioningRequest core_request_prov2; @@ -760,7 +684,8 @@ void Provisioning40CastRoundTrip::SignResponse() { SetEncryptAndSignResponseLengths(); } -OEMCryptoResult Provisioning40CastRoundTrip::LoadResponse(Session* session) { +OEMCryptoResult Provisioning40CastTwoStageRoundTrip::LoadResponse( + Session* session) { EXPECT_NE(session, nullptr); // Write corpus for oemcrypto_load_provisioning_fuzz. Fuzz script expects // unencrypted response from provisioning server as input corpus data. @@ -787,7 +712,164 @@ OEMCryptoResult Provisioning40CastRoundTrip::LoadResponse(Session* session) { return sts; } -OEMCryptoResult Provisioning40CastRoundTrip::LoadResponseNoRetry( +OEMCryptoResult Provisioning40CastTwoStageRoundTrip::LoadResponseNoRetry( + Session* session, size_t* wrapped_key_length) { + EXPECT_NE(session, nullptr); + VerifyEncryptAndSignResponseLengths(); + const std::vector context = session->GetDefaultContext(); + return OEMCrypto_LoadProvisioningCast( + session->session_id(), session->enc_session_key().data(), + session->enc_session_key().size(), context.data(), context.size(), + encrypted_response_.data(), encrypted_response_.size(), + serialized_core_message_.size(), response_signature_.data(), + response_signature_.size(), wrapped_rsa_key_.data(), wrapped_key_length); +} + +void Provisioning40CastOneStageRoundTrip::PrepareSession() { + const size_t buffer_size = 10240; // Make sure it is large enough. + std::vector public_key(buffer_size); + size_t public_key_size = buffer_size; + std::vector public_key_signature(buffer_size); + size_t public_key_signature_size = buffer_size; + std::vector wrapped_private_key(buffer_size); + size_t wrapped_private_key_size = buffer_size; + OEMCrypto_PrivateKeyType key_type; + ASSERT_EQ( + OEMCrypto_SUCCESS, + OEMCrypto_GenerateCertificateKeyPair( + session()->session_id(), public_key.data(), &public_key_size, + public_key_signature.data(), &public_key_signature_size, + wrapped_private_key.data(), &wrapped_private_key_size, &key_type)); + wrapped_private_key.resize(wrapped_private_key_size); + public_key.resize(public_key_size); + + wrapped_drm_key_ = std::move(wrapped_private_key); + drm_public_key_ = std::move(public_key); + drm_key_type_ = key_type; +} + +void Provisioning40CastOneStageRoundTrip::LoadDRMPrivateKey() { + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_LoadDRMPrivateKey(session()->session_id(), drm_key_type_, + wrapped_drm_key_.data(), + wrapped_drm_key_.size())); +} + +void Provisioning40CastOneStageRoundTrip::FillAndVerifyCoreRequest( + const std::string& core_message_string) { + EXPECT_TRUE( + oemcrypto_core_message::deserialize::CoreProvisioning40RequestFromMessage( + core_message_string, &core_request_)); + EXPECT_EQ(global_features.api_version, core_request_.api_major_version); + EXPECT_EQ(session()->nonce(), core_request_.nonce); + EXPECT_EQ(session()->session_id(), core_request_.session_id); +} + +void Provisioning40CastOneStageRoundTrip::VerifyRequestSignature( + const vector& data, const vector& generated_signature, + size_t /* core_message_length */) { + ASSERT_NO_FATAL_FAILURE( + session()->VerifySignature(data, generated_signature.data(), + generated_signature.size(), kSign_RSASSA_PSS)); +} + +// Creates a prov2 response +void Provisioning40CastOneStageRoundTrip::CreateDefaultResponse() { + uint32_t algorithm_n = htonl(allowed_schemes_); + memcpy(response_data_.rsa_key, "SIGN", 4); + memcpy(response_data_.rsa_key + 4, &algorithm_n, 4); + memcpy(response_data_.rsa_key + 8, encoded_rsa_key_.data(), + encoded_rsa_key_.size()); + response_data_.rsa_key_length = 8 + encoded_rsa_key_.size(); + response_data_.nonce = session_->nonce(); + response_data_.enc_message_key_length = 0; + core_response_.key_type = OEMCrypto_RSAPrivateKey; + core_response_.enc_private_key = + FindSubstring(response_data_.rsa_key, response_data_.rsa_key_length); + core_response_.enc_private_key_iv = FindSubstring( + response_data_.rsa_key_iv, sizeof(response_data_.rsa_key_iv)); + core_response_.encrypted_message_key = FindSubstring( + response_data_.enc_message_key, response_data_.enc_message_key_length); +} + +void Provisioning40CastOneStageRoundTrip::EncryptAndSignResponse() { + session()->key_deriver().PadAndEncryptProvisioningMessage( + &response_data_, &encrypted_response_data_); + core_response_.enc_private_key.length = + encrypted_response_data_.rsa_key_length; + SignResponse(); +} + +void Provisioning40CastOneStageRoundTrip::SignResponse() { + CoreMessageFeatures features(kServePrereleaseMessages); + + // Create prov 2 request struct from prov 4 request + oemcrypto_core_message::ODK_ProvisioningRequest core_request_prov2; + core_request_prov2.api_minor_version = core_request_.api_minor_version; + core_request_prov2.api_major_version = core_request_.api_major_version; + core_request_prov2.nonce = core_request_.nonce; + core_request_prov2.session_id = core_request_.session_id; + memcpy(&core_request_prov2.counter_info, &core_request_.counter_info, + sizeof(core_request_.counter_info)); + + ASSERT_TRUE(oemcrypto_core_message::serialize::CreateCoreProvisioningResponse( + features, core_response_, core_request_prov2, &serialized_core_message_)); + // Resizing for huge core message length unit tests. + serialized_core_message_.resize( + std::max(required_core_message_size_, serialized_core_message_.size())); + // Make the message buffer a just big enough, or the + // required size, whichever is larger. + const size_t message_size = + std::max(required_message_size_, serialized_core_message_.size() + + sizeof(encrypted_response_data_)); + // Stripe the encrypted message. + encrypted_response_.resize(message_size); + for (size_t i = 0; i < encrypted_response_.size(); i++) { + encrypted_response_[i] = i & 0xFF; + } + ASSERT_GE(encrypted_response_.size(), serialized_core_message_.size()); + memcpy(encrypted_response_.data(), serialized_core_message_.data(), + serialized_core_message_.size()); + ASSERT_GE(encrypted_response_.size(), + serialized_core_message_.size() + sizeof(encrypted_response_data_)); + memcpy(encrypted_response_.data() + serialized_core_message_.size(), + reinterpret_cast(&encrypted_response_data_), + sizeof(encrypted_response_data_)); + session()->key_deriver().ServerSignBuffer(encrypted_response_.data(), + encrypted_response_.size(), + &response_signature_); + SetEncryptAndSignResponseLengths(); +} + +OEMCryptoResult Provisioning40CastOneStageRoundTrip::LoadResponse( + Session* session) { + EXPECT_NE(session, nullptr); + // Write corpus for oemcrypto_load_provisioning_fuzz. Fuzz script expects + // unencrypted response from provisioning server as input corpus data. + // Data will be encrypted and signed again explicitly by fuzzer script after + // mutations. + if (ShouldGenerateCorpus()) { + const std::string file_name = + GetFileName("oemcrypto_load_provisioning_fuzz_seed_corpus"); + // Corpus for license response fuzzer should be in the format: + // unencrypted (core_response + response_data). + AppendToFile(file_name, reinterpret_cast(&core_response_), + sizeof(ODK_ParsedProvisioning)); + AppendToFile(file_name, reinterpret_cast(&response_data_), + sizeof(response_data_)); + } + size_t wrapped_key_length = 0; + OEMCryptoResult sts = LoadResponseNoRetry(session, &wrapped_key_length); + if (sts != OEMCrypto_ERROR_SHORT_BUFFER) return sts; + wrapped_rsa_key_.assign(wrapped_key_length, 0); + sts = LoadResponseNoRetry(session, &wrapped_key_length); + if (sts == OEMCrypto_SUCCESS) { + wrapped_rsa_key_.resize(wrapped_key_length); + } + return sts; +} + +OEMCryptoResult Provisioning40CastOneStageRoundTrip::LoadResponseNoRetry( Session* session, size_t* wrapped_key_length) { EXPECT_NE(session, nullptr); VerifyEncryptAndSignResponseLengths(); @@ -1002,7 +1084,7 @@ void LicenseRoundTrip::FillCoreResponseSubstrings() { core_response_.key_array_length = num_keys_; key_array_.clear(); for (unsigned int i = 0; i < num_keys_; i++) { - OEMCrypto_KeyObject obj; + OEMCrypto_KeyObjectV2 obj; obj.key_id = FindSubstring(response_data_.keys[i].key_id, response_data_.keys[i].key_id_length); obj.key_data_iv = FindSubstring(response_data_.keys[i].key_iv, @@ -1020,6 +1102,10 @@ void LicenseRoundTrip::FillCoreResponseSubstrings() { } obj.key_control = FindSubstring(&response_data_.keys[i].control, sizeof(response_data_.keys[i].control)); + obj.decenc_mitigation_info.mitigation_option = + OEMCrypto_DeCENC_Mitigation_Option_None; + memset(&(obj.decenc_mitigation_info.configuration_options), 0, + sizeof(obj.decenc_mitigation_info.configuration_options)); key_array_.push_back(obj); } core_response_.key_array = key_array_.data(); @@ -1121,8 +1207,10 @@ void LicenseRoundTrip::EncryptAndSignResponse() { // We might try to test a future api_version_, but we can only make a core // message with at most the current ODK version. This is only done to verify // that OEMCrypto does not attempt to load a future version. + const uint32_t api_version = + std::min(api_version_, static_cast(ODK_MAJOR_VERSION)); CoreMessageFeatures features = CoreMessageFeatures::DefaultFeatures( - std::min(api_version_, static_cast(ODK_MAJOR_VERSION))); + api_version, kServePrereleaseMessages); CreateCoreLicenseResponseWithFeatures(features); SignEncryptedResponse(); } @@ -1555,8 +1643,8 @@ void RenewalRoundTrip::EncryptAndSignResponse() { // TODO(b/191724203): Test renewal server has different version from license // server. ASSERT_NE(license_messages_, nullptr); - CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(license_messages_->api_version()); + CoreMessageFeatures features = CoreMessageFeatures::DefaultFeatures( + license_messages_->api_version(), kServePrereleaseMessages); ASSERT_TRUE(oemcrypto_core_message::serialize::CreateCoreRenewalResponse( features, core_request_, renewal_duration_seconds_, &serialized_core_message_)); @@ -1595,8 +1683,8 @@ void RenewalRoundTrip::InjectFuzzedResponseData( // TODO(b/191724203): Test renewal server has different version from license // server. ASSERT_NE(license_messages_, nullptr); - CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(license_messages_->api_version()); + CoreMessageFeatures features = CoreMessageFeatures::DefaultFeatures( + license_messages_->api_version(), kServePrereleaseMessages); // Serializing core message. // This call also sets nonce in core response to match with session nonce. oemcrypto_core_message::serialize::CreateCoreRenewalResponse( @@ -1673,11 +1761,10 @@ void ReleaseRoundTrip::EncryptAndSignResponse() { // TODO(b/191724203): Test release server has different version from license // server. ASSERT_NE(license_messages_, nullptr); - CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(license_messages_->api_version()); + CoreMessageFeatures features = CoreMessageFeatures::DefaultFeatures( + license_messages_->api_version(), kServePrereleaseMessages); ASSERT_TRUE(oemcrypto_core_message::serialize::CreateCoreReleaseResponse( - features, core_request_, seconds_since_license_received_, - seconds_since_first_decrypt_, &serialized_core_message_)); + features, core_request_, &serialized_core_message_)); // Resize serialize core message to be just big enough or required core // message size, whichever is larger. serialized_core_message_.resize( @@ -1711,14 +1798,12 @@ void ReleaseRoundTrip::InjectFuzzedResponseData( const OEMCrypto_Release_Response_Fuzz& fuzzed_data, const uint8_t* release_response, const size_t release_response_size) { ASSERT_NE(license_messages_, nullptr); - CoreMessageFeatures features = - CoreMessageFeatures::DefaultFeatures(license_messages_->api_version()); + CoreMessageFeatures features = CoreMessageFeatures::DefaultFeatures( + license_messages_->api_version(), kServePrereleaseMessages); // Serializing core message. // This call also sets nonce in core response to match with session nonce. oemcrypto_core_message::serialize::CreateCoreReleaseResponse( - features, fuzzed_data.core_request, - fuzzed_data.seconds_since_license_received, - fuzzed_data.seconds_since_first_decrypt, &serialized_core_message_); + features, fuzzed_data.core_request, &serialized_core_message_); // Copy serialized core message and encrypted response from data and // calculate signature. Now we will have a valid signature for data @@ -1743,10 +1828,6 @@ OEMCryptoResult ReleaseRoundTrip::LoadResponse(Session* session) { // OEMCrypto_Release_Response_Fuzz + license_release_response. OEMCrypto_Release_Response_Fuzz release_response_fuzz; release_response_fuzz.core_request = core_request_; - release_response_fuzz.seconds_since_license_received = - seconds_since_license_received_; - release_response_fuzz.seconds_since_first_decrypt = - seconds_since_first_decrypt_; AppendToFile(file_name, reinterpret_cast(&release_response_fuzz), sizeof(release_response_fuzz)); @@ -1796,7 +1877,11 @@ void Session::close() { void Session::GenerateNonce(int* error_counter) { // We make one attempt. If it fails, we assume there was a nonce flood. - if (OEMCrypto_SUCCESS == OEMCrypto_GenerateNonce(session_id(), &nonce_)) { + // Using |temp_nonce| to avoid member |nonce_| being modified + // during failure. + uint32_t temp_nonce = 0; + if (OEMCrypto_SUCCESS == OEMCrypto_GenerateNonce(session_id(), &temp_nonce)) { + nonce_ = temp_nonce; return; } if (error_counter) { @@ -1806,7 +1891,8 @@ void Session::GenerateNonce(int* error_counter) { // The following is after a 1 second pause, so it cannot be from a nonce // flood. ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_GenerateNonce(session_id(), &nonce_)); + OEMCrypto_GenerateNonce(session_id(), &temp_nonce)); + nonce_ = temp_nonce; } } @@ -2050,11 +2136,11 @@ void Session::SetPublicKeyFromPrivateKeyInfo(OEMCrypto_PrivateKeyType key_type, const uint8_t* buffer, size_t length) { switch (key_type) { - case OEMCrypto_RSA_Private_Key: + case OEMCrypto_RSAPrivateKey: ASSERT_NO_FATAL_FAILURE( SetRsaPublicKeyFromPrivateKeyInfo(buffer, length)); return; - case OEMCrypto_ECC_Private_Key: + case OEMCrypto_ECCPrivateKey: ASSERT_NO_FATAL_FAILURE( SetEccPublicKeyFromPrivateKeyInfo(buffer, length)); return; @@ -2079,11 +2165,11 @@ void Session::SetEccPublicKeyFromPrivateKeyInfo(const uint8_t* buffer, void Session::SetPublicKeyFromSubjectPublicKey( OEMCrypto_PrivateKeyType key_type, const uint8_t* buffer, size_t length) { switch (key_type) { - case OEMCrypto_RSA_Private_Key: + case OEMCrypto_RSAPrivateKey: ASSERT_NO_FATAL_FAILURE( SetRsaPublicKeyFromSubjectPublicKey(buffer, length)); return; - case OEMCrypto_ECC_Private_Key: + case OEMCrypto_ECCPrivateKey: ASSERT_NO_FATAL_FAILURE( SetEccPublicKeyFromSubjectPublicKey(buffer, length)); return; @@ -2115,18 +2201,22 @@ void Session::VerifyRsaSignature(const vector& message, FAIL() << "Padding scheme not supported: " << padding_scheme; return; } - const util::RsaSignatureAlgorithm algorithm = - padding_scheme == kSign_RSASSA_PSS ? util::kRsaPssDefault - : util::kRsaPkcs1Cast; - OEMCrypto_SignatureHashAlgorithm hash_algorithm = OEMCrypto_SHA1; - if (algorithm == util::kRsaPssDefault) { - ASSERT_THAT( - OEMCrypto_GetSignatureHashAlgorithm(session_id(), &hash_algorithm), - AnyOf(OEMCrypto_SUCCESS, OEMCrypto_ERROR_NOT_IMPLEMENTED)); + util::RsaSignatureAlgorithm algorithm = padding_scheme == kSign_RSASSA_PSS + ? util::kRsaPssSha1 + : util::kRsaPkcs1Cast; + if (algorithm == util::kRsaPssSha1) { + OEMCrypto_SignatureHashAlgorithm hash_algorithm = OEMCrypto_SHA1; + const OEMCryptoResult result = + OEMCrypto_GetSignatureHashAlgorithm(session_id(), &hash_algorithm); + ASSERT_THAT(result, + AnyOf(OEMCrypto_SUCCESS, OEMCrypto_ERROR_NOT_IMPLEMENTED)); + if (result == OEMCrypto_SUCCESS) { + algorithm = + util::RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm(hash_algorithm); + } } - const OEMCryptoResult result = - public_rsa_->VerifySignature(message.data(), message.size(), signature, - signature_length, algorithm, hash_algorithm); + const OEMCryptoResult result = public_rsa_->VerifySignature( + message.data(), message.size(), signature, signature_length, algorithm); ASSERT_EQ(result, OEMCrypto_SUCCESS) << "RSA signature check failed"; } @@ -2208,12 +2298,12 @@ void Session::LoadWrappedDrmKey(OEMCrypto_PrivateKeyType key_type, void Session::LoadWrappedRsaDrmKey(const vector& wrapped_rsa_key) { ASSERT_NO_FATAL_FAILURE( - LoadWrappedDrmKey(OEMCrypto_RSA_Private_Key, wrapped_rsa_key)); + LoadWrappedDrmKey(OEMCrypto_RSAPrivateKey, wrapped_rsa_key)); } void Session::LoadWrappedEccDrmKey(const vector& wrapped_ecc_key) { ASSERT_NO_FATAL_FAILURE( - LoadWrappedDrmKey(OEMCrypto_ECC_Private_Key, wrapped_ecc_key)); + LoadWrappedDrmKey(OEMCrypto_ECCPrivateKey, wrapped_ecc_key)); } void Session::CreateNewUsageEntry(OEMCryptoResult* status) { @@ -2293,8 +2383,8 @@ void Session::GenerateReport(const std::string& pst, key_deriver_.ClientSignPstReport(pst_report_buffer_, &computed_signature); EXPECT_EQ(0, memcmp(computed_signature.data(), pst_report().signature(), SHA_DIGEST_LENGTH)); - EXPECT_GE(kInactiveUnused, pst_report().status()); - EXPECT_GE(kHardwareSecureClock, pst_report().clock_security_level()); + EXPECT_GE(OEMCrypto_InactiveUnused, pst_report().status()); + EXPECT_GE(OEMCrypto_HardwareSecureClock, pst_report().clock_security_level()); EXPECT_EQ(pst.length(), pst_report().pst_length()); EXPECT_EQ(0, memcmp(pst.c_str(), pst_report().pst(), pst.length())); } @@ -2310,7 +2400,8 @@ void Session::VerifyPST(const Test_PST_Report& expected) { EXPECT_NEAR(expected.seconds_since_license_received + age, computed.seconds_since_license_received(), kTimeTolerance); // Decrypt times only valid on licenses that have been active. - if (expected.status == kActive || expected.status == kInactiveUsed) { + if (expected.status == OEMCrypto_Active || + expected.status == OEMCrypto_InactiveUsed) { EXPECT_NEAR(expected.seconds_since_first_decrypt + age, computed.seconds_since_first_decrypt(), kUsageTableTimeTolerance); @@ -2354,37 +2445,6 @@ bool ConvertByteToValidBoolean(const bool* in) { return false; } -template -void WriteRequestApiCorpus(size_t signature_length, size_t core_message_length, - vector& data) { - std::string file_name; - if (std::is_same::value) { - file_name = GetFileName("oemcrypto_license_request_fuzz_seed_corpus"); - } else if (std::is_same< - CoreRequest, - oemcrypto_core_message::ODK_ProvisioningRequest>::value) { - file_name = GetFileName("oemcrypto_provisioning_request_fuzz_seed_corpus"); - } else if (std::is_same::value) { - file_name = GetFileName("oemcrypto_renewal_request_fuzz_seed_corpus"); - } else if (std::is_same::value) { - file_name = GetFileName("oemcrypto_release_request_fuzz_seed_corpus"); - } else { - LOGE("Invalid CoreRequest type while writing request api corups."); - } - // Corpus for request APIs should be signature_length + core_message_length + - // data pointer. - OEMCrypto_Request_Fuzz request_fuzz_struct; - request_fuzz_struct.core_message_length = core_message_length; - request_fuzz_struct.signature_length = signature_length; - AppendToFile(file_name, reinterpret_cast(&request_fuzz_struct), - sizeof(OEMCrypto_Request_Fuzz)); - AppendToFile(file_name, reinterpret_cast(data.data()), - data.size()); -} - OEMCryptoResult GetKeyHandleIntoVector(OEMCrypto_SESSION session, const uint8_t* key_id, size_t key_id_length, diff --git a/oemcrypto/test/oec_session_util.h b/oemcrypto/test/oec_session_util.h index 3f2ee9a..69280c1 100644 --- a/oemcrypto/test/oec_session_util.h +++ b/oemcrypto/test/oec_session_util.h @@ -7,21 +7,29 @@ // // OEMCrypto unit tests // -#include -#include -#include +#include +#include +#include +#include + +#include #include +#include #include +#include +#include #include #include "core_message_deserialize.h" #include "core_message_features.h" #include "core_message_serialize.h" +#include "log.h" #include "odk.h" #include "oec_device_features.h" #include "oec_key_deriver.h" #include "oemcrypto_ecc_key.h" +#include "oemcrypto_corpus_generator_helper.h" #include "oemcrypto_fuzz_structs.h" #include "oemcrypto_rsa_key.h" #include "oemcrypto_types.h" @@ -67,6 +75,8 @@ constexpr int kDefaultKeyIdLength = 16; constexpr size_t kMaxPSTLength = 255; // In specification. constexpr size_t kMaxCoreMessage = 200 * kMaxNumKeys + 200; // Rough estimate. +constexpr bool kServePrereleaseMessages = true; + typedef struct { uint8_t key_id[kTestKeyIdMaxLength]; size_t key_id_length; @@ -92,9 +102,9 @@ struct MessageData { struct Test_PST_Report { Test_PST_Report(const std::string& pst_in, - OEMCrypto_Usage_Entry_Status status_in); + OEMCrypto_UsageEntryStatus status_in); - OEMCrypto_Usage_Entry_Status status; + OEMCrypto_UsageEntryStatus status; int64_t seconds_since_license_received; int64_t seconds_since_first_decrypt; int64_t seconds_since_last_decrypt; @@ -167,7 +177,7 @@ class RoundTrip { // Have OEMCrypto sign a request message and then verify the signature and the // core message. - virtual void SignAndVerifyRequest() { + void SignAndVerifyRequest() { // Boolean true generates core request and verifies the request. // Custom message sizes are 0 by default, so the behavior of following // functions will be sign and verify request without any custom buffers @@ -177,11 +187,11 @@ class RoundTrip { } // Have OEMCrypto sign and call create request APIs. Buffer parameters in API // can be set to custom values to test with varying lengths of buffers. - virtual OEMCryptoResult SignAndCreateRequestWithCustomBufferLengths( + OEMCryptoResult SignAndCreateRequestWithCustomBufferLengths( bool verify_request = false); // Used for OEMCrypto Fuzzing: Function to convert fuzzer data to valid // License/Provisioning/Renwal request data that can be serialized. - virtual void InjectFuzzedRequestData(uint8_t* data, size_t size); + void InjectFuzzedRequestData(uint8_t* data, size_t size); // Create a default |response_data| and |core_response|. virtual void CreateDefaultResponse() = 0; // Copy fields from |response_data| to |padded_response_data|, encrypting @@ -191,7 +201,7 @@ class RoundTrip { // Attempt to load the response and return the error. Short buffer errors are // handled by LoadResponse, not the caller. All other errors should be // handled by the caller. - virtual OEMCryptoResult LoadResponse() { return LoadResponse(session_); } + OEMCryptoResult LoadResponse() { return LoadResponse(session_); } // As with LoadResponse, but load into a different session. virtual OEMCryptoResult LoadResponse(Session* session) = 0; @@ -237,7 +247,7 @@ class RoundTrip { virtual void FillAndVerifyCoreRequest( const std::string& core_message_string) = 0; // Find the given pointer in the response_data_. - virtual OEMCrypto_Substring FindSubstring(const void* pointer, size_t length); + OEMCrypto_Substring FindSubstring(const void* pointer, size_t length); // Set EncryptAndSignResponse output lengths for later verification. void SetEncryptAndSignResponseLengths(); @@ -265,6 +275,161 @@ class RoundTrip { size_t response_signature_length_; }; +template +void GetDefaultRequestSignatureAndCoreMessageLengths( + uint32_t& session_id, const size_t& small_size, + size_t* gen_signature_length, size_t* core_message_length) { + vector data(small_size); + for (size_t i = 0; i < data.size(); i++) data[i] = i & 0xFF; + ASSERT_EQ( + PrepAndSignRequest(session_id, data.data(), data.size(), + core_message_length, nullptr, gen_signature_length), + OEMCrypto_ERROR_SHORT_BUFFER); +} + +// Used for OEMCrypto Fuzzing: Generates corpus for request APIs. +template +void WriteRequestApiCorpus(size_t signature_length, size_t core_message_length, + vector& data) { + std::string file_name; + if (std::is_same::value) { + file_name = GetFileName("oemcrypto_license_request_fuzz_seed_corpus"); + } else if (std::is_same< + CoreRequest, + oemcrypto_core_message::ODK_ProvisioningRequest>::value) { + file_name = GetFileName("oemcrypto_provisioning_request_fuzz_seed_corpus"); + } else if (std::is_same::value) { + file_name = GetFileName("oemcrypto_renewal_request_fuzz_seed_corpus"); + } else if (std::is_same::value) { + file_name = GetFileName("oemcrypto_release_request_fuzz_seed_corpus"); + } else { + LOGE("Invalid CoreRequest type while writing request api corups."); + } + // Corpus for request APIs should be signature_length + core_message_length + + // data pointer. + OEMCrypto_Request_Fuzz request_fuzz_struct; + request_fuzz_struct.core_message_length = core_message_length; + request_fuzz_struct.signature_length = signature_length; + AppendToFile(file_name, reinterpret_cast(&request_fuzz_struct), + sizeof(OEMCrypto_Request_Fuzz)); + AppendToFile(file_name, reinterpret_cast(data.data()), + data.size()); +} + +template +OEMCryptoResult +RoundTrip:: + SignAndCreateRequestWithCustomBufferLengths(bool verify_request) { + // In the real world, a message should be signed by the client and + // verified by the server. This simulates that. + size_t gen_signature_length = 0; + size_t core_message_length = 0; + const vector context = session()->GetDefaultContext(); + const size_t small_size = context.size(); // arbitrary. + if (RequestHasNonce()) { + session()->GenerateNonce(); + } + uint32_t session_id = session()->session_id(); + GetDefaultRequestSignatureAndCoreMessageLengths( + session_id, small_size, &gen_signature_length, &core_message_length); + // Used to test request APIs with varying lengths of core message. + core_message_length = + std::max(core_message_length, required_core_message_size_); + // Used to test request APIs with varying lengths of signature. + gen_signature_length = + std::max(gen_signature_length, required_request_signature_size_); + // Make the message buffer a little bigger than the core message, or the + // required size, whichever is larger. + size_t message_size = + std::max(required_message_size_, core_message_length + small_size); + vector data(message_size); + memcpy(&data[core_message_length], context.data(), context.size()); + for (size_t i = context.size() + core_message_length; i < data.size(); i++) { + data[i] = i & 0xFF; + } + if (ShouldGenerateCorpus()) { + WriteRequestApiCorpus(gen_signature_length, + core_message_length, data); + } + + vector gen_signature(gen_signature_length); + OEMCryptoResult result = PrepAndSignRequest( + session()->session_id(), data.data(), data.size(), &core_message_length, + gen_signature.data(), &gen_signature_length); + // We need to fill in core request and verify signature only for calls other + // than OEMCryptoMemory buffer overflow test. Any test other than buffer + // overflow will pass true. + if (result == OEMCrypto_SUCCESS) { + gen_signature.resize(gen_signature_length); + } + if (!verify_request || result != OEMCrypto_SUCCESS) return result; + std::string core_message(reinterpret_cast(data.data()), + core_message_length); + FillAndVerifyCoreRequest(core_message); + VerifyRequestSignature(data, gen_signature, core_message_length); + return result; +} + +template +void RoundTrip::SetEncryptAndSignResponseLengths() { + encrypted_response_length_ = encrypted_response_.size(); + response_signature_length_ = response_signature_.size(); +} + +template +void RoundTrip::VerifyEncryptAndSignResponseLengths() const { + EXPECT_NE(encrypted_response_length_, 0u); + EXPECT_EQ(encrypted_response_length_, encrypted_response_.size()); + EXPECT_NE(response_signature_length_, 0u); + EXPECT_EQ(response_signature_length_, response_signature_.size()); +} + +template +void RoundTrip::InjectFuzzedRequestData(uint8_t* data, + size_t size) { + OEMCrypto_Request_Fuzz fuzz_structure; + // Copy data into fuzz structure, cap signature length at 1mb as it will be + // used to initialize signature vector. + memcpy(&fuzz_structure, data, sizeof(fuzz_structure)); + fuzz_structure.signature_length = + std::min(fuzz_structure.signature_length, MB); + vector signature(fuzz_structure.signature_length); + + // Interpret rest of data as actual message buffer to request APIs. + uint8_t* message_ptr = data + sizeof(fuzz_structure); + size_t message_size = size - sizeof(fuzz_structure); + PrepAndSignRequest(session()->session_id(), message_ptr, message_size, + &fuzz_structure.core_message_length, signature.data(), + &fuzz_structure.signature_length); +} + +template +OEMCrypto_Substring RoundTrip::FindSubstring(const void* pointer, + size_t length) { + OEMCrypto_Substring substring; + if (length == 0 || pointer == nullptr) { + substring.offset = 0; + substring.length = 0; + } else { + substring.offset = reinterpret_cast(pointer) - + reinterpret_cast(&response_data_); + substring.length = length; + } + return substring; +} + class ProvisioningRoundTrip : public RoundTrip< /* CoreRequest */ oemcrypto_core_message::ODK_ProvisioningRequest, @@ -279,12 +444,12 @@ class ProvisioningRoundTrip keybox_(nullptr), encoded_rsa_key_(encoded_rsa_key) {} // Prepare the session for signing the request. - virtual void PrepareSession(const wvoec::WidevineKeybox& keybox); + void PrepareSession(const wvoec::WidevineKeybox& keybox); void CreateDefaultResponse() override; void EncryptAndSignResponse() override; void EncryptAndSignResponseWithoutUpdatingEncPrivateKeyLength(); void SignResponse(); - OEMCryptoResult LoadResponse() override { return LoadResponse(session_); } + using RoundTrip::LoadResponse; OEMCryptoResult LoadResponse(Session* session) override; void VerifyLoadFailed(); const std::vector& request() { return request_; } @@ -325,23 +490,24 @@ class ProvisioningRoundTrip std::vector wrapped_rsa_key_; }; -class Provisioning40RoundTrip +// In OEMCrypto v18 and v19, provisioning 4.0 uses two-stages. +// First stage is OEM certificate, second stage is DRM certificate. +class Provisioning40TwoStageRoundTrip : public RoundTrip< /* CoreRequest */ oemcrypto_core_message::ODK_Provisioning40Request, OEMCrypto_PrepAndSignProvisioningRequest, /* CoreResponse */ ODK_ParsedProvisioning, /* ResponseData */ Prov40CertMessage> { public: - Provisioning40RoundTrip(Session* session) - : RoundTrip(session), allowed_schemes_(kSign_RSASSA_PSS) {} + Provisioning40TwoStageRoundTrip(Session* session) : RoundTrip(session) {} void PrepareSession(bool is_oem_key); // Not used. Use Load*CertResponse() below to load OEM/DRM response // respectively. - void CreateDefaultResponse() override {}; - void EncryptAndSignResponse() override {}; - OEMCryptoResult LoadResponse(Session* session) override { - (void)session; + void CreateDefaultResponse() override {} + void EncryptAndSignResponse() override {} + using RoundTrip::LoadResponse; + OEMCryptoResult LoadResponse(Session*) override { return OEMCrypto_ERROR_NOT_IMPLEMENTED; } @@ -367,35 +533,142 @@ class Provisioning40RoundTrip virtual void FillAndVerifyCoreRequest( const std::string& core_message_string) override; - uint32_t allowed_schemes_; + uint32_t allowed_schemes_ = kSign_RSASSA_PSS; + std::vector wrapped_oem_key_; std::vector oem_public_key_; OEMCrypto_PrivateKeyType oem_key_type_; + std::vector wrapped_drm_key_; std::vector drm_public_key_; OEMCrypto_PrivateKeyType drm_key_type_; }; -class Provisioning40CastRoundTrip +// Starting in OEMCrypto v20, provisioning 4.0 uses one-stage +// provisioning. +class Provisioning40OneStageRoundTrip + : public RoundTrip< + /* CoreRequest */ oemcrypto_core_message::ODK_Provisioning40Request, + OEMCrypto_PrepAndSignProvisioningRequest, + /* CoreResponse */ ODK_ParsedProvisioning, + /* ResponseData */ Prov40CertMessage> { + public: + Provisioning40OneStageRoundTrip(Session* session) : RoundTrip(session) {} + void PrepareSession(); + + // Not used. Use Load*CertResponse() below to load OEM/DRM response + // respectively. + void CreateDefaultResponse() override {} + void EncryptAndSignResponse() override {} + using RoundTrip::LoadResponse; + OEMCryptoResult LoadResponse(Session*) override { + return OEMCrypto_ERROR_NOT_IMPLEMENTED; + } + + OEMCryptoResult LoadDRMCertResponse(); + + const std::vector& wrapped_drm_key() { return wrapped_drm_key_; } + const std::vector& drm_public_key() { return drm_public_key_; } + OEMCrypto_PrivateKeyType drm_key_type() { return drm_key_type_; } + void set_allowed_schemes(uint32_t allowed_schemes) { + allowed_schemes_ = allowed_schemes; + } + + protected: + bool RequestHasNonce() override { return true; } + void VerifyRequestSignature(const vector& data, + const vector& generated_signature, + size_t core_message_length) override; + // Verify the values of the core response. + virtual void FillAndVerifyCoreRequest( + const std::string& core_message_string) override; + + uint32_t allowed_schemes_ = kSign_RSASSA_PSS; + std::vector wrapped_drm_key_; + std::vector drm_public_key_; + OEMCrypto_PrivateKeyType drm_key_type_; +}; + +// In OEMCrypto v18 and v19, cast provisioning 4.0 uses two-stages. +// First stage is OEM certificate, second stage is DRM certificate. +class Provisioning40CastTwoStageRoundTrip : public RoundTrip< /* CoreRequest */ oemcrypto_core_message::ODK_Provisioning40Request, OEMCrypto_PrepAndSignProvisioningRequest, /* CoreResponse */ ODK_ParsedProvisioning, /* ResponseData */ RSAPrivateKeyMessage> { public: - Provisioning40CastRoundTrip(Session* session, - const std::vector& encoded_rsa_key) - : RoundTrip(session), - allowed_schemes_(kSign_RSASSA_PSS), - encryptor_(), - encoded_rsa_key_(encoded_rsa_key) {} + Provisioning40CastTwoStageRoundTrip( + Session* session, const std::vector& encoded_rsa_key) + : RoundTrip(session), encoded_rsa_key_(encoded_rsa_key) {} void PrepareSession(); void LoadDRMPrivateKey(); void CreateDefaultResponse() override; void SignResponse(); void EncryptAndSignResponse() override; - OEMCryptoResult LoadResponse() override { return LoadResponse(session_); } + using RoundTrip::LoadResponse; + OEMCryptoResult LoadResponse(Session* session) override; + OEMCryptoResult LoadResponseNoRetry(Session* session, + size_t* wrapped_key_length); + + // Returned + const std::vector& wrapped_oem_key() { return wrapped_oem_key_; } + const std::vector& oem_public_key() { return oem_public_key_; } + OEMCrypto_PrivateKeyType oem_key_type() { return oem_key_type_; } + + const std::vector& wrapped_drm_key() { return wrapped_drm_key_; } + const std::vector& wrapped_rsa_key() { return wrapped_rsa_key_; } + const std::vector& drm_public_key() { return drm_public_key_; } + OEMCrypto_PrivateKeyType drm_key_type() { return drm_key_type_; } + + void set_allowed_schemes(uint32_t allowed_schemes) { + allowed_schemes_ = allowed_schemes; + } + + protected: + bool RequestHasNonce() override { return true; } + void VerifyRequestSignature(const vector& data, + const vector& generated_signature, + size_t core_message_length) override; + // Verify the values of the core response. + virtual void FillAndVerifyCoreRequest( + const std::string& core_message_string) override; + + uint32_t allowed_schemes_ = kSign_RSASSA_PSS; + Encryptor encryptor_; + // OEM key. + std::vector wrapped_oem_key_; + std::vector oem_public_key_; + OEMCrypto_PrivateKeyType oem_key_type_; + // DRM key + std::vector wrapped_drm_key_; + std::vector drm_public_key_; + OEMCrypto_PrivateKeyType drm_key_type_; + // Cast key. + std::vector encoded_rsa_key_; + std::vector wrapped_rsa_key_; +}; + +// Starting in OEMCrypto v20, cast provisioning 4.0 uses one-stage +// provisioning. +class Provisioning40CastOneStageRoundTrip + : public RoundTrip< + /* CoreRequest */ oemcrypto_core_message::ODK_Provisioning40Request, + OEMCrypto_PrepAndSignProvisioningRequest, + /* CoreResponse */ ODK_ParsedProvisioning, + /* ResponseData */ RSAPrivateKeyMessage> { + public: + Provisioning40CastOneStageRoundTrip( + Session* session, const std::vector& encoded_rsa_key) + : RoundTrip(session), encoded_rsa_key_(encoded_rsa_key) {} + + void PrepareSession(); + void LoadDRMPrivateKey(); + void CreateDefaultResponse() override; + void SignResponse(); + void EncryptAndSignResponse() override; + using RoundTrip::LoadResponse; OEMCryptoResult LoadResponse(Session* session) override; OEMCryptoResult LoadResponseNoRetry(Session* session, size_t* wrapped_key_length); @@ -418,14 +691,13 @@ class Provisioning40CastRoundTrip virtual void FillAndVerifyCoreRequest( const std::string& core_message_string) override; - uint32_t allowed_schemes_; + uint32_t allowed_schemes_ = kSign_RSASSA_PSS; Encryptor encryptor_; - std::vector wrapped_oem_key_; - std::vector oem_public_key_; - OEMCrypto_PrivateKeyType oem_key_type_; + // DRM key std::vector wrapped_drm_key_; std::vector drm_public_key_; OEMCrypto_PrivateKeyType drm_key_type_; + // Cast key. std::vector encoded_rsa_key_; std::vector wrapped_rsa_key_; }; @@ -467,7 +739,7 @@ class LicenseRoundTrip // is allowed only one type of operation. void CreateResponseWithGenericCryptoKeys(); // Fill the |core_response| substrings. - virtual void FillCoreResponseSubstrings(); + void FillCoreResponseSubstrings(); void EncryptAndSignResponse() override; // Encrypt and sign license response created from a specific odk version. void EncryptAndSignResponseWithCoreMessageFeatures( @@ -481,7 +753,7 @@ class LicenseRoundTrip const oemcrypto_core_message::features::CoreMessageFeatures& features); // Sign license response. This is used in EncryptAndSignResponse(). void SignEncryptedResponse(); - OEMCryptoResult LoadResponse() override { return LoadResponse(session_); } + using RoundTrip::LoadResponse; OEMCryptoResult LoadResponse(Session* session) override; OEMCryptoResult LoadResponse(Session* session, bool verify_keys); // Reload an offline license into a different session. This derives new mac @@ -563,7 +835,7 @@ class LicenseRoundTrip uint8_t request_hash_[ODK_SHA256_HASH_SIZE]; // Used to hold and add/update key information to be transferred into the core // response later on. - std::vector key_array_; + std::vector key_array_; }; class RenewalRoundTrip @@ -587,7 +859,7 @@ class RenewalRoundTrip void InjectFuzzedResponseData( const OEMCrypto_Renewal_Response_Fuzz& fuzzed_data, const uint8_t* renewal_response, size_t renewal_response_size); - OEMCryptoResult LoadResponse() override { return LoadResponse(session_); } + using RoundTrip::LoadResponse; OEMCryptoResult LoadResponse(Session* session) override; uint64_t renewal_duration_seconds() const { return renewal_duration_seconds_; @@ -627,21 +899,8 @@ class ReleaseRoundTrip void InjectFuzzedResponseData( const OEMCrypto_Release_Response_Fuzz& fuzzed_data, const uint8_t* release_response, size_t release_response_size); - OEMCryptoResult LoadResponse() override { return LoadResponse(session_); } + using RoundTrip::LoadResponse; OEMCryptoResult LoadResponse(Session* session) override; - int64_t seconds_since_license_received() const { - return seconds_since_license_received_; - } - void set_seconds_since_license_received( - int64_t seconds_since_license_received) { - seconds_since_license_received_ = seconds_since_license_received; - } - int64_t seconds_since_first_decrypt() const { - return seconds_since_first_decrypt_; - } - void set_seconds_since_first_decrypt(int64_t seconds_since_first_decrypt) { - seconds_since_first_decrypt_ = seconds_since_first_decrypt; - } protected: bool RequestHasNonce() override { return false; } @@ -652,8 +911,6 @@ class ReleaseRoundTrip virtual void FillAndVerifyCoreRequest( const std::string& core_message_string) override; LicenseRoundTrip* license_messages_; - int64_t seconds_since_license_received_; - int64_t seconds_since_first_decrypt_; }; class EntitledMessage { @@ -950,14 +1207,6 @@ class Session { // Used for OEMCrypto Fuzzing: Convert byte to a valid boolean to avoid errors // generated by msan. bool ConvertByteToValidBoolean(const bool* in); -// Used for OEMCrypto Fuzzing: Generates corpus for request APIs. -template -void WriteRequestApiCorpus(size_t signature_length, size_t core_message_length, - vector& data); -template -void GetDefaultRequestSignatureAndCoreMessageLengths( - uint32_t& session_id, const size_t& small_size, - size_t* gen_signature_length, size_t* core_message_length); // Loads the key matching the given |key_id| into the |session| in OEMCrypto for // the given |cipher_mode| and returns a handle to that key. This function // handles negotiating the size of the |key_handle| buffer. For non-bypassing diff --git a/oemcrypto/test/oemcrypto_basic_test.cpp b/oemcrypto/test/oemcrypto_basic_test.cpp index 71ad453..eaac56a 100644 --- a/oemcrypto/test/oemcrypto_basic_test.cpp +++ b/oemcrypto/test/oemcrypto_basic_test.cpp @@ -317,14 +317,14 @@ TEST_F(OEMCryptoClientTest, FreeUnallocatedSecureBufferNoFailure) { */ TEST_F(OEMCryptoClientTest, VersionNumber) { const std::string log_message = - "OEMCrypto unit tests for API 19.5. Tests last updated 2025-03-11"; + "OEMCrypto unit tests for API 20.0. Tests last updated 2025-05-20"; cout << " " << log_message << "\n"; - cout << " " << "These tests are part of Android V." << "\n"; + cout << " " << "These tests are part of Android 17." << "\n"; LOGI("%s", log_message.c_str()); // If any of the following fail, then it is time to update the log message // above. - EXPECT_EQ(ODK_MAJOR_VERSION, 19); - EXPECT_EQ(ODK_MINOR_VERSION, 5); + EXPECT_EQ(ODK_MAJOR_VERSION, 20); + EXPECT_EQ(ODK_MINOR_VERSION, 0); EXPECT_EQ(kCurrentAPI, static_cast(ODK_MAJOR_VERSION)); RecordWvProperty("test_major_version", std::to_string(ODK_MAJOR_VERSION)); RecordWvProperty("test_minor_version", std::to_string(ODK_MINOR_VERSION)); @@ -498,45 +498,58 @@ TEST_F(OEMCryptoClientTest, CheckBuildInformation_OutputLengthAPI17) { ASSERT_GT(build_info_length, kZero) << "Signaling ERROR_SHORT_BUFFER should have assigned a length"; + // Try again using the size they provided, ensuring that it + // is successful. + const size_t initial_estimate_length = build_info_length; + build_info.assign(build_info_length, kNullChar); + result = OEMCrypto_BuildInformation(&build_info[0], &build_info_length); + ASSERT_EQ(result, OEMCrypto_SUCCESS) + << "initial_estimate_length = " << initial_estimate_length + << ", build_info_length (output) = " << build_info_length; + ASSERT_GT(build_info_length, kZero) << "Build info cannot be empty"; + // Ensure the real length is within the size originally specified. + // OK if final length is smaller than estimated length. + ASSERT_LE(build_info_length, initial_estimate_length); + const size_t expected_length = build_info_length; + // Force a ERROR_SHORT_BUFFER using a non-zero value. // Note: It is assumed that vendors will provide more than a single // character of info. - const size_t second_attempt_length = - (build_info_length >= 2) ? build_info_length / 2 : 1; - build_info.assign(second_attempt_length, kNullChar); + const size_t short_length = (expected_length >= 2) ? expected_length / 2 : 1; + build_info.assign(short_length, kNullChar); build_info_length = build_info.size(); result = OEMCrypto_BuildInformation(&build_info[0], &build_info_length); ASSERT_EQ(result, OEMCrypto_ERROR_SHORT_BUFFER) - << "second_attempt_length = " << second_attempt_length - << ", build_info_length" << build_info_length; + << "short_length = " << short_length + << ", expected_length = " << expected_length << ", build_info_length" + << build_info_length; // OEM specified build info length should be larger than the // original length if returning ERROR_SHORT_BUFFER. - ASSERT_GT(build_info_length, second_attempt_length); + ASSERT_GT(build_info_length, short_length); // Final attempt with a buffer large enough buffer, padding to // ensure the caller truncates. constexpr size_t kBufferPadSize = 42; - const size_t expected_length = build_info_length; - const size_t final_attempt_length = expected_length + kBufferPadSize; - build_info.assign(final_attempt_length, kNullChar); + const size_t oversize_length = expected_length + kBufferPadSize; + build_info.assign(oversize_length, kNullChar); build_info_length = build_info.size(); result = OEMCrypto_BuildInformation(&build_info[0], &build_info_length); ASSERT_EQ(result, OEMCrypto_SUCCESS) - << "final_attempt_length = " << final_attempt_length + << "oversize_length = " << oversize_length << ", expected_length = " << expected_length - << ", build_info_length = " << build_info_length; + << ", build_info_length (output) = " << build_info_length; // Ensure not empty. ASSERT_GT(build_info_length, kZero) << "Build info cannot be empty"; // Ensure it was truncated down from the padded length. - ASSERT_LT(build_info_length, final_attempt_length) + ASSERT_LT(build_info_length, oversize_length) << "Should have truncated from oversized buffer: expected_length = " << expected_length; - // Ensure the real length is within the size originally specified. - // OK if final length is smaller than estimated length. - ASSERT_LE(build_info_length, expected_length); + // Ensure that length is equal to the length of the previous + // successful call. + ASSERT_EQ(build_info_length, expected_length); } // Verifies that OEMCrypto_BuildInformation() is behaving as expected diff --git a/oemcrypto/test/oemcrypto_cast_test.h b/oemcrypto/test/oemcrypto_cast_test.h index 0aa0491..fbc1383 100644 --- a/oemcrypto/test/oemcrypto_cast_test.h +++ b/oemcrypto/test/oemcrypto_cast_test.h @@ -98,7 +98,39 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { size_t len = provisioning_messages.response_data().rsa_key_length; encoded_rsa_key_ = std::vector(ptr, ptr + len); wrapped_drm_key_ = provisioning_messages.wrapped_rsa_key(); - drm_key_type_ = OEMCrypto_RSA_Private_Key; + drm_key_type_ = OEMCrypto_RSAPrivateKey; + EXPECT_GT(wrapped_drm_key_.size(), 0u); + EXPECT_EQ(nullptr, find(wrapped_drm_key_, encoded_rsa_key_)); + } + if (force) { + EXPECT_EQ(OEMCrypto_SUCCESS, sts); + } + } else if (global_features.provisioning_method == + OEMCrypto_BootCertificateChain && + global_features.api_version >= 20) { + // In v20+, use single stage provisioning. + Session s; + ASSERT_NO_FATAL_FAILURE(s.open()); + Provisioning40CastOneStageRoundTrip prov_cast(&s, encoded_rsa_key_); + prov_cast.set_allowed_schemes(schemes); + ASSERT_NO_FATAL_FAILURE(prov_cast.PrepareSession()); + ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); + + ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( + prov_cast.drm_key_type(), prov_cast.drm_public_key().data(), + prov_cast.drm_public_key().size())); + ASSERT_NO_FATAL_FAILURE(prov_cast.SignAndVerifyRequest()); + ASSERT_NO_FATAL_FAILURE(s.GenerateDerivedKeysFromSessionKey()); + ASSERT_NO_FATAL_FAILURE(prov_cast.CreateDefaultResponse()); + ASSERT_NO_FATAL_FAILURE(prov_cast.EncryptAndSignResponse()); + OEMCryptoResult sts = prov_cast.LoadResponse(); + key_loaded_ = (OEMCrypto_SUCCESS == sts); + if (key_loaded_) { + uint8_t* ptr = prov_cast.response_data().rsa_key; + size_t len = prov_cast.response_data().rsa_key_length; + encoded_rsa_key_ = std::vector(ptr, ptr + len); + wrapped_drm_key_ = prov_cast.wrapped_rsa_key(); + drm_key_type_ = OEMCrypto_RSAPrivateKey; EXPECT_GT(wrapped_drm_key_.size(), 0u); EXPECT_EQ(nullptr, find(wrapped_drm_key_, encoded_rsa_key_)); } @@ -107,6 +139,7 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { } } else if (global_features.provisioning_method == OEMCrypto_BootCertificateChain) { + // In v18 and v19, use two stage provisioning. Session s1; ASSERT_NO_FATAL_FAILURE(s1.open()); ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&s1)); @@ -117,7 +150,7 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { wrapped_oem_key_.data(), wrapped_oem_key_.size()), OEMCrypto_SUCCESS); - Provisioning40CastRoundTrip prov_cast(&s2, encoded_rsa_key_); + Provisioning40CastTwoStageRoundTrip prov_cast(&s2, encoded_rsa_key_); prov_cast.set_allowed_schemes(schemes); ASSERT_NO_FATAL_FAILURE(prov_cast.PrepareSession()); ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); @@ -136,7 +169,7 @@ class OEMCryptoLoadsCertificateAlternates : public OEMCryptoLoadsCertificate { size_t len = prov_cast.response_data().rsa_key_length; encoded_rsa_key_ = std::vector(ptr, ptr + len); wrapped_drm_key_ = prov_cast.wrapped_rsa_key(); - drm_key_type_ = OEMCrypto_RSA_Private_Key; + drm_key_type_ = OEMCrypto_RSAPrivateKey; EXPECT_GT(wrapped_drm_key_.size(), 0u); EXPECT_EQ(nullptr, find(wrapped_drm_key_, encoded_rsa_key_)); } diff --git a/oemcrypto/test/oemcrypto_license_test.cpp b/oemcrypto/test/oemcrypto_license_test.cpp index 9ad5542..b814aa7 100644 --- a/oemcrypto/test/oemcrypto_license_test.cpp +++ b/oemcrypto/test/oemcrypto_license_test.cpp @@ -11,6 +11,7 @@ #include "test_sleep.h" using ::testing::Range; +using oemcrypto_core_message::features::CoreMessageFeatures; namespace wvoec { @@ -55,9 +56,8 @@ void TestMaxKeys(SessionUtil* util, size_t num_keys_per_session) { std::vector> licenses; size_t total_keys = 0; for (size_t i = 0; total_keys < max_total_keys; i++) { - sessions.push_back(std::unique_ptr(new Session())); - licenses.push_back(std::unique_ptr( - new LicenseRoundTrip(sessions[i].get()))); + sessions.push_back(std::make_unique()); + licenses.push_back(std::make_unique(sessions[i].get())); const size_t num_keys = std::min(max_total_keys - total_keys, num_keys_per_session); licenses[i]->set_num_keys(static_cast(num_keys)); @@ -694,9 +694,8 @@ TEST_F(OEMCryptoSessionTests, ClearKcbAPI17) { ASSERT_NO_FATAL_FAILURE(license_messages.SignAndVerifyRequest()); ASSERT_NO_FATAL_FAILURE(license_messages.CreateDefaultResponse()); // Set odk version in the license response to be 16.4 - oemcrypto_core_message::features::CoreMessageFeatures features = {}; - features.maximum_major_version = 16; - features.maximum_minor_version = 4; + CoreMessageFeatures features = + CoreMessageFeatures::DefaultFeatures(16, kServePrereleaseMessages); constexpr bool kForceClearKcb = true; ASSERT_NO_FATAL_FAILURE( license_messages.EncryptAndSignResponseWithCoreMessageFeatures( @@ -917,7 +916,7 @@ TEST_P(OEMCryptoRefreshTest, RefreshWithNoSelectKey) { // Test that playback clock is correctly started and that the license can be // renewed. TEST_P(OEMCryptoRefreshTest, RenewLicenseLoadSuccess) { - license_messages_.core_response().renewal_delay_base = OEMCrypto_License_Load; + license_messages_.core_response().renewal_delay_base = OEMCrypto_LicenseLoad; timer_limits_.rental_duration_seconds = kDuration; // 2 seconds. timer_limits_.initial_renewal_duration_seconds = kLongDuration; // 5 seconds. // First version to support Renew on Load. @@ -955,7 +954,7 @@ TEST_P(OEMCryptoRefreshTest, RenewLicenseLoadSuccess) { } TEST_P(OEMCryptoRefreshTest, RenewLicenseLoadOutsideRentalDuration) { - license_messages_.core_response().renewal_delay_base = OEMCrypto_License_Load; + license_messages_.core_response().renewal_delay_base = OEMCrypto_LicenseLoad; timer_limits_.rental_duration_seconds = kDuration; // 2 seconds. timer_limits_.initial_renewal_duration_seconds = kLongDuration; // 5 seconds. @@ -969,6 +968,77 @@ TEST_P(OEMCryptoRefreshTest, RenewLicenseLoadOutsideRentalDuration) { session_.TestDecryptCTR(false, OEMCrypto_ERROR_UNKNOWN_FAILURE)); } +TEST_P(OEMCryptoRefreshTest, ReleaseLicenseLoadSuccessNoDecrypt) { + if (license_api_version_ < 20 || wvoec::global_features.api_version < 20) { + GTEST_SKIP() << "Test for versions 20 and up only."; + } + if (!wvoec::global_features.usage_table) { + GTEST_SKIP() << "Test for usage table devices only."; + } + license_messages_.set_pst("pst"); + ASSERT_NO_FATAL_FAILURE(session_.CreateNewUsageEntry()); + LoadLicense(); + + OEMCrypto_UsageEntryStatus status; + int64_t seconds_since_license_received; + int64_t seconds_since_first_decrypt; + std::vector buf; + session_.UpdateUsageEntry(&buf); + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_GetUsageEntryInfo(session_.session_id(), &status, + &seconds_since_license_received, + &seconds_since_first_decrypt)); + EXPECT_EQ(status, OEMCrypto_Unused); + + ReleaseRoundTrip release_messages(&license_messages_); + MakeReleaseRequest(&release_messages); + + session_.UpdateUsageEntry(&buf); + + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_GetUsageEntryInfo(session_.session_id(), &status, + &seconds_since_license_received, + &seconds_since_first_decrypt)); + EXPECT_EQ(status, OEMCrypto_InactiveUnused); +} + +TEST_P(OEMCryptoRefreshTest, ReleaseLicenseLoadSuccessDecrypt) { + if (license_api_version_ < 20 || wvoec::global_features.api_version < 20) { + GTEST_SKIP() << "Test for versions 20 and up only."; + } + if (!wvoec::global_features.usage_table) { + GTEST_SKIP() << "Test for usage table devices only."; + } + + license_messages_.set_pst("pst"); + ASSERT_NO_FATAL_FAILURE(session_.CreateNewUsageEntry()); + LoadLicense(); + + OEMCrypto_UsageEntryStatus status; + int64_t seconds_since_license_received; + int64_t seconds_since_first_decrypt; + std::vector buf; + session_.UpdateUsageEntry(&buf); + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_GetUsageEntryInfo(session_.session_id(), &status, + &seconds_since_license_received, + &seconds_since_first_decrypt)); + EXPECT_EQ(status, OEMCrypto_Unused); + + ASSERT_NO_FATAL_FAILURE(session_.TestDecryptCTR(true, OEMCrypto_SUCCESS)); + + ReleaseRoundTrip release_messages(&license_messages_); + MakeReleaseRequest(&release_messages); + + session_.UpdateUsageEntry(&buf); + + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_GetUsageEntryInfo(session_.session_id(), &status, + &seconds_since_license_received, + &seconds_since_first_decrypt)); + EXPECT_EQ(status, OEMCrypto_InactiveUsed); +} + INSTANTIATE_TEST_SUITE_P(TestAll, OEMCryptoRefreshTest, Range(kCurrentAPI - 1, kCurrentAPI + 1)); diff --git a/oemcrypto/test/oemcrypto_license_test.h b/oemcrypto/test/oemcrypto_license_test.h index e80c22c..f06b38a 100644 --- a/oemcrypto/test/oemcrypto_license_test.h +++ b/oemcrypto/test/oemcrypto_license_test.h @@ -270,7 +270,7 @@ class LicenseWithUsageEntry { } ASSERT_NO_FATAL_FAILURE( session_.UpdateUsageEntry(&(test->encrypted_usage_header_))); - ASSERT_NO_FATAL_FAILURE(GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(session_.close()); } @@ -325,7 +325,7 @@ class LicenseWithUsageEntry { reinterpret_cast(pst().c_str()), pst().length())); } - void GenerateVerifyReport(OEMCrypto_Usage_Entry_Status status) { + void GenerateVerifyReport(OEMCrypto_UsageEntryStatus status) { ASSERT_NO_FATAL_FAILURE(session_.GenerateReport(pst())); Test_PST_Report expected(pst(), status); ASSERT_NO_FATAL_FAILURE( diff --git a/oemcrypto/test/oemcrypto_provisioning_test.cpp b/oemcrypto/test/oemcrypto_provisioning_test.cpp index 646acdc..2bfb6f0 100644 --- a/oemcrypto/test/oemcrypto_provisioning_test.cpp +++ b/oemcrypto/test/oemcrypto_provisioning_test.cpp @@ -173,46 +173,6 @@ TEST_F(OEMCryptoProv30Test, GetCertOnlyAPI16) { ASSERT_EQ(OEMCrypto_SUCCESS, license_messages.LoadResponse()); } -/** This verifies that the OEM Certificate cannot be used with - * GenerateRSASignature. - */ -TEST_F(OEMCryptoProv40Test, OEMCertForbidGenerateRSASignature1) { - // Create an OEM Cert and save it for later. - Session s1; - ASSERT_NO_FATAL_FAILURE(s1.open()); - ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&s1)); - ASSERT_EQ(s1.IsPublicKeySet(), true); - s1.close(); - Session s2; - ASSERT_NO_FATAL_FAILURE(s2.open()); - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_InstallOemPrivateKey( - s2.session_id(), oem_key_type_, - reinterpret_cast(wrapped_oem_key_.data()), - wrapped_oem_key_.size())); - DisallowForbiddenPadding(s2.session_id(), kSign_PKCS1_Block1, 80); -} - -/** This verifies that the OEM Certificate cannot be used with - * GenerateRSASignature. - */ -TEST_F(OEMCryptoProv40Test, OEMCertForbidGenerateRSASignature2) { - // Create an OEM Cert and save it for later. - Session s1; - ASSERT_NO_FATAL_FAILURE(s1.open()); - ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&s1)); - ASSERT_EQ(s1.IsPublicKeySet(), true); - s1.close(); - Session s2; - ASSERT_NO_FATAL_FAILURE(s2.open()); - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_InstallOemPrivateKey( - s2.session_id(), oem_key_type_, - reinterpret_cast(wrapped_oem_key_.data()), - wrapped_oem_key_.size())); - DisallowForbiddenPadding(s2.session_id(), kSign_RSASSA_PSS, 80); -} - // This verifies that the device really does claim to have BCC. // It should be filtered out for devices that have a keybox or factory OEM // cert. @@ -252,6 +212,9 @@ TEST_F(OEMCryptoProv40Test, GetBootCertificateChainSuccess) { OEMCrypto_SUCCESS); util::BccValidator validator; EXPECT_EQ(util::CborMessageStatus::kCborParseOk, validator.Parse(bcc)); + if (wvoec::global_features.api_version >= 20) { + validator.SetCheckWidevineComponentName(true); + } util::CborMessageStatus status = validator.Validate(); EXPECT_LT(status, util::CborMessageStatus::kCborValidateError); if (status >= util::CborMessageStatus::kCborValidateError) { @@ -489,172 +452,6 @@ TEST_F(OEMCryptoProv40Test, GetDeviceSignedCsrPayloadInvalid) { ASSERT_EQ(sts, OEMCrypto_ERROR_INVALID_CONTEXT); } -// Verifies that an OEM private key can be installed. -TEST_F(OEMCryptoProv40Test, InstallOemPrivateKeySuccess) { - Session s; - ASSERT_NO_FATAL_FAILURE(s.open()); - // First generate a key pair. - // Large buffer to make sure it is large enough. - size_t public_key_size = 10000; - std::vector public_key(public_key_size); - size_t public_key_signature_size = 10000; - std::vector public_key_signature(public_key_signature_size); - size_t wrapped_private_key_size = 10000; - std::vector wrapped_private_key(wrapped_private_key_size); - OEMCrypto_PrivateKeyType key_type; - ASSERT_EQ( - OEMCrypto_GenerateCertificateKeyPair( - s.session_id(), public_key.data(), &public_key_size, - public_key_signature.data(), &public_key_signature_size, - wrapped_private_key.data(), &wrapped_private_key_size, &key_type), - OEMCrypto_SUCCESS); - public_key.resize(public_key_size); - public_key_signature.resize(public_key_signature_size); - wrapped_private_key.resize(wrapped_private_key_size); - - // Install the generated private key. - ASSERT_EQ(OEMCrypto_InstallOemPrivateKey(s.session_id(), key_type, - wrapped_private_key.data(), - wrapped_private_key_size), - OEMCrypto_SUCCESS); -} - -// If data is empty or random, the API should return non-success status. -TEST_F(OEMCryptoProv40Test, InstallOemPrivateKeyInvalidDataFail) { - Session s; - ASSERT_NO_FATAL_FAILURE(s.open()); - - // Empty key fails. - std::vector wrapped_private_key; - OEMCrypto_PrivateKeyType key_type = OEMCrypto_RSA_Private_Key; - ASSERT_NE(OEMCrypto_InstallOemPrivateKey(s.session_id(), key_type, - wrapped_private_key.data(), - wrapped_private_key.size()), - OEMCrypto_SUCCESS); - - // Random key data fails. - wrapped_private_key = {1, 2, 3}; - ASSERT_NE(OEMCrypto_InstallOemPrivateKey(s.session_id(), key_type, - wrapped_private_key.data(), - wrapped_private_key.size()), - OEMCrypto_SUCCESS); -} - -// Verifies that an OEM private key can be installed, and used by -// GenerateCertificateKeyPair call. -TEST_F(OEMCryptoProv40Test, InstallOemPrivateKeyCanBeUsed) { - Session s; - ASSERT_NO_FATAL_FAILURE(s.open()); - // First generate a key pair. - size_t public_key_size1 = 10000; - std::vector public_key1(public_key_size1); - size_t public_key_signature_size1 = 10000; - std::vector public_key_signature1(public_key_signature_size1); - size_t wrapped_private_key_size1 = 10000; - std::vector wrapped_private_key1(wrapped_private_key_size1); - OEMCrypto_PrivateKeyType key_type1; - ASSERT_EQ( - OEMCrypto_GenerateCertificateKeyPair( - s.session_id(), public_key1.data(), &public_key_size1, - public_key_signature1.data(), &public_key_signature_size1, - wrapped_private_key1.data(), &wrapped_private_key_size1, &key_type1), - OEMCrypto_SUCCESS); - EXPECT_NE(public_key_size1, 0UL); - EXPECT_NE(public_key_signature_size1, 0UL); - EXPECT_NE(wrapped_private_key_size1, 0UL); - public_key1.resize(public_key_size1); - public_key_signature1.resize(public_key_signature_size1); - wrapped_private_key1.resize(wrapped_private_key_size1); - - // Install the generated private key. - ASSERT_EQ(OEMCrypto_InstallOemPrivateKey(s.session_id(), key_type1, - wrapped_private_key1.data(), - wrapped_private_key_size1), - OEMCrypto_SUCCESS); - - // Now calling GenerateCertificateKeyPair should use wrapped_private_key to - // sign the newly generated public key. - size_t public_key_size2 = 10000; - std::vector public_key2(public_key_size2); - size_t public_key_signature_size2 = 10000; - std::vector public_key_signature2(public_key_signature_size2); - size_t wrapped_private_key_size2 = 10000; - std::vector wrapped_private_key2(wrapped_private_key_size2); - OEMCrypto_PrivateKeyType key_type2; - ASSERT_EQ( - OEMCrypto_GenerateCertificateKeyPair( - s.session_id(), public_key2.data(), &public_key_size2, - public_key_signature2.data(), &public_key_signature_size2, - wrapped_private_key2.data(), &wrapped_private_key_size2, &key_type2), - OEMCrypto_SUCCESS); - EXPECT_NE(public_key_size2, 0UL); - EXPECT_NE(public_key_signature_size2, 0UL); - EXPECT_NE(wrapped_private_key_size2, 0UL); - public_key2.resize(public_key_size2); - public_key_signature2.resize(public_key_signature_size2); - wrapped_private_key2.resize(wrapped_private_key_size2); - - // Verify public_key_signature2 with public_key1. - if (key_type1 == OEMCrypto_PrivateKeyType::OEMCrypto_RSA_Private_Key) { - ASSERT_NO_FATAL_FAILURE(s.SetRsaPublicKeyFromSubjectPublicKey( - public_key1.data(), public_key1.size())); - ASSERT_NO_FATAL_FAILURE( - s.VerifyRsaSignature(public_key2, public_key_signature2.data(), - public_key_signature2.size(), kSign_RSASSA_PSS)); - } else if (key_type1 == OEMCrypto_PrivateKeyType::OEMCrypto_ECC_Private_Key) { - ASSERT_NO_FATAL_FAILURE(s.SetEccPublicKeyFromSubjectPublicKey( - public_key1.data(), public_key1.size())); - ASSERT_NO_FATAL_FAILURE(s.VerifyEccSignature(public_key2, - public_key_signature2.data(), - public_key_signature2.size())); - } -} - -/** Verify that the private key from an OEM Cert cannot be loaded as a DRM - * cert. - */ -TEST_F(OEMCryptoProv40Test, OEMPrivateKeyCannotBeDRMKey) { - // Create an OEM Cert and save it for later. - Session s1; - ASSERT_NO_FATAL_FAILURE(s1.open()); - ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&s1)); - ASSERT_EQ(s1.IsPublicKeySet(), true); - s1.close(); - const std::vector wrapped_oem_key1 = wrapped_oem_key_; - // Now create a new OEM cert, load the second key, and try to load key1 - // as the DRM key. - Session s2; - ASSERT_NO_FATAL_FAILURE(s2.open()); - ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&s2)); - s2.close(); - // Load the current key as the OEM key in session 3. - Session s3; - ASSERT_NO_FATAL_FAILURE(s3.open()); - // Now try to load key 1 as a DRM key. That should fail. - ASSERT_EQ(OEMCrypto_ERROR_INVALID_KEY, - OEMCrypto_LoadDRMPrivateKey(s3.session_id(), oem_key_type_, - wrapped_oem_key1.data(), - wrapped_oem_key1.size())); -} - -/** The private key for a DRM Cert cannot be loaded as an OEM Certificate. */ -TEST_F(OEMCryptoProv40Test, DRMPrivateKeyCannotBeOEMKey) { - // Create a DRM cert and save it for later. - Session s1; - // Make sure the drm private key exists. - ASSERT_NO_FATAL_FAILURE(s1.open()); - ASSERT_NO_FATAL_FAILURE(InstallTestDrmKey(&s1)); - ASSERT_NE(wrapped_drm_key_.size(), 0u); - // Now try to load the drm private key as an OEM key. - Session s2; - ASSERT_NO_FATAL_FAILURE(s2.open()); - ASSERT_EQ(OEMCrypto_ERROR_INVALID_KEY, - OEMCrypto_InstallOemPrivateKey( - s2.session_id(), drm_key_type_, - reinterpret_cast(wrapped_drm_key_.data()), - wrapped_drm_key_.size())); -} - TEST_F(OEMCryptoProv40Test, GetDeviceId) { OEMCryptoResult sts; std::vector dev_id; @@ -678,14 +475,6 @@ TEST_F(OEMCryptoProv40Test, GetDeviceId) { ASSERT_EQ(dev_id2, dev_id); } -// Verifies provisioning stage 1 OEM cert provisioning round trip works -TEST_F(OEMCryptoProv40Test, ProvisionOemCert) { - Session s; - ASSERT_NO_FATAL_FAILURE(s.open()); - ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&s)); - ASSERT_EQ(s.IsPublicKeySet(), true); -} - // Verifies both provisioning stages OEM and DRM cert provisioning round trip // works TEST_F(OEMCryptoProv40Test, ProvisionDrmCert) { @@ -696,35 +485,40 @@ TEST_F(OEMCryptoProv40Test, ProvisionDrmCert) { } TEST_P(OEMCryptoProv40CastTest, ProvisionCastWorks) { - // Generate an OEM key first, to load into next session - Session s; - ASSERT_NO_FATAL_FAILURE(s.open()); - size_t public_key_size = 10000; - std::vector public_key(public_key_size); - size_t public_key_signature_size = 10000; - std::vector public_key_signature(public_key_signature_size); - size_t wrapped_private_key_size = 10000; - std::vector wrapped_private_key(wrapped_private_key_size); - OEMCrypto_PrivateKeyType key_type; - ASSERT_EQ( - OEMCrypto_GenerateCertificateKeyPair( - s.session_id(), public_key.data(), &public_key_size, - public_key_signature.data(), &public_key_signature_size, - wrapped_private_key.data(), &wrapped_private_key_size, &key_type), - OEMCrypto_SUCCESS); - public_key.resize(public_key_size); - public_key_signature.resize(public_key_signature_size); - wrapped_private_key.resize(wrapped_private_key_size); - ASSERT_NO_FATAL_FAILURE(s.close()); - - // Install OEM key and get cast RSA Session s1; - ASSERT_NO_FATAL_FAILURE(s1.open()); - ASSERT_EQ(OEMCrypto_InstallOemPrivateKey(s1.session_id(), key_type, - wrapped_private_key.data(), - wrapped_private_key_size), - OEMCrypto_SUCCESS); + if (global_features.api_version >= 20) { + // Single stage, just need to open the session. + ASSERT_NO_FATAL_FAILURE(s1.open()); + } else { + // Prior to v20, prov 4.0 used two stages. + // Generate an OEM key first, to load into next session + Session s; + ASSERT_NO_FATAL_FAILURE(s.open()); + size_t public_key_size = 10000; + std::vector public_key(public_key_size); + size_t public_key_signature_size = 10000; + std::vector public_key_signature(public_key_signature_size); + size_t wrapped_private_key_size = 10000; + std::vector wrapped_private_key(wrapped_private_key_size); + OEMCrypto_PrivateKeyType key_type; + ASSERT_EQ( + OEMCrypto_GenerateCertificateKeyPair( + s.session_id(), public_key.data(), &public_key_size, + public_key_signature.data(), &public_key_signature_size, + wrapped_private_key.data(), &wrapped_private_key_size, &key_type), + OEMCrypto_SUCCESS); + public_key.resize(public_key_size); + public_key_signature.resize(public_key_signature_size); + wrapped_private_key.resize(wrapped_private_key_size); + ASSERT_NO_FATAL_FAILURE(s.close()); + // Install OEM key and get cast RSA + ASSERT_NO_FATAL_FAILURE(s1.open()); + ASSERT_EQ(OEMCrypto_InstallOemPrivateKey(s1.session_id(), key_type, + wrapped_private_key.data(), + wrapped_private_key_size), + OEMCrypto_SUCCESS); + } ASSERT_NO_FATAL_FAILURE(CreateProv4CastKey(&s1, GetParam())); } @@ -1262,9 +1056,8 @@ TEST_F(OEMCryptoLoadsCertificate, TestMaxDRMKeys) { // It should be able to load up to kMaxTotalDRMPrivateKeys keys for (size_t i = 0; i < max_total_keys; i++) { - sessions.push_back(std::unique_ptr(new Session())); - licenses.push_back(std::unique_ptr( - new LicenseRoundTrip(sessions[i].get()))); + sessions.push_back(std::make_unique()); + licenses.push_back(std::make_unique(sessions[i].get())); const size_t key_index = i % kTestRSAPKCS8PrivateKeys_2048.size(); encoded_rsa_key_.assign(kTestRSAPKCS8PrivateKeys_2048[key_index].begin(), kTestRSAPKCS8PrivateKeys_2048[key_index].end()); diff --git a/oemcrypto/test/oemcrypto_security_test.cpp b/oemcrypto/test/oemcrypto_security_test.cpp index 5f08000..a6cfc49 100644 --- a/oemcrypto/test/oemcrypto_security_test.cpp +++ b/oemcrypto/test/oemcrypto_security_test.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -831,8 +832,8 @@ TEST_F(OEMCryptoLoadsCertificate, vector wrapped_drm_key_buffer = wrapped_drm_key_; wrapped_drm_key_buffer.resize(wrapped_drm_key_length); OEMCryptoResult result = OEMCrypto_LoadDRMPrivateKey( - s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_drm_key_buffer.data(), wrapped_drm_key_buffer.size()); + s.session_id(), OEMCrypto_RSAPrivateKey, wrapped_drm_key_buffer.data(), + wrapped_drm_key_buffer.size()); s.close(); return result; }; @@ -860,8 +861,8 @@ TEST_F( s.open(); vector wrapped_drm_key_buffer(wrapped_drm_key_length); OEMCryptoResult result = OEMCrypto_LoadDRMPrivateKey( - s.session_id(), OEMCrypto_RSA_Private_Key, - wrapped_drm_key_buffer.data(), wrapped_drm_key_buffer.size()); + s.session_id(), OEMCrypto_RSAPrivateKey, wrapped_drm_key_buffer.data(), + wrapped_drm_key_buffer.size()); s.close(); return result; }; @@ -2277,5 +2278,26 @@ TEST_P(OEMCryptoLicenseOverflowTest, INSTANTIATE_TEST_SUITE_P(TestAll, OEMCryptoLicenseOverflowTest, Range(kCurrentAPI - 1, kCurrentAPI + 1)); +/** Generate an entropy file. */ +TEST_F(OEMCryptoClientTest, GenerateEntropyFile) { + constexpr char kFilename[] = "entropy.bin"; + constexpr size_t kChunkSize = 100'000; + constexpr int kNumChunks = 10; + + if (OEMCrypto_GetRandom(nullptr, 0) == OEMCrypto_ERROR_NOT_IMPLEMENTED) { + GTEST_SKIP() << "OEMCrypto_GetRandom() is not implemented."; + } + + std::ofstream file(kFilename, std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + std::vector chunk(kChunkSize); + for (int i = 0; i < kNumChunks; ++i) { + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_GetRandom(chunk.data(), chunk.size())); + file.write(reinterpret_cast(chunk.data()), chunk.size()); + ASSERT_TRUE(file.good()); + } +} + /// @} } // namespace wvoec diff --git a/oemcrypto/test/oemcrypto_security_tests.gypi b/oemcrypto/test/oemcrypto_security_tests.gypi index fda1b75..3b7886a 100644 --- a/oemcrypto/test/oemcrypto_security_tests.gypi +++ b/oemcrypto/test/oemcrypto_security_tests.gypi @@ -63,7 +63,7 @@ ], 'dependencies': [ '<(oemcrypto_dir)/odk/src/odk.gyp:odk', - '<(oemcrypto_dir)/util/oec_ref_util.gyp:oec_ref_util', + '<(oemcrypto_dir)/util/build.gyp:liboec_ref_util', ], 'includes': [ '../../util/libssl_dependency.gypi' ], } diff --git a/oemcrypto/test/oemcrypto_session_tests_helper.cpp b/oemcrypto/test/oemcrypto_session_tests_helper.cpp index 175bb51..633331b 100644 --- a/oemcrypto/test/oemcrypto_session_tests_helper.cpp +++ b/oemcrypto/test/oemcrypto_session_tests_helper.cpp @@ -34,7 +34,7 @@ void SessionUtil::CreateWrappedDRMKey() { ASSERT_NO_FATAL_FAILURE(provisioning_messages.EncryptAndSignResponse()); ASSERT_EQ(OEMCrypto_SUCCESS, provisioning_messages.LoadResponse()); wrapped_drm_key_ = provisioning_messages.wrapped_rsa_key(); - drm_key_type_ = OEMCrypto_RSA_Private_Key; + drm_key_type_ = OEMCrypto_RSAPrivateKey; drm_public_key_.clear(); } } @@ -111,10 +111,11 @@ void SessionUtil::InstallTestDrmKey(Session* s) { // private to session s. void SessionUtil::CreateProv4OEMKey(Session* s) { ASSERT_NE(s, nullptr); - if (global_features.provisioning_method != OEMCrypto_BootCertificateChain) { - FAIL() << "Provisioning 4.0 is required."; + if (global_features.provisioning_method != OEMCrypto_BootCertificateChain || + global_features.api_version >= 20) { + FAIL() << "Two-stage provisioning 4.0 is required."; } - Provisioning40RoundTrip provisioning_messages(s); + Provisioning40TwoStageRoundTrip provisioning_messages(s); // Generate key pair. ASSERT_NO_FATAL_FAILURE(provisioning_messages.PrepareSession(true)); // Need OEM public key to verify the signed request. @@ -131,84 +132,138 @@ void SessionUtil::CreateProv4OEMKey(Session* s) { ASSERT_EQ(OEMCrypto_SUCCESS, provisioning_messages.LoadOEMCertResponse()); } -// Generate DRM key pair, craft a provisioning 4.0 DRM cert request, sign it -// with the OEM private key and verify the signature. Finally, install DRM -// private to session s. An OEM cert needs to be installed first. It is also -// done in this function. +// Generate a DRM key pair, craft a Provisioning 4.0 DRM cert request, sign the +// request with the DRM private key, and verify the signature. Finally, install +// the DRM private key to session s. void SessionUtil::CreateProv4DRMKey() { if (global_features.provisioning_method != OEMCrypto_BootCertificateChain) { FAIL() << "Provisioning 4.0 is required."; } - // Provision OEM key first. - if (wrapped_oem_key_.size() == 0) { - Session oem_session; - ASSERT_NO_FATAL_FAILURE(oem_session.open()); - ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&oem_session)); - } - Session s; - ASSERT_NO_FATAL_FAILURE(s.open()); - ASSERT_EQ(OEMCrypto_SUCCESS, - OEMCrypto_InstallOemPrivateKey( - s.session_id(), oem_key_type_, - reinterpret_cast(wrapped_oem_key_.data()), - wrapped_oem_key_.size())); - ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( - oem_key_type_, oem_public_key_.data(), oem_public_key_.size())); + if (global_features.api_version >= 20) { + Session s; + ASSERT_NO_FATAL_FAILURE(s.open()); + // Provision DRM key. + Provisioning40OneStageRoundTrip provisioning_messages(&s); + ASSERT_NO_FATAL_FAILURE(provisioning_messages.PrepareSession()); + // Need DRM public key to verify DRM request signature. + ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( + provisioning_messages.drm_key_type(), + provisioning_messages.drm_public_key().data(), + provisioning_messages.drm_public_key().size())); + ASSERT_NO_FATAL_FAILURE(provisioning_messages.SignAndVerifyRequest()); + ASSERT_EQ(OEMCrypto_SUCCESS, provisioning_messages.LoadDRMCertResponse()); + wrapped_drm_key_ = provisioning_messages.wrapped_drm_key(); + drm_key_type_ = provisioning_messages.drm_key_type(); + drm_public_key_ = provisioning_messages.drm_public_key(); + } else { + // Two stage provisioning. + // Provision OEM key first. + if (wrapped_oem_key_.size() == 0) { + Session oem_session; + ASSERT_NO_FATAL_FAILURE(oem_session.open()); + ASSERT_NO_FATAL_FAILURE(CreateProv4OEMKey(&oem_session)); + } + Session s; + ASSERT_NO_FATAL_FAILURE(s.open()); + ASSERT_EQ(OEMCrypto_SUCCESS, + OEMCrypto_InstallOemPrivateKey( + s.session_id(), oem_key_type_, + reinterpret_cast(wrapped_oem_key_.data()), + wrapped_oem_key_.size())); + ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( + oem_key_type_, oem_public_key_.data(), oem_public_key_.size())); - // Provision DRM key. - Provisioning40RoundTrip provisioning_messages(&s); - ASSERT_NO_FATAL_FAILURE(provisioning_messages.PrepareSession(false)); - // Need DRM public key to verify DRM request signature. - ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( - provisioning_messages.drm_key_type(), - provisioning_messages.drm_public_key().data(), - provisioning_messages.drm_public_key().size())); - ASSERT_NO_FATAL_FAILURE(provisioning_messages.SignAndVerifyRequest()); - ASSERT_EQ(OEMCrypto_SUCCESS, provisioning_messages.LoadDRMCertResponse()); - wrapped_drm_key_ = provisioning_messages.wrapped_drm_key(); - drm_key_type_ = provisioning_messages.drm_key_type(); - drm_public_key_ = provisioning_messages.drm_public_key(); + // Provision DRM key. + Provisioning40TwoStageRoundTrip provisioning_messages(&s); + ASSERT_NO_FATAL_FAILURE(provisioning_messages.PrepareSession(false)); + // Need DRM public key to verify DRM request signature. + ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( + provisioning_messages.drm_key_type(), + provisioning_messages.drm_public_key().data(), + provisioning_messages.drm_public_key().size())); + ASSERT_NO_FATAL_FAILURE(provisioning_messages.SignAndVerifyRequest()); + ASSERT_EQ(OEMCrypto_SUCCESS, provisioning_messages.LoadDRMCertResponse()); + wrapped_drm_key_ = provisioning_messages.wrapped_drm_key(); + drm_key_type_ = provisioning_messages.drm_key_type(); + drm_public_key_ = provisioning_messages.drm_public_key(); + } } -// Requires stage 1 prov4 to be complete, ie OEM key is available void SessionUtil::CreateProv4CastKey(Session* s, bool load_drm_before_prov_req) { if (global_features.provisioning_method != OEMCrypto_BootCertificateChain) { FAIL() << "Provisioning 4.0 is required."; } - Provisioning40CastRoundTrip prov_cast(s, encoded_rsa_key_); + if (global_features.api_version >= 20) { + // Single stage cast provisioning. + Provisioning40CastOneStageRoundTrip prov_cast(s, encoded_rsa_key_); - // Calls GenerateCertificateKeyPair(). Generated keys stored in - // prov_cast.drm_public_key_ and prov_cast.wrapped_drm_key_ - ASSERT_NO_FATAL_FAILURE(prov_cast.PrepareSession()); + // Calls GenerateCertificateKeyPair(). Generated keys stored in + // prov_cast.drm_public_key_ and prov_cast.wrapped_drm_key_ + ASSERT_NO_FATAL_FAILURE(prov_cast.PrepareSession()); - // Can choose to load DRM key before preparing the provisioning request, or - // after - if (load_drm_before_prov_req) { - ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); + // Can choose to load DRM key before preparing the provisioning request, or + // after + if (load_drm_before_prov_req) { + ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); + } + ASSERT_NO_FATAL_FAILURE(s->SetPublicKeyFromSubjectPublicKey( + prov_cast.drm_key_type(), prov_cast.drm_public_key().data(), + prov_cast.drm_public_key().size())); + ASSERT_NO_FATAL_FAILURE(prov_cast.SignAndVerifyRequest()); + if (!load_drm_before_prov_req) { + ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); + } + + // Generate derived keys in order to verify and decrypt response. + // We are cheating a little bit here since this GenerateDerivedKeys helper + // simulates work on both client side (calls + // OEMCrypto_GenerateDerivedKeysFromSessionKey) and server side (sets + // key_deriver() keys used to create response) + ASSERT_NO_FATAL_FAILURE(s->GenerateDerivedKeysFromSessionKey()); + + // Response is provisioning 2 with CAST key + ASSERT_NO_FATAL_FAILURE(prov_cast.CreateDefaultResponse()); + ASSERT_NO_FATAL_FAILURE(prov_cast.EncryptAndSignResponse()); + + // Should parse and load successfully + ASSERT_EQ(OEMCrypto_SUCCESS, prov_cast.LoadResponse()); + } else { + // Two stage cast provisioning. + // Single stage cast provisioning. + Provisioning40CastTwoStageRoundTrip prov_cast(s, encoded_rsa_key_); + + // Calls GenerateCertificateKeyPair(). Generated keys stored in + // prov_cast.drm_public_key_ and prov_cast.wrapped_drm_key_ + ASSERT_NO_FATAL_FAILURE(prov_cast.PrepareSession()); + + // Can choose to load DRM key before preparing the provisioning request, or + // after + if (load_drm_before_prov_req) { + ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); + } + ASSERT_NO_FATAL_FAILURE(s->SetPublicKeyFromSubjectPublicKey( + prov_cast.drm_key_type(), prov_cast.drm_public_key().data(), + prov_cast.drm_public_key().size())); + ASSERT_NO_FATAL_FAILURE(prov_cast.SignAndVerifyRequest()); + if (!load_drm_before_prov_req) { + ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); + } + + // Generate derived keys in order to verify and decrypt response. + // We are cheating a little bit here since this GenerateDerivedKeys helper + // simulates work on both client side (calls + // OEMCrypto_GenerateDerivedKeysFromSessionKey) and server side (sets + // key_deriver() keys used to create response) + ASSERT_NO_FATAL_FAILURE(s->GenerateDerivedKeysFromSessionKey()); + + // Response is provisioning 2 with CAST key + ASSERT_NO_FATAL_FAILURE(prov_cast.CreateDefaultResponse()); + ASSERT_NO_FATAL_FAILURE(prov_cast.EncryptAndSignResponse()); + + // Should parse and load successfully + ASSERT_EQ(OEMCrypto_SUCCESS, prov_cast.LoadResponse()); } - ASSERT_NO_FATAL_FAILURE(s->SetPublicKeyFromSubjectPublicKey( - prov_cast.drm_key_type(), prov_cast.drm_public_key().data(), - prov_cast.drm_public_key().size())); - ASSERT_NO_FATAL_FAILURE(prov_cast.SignAndVerifyRequest()); - if (!load_drm_before_prov_req) { - ASSERT_NO_FATAL_FAILURE(prov_cast.LoadDRMPrivateKey()); - } - - // Generate derived keys in order to verify and decrypt response. - // We are cheating a little bit here since this GenerateDerivedKeys helper - // simulates work on both client side (calls - // OEMCrypto_GenerateDerivedKeysFromSessionKey) and server side (sets - // key_deriver() keys used to create response) - ASSERT_NO_FATAL_FAILURE(s->GenerateDerivedKeysFromSessionKey()); - - // Response is provisioning 2 with CAST key - ASSERT_NO_FATAL_FAILURE(prov_cast.CreateDefaultResponse()); - ASSERT_NO_FATAL_FAILURE(prov_cast.EncryptAndSignResponse()); - - // Should parse and load successfully - ASSERT_EQ(OEMCrypto_SUCCESS, prov_cast.LoadResponse()); } - } // namespace wvoec diff --git a/oemcrypto/test/oemcrypto_session_tests_helper.h b/oemcrypto/test/oemcrypto_session_tests_helper.h index eba241b..12e03b4 100644 --- a/oemcrypto/test/oemcrypto_session_tests_helper.h +++ b/oemcrypto/test/oemcrypto_session_tests_helper.h @@ -38,7 +38,7 @@ class SessionUtil { void InstallTestDrmKey(Session* s); // Create and install an OEM Cert private key. After creation, the key is - // saved to oem_public_key_. Only for provisioning 4.0 + // saved to oem_public_key_. Only for two-stage provisioning 4.0 void CreateProv4OEMKey(Session* s); // Create a new DRM Cert. Only for provisioning 4.0 @@ -53,7 +53,7 @@ class SessionUtil { std::vector drm_public_key_; wvoec::WidevineKeybox keybox_; - // Used by prov4.0 + // Used by prov4.0 single stage. std::vector wrapped_oem_key_; std::vector oem_public_key_; OEMCrypto_PrivateKeyType oem_key_type_; diff --git a/oemcrypto/test/oemcrypto_unittests.gypi b/oemcrypto/test/oemcrypto_unittests.gypi index 7b2c8ba..2cc8943 100644 --- a/oemcrypto/test/oemcrypto_unittests.gypi +++ b/oemcrypto/test/oemcrypto_unittests.gypi @@ -75,7 +75,8 @@ ], 'dependencies': [ '<(oemcrypto_dir)/odk/src/odk.gyp:odk', - '<(oemcrypto_dir)/util/oec_ref_util.gyp:oec_ref_util', + '<(oemcrypto_dir)/util/build.gyp:liboec_ref_util', + '<(oemcrypto_dir)/util/wvcrc32.gyp:libwvcrc32', ], 'includes': [ '../../util/libssl_dependency.gypi' ], } diff --git a/oemcrypto/test/oemcrypto_usage_table_test.cpp b/oemcrypto/test/oemcrypto_usage_table_test.cpp index bae0c9c..b58bc2c 100644 --- a/oemcrypto/test/oemcrypto_usage_table_test.cpp +++ b/oemcrypto/test/oemcrypto_usage_table_test.cpp @@ -5,6 +5,8 @@ #include "oemcrypto_usage_table_test.h" +#include + using ::testing::Range; using ::testing::Values; @@ -43,12 +45,33 @@ TEST_F(OEMCryptoSessionTests, Provisioning_IncrementCounterAPI18) { }; // prep and sign provisioning4 request, then extract counter values - auto provision4 = [&](counts* c) { + const std::function provision4_single_stage = [&](counts* c) { + // Same as SessionUtil::CreateProv4DRMKey, but we can't extract counter + // values using that function + Session s; + ASSERT_NO_FATAL_FAILURE(s.open()); + Provisioning40OneStageRoundTrip provisioning_messages(&s); + ASSERT_NO_FATAL_FAILURE(provisioning_messages.PrepareSession()); + ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( + provisioning_messages.drm_key_type(), + provisioning_messages.drm_public_key().data(), + provisioning_messages.drm_public_key().size())); + ASSERT_NO_FATAL_FAILURE(provisioning_messages.SignAndVerifyRequest()); + ASSERT_EQ(OEMCrypto_SUCCESS, provisioning_messages.LoadDRMCertResponse()); + c->prov = + provisioning_messages.core_request().counter_info.provisioning_count; + c->lic = provisioning_messages.core_request().counter_info.license_count; + c->decrypt = + provisioning_messages.core_request().counter_info.decrypt_count; + c->mgn = provisioning_messages.core_request() + .counter_info.master_generation_number; + }; + const std::function provision4_two_stage = [&](counts* c) { // Same as SessionUtil::CreateProv4OEMKey, but we can't extract counter // values using that function Session s; ASSERT_NO_FATAL_FAILURE(s.open()); - Provisioning40RoundTrip provisioning_messages(&s); + Provisioning40TwoStageRoundTrip provisioning_messages(&s); ASSERT_NO_FATAL_FAILURE(provisioning_messages.PrepareSession(true)); ASSERT_NO_FATAL_FAILURE(s.SetPublicKeyFromSubjectPublicKey( provisioning_messages.oem_key_type(), @@ -67,6 +90,9 @@ TEST_F(OEMCryptoSessionTests, Provisioning_IncrementCounterAPI18) { c->mgn = provisioning_messages.core_request() .counter_info.master_generation_number; }; + const auto provision4 = global_features.api_version >= 20 + ? provision4_single_stage + : provision4_two_stage; if (global_features.provisioning_method == OEMCrypto_OEMCertificate || global_features.provisioning_method == OEMCrypto_DrmCertificate) { @@ -172,10 +198,10 @@ TEST_F(OEMCryptoSessionTests, MasterGeneration_IncrementCounterAPI18) { Session& s = entry.session(); ASSERT_NO_FATAL_FAILURE(entry.OpenAndReload(this)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); Session s2; s2.open(); @@ -228,18 +254,18 @@ TEST_P(OEMCryptoUsageTableTest, OnlineLicense) { ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // test repeated report generation - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); // Flag the entry as inactive. ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // It should report as inactive. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); // Decrypt should fail. ASSERT_NO_FATAL_FAILURE( entry.TestDecryptCTR(false, OEMCrypto_ERROR_UNKNOWN_FAILURE)); @@ -248,7 +274,7 @@ TEST_P(OEMCryptoUsageTableTest, OnlineLicense) { ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // It should report as inactive. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); } // Test the usage report when the license is loaded but the keys are never @@ -260,12 +286,12 @@ TEST_P(OEMCryptoUsageTableTest, OnlineLicenseUnused) { Session& s = entry.session(); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // No decrypt. We do not use this license. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); // Flag the entry as inactive. ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // It should report as inactive. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUnused)); // Decrypt should fail. ASSERT_NO_FATAL_FAILURE( entry.TestDecryptCTR(true, OEMCrypto_ERROR_UNKNOWN_FAILURE)); @@ -274,7 +300,7 @@ TEST_P(OEMCryptoUsageTableTest, OnlineLicenseUnused) { ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // It should report as inactive. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUnused)); } // Test that the usage table has been updated and saved before a report can be @@ -285,14 +311,14 @@ TEST_P(OEMCryptoUsageTableTest, ForbidReportWithNoUpdate) { entry.MakeAndLoadOnline(this); Session& s = entry.session(); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); // Cannot generate a report without first updating the file. ASSERT_NO_FATAL_FAILURE( s.GenerateReport(entry.pst(), OEMCrypto_ERROR_ENTRY_NEEDS_UPDATE)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // Now it's OK. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); // Flag the entry as inactive. ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); // Cannot generate a report without first updating the file. @@ -315,7 +341,7 @@ TEST_P(OEMCryptoUsageTableTest, OnlineLicenseWithRefreshAPI16) { MakeRenewalRequest(&renewal_messages); LoadRenewal(&renewal_messages, OEMCrypto_SUCCESS); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); } // Verify that a streaming license cannot be reloaded. @@ -491,10 +517,10 @@ TEST_P(OEMCryptoUsageTableTest, GenericCryptoEncrypt) { ASSERT_EQ(OEMCrypto_SUCCESS, sts); EXPECT_EQ(expected_encrypted, encrypted); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); encrypted.assign(clear_buffer_.size(), 0); sts = OEMCrypto_Generic_Encrypt(key_handle.data(), key_handle.size(), clear_buffer_.data(), clear_buffer_.size(), @@ -532,10 +558,10 @@ TEST_P(OEMCryptoUsageTableTest, GenericCryptoDecrypt) { ASSERT_EQ(OEMCrypto_SUCCESS, sts); EXPECT_EQ(clear_buffer_, resultant); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); resultant.assign(encrypted.size(), 0); sts = OEMCrypto_Generic_Decrypt( key_handle.data(), key_handle.size(), encrypted.data(), encrypted.size(), @@ -582,10 +608,10 @@ TEST_P(OEMCryptoUsageTableTest, GenericCryptoSign) { ASSERT_EQ(expected_signature, signature); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); signature.assign(SHA256_DIGEST_LENGTH, 0); gen_signature_length = SHA256_DIGEST_LENGTH; sts = OEMCrypto_Generic_Sign(key_handle.data(), key_handle.size(), @@ -624,10 +650,10 @@ TEST_P(OEMCryptoUsageTableTest, GenericCryptoVerify) { signature.size()); ASSERT_EQ(OEMCrypto_SUCCESS, sts); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); sts = OEMCrypto_Generic_Verify(key_handle.data(), key_handle.size(), clear_buffer_.data(), clear_buffer_.size(), OEMCrypto_HMAC_SHA256, signature.data(), @@ -657,15 +683,15 @@ TEST_P(OEMCryptoUsageTableTest, OfflineLicenseRefresh) { LoadRenewal(&renewal_messages, OEMCrypto_SUCCESS); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); } // Test that an offline license can be loaded and that the license can be // released TEST_P(OEMCryptoUsageTableTest, OfflineLicenseReleaseAPI19) { - // License release is new in OEMCrypto v19. - if (wvoec::global_features.api_version < 19 || license_api_version_ < 19) { - GTEST_SKIP() << "Test for versions 19 and up only."; + // License release is new in OEMCrypto v20. + if (wvoec::global_features.api_version < 20 || license_api_version_ < 20) { + GTEST_SKIP() << "Test for versions 20 and up only."; } LicenseWithUsageEntry entry; entry.license_messages().set_api_version(license_api_version_); @@ -690,10 +716,10 @@ TEST_P(OEMCryptoUsageTableTest, ReloadOfflineLicense) { Session& s = entry.session(); ASSERT_NO_FATAL_FAILURE(entry.OpenAndReload(this)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); } // Test that an offline license can be reloaded in a new session, and then @@ -706,7 +732,7 @@ TEST_P(OEMCryptoUsageTableTest, ReloadOfflineLicenseWithRefresh) { ASSERT_NO_FATAL_FAILURE(entry.OpenAndReload(this)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); RenewalRoundTrip renewal_messages(&entry.license_messages()); @@ -714,7 +740,7 @@ TEST_P(OEMCryptoUsageTableTest, ReloadOfflineLicenseWithRefresh) { LoadRenewal(&renewal_messages, OEMCrypto_SUCCESS); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); } // Verify that we can still reload an offline license after @@ -731,10 +757,10 @@ TEST_P(OEMCryptoUsageTableTest, ReloadOfflineLicenseWithTerminate) { ASSERT_NO_FATAL_FAILURE(entry.OpenAndReload(this)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); } // If we attempt to load a second license with the same usage entry as the @@ -765,7 +791,7 @@ TEST_P(OEMCryptoUsageTableTest, BadReloadOfflineLicense) { // Now we go back to the original license response. It should load OK. ASSERT_NO_FATAL_FAILURE(entry.OpenAndReload(this)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); } // An offline license should not load on the first call if the nonce is bad. @@ -841,7 +867,7 @@ TEST_P(OEMCryptoUsageTableTest, DeactivateOfflineLicense) { ASSERT_NO_FATAL_FAILURE( entry.TestDecryptCTR(false, OEMCrypto_ERROR_UNKNOWN_FAILURE)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); ASSERT_NO_FATAL_FAILURE(s.close()); // Offline license can not be reused if it has been deactivated. @@ -858,12 +884,12 @@ TEST_P(OEMCryptoUsageTableTest, DeactivateOfflineLicense) { // only work if the license server can handle v16 licenses. This is a rare // condition, so it is OK to break it during the transition months. entry.license_messages().set_api_version(global_features.api_version); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); // We could call DeactivateUsageEntry multiple times. The state should not // change. ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); } // The usage report should indicate that the keys were never used for @@ -880,7 +906,7 @@ TEST_P(OEMCryptoUsageTableTest, DeactivateOfflineLicenseUnused) { ASSERT_NO_FATAL_FAILURE( entry.TestDecryptCTR(true, OEMCrypto_ERROR_UNKNOWN_FAILURE)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUnused)); ASSERT_NO_FATAL_FAILURE(s.close()); // Offline license can not be reused if it has been deactivated. @@ -897,12 +923,12 @@ TEST_P(OEMCryptoUsageTableTest, DeactivateOfflineLicenseUnused) { // only work if the license server can handle v16 licenses. This is a rare // condition, so it is OK to break it during the transition months. entry.license_messages().set_api_version(global_features.api_version); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUnused)); // We could call DeactivateUsageEntry multiple times. The state should not // change. ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUnused)); } TEST_P(OEMCryptoUsageTableTest, SecureStop) { @@ -921,7 +947,7 @@ TEST_P(OEMCryptoUsageTableTest, SecureStop) { ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); // It should report as inactive. - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); } // Test update usage table fails when passed a null pointer. @@ -954,7 +980,7 @@ class OEMCryptoUsageTableDefragTest : public OEMCryptoUsageTableTest { ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); ASSERT_NO_FATAL_FAILURE(entry->TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry->GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry->GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(s.close()); } @@ -1206,8 +1232,7 @@ TEST_P(OEMCryptoUsageTableDefragTest, ManyUsageEntries) { while (successful_count < attempt_count && status == OEMCrypto_SUCCESS) { wvutil::TestSleep::SyncFakeClock(); LOGD("Creating license for entry %zu", successful_count); - entries.push_back( - std::unique_ptr(new LicenseWithUsageEntry())); + entries.push_back(std::make_unique()); entries.back()->set_pst("pst " + std::to_string(successful_count)); ASSERT_NO_FATAL_FAILURE(entries.back()->MakeOfflineAndClose(this, &status)) << "Failed creating license for entry " << successful_count; @@ -1253,8 +1278,7 @@ TEST_P(OEMCryptoUsageTableDefragTest, ManyUsageEntries) { // Create a few more license for (size_t i = 0; i < small_number; i++) { wvutil::TestSleep::SyncFakeClock(); - entries.push_back( - std::unique_ptr(new LicenseWithUsageEntry())); + entries.push_back(std::make_unique()); entries.back()->set_pst("new pst " + std::to_string(smaller_size + i)); entries.back()->MakeOfflineAndClose(this); } @@ -1264,10 +1288,10 @@ TEST_P(OEMCryptoUsageTableDefragTest, ManyUsageEntries) { Session& s = entries[i]->session(); ASSERT_NO_FATAL_FAILURE(entries[i]->OpenAndReload(this)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entries[i]->GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entries[i]->GenerateVerifyReport(OEMCrypto_Unused)); ASSERT_NO_FATAL_FAILURE(entries[i]->TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entries[i]->GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entries[i]->GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(s.close()); } } @@ -1533,11 +1557,11 @@ TEST_P(OEMCryptoUsageTableTest, TimingTest) { wvutil::TestSleep::Sleep(kLongSleep); ASSERT_NO_FATAL_FAILURE(s1.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry1.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry1.GenerateVerifyReport(OEMCrypto_InactiveUsed)); ASSERT_NO_FATAL_FAILURE(s2.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry2.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry2.GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(s3.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry3.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry3.GenerateVerifyReport(OEMCrypto_Unused)); } // Verify the times in the usage report. For performance reasons, we allow @@ -1552,7 +1576,7 @@ TEST_P(OEMCryptoUsageTableTest, VerifyUsageTimes) { Session& s = entry.session(); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); const int64_t kDotIntervalInSeconds = 5; const int64_t kIdleInSeconds = 20; @@ -1569,7 +1593,7 @@ TEST_P(OEMCryptoUsageTableTest, VerifyUsageTimes) { PrintDotsWhileSleep(kIdleInSeconds, kDotIntervalInSeconds); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kUnused)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Unused)); cout << "Start simulated playback..." << endl; int64_t dot_time = kDotIntervalInSeconds; @@ -1578,7 +1602,7 @@ TEST_P(OEMCryptoUsageTableTest, VerifyUsageTimes) { do { ASSERT_NO_FATAL_FAILURE(entry.TestDecryptCTR()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); wvutil::TestSleep::Sleep(kShortSleep); playback_time = wvutil::Clock().GetCurrentTime() - start_time; ASSERT_LE(0, playback_time); @@ -1591,7 +1615,7 @@ TEST_P(OEMCryptoUsageTableTest, VerifyUsageTimes) { cout << "\nSimulated playback time = " << playback_time << " seconds.\n"; ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); EXPECT_NEAR(s.pst_report().seconds_since_first_decrypt() - s.pst_report().seconds_since_last_decrypt(), playback_time, kUsageTableTimeTolerance); @@ -1611,10 +1635,10 @@ TEST_P(OEMCryptoUsageTableTest, VerifyUsageTimes) { // |<--->| = seconds_since_last_decrypt // |<----------------------------->| = seconds_since_first_decrypt // |<------------------------------------| = seconds_since_license_received - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kActive)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_Active)); ASSERT_NO_FATAL_FAILURE(entry.DeactivateUsageEntry()); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); ASSERT_NO_FATAL_FAILURE( entry.TestDecryptCTR(false, OEMCrypto_ERROR_UNKNOWN_FAILURE)); } @@ -1736,13 +1760,13 @@ TEST_P(OEMCryptoUsageTableTestWallClock, TimeRollbackPrevention) { ASSERT_NO_FATAL_FAILURE(s1.GenerateReport(entry1.pst())); wvutil::Unpacked_PST_Report report1 = s1.pst_report(); - EXPECT_EQ(report1.status(), kActive); + EXPECT_EQ(report1.status(), OEMCrypto_Active); EXPECT_GE(report1.seconds_since_license_received(), kTotalTime); EXPECT_GE(report1.seconds_since_first_decrypt(), kTotalTime); ASSERT_NO_FATAL_FAILURE(s2.GenerateReport(entry2.pst())); wvutil::Unpacked_PST_Report report2 = s2.pst_report(); - EXPECT_EQ(report2.status(), kUnused); + EXPECT_EQ(report2.status(), OEMCrypto_Unused); EXPECT_GE(report2.seconds_since_license_received(), kTotalTime); } @@ -1761,7 +1785,7 @@ TEST_P(OEMCryptoUsageTableTest, PSTLargeBuffer) { ASSERT_NO_FATAL_FAILURE( entry.TestDecryptCTR(false, OEMCrypto_ERROR_UNKNOWN_FAILURE)); ASSERT_NO_FATAL_FAILURE(s.UpdateUsageEntry(&encrypted_usage_header_)); - ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(kInactiveUsed)); + ASSERT_NO_FATAL_FAILURE(entry.GenerateVerifyReport(OEMCrypto_InactiveUsed)); ASSERT_NO_FATAL_FAILURE(s.close()); } diff --git a/oemcrypto/util/Android.bp b/oemcrypto/util/Android.bp new file mode 100644 index 0000000..e9d585c --- /dev/null +++ b/oemcrypto/util/Android.bp @@ -0,0 +1,127 @@ +// Copyright (C) 2025 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// *** THIS PACKAGE HAS SPECIAL LICENSING CONDITIONS. PLEASE +// CONSULT THE OWNERS AND opensource-licensing@google.com BEFORE +// DEPENDING ON IT IN YOUR PROJECT. *** +package { + default_team: "trendy_team_media_framework_drm", + // See: http://go/android-license-faq + // A large-scale-change added "default_applicable_licenses" to import + // all of the "license_kinds" from "vendor_widevine_license" + // to get the below license kinds: + // legacy_by_exception_only (by exception only) + default_applicable_licenses: ["vendor_widevine_license"], +} + +// ---------------------------------------------------------------- +// Widevine CRC-32 +// Builds: libwvcrc32.a +cc_defaults { + name: "libwvcrc32_defaults", + defaults: ["widevine_code_protected"], + + local_include_dirs: ["include"], + export_include_dirs: ["include"], + srcs: ["src/wvcrc32.cpp"], + + owner: "widevine", + min_sdk_version: "34", +} + +cc_library_static { + name: "libwvcrc32", + defaults: ["libwvcrc32_defaults"], + proprietary: true, +} + +cc_library_static { + name: "libwvcrc32.system", + defaults: ["libwvcrc32_defaults"], + system_ext_specific: true, +} + +// ---------------------------------------------------------------- +// Widevine Cryptography Library (using BoringSSL) +// Builds: libwvcrypto.a + +cc_defaults { + name: "libwvcrypto_defaults", + + // DO NOT add header dependencies outside of the CDM utils, + // public OEMCrypto headers, and this library itself. + + // For OEMCryptoCENC.h and OEMCryptoCENCCommon.h + header_libs: ["oemcrypto_api_headers"], + local_include_dirs: [ + "include", + ], + // TODO(b/404626424): Remove this need for `include_dirs`. + include_dirs: [ + // Class utils, logging and string utilities. + "vendor/widevine/libwvdrmengine/cdm/util/include", + ], + + export_header_lib_headers: ["oemcrypto_api_headers"], + export_include_dirs: [ + "include", + ], + + srcs: [ + // Generic cryptography. + "src/cmac.cpp", + "src/hmac.cpp", + // DICE / COSE cryptography. + "src/cose_utils.cpp", + // OEMCrypto-specific cryptography. + "src/oemcrypto_cose_key.cpp", + "src/oemcrypto_drm_key.cpp", + "src/oemcrypto_ecc_key.cpp", + "src/oemcrypto_ed_key.cpp", + "src/oemcrypto_key_deriver.cpp", + "src/oemcrypto_oem_cert.cpp", + "src/oemcrypto_oem_cert_chain.cpp", + "src/oemcrypto_rsa_key.cpp", + ], + + shared_libs: [ + "libcppbor", + "libcrypto", + ], + + // TODO(b/400697403): Change implementation to not expose + // libcrypto headers. + export_shared_lib_headers: [ + "libcrypto", + ], + + owner: "widevine", + min_sdk_version: "34", +} + +cc_library_static { + name: "libwvcrypto", + defaults: ["libwvcrypto_defaults"], + + static_libs: ["libcdm_utils"], + proprietary: true, +} + +cc_library_static { + name: "libwvcrypto.system", + defaults: ["libwvcrypto_defaults"], + + static_libs: ["libcdm_utils.system"], + system_ext_specific: true, +} diff --git a/oemcrypto/util/WVCRC32.md b/oemcrypto/util/WVCRC32.md new file mode 100644 index 0000000..d28289c --- /dev/null +++ b/oemcrypto/util/WVCRC32.md @@ -0,0 +1,71 @@ +# Widevine's Cyclic Redundancy Check (CRC) + +Widevine uses an implementation of [CRC-32][CRC-Wiki] for performing data +integrity checks. This is used for verifying the +[decryption hash][OEMCrypto-SetDecryptHash] and for keybox integrity checking. + +## Parameters of Widevine's CRC-32 + +The following are the parameters of Widevine's CRC-32 algorithm. + +| Parameter | Value | Comment | +|:----------------------:|:-------------|:------------------------| +| *Generator Polynomial* | `0x04c11db7` | "Normal" form, MSC of 1 | +| *Initial Value* | `0xffffffff` | | +| *Output Mask* | `0x00000000` | No output masking | +| *Input Reflection* | `false` | No input reflection | +| *Output Reflection* | `false` | No output reflection | + +No additional operations are performed when finalizing. + +Note: Widevine's CRC-32 algorithm is based on CRC-32/MPEG-2. + +## Brief Introduction to CRC + +A Cyclic Redundancy Check (CRC) is a form of error detecting coding which +appends a fixed-width value (code) to the message. CRC is a family of error +detecting codes. + +The *mathematics* behind CRC are that of polynomial division under a Galois +Field of two elements (0 or 1). The data being check is treated as a +polynomial whose coefficients are the bits of the data. The result of +a CRC calculation is the remainder. + +### What is the CRC result? + +Fundamentally, the CRC-32 "result" is the 32-bit integer (representing the +remainder polynomial). + +Historically, the Widevine documentation had blurred the lines between +the CRC result and the serialized form of the CRC used in keyboxes. +Widevine *almost* always serializes this result in a 4 byte sequence +containing the result in network-byte-order, where most systems (such as +x86-64) operate as little-endian. + +Due to this confusion, some implementations of OEMCrypto decryption hash +checking might accept a CRC result for both byte orders (ex. `0x11223344` +and `0x44332211` might be treated as equal). + +### General CRC Parameters + +Each CRC algorithm contains a set of parameters which affect the result. + +* Generator Polynomial - Polynomial used as the divisor in a CRC calculation +* Accumulator - A fixed with unsigned integer, which represents the current + state of the polynomial division +* Initial Value - The initial value of the *accumulator* when starting + the calculation (typically all zeros or all ones) +* Output Mask - An XOR bit mask applied to the final value of the accumulator + before converting to a result. +* Input reflection - Reverses the order of the bits within the octet when + updating. This does not reverse the entire sequence, only the bits + within each octet. +* Output reflection - Reverses the order of the coefficients of the output + polynomial remainder. This does not reverse the entire sequence, only + the bits within the output polynomial unit. + +Some CRC algorithms have additional requirements when finalizing the +result (such as appending bytes). + +[CRC-Wiki]: https://en.wikipedia.org/wiki/Cyclic_redundancy_check +[OEMCrypto-SetDecryptHash]: https://developers.google.com/widevine/drm/client/oemcrypto/v19/oemcrypto-api/group/test-verify#oemcrypto_setdecrypthash diff --git a/oemcrypto/util/build.gyp b/oemcrypto/util/build.gyp new file mode 100644 index 0000000..58fbfee --- /dev/null +++ b/oemcrypto/util/build.gyp @@ -0,0 +1,103 @@ +# Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +# source code may only be used and distributed under the Widevine +# License Agreement. +{ + 'variables': { + 'oemcrypto_dir%': '..', + 'util_dir%': '../../util', + 'third_party_path%': '../../third_party', + 'privacy_crypto_impl%': 'boringssl', + # 'boringssl_libcrypto_path' is used by libcrypto_dependency.gypi + # to select proper libcrypto library. + 'boringssl_libcrypto_path%': '../../third_party/boringssl/boringssl.gyp:crypto', + 'libcppbor_path%': '../../third_party/libcppbor.gyp:cppbor', + # TODO(b/404626424): Update this when moved. + 'wvcrypto_dir%': '../util', + }, + 'targets': [ + { + 'target_name': 'libwvcrypto', + 'type': 'static_library', + 'standalone_static_library': 1, + 'hard_dependency': 1, + + 'include_dirs': [ + '<(oemcrypto_dir)/include', + '<(util_dir)/include', + '<(wvcrypto_dir)/include', + ], + + # Includes used in the headers of wvcrypto. + 'direct_dependent_settings': { + 'include_dirs': [ + '<(oemcrypto_dir)/include', + '<(wvcrypto_dir)/include', + # libcrypto includes cannot be exported easily. + ], + }, + + 'sources': [ + '<(wvcrypto_dir)/src/cmac.cpp', + '<(wvcrypto_dir)/src/cose_utils.cpp', + '<(wvcrypto_dir)/src/hmac.cpp', + '<(wvcrypto_dir)/src/oemcrypto_cose_key.cpp', + '<(wvcrypto_dir)/src/oemcrypto_drm_key.cpp', + '<(wvcrypto_dir)/src/oemcrypto_ecc_key.cpp', + '<(wvcrypto_dir)/src/oemcrypto_ed_key.cpp', + '<(wvcrypto_dir)/src/oemcrypto_key_deriver.cpp', + '<(wvcrypto_dir)/src/oemcrypto_oem_cert.cpp', + '<(wvcrypto_dir)/src/oemcrypto_oem_cert_chain.cpp', + '<(wvcrypto_dir)/src/oemcrypto_rsa_key.cpp', + ], + + 'includes': [ + # Sets values needed for using OpenSSL/BoringSSL. + '../../util/libcrypto_dependency.gypi', + ], + + 'dependencies': [ + '<(libcppbor_path)', + ], + }, + { + 'target_name': 'liboec_ref_util', + 'type': 'static_library', + 'standalone_static_library': 1, + 'hard_dependency': 1, + + 'include_dirs': [ + '<(oemcrypto_dir)/include', + '<(util_dir)/include', + '<(wvcrypto_dir)/include', + ], + + # Includes used in the headers of wvcrypto. + 'direct_dependent_settings': { + 'include_dirs': [ + '<(oemcrypto_dir)/include', + '<(oemcrypto_dir)/util/include', + '<(wvcrypto_dir)/include', + # libcrypto includes cannot be exported easily. + ], + }, + + 'sources': [ + '<(oemcrypto_dir)/util/src/bcc_validator.cpp', + '<(oemcrypto_dir)/util/src/cbor_validator.cpp', + '<(oemcrypto_dir)/util/src/device_info_validator.cpp', + '<(oemcrypto_dir)/util/src/prov4_validation_helper.cpp', + '<(oemcrypto_dir)/util/src/signed_csr_payload_validator.cpp', + ], + + 'dependencies': [ + '<(libcppbor_path)', + 'libwvcrypto', + ], + + 'includes': [ + # Sets values needed for using OpenSSL/BoringSSL. + '../../util/libcrypto_dependency.gypi', + ], + } + ], +} diff --git a/oemcrypto/util/include/bcc_validation_config.h b/oemcrypto/util/include/bcc_validation_config.h new file mode 100644 index 0000000..fd5392b --- /dev/null +++ b/oemcrypto/util/include/bcc_validation_config.h @@ -0,0 +1,27 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVOEC_UTIL_BCC_VALIDATION_CONFIG_H_ +#define WVOEC_UTIL_BCC_VALIDATION_CONFIG_H_ + +namespace wvoec { +namespace util { +struct BccValidationConfig { + bool check_widevine_component_name = false; + // ... other flags as needed ... +}; + +class BccValidationConfigBuilder { + public: + BccValidationConfigBuilder& set_check_widevine_component_name(bool value) { + config_.check_widevine_component_name = value; + return *this; + } + BccValidationConfig build() { return config_; } + + private: + BccValidationConfig config_; +}; +} // namespace util +} // namespace wvoec +#endif // WVOEC_UTIL_BCC_VALIDATION_CONFIG_H_ diff --git a/oemcrypto/util/include/bcc_validator.h b/oemcrypto/util/include/bcc_validator.h index adc84a6..ac4e693 100644 --- a/oemcrypto/util/include/bcc_validator.h +++ b/oemcrypto/util/include/bcc_validator.h @@ -9,6 +9,7 @@ #include +#include "bcc_validation_config.h" #include "cbor_validator.h" #include "prov4_validation_helper.h" @@ -121,7 +122,7 @@ struct Bcc { std::string ToString() const; CborMessageStatus Validate( std::vector>& msgs, - bool is_degenerated) const; + bool is_degenerated, const BccValidationConfig& validation_config) const; }; // BccValidator processes a Provisioning 4.0 device root of trust. It extracts @@ -132,6 +133,8 @@ struct Bcc { class BccValidator : public CborValidator { public: BccValidator() = default; + BccValidator(const BccValidationConfig& config) + : validation_config_(config) {} virtual ~BccValidator() override = default; WVCDM_DISALLOW_COPY_AND_MOVE(BccValidator); @@ -140,6 +143,13 @@ class BccValidator : public CborValidator { // Outputs formatted BCC. virtual std::string GetFormattedMessage() const override; + void SetValidationConfig(const BccValidationConfig& config) { + validation_config_ = config; + } + void SetCheckWidevineComponentName(bool value) { + validation_config_.check_widevine_component_name = value; + } + private: // Processes CoseKey PubKeyEd25519 / PubKeyECDSA256 / PubKeyECDSA384, which // contains subject public key, and extracts the PubKey to *|public_key_info|. @@ -168,6 +178,8 @@ class BccValidator : public CborValidator { bool VerifySignature(const BccPublicKeyInfo& signing_key, const std::vector& message, const std::vector& signature); + // Validation configuration. + BccValidationConfig validation_config_; // Used to generate formatted message. std::stringstream msg_ss_; }; // class BccValidator diff --git a/oemcrypto/util/include/cose_utils.h b/oemcrypto/util/include/cose_utils.h new file mode 100644 index 0000000..6c9add3 --- /dev/null +++ b/oemcrypto/util/include/cose_utils.h @@ -0,0 +1,462 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#ifndef WVOEC_UTIL_COSE_UTILS_H_ +#define WVOEC_UTIL_COSE_UTILS_H_ + +#include + +#include +#include +#include + +#include "wv_class_utils.h" + +namespace wvoec { +namespace util { +// ==== COSE Utilities == +// Most of the following constants related to COSE are found +// in RFC 9052 and 9053. +// The term "label" is borrowed from the COSE RFCs and refers to the +// map keys of defined COSE structs. + +// == COSE_Key == + +// CDDL of COSE_Key is from RFC 9052 section 7. +// COSE_Key = { +// 1 => tstr / int, ; kty +// ? 2 => bstr, ; kid (not supported) +// ? 3 => tstr / int, ; alg +// ? 4 => [+ (tstr / int) ], ; key_ops (only int supported) +// ? 5 => bstr, ; Base IV (not supported) +// ; Some labels are well-defined depending on the context +// ; and value of 'kty' +// * label => values +// } + +// COSE_Key labels (map keys of COSE_Key). +// RFC 9052, section 7.1, table 4. +// Key type ("kty") => tstr/int (only int is supported) +// Identifies the key type within the struct. +constexpr int64_t kCoseKeyKeyTypeLabel = 1; +// Key algorithm ("alg") => tstr/int (only int is supported) +// Used to identify/restrict the use of the specified key. +constexpr int64_t kCoseKeyAlgorithmLabel = 3; +// Key operations ("key_ops") => [+ (tstr/int) ] (only int is +// supported). +// Used to identify/restrict the use of the specified key. +constexpr int64_t kCoseKeyKeyOpsLabel = 4; + +// Extended COSE_Key labels for case of EC2 (P-256, P-384) and +// OKP (Ed25519) key types. +// Curve identifier ("crv") => tstr/int (only int is supported) +// Identifier for the curve. Used for both EC2 and OKP +constexpr int64_t kCoseKeyCurveLabel = -1; +// Private key ("d") => bstr +// For EC2 keys, big-endian encoded private scalar. +// For Ed25519 keys, RFC 8032 encoded private key. +constexpr int64_t kCoseKeyPrivateKeyLabel = -4; + +// Double Coordinate Curves (kty=EC2). +// RFC 9053, section 7.1, table 19. +// X coordinate ("x") => bstr +// Big-endian encoding of X coordinate. +constexpr int64_t kCoseKeyXCoordLabel = -2; +// Y coordinate ("y") => bstr/bool +// When bstr, big-endian encoding of Y coordinate. +// When bool, sign-bit of Y coordinate +constexpr int64_t kCoseKeyYCoordLabel = -3; + +// Twisted Edwards Key (kty=OKP, crv=Ed25519) +// RFC 9053, section 7.2, table 20. +// Public key ("x") => bstr +// RFC 8032 encoded Ed25519 public key. +constexpr int64_t kCoseKeyPublicKeyLabel = -2; + +// COSE_Key key type ("kty") values. +// RFC 9053, section 10.1, table 22. +// Octet-Key-Pair (used for Ed25519 keys) +constexpr int64_t kCoseKeyTypeOkp = 1; +// Double Coordinate Elliptic Curve (used for P-256 and P-384 keys) +constexpr int64_t kCoseKeyTypeEc2 = 2; + +// COSE_Key algorithm ("alg") values. +// RFC 9053: +// ECDSA -> section 2.1, table 1 +// EdDSA -> section 2.2, table 2 +// ECDSA w/ SHA-256 +constexpr int64_t kCoseAlgorithmEs256 = -7; +// ECDSA w/ SHA-384 +constexpr int64_t kCoseAlgorithmEs384 = -35; +// EdDSA w/ SHA-512 +constexpr int64_t kCoseAlgorithmEdDsa = -8; + +// COSE_Key key operations ("key_ops") element values. +// RFC 9052, section 7.1, table 5 +// May be used to create signatures; requires private key fields. +constexpr int64_t kCoseKeyOpSign = 1; +// May be used to verify signatures. +constexpr int64_t kCoseKeyOpVerify = 2; +// May be used for key transport encryption. +constexpr int64_t kCoseKeyOpEncrypt = 3; +// May be used for key transport decryption; requires private key fields. +constexpr int64_t kCoseKeyOpDecrypt = 4; +// May be used for key wrap encryption. +constexpr int64_t kCoseKeyOpWrapKey = 5; +// May be used for key wrap decryption; requires private key fields. +constexpr int64_t kCoseKeyOpUnwrapKey = 6; +// May be used for key derivation; requires private key fields. +constexpr int64_t kCoseKeyOpDeriveKey = 7; +// May be used for bit derivation (not keys); requires private key fields. +constexpr int64_t kCoseKeyOpDeriveBits = 8; +// May be used to create MAC keys. +constexpr int64_t kCoseKeyOpMacCreate = 9; +// May be used to verify MAC keys. +constexpr int64_t kCoseKeyOpMacVerify = 10; + +// COSE_Key curve ("crv") values. +// RFC 9053, section 7.1 table 18. +// Elliptic Curve P-256 (kty=EC2). +constexpr int64_t kCoseCurveP256 = 1; +// Elliptic Curve P-384 (kty=EC2). +constexpr int64_t kCoseCurveP384 = 2; +// Twisted Edwards Curve 25519 (kty=OKP). +constexpr int64_t kCoseCurveEd25519 = 6; + +// Wrapper around the key data related to EC2 keys. +struct CoseEc2KeyData { + static constexpr int64_t key_type = kCoseKeyTypeEc2; + std::optional algorithm; + std::vector key_ops; + int64_t curve = 0; + std::vector x; // empty if not set. + std::variant, bool> y; // empty vector if not set. + std::vector private_key; // empty if not set. + + bool has_public_key() const { + // Both components must be set to be considered valid. + if (x.empty()) return false; + if (std::holds_alternative(y)) return true; + return !std::get>(y).empty(); + } + + void clear_public_key() { + x.clear(); + y = std::vector(); + } + + bool has_private_key() const { return !private_key.empty(); } + void clear_private_key() { private_key.clear(); } + + void Clear() { + algorithm.reset(); + key_ops.clear(); + curve = 0; + clear_public_key(); + clear_private_key(); + } +}; + +// Wrapper around the key data related to OKP (Ed25519) keys. +struct CoseOkpKeyData { + static constexpr int64_t key_type = kCoseKeyTypeOkp; + std::optional algorithm; + std::vector key_ops; + int64_t curve = 0; + std::vector public_key; // empty if not set. + std::vector private_key; // empty if not set. + + bool has_public_key() const { return !public_key.empty(); } + void clear_public_key() { public_key.clear(); } + + bool has_private_key() const { return !private_key.empty(); } + void clear_private_key() { private_key.clear(); } + + void Clear() { + algorithm.reset(); + key_ops.clear(); + curve = 0; + clear_public_key(); + clear_private_key(); + } +}; + +// Wrapper around a COSE_Key. +// This implementation only supports EC2 and OKP and limits +// the supported fields to those used by OEMCrypto. +// +// The effective CDDL this class uses is: +// +// COSE_Key = { +// 1 => int, ; kty (must be OKP(1) or EC2(2)) +// ? 2 => [+int] ; key_ops +// ? 3 => int, ; alg +// -1 => int, ; crv +// ? -2 => bstr, ; x-coord (EC2) or public key (OKP) +// ? -3 => bstr / bool, ; y-coord/y-sign-bit (EC2 only) +// ? -4 => bstr ; d (EC2) or private key (OKP) +// ; All other fields are ignored +// * label => values +// } +class CoseKeyData { + public: + CoseKeyData() = default; + WVCDM_DEFAULT_COPY_AND_MOVE(CoseKeyData); + + int64_t key_type() const; + + bool has_algorithm() const; + std::optional algorithm() const; + // Sets the algorithm without validation. + void set_algorithm(int64_t algorithm); + // Clear the algorithm. + void clear_algorithm(); + + const std::vector& key_ops() const; + bool has_key_ops() const; + void set_key_ops(const std::vector& key_ops); + void clear_key_ops(); + + // Operations on individual key operations. + bool ContainsKeyOp(int64_t key_op) const; + void AppendKeyOp(int64_t key_op); + void RemoveKeyOp(int64_t key_op); + + int64_t curve() const; + // Sets the curve without validation. + void set_curve(int64_t curve); + + bool has_public_key() const; + void clear_public_key(); + + const std::vector& private_key() const; + bool has_private_key() const; + void set_private_key(const std::vector& private_key); + void clear_private_key(); + + bool IsEc2Data() const { + return std::holds_alternative(data_); + } + std::optional ec2_data() const; + void set_ec2_data(const CoseEc2KeyData& data) { data_ = data; } + + bool IsOkpData() const { + return std::holds_alternative(data_); + } + std::optional okp_data() const; + void set_okp_data(const CoseOkpKeyData& data) { data_ = data; } + + // Clears the key data, but keeps the same key type. + void Clear(); + + // Attempts to parse the CBOR COSE_Key. + // Input should be a CBOR serialized COSE_Key. + // Requires 'kty' and 'crv' fields. + // Only validates the value of 'kty' as being EC2 or OKP. + // For all other defined fields, only the value type is validated. + bool ParseCbor(const uint8_t* buffer, size_t buffer_size); + bool ParseCbor(const std::vector& buffer) { + return ParseCbor(buffer.data(), buffer.size()); + } + // Serializes the COSE_Key data. + // - "kty" and "curve" are always serialized. + // - "alg" is only serialized if set + // - "key_ops" is only serialized if non-empty + // - bstr fields are only serialized of non-empty + std::vector SerializeCbor() const; + + private: + // Internal getters. Must check type before calling, otherwise + // bad things will happen. + const CoseEc2KeyData& GetEc2Data() const { + return std::get(data_); + } + CoseEc2KeyData& GetEc2Data() { return std::get(data_); } + + const CoseOkpKeyData& GetOkpData() const { + return std::get(data_); + } + CoseOkpKeyData& GetOkpData() { return std::get(data_); } + + // Private mutable key_ops getter. + std::vector& mutable_key_ops(); + + std::variant data_; +}; + +// == COSE_Sign1 == + +// CDDL of COSE_Sign1 is partially RFC 9052 section 4.2; though +// the sub-group definitions are expanded for our use case. +// COSE_Sign1 = [ +// ; If present, only algorithm identifier is used. +// protected: bstr .cbor header_map / bstr .size 0, +// ; Values within "unprotected" are unused +// unprotected: header_map, +// ; Optional payload. +// payload : bstr / nil, +// ; Signature of the CBOR encoded Sig_structure +// signature : bstr +// ] +// +// header_map = { ; Expanded from Generic_Headers +// ? 1 => int / tstr, ; algorithm identifier +// ? 2 => [+label], ; criticality (not supported) +// ? 3 => tstr / int, ; content type (not supported) +// ? 4 => bstr, ; key identifier (not supported) +// ? ( 5 => bstr // ; IV (not supported) +// 6 => bstr ) ; Partial IV (not supported) +// * label => values (not supported) +// } + +// A COSE_Sign1 sequence must contain 4 elements. +constexpr size_t kCoseSign1Length = 4; +// Protected parameters ("protected") => bstr .cbor header_map / bstr .size 0 +// Optional parameters used to generate/verify the signature. +// For OEMCrypto's use case, only the "alg" field is checked. +// If protected parameters are to be omitted, this element MUST be an +// empty bstr. +constexpr size_t kCoseSign1ProtectedIndex = 0; +// Unprotected parameters ("unprotected") => header_map (unused) +constexpr size_t kCoseSign1UnprotectedIndex = 1; +// Signed content ("payload") => bstr / nil +// COSE_Sign1 signatures MAY contain the payload which it signs +// ("attached content"). If the payload is sent separately from +// the content then this field may be nil ("detached content"). +constexpr size_t kCoseSign1PayloadIndex = 2; +// Computed COSE signature ("signature") => bstr +// Raw signature bytes, exact format depends on the signing +// algorithm. +// The signature is computed over the CBOR encoded Sig_structure +// which includes the payload and some additional context; +// not just the payload. +constexpr size_t kCoseSign1SignatureIndex = 3; + +// Generic_Headers labels (map keys of Generic_Headers, used +// in the definition of header_map). +// RFC 9052, section 3.1, table 3. +// Signature algorithm ("alg") => tstr/int (only int is supported). +// Should be the same value as "alg" from COSE_Key. +constexpr int64_t kCoseGenHeaderAlgorithmLabel = 1; + +// Wrapped around COSE_Sign1 array. +// This implementation only supports the fields used by OEMCrypto. +// +// COSE_Sign1 = [ +// protected: bstr .size 0 / bstr .cbor { +// ? 1 => int, ; alg +// ; All other fields are ignored +// * label => values +// }, +// unprotected: { +// ; All fields are ignored +// * label => values +// }, +// payload : bstr / nil, +// signature : bstr +// ] +// +// Note: using "protected_params" as "protected" is a C++ keyword. +class CoseSign1Data { + public: + CoseSign1Data() = default; + WVCDM_DEFAULT_COPY_AND_MOVE(CoseSign1Data); + + // CBOR encoded header_map or empty. + const std::vector& protected_params() const { + return protected_params_; + } + // Set protected without validation. + void set_protected_params(const std::vector& params) { + protected_params_ = params; + } + void clear_protected_params() { protected_params_.clear(); } + + // Parsed from |protected_params_|. + // Is successful, returns an option containing the algorithm; + // otherwise returns nullopt (see logs for details). + std::optional ExtractAlgorithm() const; + + // Sets |protected_params_| to be a CBOR encoded header_map + // containing the "alg" field. + // Note: this will clear any existing data in |protected_params_|. + void SetAlgorithm(int64_t algorithm); + + // Optional payload, nullopt if CBOR value is nil. + // Note: an empty vector is a valid payload. + const std::optional>& payload() const { + return payload_; + } + bool has_payload() const { return payload_.has_value(); } + void set_payload(const std::vector& payload) { payload_ = payload; } + void clear_payload() { payload_.reset(); } + + // Signature bytes. + const std::vector& signature() const { return signature_; } + bool has_signature() const { return !signature_.empty(); } + void set_signature(const std::vector& signature) { + signature_ = signature; + } + void clear_signature() { signature_.clear(); } + + void Clear() { + protected_params_.clear(); + payload_.reset(); + signature_.clear(); + } + + // Attempts to parse the CBOR COSE_Sign1. + // Input should be a CBOR serialized COSE_Sign1 sequence. + // Limited validation: + // - "protected" : checked to be an empty bstr or a + // CBOR encoded map. If a map, then only the "alg" is + // checked to be an integer. + // - "unprotected" : checked to be a map, contents are ignored. + // - "payload" : checked to be bstr or nil + // - "signature" : checked to be a bstr + // Returns true if successfully parsed, false otherwise. + bool ParseCbor(const uint8_t* buffer, size_t buffer_size); + bool ParseCbor(const std::vector& buffer) { + return ParseCbor(buffer.data(), buffer.size()); + } + + // Serializes the COSE_Sign1 sequence. + // - "protected" is serialized as is, always a bstr + // - "unprotected" is always an empty map + // - "payload" if serialized as is, bstr or nil + // - "signature" is serialized as is, always a bstr + std::vector SerializeCbor() const; + + private: + std::vector protected_params_; + std::optional> payload_; + std::vector signature_; +}; + +// Constructs a COSE_Sign1 specific Sig_structure. +// Defined in RFC 9052 section 4.4. +// +// Sig_structure = [ +// ; Specific to COSE_Sign1 +// context : "Signature1", +// ; Must be the exact bytes from COSE_Sign1 "protected" +// ; (not a re-serialization of one). +// body_protected : bstr .cbor header_map / bstr .size 0, +// ; Not used, always empty. +// external_aad : bstr, +// payload : bstr +// ] +// +// Parameters: +// - |protected_params| is used for "body_protected" +// - |payload| is used for "payload" +std::vector PackageCoseSign1SigStructure( + const std::vector& protected_params, + const std::vector& payload); + +} // namespace util +} // namespace wvoec +#endif // WVOEC_UTIL_COSE_UTILS_H_ diff --git a/oemcrypto/util/include/oemcrypto_cose_key.h b/oemcrypto/util/include/oemcrypto_cose_key.h new file mode 100644 index 0000000..02d5881 --- /dev/null +++ b/oemcrypto/util/include/oemcrypto_cose_key.h @@ -0,0 +1,343 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#ifndef WVOEC_UTIL_OEMCRYPTO_COSE_KEY_H_ +#define WVOEC_UTIL_OEMCRYPTO_COSE_KEY_H_ + +#include + +#include +#include +#include + +#include "OEMCryptoCENCCommon.h" +#include "oemcrypto_ecc_key.h" +#include "oemcrypto_ed_key.h" +#include "wv_class_utils.h" + +namespace wvoec { +namespace util { +// The COSE protocol uses several parameters to define a particular +// type of key. For OEMCrypto's use case, the combinations of these +// parameters are limited. This enum "CoseKeyFamily" are catch all +// for OEMCrypto's use case. +// +// See the details of each enum below. +enum CoseKeyFamily { + kCoseKeyFamilyUnknown = 0, + // Elliptic curve of P-256, equivalent to secp256r1 from the + // DRM key. + // kty = EC2 + // curve = P-256 + // alg = ES256 (ECDSA w/ SHA-256) + kCoseKeyP256 = 256, + + // Elliptic curve of P-384, equivalent to secp384r1 from the + // DRM key. + // kty = EC2 + // curve = P-384 + // alg = ES384 (ECDSA w/ SHA-384) + kCoseKeyP384 = 384, + + // Not supported: Android/Widevine DICE protocol does not support + // P-521; although support is relatively easy given its use for + // DRM keys. + // kty = EC2 + // curve = P-521 + // alg = ES512 (ECDSA w/ SHA-512) + // kCoseKeyP521 = 521, + + // Twisted Edwards Curve Ed25519 + // kty = OKP + // curve = Ed25519 + // alg = EdDSA (PureEdDSA) + kCoseKeyEd25519 = 25519, +}; + +std::string CoseKeyFamilyToString(CoseKeyFamily family); + +class CosePrivateKey; + +class CosePublicKey { + public: + CosePublicKey() = delete; + ~CosePublicKey() = default; + WVCDM_DISALLOW_COPY_AND_MOVE(CosePublicKey); + + static std::unique_ptr New(const CosePrivateKey& private_key); + + // Loads a COSE public key from the provided untagged CBOR-encoded + // COSE_Key. The encoded key may be public or private; however, + // only the public components will be loaded. + // + // This load function only supports P-256, P-384 and Ed25519 + // keys. + // + // The CBOR format that this method is expecting is as follows: + // + // COSE_Key = { + // 1 => int ; Required, "kty" key type, integer only, + // ; must be either EC2(2) or OKP(1) + // ? 3 => int ; Optional, algorithm, integer only, must be an + // ; appropriate algorithm base on the curve. + // ; Assumed to be the correct algorithm based + // ; on curve + // -1 => int ; Required, "crv" curve, integer only, + // ; must be either P-256(1), P-384(2) or Ed25519(6); + // ; validated against "kty" + // ; P-256/P-384 -> kty = EC2 + // ; Ed25519 -> kty = OKP + // ? -2 => bstr ; Conditionally optional, x-coordinate or public key, + // ; required if private key is not provided. + // ; Format depends on curve + // ; P-256/P-384 -> big-endian (x-coordinate) + // ; Ed25519 -> little endian (public key) + // ? -3 => bstr / ; Conditionally optional, y-coordinate, + // bool ; required if curve is P-256 or P-384 and private + // ; key is not provided. + // ; When bstr, big-endian encoding of x-coordinate + // ; When bool, considered the "signed bit" + // ; true indicates negative y + // ; false indicates positive y + // ? -4 => bstr ; Conditionally optional, private key, + // ; required if public components are not provided. + // ; Format depends on curve + // ; P-256/P-384 -> big-endian (private key point) + // ; Ed25519 -> little endian (private scalar) + // ; All other fields are ignored + // } + static std::unique_ptr LoadCborCoseKey(const uint8_t* buffer, + size_t buffer_size); + static std::unique_ptr LoadCborCoseKey( + const std::string& buffer); + static std::unique_ptr LoadCborCoseKey( + const std::vector& buffer); + + // Loads a P-256 or P-384 COSE key from their key points. + // Note: |family| must be one of the supported EC2 key types. + static std::unique_ptr LoadEccKeyPoint( + CoseKeyFamily family, const std::vector& x, + const std::vector& y); + static std::unique_ptr LoadEccKeyPoint( + CoseKeyFamily family, const std::vector& x, bool negative_y); + + // Loads a P-256 or P-384 COSE key using the SEC 1 format EC key point. + // Note: |family| must be one of the supported EC2 key types. + static std::unique_ptr LoadEccSec1KeyPoint( + CoseKeyFamily family, const std::vector& key_point); + + // Loads a P-256 or P-384 COSE key using the private scalar + // encoded as big-endian. + // Note: |family| must be one of the supported EC2 key types. + static std::unique_ptr LoadEccFromPrivateScalar( + CoseKeyFamily family, const std::vector& private_scalar); + + // Loads an Ed25519 COSE key using the raw public key. + static std::unique_ptr LoadEd25519FromPublic( + const std::vector& public_key); + + // Loads an Ed25519 COSE key using the raw private key. + static std::unique_ptr LoadEd25519FromPrivate( + const std::vector& private_key); + + CoseKeyFamily family() const; + + const EccPublicKey* ecc_key() const { return ecc_key_.get(); } + const EdPublicKey* ed_key() const { return ed_key_.get(); } + + bool IsMatchingPrivateKey(const CosePrivateKey& private_key) const; + + // Serializes the public key to a CBOR encoded COSE_Key. + // Only the public components are serialized. + std::vector SerializeCbor(bool include_algorithm = true) const; + + // Verifies the COSE_Sign1 signature. + // + // The two different version of this method allow for cases where + // the signature and payload are combined, or when they are separate. + // + // The parameter |cose_sign1| is expected to be a CBOR encoded + // COSE_Sign1 array. + // + // COSE_Sign1 = [ + // protected: bstr .cbor { + // ? 1 => int ; Optional, signing algorithm. + // ; If present, must be the expected + // ; value for this key family. + // }, + // unprotected: {}, ; Any parameters here are ignored + // payload: bstr / nil ; Required, + // ; If calling without |message| parameter, + // ; this value be bstr and the contents will + // ; be used to verify signature + // ; If calling with |message| parameter, + // ; this value must be nil, or if bstr, + // ; it must be identical to |message| + // signature: bstr ; Required, signature to be verified. + // ; The signature is compared against + // ; the constructed Sig_structure following + // ; the COSE_Sign1 protocol. + // ] + // + // Returns: + // - OEMCrypto_SUCCESS if signature is valid + // - OEMCrypto_ERROR_SIGNATURE_FAILURE if the signature is not + // valid + // - OEMCrypto_ERROR_INVALID_KEY if the algorithm specified in the + // "protected" field does not match the key + // - OEMCrypto_ERROR_INVALID_CONTEXT if |cose_sign1| could not be + // parsed, or the there were unexpected values + // - OEMCrypto_ERROR_UNKNOWN_FAILURE if an internal error + // occurs. + OEMCryptoResult VerifyCoseSign1Signature( + const std::vector& cose_sign1) const; + OEMCryptoResult VerifyCoseSign1Signature( + const std::vector& message, + const std::vector& cose_sign1) const; + + private: + CosePublicKey(std::unique_ptr&& ecc_key); + CosePublicKey(std::unique_ptr&& ed_key); + + // Internal helper for verifying COSE_Sign1 signatures. + // The provided parameters should be taken from the COSE_Sign1 + // structure, with the exception of |payload| which may or + // may not be from the COSE_Sign1 structure (attached or + // detached). + OEMCryptoResult VerifyCoseSign1SignatureInternal( + const std::vector& protected_params, + const std::vector& payload, + const std::vector& signature) const; + + // Internal helper for verifying signatures. + // Verifies the provided |signature| against |verifying_payload| + // using the appropriate algorithm for the public key type. + // Generally, it is expected that |verifying_payload| is a CBOR + // encoded Sig_structure. + OEMCryptoResult VerifyRawSignatureInternal( + const std::vector& verifying_payload, + const std::vector& signature) const; + + // Only one of the two will be set. + std::unique_ptr ecc_key_; + std::unique_ptr ed_key_; +}; // class CosePublicKey + +class CosePrivateKey { + public: + CosePrivateKey() = delete; + ~CosePrivateKey() = default; + WVCDM_DISALLOW_COPY_AND_MOVE(CosePrivateKey); + + static std::unique_ptr New(CoseKeyFamily family); + + // Loads a COSE private key from the provided untagged CBOR-encoded + // COSE_Key. The encoded key must be private. + // + // This load function only supports P-256, P-384 and Ed25519 + // keys. + // + // The CBOR format that this method is expecting is as follows: + // + // COSE_Key = { + // 1 => int ; Required, "kty" key type, integer only, + // ; must be either EC2(2) or OKP(1) + // ? 3 => int ; Optional, algorithm, integer only, must be an + // ; appropriate algorithm base on the curve. + // ; Assumed to be the correct algorithm based + // ; on curve + // -1 => int ; Required, "crv" curve, integer only, + // ; must be either P-256(1), P-384(2) or Ed25519(6); + // ; validated against "kty" + // ; P-256/P-384 -> kty = EC2 + // ; Ed25519 -> kty = OKP + // -4 => bstr ; Required, private key / private key scalar + // ; required if public components are not provided. + // ; Format depends on curve + // ; P-256/P-384 -> big-endian (private key point) + // ; Ed25519 -> little endian (private scalar) + // ; All other fields are ignored + // } + static std::unique_ptr LoadCborCoseKey(const uint8_t* buffer, + size_t buffer_size); + static std::unique_ptr LoadCborCoseKey( + const std::string& buffer); + static std::unique_ptr LoadCborCoseKey( + const std::vector& buffer); + + // Loads a P-256 or P-384 COSE key using the private scalar + // encoded as big-endian. + // Note: |family| must be one of the supported EC2 key types. + static std::unique_ptr LoadEccPrivateScalar( + CoseKeyFamily family, const std::vector& private_key); + + // Loads an Ed25519 COSE key using the raw private key. + static std::unique_ptr LoadEd25519RawPrivate( + const std::vector& private_key); + + // Creates a new COSE public key of this private key. + // Equivalent to calling CosePublicKey::New with this private + // key. + std::unique_ptr MakePublicKey() const; + + CoseKeyFamily family() const; + + const EccPrivateKey* ecc_key() const { return ecc_key_.get(); } + const EdPrivateKey* ed_key() const { return ed_key_.get(); } + + bool IsMatchingPublicKey(const CosePublicKey& public_key) const; + + // Serializes the public key to a CBOR encoded COSE_Key. + // Both the public and private components are serialized. + std::vector SerializeCbor(bool include_algorithm = true) const; + + // Generates a COSE_Sign1 signature of the provided + // |payload|. + // + // The CBOR format of this method is: + // + // COSE_Sign1 = [ + // protected: bstr .cbor { + // 1 => int ; Signing algorithm, see CoseKeyFamily + // ; enums for which algorithm will be set. + // }, + // unprotected: {}, ; Always empty. + // payload: bstr / nil ; Optional copy of |payload|. + // ; Included if |include_payload| is true; + // ; otherwise, set to nil. + // signature: bstr ; Signature generated on the constructed + // ; Sig_structure following the COSE_Sign1 + // ; protocol. + // ] + std::vector GenerateCoseSign1Signature( + const std::vector& payload, bool include_payload = false) const; + + // This exposes GenerateRawSignatureInternal() but is for + // testing only (testbed != testing). + std::vector GenerateRawSignatureForTest( + const std::vector& signing_payload) const { + return GenerateRawSignatureInternal(signing_payload); + } + + private: + CosePrivateKey(std::unique_ptr&& ecc_key); + CosePrivateKey(std::unique_ptr&& ed_key); + + // Internal helper for generating signatures. + // Signs the provided |signing_payload| as is using the appropriate + // algorithm for the private key type. + // Generally, it is expected that |signing_payload| is a CBOR + // encoded Sig_structure. + std::vector GenerateRawSignatureInternal( + const std::vector& signing_payload) const; + + // Only one of the two will be set. + std::unique_ptr ecc_key_; + std::unique_ptr ed_key_; +}; // class CosePrivateKey +} // namespace util +} // namespace wvoec +#endif // WVOEC_UTIL_OEMCRYPTO_COSE_KEY_H_ diff --git a/oemcrypto/util/include/oemcrypto_drm_key.h b/oemcrypto/util/include/oemcrypto_drm_key.h index a02b94e..faaf3a5 100644 --- a/oemcrypto/util/include/oemcrypto_drm_key.h +++ b/oemcrypto/util/include/oemcrypto_drm_key.h @@ -7,10 +7,14 @@ #ifndef WVOEC_UTIL_DRM_KEY_H_ #define WVOEC_UTIL_DRM_KEY_H_ +#include + #include +#include #include #include +#include "OEMCryptoCENC.h" #include "OEMCryptoCENCCommon.h" #include "oemcrypto_ecc_key.h" #include "oemcrypto_rsa_key.h" @@ -18,18 +22,89 @@ namespace wvoec { namespace util { +// OEMCrypto may use a different hash algorithm for +// RSASSA-PSS signatures than the default of SHA-1. +// +// This option does not apply to ECC keys, as the hash algorithm +// is determined by the curve. +using DrmSignatureHashAlgorithmOption = + std::optional; + +// DRM public key performs some of the operations required by an +// OEMCrypto session's RSA/ECC private key. +class DrmPublicKey { + public: + DrmPublicKey() = delete; + WVCDM_DISALLOW_COPY_AND_MOVE(DrmPublicKey); + ~DrmPublicKey() = default; + + // Create an RSA-based DRM key. + static std::unique_ptr Create( + std::shared_ptr&& rsa_key, + DrmSignatureHashAlgorithmOption hash_algorithm_opt = std::nullopt); + static std::unique_ptr Create( + std::unique_ptr&& rsa_key, + DrmSignatureHashAlgorithmOption hash_algorithm_opt = std::nullopt); + // Create an ECC-based DRM key. + static std::unique_ptr Create( + std::shared_ptr&& ecc_key); + static std::unique_ptr Create( + std::unique_ptr&& ecc_key); + + bool IsRsaKey() const { return static_cast(rsa_key_); } + bool IsEccKey() const { return static_cast(ecc_key_); } + + // Get the OEMCrypto signature hash algorithm used by this + // key when calling VerifySignature(). + OEMCrypto_SignatureHashAlgorithm GetSignatureHashAlgorithm() const; + + // Verifies the |signature| matches the provided |message| + // by the private equivalent of this public key. + // + // For RSA keys, the signature is RSASSA-PSS. + // For ECC keys, the signature is ECDSA. + // + // See EccPublicKey::VerifySignature and + // RsaPublicKey::VerifySignature for details. + OEMCryptoResult VerifySignature(const uint8_t* message, size_t message_length, + const uint8_t* signature, + size_t signature_length) const; + OEMCryptoResult VerifySignature(const std::vector& message, + const std::vector& signature) const; + OEMCryptoResult VerifySignature(const std::string& message, + const std::string& signature) const; + + private: + DrmPublicKey(std::shared_ptr&& rsa_key, + OEMCrypto_SignatureHashAlgorithm rsa_hash_algorithm); + DrmPublicKey(std::shared_ptr&& ecc_key); + + // Only one will be set. + std::shared_ptr ecc_key_; + std::shared_ptr rsa_key_; + // |rsa_hash_algorithm_| is only used with |rsa_key_| when + // verifying non-CAST signatures. + OEMCrypto_SignatureHashAlgorithm rsa_hash_algorithm_ = OEMCrypto_SHA1; +}; // class DrmPublicKey + // DRM private key performs all of the operations required by an // OEMCrypto session's RSA/ECC private key. class DrmPrivateKey { public: + DrmPrivateKey() = delete; WVCDM_DISALLOW_COPY_AND_MOVE(DrmPrivateKey); ~DrmPrivateKey() = default; // Create an RSA-based DRM key. + // + // |hash_algorithm_opt| may be specified to use a RSASSA-PSS + // hash algorithm other than the default of SHA-1. static std::unique_ptr Create( - std::shared_ptr&& rsa_key); + std::shared_ptr&& rsa_key, + DrmSignatureHashAlgorithmOption hash_algorithm_opt = std::nullopt); static std::unique_ptr Create( - std::unique_ptr&& rsa_key); + std::unique_ptr&& rsa_key, + DrmSignatureHashAlgorithmOption hash_algorithm_opt = std::nullopt); // Create an ECC-based DRM key. static std::unique_ptr Create( std::shared_ptr&& ecc_key); @@ -39,6 +114,14 @@ class DrmPrivateKey { bool IsRsaKey() const { return static_cast(rsa_key_); } bool IsEccKey() const { return static_cast(ecc_key_); } + // Get the OEMCrypto signature hash algorithm used by this + // key when calling GenerateSignature(). + OEMCrypto_SignatureHashAlgorithm GetSignatureHashAlgorithm() const; + // Similar to the above, but behaves exactly as the OEMCrypto + // API OEMCrypto_GetSignatureHashAlgorithm(). + OEMCryptoResult GetSignatureHashAlgorithm( + OEMCrypto_SignatureHashAlgorithm* algorithm) const; + // Generates a session key from the key source. // For RSA keys, |key_source| is an encrypted session key. // For ECC keys, |key_source| is a ephemeral public key to be @@ -63,10 +146,11 @@ class DrmPrivateKey { size_t* signature_length) const; std::vector GenerateSignature( const std::vector& message) const; + std::vector GenerateSignature(const std::string& message) const; size_t SignatureSize() const; // Generates a signature for the provided message. - // For RSA keys, the signature is RSASSA-PKCS1. + // For RSA keys, the signature is CAST's special RSASSA-PKCS1. // For ECC keys, this is not supported. OEMCryptoResult GenerateRsaSignature(const uint8_t* message, size_t message_length, @@ -76,12 +160,17 @@ class DrmPrivateKey { const std::vector& message) const; private: - DrmPrivateKey() {} + DrmPrivateKey(std::shared_ptr&& ecc_key); + DrmPrivateKey(std::shared_ptr&& rsa_key, + OEMCrypto_SignatureHashAlgorithm rsa_hash_algorithm); // Only one will be set. std::shared_ptr ecc_key_; std::shared_ptr rsa_key_; -}; + // |rsa_hash_algorithm_| is only used with |rsa_key_| when + // generating non-CAST signatures. + OEMCrypto_SignatureHashAlgorithm rsa_hash_algorithm_ = OEMCrypto_SHA1; +}; // class DrmPrivateKey } // namespace util } // namespace wvoec #endif // WVOEC_UTIL_DRM_KEY_H_ diff --git a/oemcrypto/util/include/oemcrypto_ecc_key.h b/oemcrypto/util/include/oemcrypto_ecc_key.h index ac69505..f2a29dc 100644 --- a/oemcrypto/util/include/oemcrypto_ecc_key.h +++ b/oemcrypto/util/include/oemcrypto_ecc_key.h @@ -16,6 +16,7 @@ #include +#include "OEMCryptoCENC.h" #include "OEMCryptoCENCCommon.h" #include "wv_class_utils.h" @@ -107,13 +108,42 @@ class EccPublicKey { static std::unique_ptr LoadPrivateKeyInfo( const std::vector& buffer); + // Loads the public key from the raw key point. + // + // The provided |x| and |y| are decoded as big-endian integers + // to form the key point. + static std::unique_ptr LoadKeyPoint( + EccCurve curve, const std::vector& x, + const std::vector& y); + // Same as the above, except the Y coordinate is derived from + // the |curve| and |x| coordinate, using the |y_sign_bit|. + // If |y_sign_bit| is true, then Y is negative, if false than + // Y is positive. + static std::unique_ptr LoadKeyPoint( + EccCurve curve, const std::vector& x, bool y_sign_bit); + + // Loads the public key from the private scalar. + // + // The provided private scalar |d| should be the big-endian + // encoded integer. + static std::unique_ptr LoadFromPrivateScalar( + EccCurve curve, const std::vector& d); + EccCurve curve() const { return curve_; } const EC_KEY* GetEcKey() const { return key_; } + // Returns the signing hash algorithm used by this key. + // Result is determined by the curve. + OEMCrypto_SignatureHashAlgorithm GetSignatureHashAlgorithm() const; + // Checks if the provided |private_key| is the EC private key of this // public key. bool IsMatchingPrivateKey(const EccPrivateKey& private_key) const; + // Checks if the provided |other_key| is the same key as this public + // key. + bool IsMatchingPublicKey(const EccPublicKey& other_key) const; + // Serializes the public key into an ASN.1 DER encoded SubjectPublicKey // representation. // On success, |buffer_size| is populated with the number of bytes @@ -136,6 +166,14 @@ class EccPublicKey { // Returns an empty vector on error. std::vector SerializeAsSec1KeyPoint(bool compressed = false) const; + // Serializes the ECC key point by their components. The X and Y + // coordinates are encoded as fixed-width big-endian integer. + std::vector SerializeXCoord() const; + std::vector SerializeYCoord() const; + // Returns the Y sign-bit. Returns true if Y is negative, false if + // Y is positive. + bool GetYSignBit() const; + // Verifies the |signature| matches the provided |message| by the // private equivalent of this public key. // The |signature| should be a valid ASN.1 DER encoded @@ -191,6 +229,16 @@ class EccPublicKey { // SEC 1 encoded EC key point |buffer|. bool InitFromSec1KeyPoint(EccCurve curve, const uint8_t* buffer, size_t length); + // Initializes the public key object from the provided |curve| and + // components |x| and |y| (or |y_sign_bit|). + // Internally uses SEC 1. + bool InitFromKeyPoint(EccCurve curve, const std::vector& x, + const std::vector& y); + bool InitFromKeyPoint(EccCurve curve, const std::vector& x, + bool y_sign_bit); + // Initializes the public key object from the provided |curve| and + // private scalar. + bool InitFromPrivateScalar(EccCurve curve, const std::vector& d); // Digests the |message| and verifies signature against the provided // ECDSA signature point |sig_point|. OEMCryptoResult DigestAndVerify(const uint8_t* message, size_t message_length, @@ -250,6 +298,13 @@ class EccPrivateKey { static std::unique_ptr Load( const std::vector& buffer); + // Loads the private key from the private scalar. + // + // The provided private scalar |d| should be the big-endian + // encoded integer. + static std::unique_ptr LoadPrivateScalar( + EccCurve curve, const std::vector& d); + // Creates a new ECC public key of this private key. // Equivalent to calling EccPublicKey::New with this private // key. @@ -262,6 +317,10 @@ class EccPrivateKey { // private key. bool IsMatchingPublicKey(const EccPublicKey& public_key) const; + // Returns the signing hash algorithm used by this key. + // Result is determined by the curve. + OEMCrypto_SignatureHashAlgorithm GetSignatureHashAlgorithm() const; + // Serializes the private key into an ASN.1 DER encoded PrivateKeyInfo // representation. // On success, |buffer_size| is populated with the number of bytes @@ -293,6 +352,16 @@ class EccPrivateKey { std::vector SerializeAsPublicSec1KeyPoint( bool compressed = false) const; + // Serializes the ECC key point by their components. The X and Y + // coordinates are encoded as fixed-width big-endian integer. + std::vector SerializeXCoord() const; + std::vector SerializeYCoord() const; + // Returns the Y sign-bit. Returns true if Y is negative, false if + // Y is positive. + bool GetYSignBit() const; + // Returns the big-endian encoded private scalar. + std::vector SerializePrivateScalar() const; + // Signs the provided |message| and serializes the signature // point to |signature| as a ASN.1 DER encoded ECDSA-Sig-Value. // This implementation uses ECDSA with the following digest @@ -348,6 +417,9 @@ class EccPrivateKey { bool InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length); // Generates a new key based on the provided curve. bool InitFromCurve(EccCurve curve); + // Initializes the private key object from the provided |curve| and + // private scalar. + bool InitFromPrivateScalar(EccCurve curve, const std::vector& d); // OpenSSL/BoringSSL implementation of an ECC key. // The public point of the key will always be present. diff --git a/oemcrypto/util/include/oemcrypto_ed_key.h b/oemcrypto/util/include/oemcrypto_ed_key.h new file mode 100644 index 0000000..124f2d6 --- /dev/null +++ b/oemcrypto/util/include/oemcrypto_ed_key.h @@ -0,0 +1,282 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#ifndef WVOEC_UTIL_ED_KEY_H_ +#define WVOEC_UTIL_ED_KEY_H_ + +#include +#include + +#include +#include +#include + +#include + +#include "OEMCryptoCENCCommon.h" +#include "wv_class_utils.h" + +namespace wvoec { +namespace util { +// Ed25519 keys are special asymmetric keys which can only be used for +// signing and signature verifications. + +// Forward declarations. +class EdPrivateKey; + +// High-level wrapper around an Ed25519 signature verification key. +// Compatible with both OpenSSL and BoringSSL. +class EdPublicKey { + public: + ~EdPublicKey(); + WVCDM_DISALLOW_COPY_AND_MOVE(EdPublicKey); + + // Creates a new public key equivalent of the provided private key. + static std::unique_ptr New(const EdPrivateKey& private_key); + + // Loads a ASN.1 DER serialized ED public key. + // + // The provided |buffer| must contain a valid ASN.1 DER encoded + // SubjectPublicKey. Only supported algorithm is id-Ed25519(1.3.101.112) + // + // buffer: SubjectPublicKeyInfo = { + // algorithm: AlgorithmIdentifier = { + // algorithm: OID = id-Ed25519 + // -- No parameters for Ed25519 + // }, + // subjectPublicKey: BIT STRING = ... -- RFC 8032 encoded public key. + // } + // + // Failure will occur if the provided |buffer| does not contain a + // valid SubjectPublicKey, or if the specified curve is not + // supported. + static std::unique_ptr Load(const uint8_t* buffer, + size_t length); + static std::unique_ptr Load(const std::string& buffer); + static std::unique_ptr Load(const std::vector& buffer); + + // Load raw Ed25519 key (RFC 8032, section 3.1 format). + static std::unique_ptr LoadRaw(const uint8_t* buffer, + size_t length); + static std::unique_ptr LoadRaw(const std::string& buffer); + static std::unique_ptr LoadRaw( + const std::vector& buffer); + + // Load raw Ed25519 private key as public key (RFC 8032, + // section 3.1 format). + static std::unique_ptr LoadFromRawPrivate(const uint8_t* buffer, + size_t length); + static std::unique_ptr LoadFromRawPrivate( + const std::string& buffer); + static std::unique_ptr LoadFromRawPrivate( + const std::vector& buffer); + + // Checks if the provided |private_key| is the Ed25519 private + // key of this public key. + bool IsMatchingPrivateKey(const EdPrivateKey& private_key) const; + + // Get handle of OpenSSL/BoringSSL EVP key. + const EVP_PKEY* GetEvpPkey() const { return key_; } + + // Serializes the public key into an ASN.1 DER encoded SubjectPublicKey + // representation. + // On success, |buffer_size| is populated with the number of bytes + // written to |buffer|, and OEMCrypto_SUCCESS is returned. + // If the provided |buffer_size| is too small, ERROR_SHORT_BUFFER + // is returned and |buffer_size| is set to the required buffer size. + OEMCryptoResult SerializeSubjectPublicKeyInfo(uint8_t* buffer, + size_t* buffer_size) const; + // Same as above, except directly returns the serialized key. + // Returns an empty vector on error. + std::vector SerializeSubjectPublicKeyInfo() const; + // Serializes raw public key value. + // The encoding is defined in RFC 8032, section 3.1, and should be + // 32-bytes for all key values. + std::vector SerializeRaw() const; + + // Verifies the |signature| matches the provided |message| by the + // private equivalent of this public key. + // + // Uses PureEdDSA. The signature should be the raw encoded + // signature output. + // + // Returns: + // OEMCrypto_SUCCESS if signature is valid + // OEMCrypto_ERROR_SIGNATURE_FAILURE if the signature is invalid + // Any other result indicates an unexpected error + OEMCryptoResult VerifySignature(const uint8_t* message, size_t message_length, + const uint8_t* signature, + size_t signature_length) const; + OEMCryptoResult VerifySignature(const std::string& message, + const std::string& signature) const; + OEMCryptoResult VerifySignature(const std::vector& message, + const std::vector& signature) const; + + private: + EdPublicKey() = default; + + // All of the following Init*() functions are call exactly + // once by the protected constructors. + // + // The |buffer| and |length| parameters should be checked by the + // caller. + + // Initializes |key_| from an ASN.1 DER encoded SubjectPublicKeyInfo. + bool InitFromSubjectPublicKeyInfo(const uint8_t* buffer, size_t length); + // Initializes |key_| from an ASN.1 DER encoded PrivateKeyInfo. + bool InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length); + // Initializes |key_| from an RFC 8032 raw public key. + bool InitFromRawPublicKey(const uint8_t* buffer, size_t length); + // Initializes |key_| from an RFC 8032 raw private key. + bool InitFromRawPrivateKey(const uint8_t* buffer, size_t length); + // Initializes |key_| from the private key. + bool InitFromPrivateKey(const EdPrivateKey& private_key); + + EVP_PKEY* key_ = nullptr; +}; // class EdPublicKey + +// High-level wrapper around an Ed25519 signing key. +// Compatible with both OpenSSL and BoringSSL. +class EdPrivateKey { + public: + ~EdPrivateKey(); + WVCDM_DISALLOW_COPY_AND_MOVE(EdPrivateKey); + + // Creates a new, pseudorandom Ed25519 private key. + static std::unique_ptr New(); + + // Loads a serialized Ed25519 private key. + // The provided |buffer| must contain a valid ASN.1 DER encoded + // PrivateKeyInfo/OneAsymmetricKey containing a valid Ed25519 key + // description. + // + // Note: ASN.1 OneAsymmetricKey and PrivateKeyInfo sequences + // are compatible with each other. + // + // PrivateKeyInfo ::= { + // version: INTEGER = v1(0) | v2(1), + // privateKeyAlgorithm: AlgorithmIdentifier := { + // algorithm: OID = id-Ed25519 + // -- no parameters + // }, + // privateKey: OCTET STRING = ..., -- BER encoding of CurvePrivateKey + // attributes: [0] Attributes OPTIONAL, -- Not used + // publicKey: [1] BIT STRING OPTIONAL -- RFC 8032 encoded private key. + // } + // + // CurvePrivateKey ::= OCTET STRING -- RFC 8032 encoded Ed25519 private key. + // + // Note: If the public key is not included, then it is computed from + // the private key. + // + // References: + // RFC 5958 - Description of OneAsymmetricKey/PrivateKeyInfo + // RFC 8410 - OneAsymmetricKey for Ed25519 + // RFC 8032 - Ed25519 public/private key encoding. + // + // Failure will occur if the provided |buffer| does not contain a + // valid PrivateKeyInfo, key is not an ECC key, the specified + // curve is not supported, or the key is not valid. + static std::unique_ptr Load(const uint8_t* buffer, + size_t length); + static std::unique_ptr Load(const std::string& buffer); + static std::unique_ptr Load(const std::vector& buffer); + + // Load raw Ed25519 public key (RFC 8032, section 3.1 format). + // Note: This is called "seed" in Open Profile for DICE. + static std::unique_ptr LoadRaw(const uint8_t* buffer, + size_t length); + static std::unique_ptr LoadRaw(const std::string& buffer); + static std::unique_ptr LoadRaw( + const std::vector& buffer); + + // Creates a new Ed25519 public key of this private key. + // Equivalent to calling EdPublicKey::New with this private + // key. + std::unique_ptr MakePublicKey() const; + + // Get handle of OpenSSL/BoringSSL EVP key. + const EVP_PKEY* GetEvpPkey() const { return key_; } + + // Checks if the provided |public_key| is the EC public key of this + // private key. + bool IsMatchingPublicKey(const EdPublicKey& public_key) const; + + // Serializes the private key into an ASN.1 DER encoded PrivateKeyInfo + // representation. + // On success, |buffer_size| is populated with the number of bytes + // written to |buffer|, and SUCCESS is returned. + // If the provided |buffer_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |buffer_size| is + // set to the required buffer size. + OEMCryptoResult SerializePrivateKeyInfo(uint8_t* buffer, + size_t* buffer_size) const; + // Same as above, except directly returns the serialized key. + // Returns an empty vector on error. + std::vector SerializePrivateKeyInfo() const; + + // Serializes the public component of the private key into an ASN.1 + // DER encoded SubjectPublicKey representation. + // On success, |buffer_size| is populated with the number of bytes + // written to |buffer|, and SUCCESS is returned. + // If the provided |buffer_size| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |buffer_size| is + // set to the required buffer size. + OEMCryptoResult SerializeAsSubjectPublicKeyInfo(uint8_t* buffer, + size_t* buffer_size) const; + // Same as above, except directly returns the serialized key. + // Returns an empty vector on error. + std::vector SerializeAsSubjectPublicKeyInfo() const; + + // Serializes as a raw private key. + std::vector SerializeRaw() const; + // Serializes as a raw public key. + std::vector SerializeAsRawPublicKey() const; + + // Signs the provided |message| and generates a signature + // to |signature| as an RFC 8032 signature value. + // Signs using PureEdDSA. + // + // On success, |signature_length| is populated with the number of + // bytes written to |signature|, and SUCCESS is returned. + // If the provided |signature_length| is too small, + // OEMCrypto_ERROR_SHORT_BUFFER is returned and |signature_length| + // is set to the required signature size. + OEMCryptoResult GenerateSignature(const uint8_t* message, + size_t message_length, uint8_t* signature, + size_t* signature_length) const; + // Same as above, except directly returns the serialized signature. + // Returns an empty vector on error. + std::vector GenerateSignature( + const std::vector& message) const; + std::vector GenerateSignature(const std::string& message) const; + // Returns the signature size. + // Note: Ed25519 signatures are always exactly 64-bytes (unique compared + // to other asymmetric key types). + size_t SignatureSize() const; + + private: + EdPrivateKey() = default; + + // All of the following Init*() functions are call exactly + // once by the protected constructors. + // + // The |buffer| and |length| parameters should be checked by the + // caller. + + // Initializes |key_| from an ASN.1 DER encoded PrivateKeyInfo. + bool InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length); + + // Initializes |key_| from an RFC 8032 raw private key. + bool InitFromRaw(const uint8_t* buffer, size_t length); + // Initializes |key_| with a new pseudorandom Ed25519 private key + bool InitNew(); + + EVP_PKEY* key_ = nullptr; +}; // class EdPrivateKey +} // namespace util +} // namespace wvoec +#endif // WVOEC_UTIL_ED_KEY_H_ diff --git a/oemcrypto/util/include/oemcrypto_oem_cert.h b/oemcrypto/util/include/oemcrypto_oem_cert.h index cd17791..9e52084 100644 --- a/oemcrypto/util/include/oemcrypto_oem_cert.h +++ b/oemcrypto/util/include/oemcrypto_oem_cert.h @@ -13,12 +13,11 @@ #include #include "OEMCryptoCENCCommon.h" +#include "oemcrypto_oem_cert_chain.h" #include "wv_class_utils.h" namespace wvoec { namespace util { -class OemPublicCertificate; - // An OEM Certificate is a factory provisioned root of trust // certificate which consists of a public certificate and its // matching private key. @@ -35,14 +34,7 @@ class OemPublicCertificate; // the reference implementation uses PKCS8 PrivateKeyInfo. class OemCertificate { public: - enum KeyType { - kNone = 0, - // Private key is an ASN.1 DER encoded PrivateKeyInfo specifying - // an RSA encryption key. - kRsa = 1 - }; - - ~OemCertificate(); + ~OemCertificate() = default; WVCDM_DISALLOW_COPY_AND_MOVE(OemCertificate); // Creates a new OEM Certificate and performs basic validation @@ -50,8 +42,8 @@ class OemCertificate { // The |public_cert| provided is parsed as an X.509 Certificate // and the public key is verified against the private key. // The |private_key| is parsed depending on the key type. - // If any error occurs or if the provided data is malformed, an - // empty pointer is returned. + // If any error occurs or if the provided data is malformed, a + // null pointer is returned. static std::unique_ptr Create(const uint8_t* private_key, size_t private_key_size, const uint8_t* public_cert, @@ -61,13 +53,18 @@ class OemCertificate { const std::vector& public_cert); // Returns the key type of the OEM Public key and private key. - // As of OEMCrypto v16, the only supported key type is RSA. - KeyType key_type() const; + // As of OEMCrypto v19, the only supported key types are RSA + // and ECC. + OemCertKeyType GetKeyType() const; // Returns the private key data. Intended to be used for calls // to OEMCrypto_LoadOEMPrivateKey(). const std::vector& GetPrivateKey() const { return private_key_; } + // Returns the ASN.1 DER encoded SubjectPublicKeyInfo of + // the device public key. + std::vector SerializePublicKey() const; + // Returns a copy of the ASN.1 DER encoded PKCS #7 certificate chain. // If |*public_cert_length| is large enough, the complete // certificate is copied to the buffer specified by |public_cert|, @@ -91,12 +88,12 @@ class OemCertificate { OEMCryptoResult IsCertificateValid() const; private: - OemCertificate(); + OemCertificate() = default; // Serialized private key matching the OEM certificate. std::vector private_key_; // Serialized OEM Certificate. - std::unique_ptr public_cert_; + std::unique_ptr cert_chain_; }; // class OemCertificate } // namespace util } // namespace wvoec diff --git a/oemcrypto/util/include/oemcrypto_oem_cert_chain.h b/oemcrypto/util/include/oemcrypto_oem_cert_chain.h new file mode 100644 index 0000000..ea66aad --- /dev/null +++ b/oemcrypto/util/include/oemcrypto_oem_cert_chain.h @@ -0,0 +1,130 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#ifndef WVOEC_UTIL_OEM_CERT_CHAIN_H_ +#define WVOEC_UTIL_OEM_CERT_CHAIN_H_ + +#include + +#include +#include + +#include +#include + +#include "scoped_object.h" +#include "wv_class_utils.h" + +namespace wvoec { +namespace util { +class OemPublicCertificate; + +enum class OemCertKeyType { + kUnknown = 0, + // RSA: Either RSA-2048 or RSA-3072 + kRsa = 1, + // ECC: Either secp256r1, secp384r1 or secp521r1. + kEcc = 2, +}; +const char* OemCertKeyTypeToString(OemCertKeyType type); + +// Wrapper around an OEMCrypto OEM Certificate Chain as used by +// the device (sometimes referred simply as OEM Certificate). +// +// The OEM certificate chain is used by Provisioning 3.0 +// (chain built-in at the factory) and Provisioning 4.0 (chain +// is provided by the server). +// +// The certificate chain sequence of X.509 Certificates, with +// the first certificate being the "leaf" or "device" certificate, +// followed by one "intermediate" certificate. +// +// The "leaf" certificate is signed by the "intermediate" certificate. +// The "intermediate" certificate is signed by the "Widevine ROT" +// certificate. +// The "Widevine ROT" certificate is not available on the device. +// +// This implementation has limited utility, and is primarily used +// to verify small parts of information from within the certificate +// chain. +// +// The certificate chain must be an ASN.1 DER encoded PKCS#7 +// ContentInfo of type signedData (RFC2315). The "certificates" +// field of the SignedData must contain two X.509 Certificates. +// +// For Provisioning 3.0, only RSA leaf keys are expected; for +// Provisioning 4.0, only RSA or ECC keys are expected. +// This requirement is not enforced by this class. +class OemCertificateChain { + public: + using ScopedPkcs7 = ScopedObject; + + static constexpr size_t kExpectedCertCount = 2; // Leaf and intermediate. + static constexpr size_t kDeviceCertIndex = 0; + static constexpr size_t kIntermediateCertIndex = 1; + + ~OemCertificateChain() = default; + WVCDM_DISALLOW_COPY_AND_MOVE(OemCertificateChain); + + static std::unique_ptr LoadPkcs7(const uint8_t* message, + size_t message_length); + static std::unique_ptr LoadPkcs7( + const std::vector& message); + + const PKCS7* pkcs7() const { return pkcs7_.get(); } + + const std::vector& cert_data() const { return cert_data_; } + + // Returns a pointer to the device/intermediate certificate. + // Note: The point is owned by the OemCertificateChain instance. + const OemPublicCertificate* device_cert() const { return device_cert_.get(); } + const OemPublicCertificate* intermediate_cert() const { + return intermediate_cert_.get(); + } + + private: + OemCertificateChain() = default; + + bool InitFromPkcs7(const uint8_t* message, size_t message_length); + + // |pkcs7_| will only ever contain the PKCS7 instance containing + // SignedData (OpenSSL's PKCS7_SIGNED). + ScopedPkcs7 pkcs7_; + // Copy of the certificate data used to initialize + // the object. + std::vector cert_data_; + + std::unique_ptr device_cert_; + std::unique_ptr intermediate_cert_; +}; // class OemCertificateChain + +class OemPublicCertificate { + public: + OemPublicCertificate() = delete; + ~OemPublicCertificate() = default; + WVCDM_DISALLOW_COPY_AND_MOVE(OemPublicCertificate); + + // == Subject Public Key Information == + // Get the enumerated types of public keys. + // If not one of the support OemCertKeyType, then kUnknown + // is returned. + OemCertKeyType GetSubjectPublicKeyType() const; + // Returns the ASN.1 DER encoded SubjectPublicKeyInfo. + std::vector SerializedSubjectPublicKeyInfo() const; + + private: + explicit OemPublicCertificate(X509* cert) : cert_(cert) {} + + // |cert_| is not owned by the class. + X509* cert_ = nullptr; + + // Only allow OemCertificateChain to create an instance of + // OemPublicCertificate. + friend class OemCertificateChain; +}; +} // namespace util +} // namespace wvoec +#endif // WVOEC_UTIL_OEM_CERT_CHAIN_H_ diff --git a/oemcrypto/util/include/oemcrypto_rsa_key.h b/oemcrypto/util/include/oemcrypto_rsa_key.h index 0eb7e77..012a49c 100644 --- a/oemcrypto/util/include/oemcrypto_rsa_key.h +++ b/oemcrypto/util/include/oemcrypto_rsa_key.h @@ -24,30 +24,60 @@ namespace util { enum RsaFieldSize { kRsaFieldUnknown = 0, kRsa2048Bit = 2048, - kRsa3072Bit = 3084 + kRsa3072Bit = 3072 }; // Identifies the RSA signature algorithm to be used when signing // messages or verifying message signatures. // The two standard signing algorithms specified by PKCS1 RSA V2.1 // are RSASSA-PKCS1 and RSASSA-PSS. Each require agreement on a -// set of options. For OEMCrypto, only one set of options are agreed -// upon for each RSA signature scheme. CAST receivers specify a -// special implementation of PKCS1 where the message is already -// digested and encoded when provided. +// set of options. +// For RSASSA-PSS in OEMCrypto v18+, a limited variety of options +// are available. +// For CAST receivers specify a special implementation of PKCS1 +// where the message is already digested and encoded when provided. enum RsaSignatureAlgorithm { - // RSASSA-PSS with default options: + // RSASSA-PSS with options: // Hash algorithm: SHA-1 // MGF: MGF1 with SHA-1 // Salt length: 20 bytes // Trailer field: 0xbc - kRsaPssDefault = 0, + kRsaPssSha1 = 0, + // RSASSA-PSS with options: + // Hash algorithm: SHA-256 + // MGF: MGF1 with SHA-1 + // Salt length: 20 bytes + // Trailer field: 0xbc + kRsaPssSha256 = 256, + // RSASSA-PSS with options: + // Hash algorithm: SHA-384 + // MGF: MGF1 with SHA-1 + // Salt length: 20 bytes + // Trailer field: 0xbc + kRsaPssSha384 = 384, + // RSASSA-PSS with options: + // Hash algorithm: SHA-512 + // MGF: MGF1 with SHA-1 + // Salt length: 20 bytes + // Trailer field: 0xbc + kRsaPssSha512 = 512, // RSASSA-PKCS1 for CAST receivers. + // This is a very special version of RSASSA-PKCS1 and should + // only be used for CAST RSA signature; do not use this in + // place of a standard PKCS1 signing algorithm. // Assumes message is already digested & encoded. Max message length // is 83 bytes. - kRsaPkcs1Cast = 1 + kRsaPkcs1Cast = 83, + // The default RSASSA-PSS specification uses SHA-1. + kRsaPssDefault = kRsaPssSha1, }; +// Returns one of the kRsaPss* algorithm types based on the +// OEMCrypto signature hash algorithm types. +// If |hash_algorithm| is undefined, then kRsaPssDefault is returned. +RsaSignatureAlgorithm RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm( + OEMCrypto_SignatureHashAlgorithm hash_algorithm); + // Returns the string representation of the provided RSA field size. // Intended for logging purposes. std::string RsaFieldSizeToString(RsaFieldSize field_size); @@ -87,7 +117,7 @@ class RsaPublicKey { // } // // Failure will occur if the provided |buffer| does not contain a - // valid SubjectPublicKey, or if the specified curve is not + // valid SubjectPublicKeyInfo, or if the specified key size is not // supported. static std::unique_ptr Load(const uint8_t* buffer, size_t length); @@ -102,6 +132,24 @@ class RsaPublicKey { static std::unique_ptr LoadPrivateKeyInfo( const std::vector& buffer); + // Loads a serialized RSA public key. + // The provided |buffer| must contain a valid ASN.1 DER encoded + // RSAPublicKey (RFC 3447 / RFC 8017). + // + // buffer: RSAPublicKey ::= SEQUENCE { + // modulus INTEGER, -- n + // publicExponent INTEGER, -- e + // } + // + // Failure will occur if the provided |buffer| does not contain a + // valid RSAPublicKey, or if the specified key size is not supported. + static std::unique_ptr LoadRsaPublicKey(const uint8_t* buffer, + size_t length); + static std::unique_ptr LoadRsaPublicKey( + const std::string& buffer); + static std::unique_ptr LoadRsaPublicKey( + const std::vector& buffer); + RsaFieldSize field_size() const { return field_size_; } uint32_t allowed_schemes() const { return allowed_schemes_; } const RSA* GetRsaKey() const { return key_; } @@ -109,6 +157,9 @@ class RsaPublicKey { // Checks if the provided |private_key| is the RSA private key of this // public key. bool IsMatchingPrivateKey(const RsaPrivateKey& private_key) const; + // Checks if the provided |other_key| is the same key as this public + // key. + bool IsMatchingPublicKey(const RsaPublicKey& other_key) const; // Serializes the public key into an ASN.1 DER encoded SubjectPublicKey // representation. @@ -122,12 +173,14 @@ class RsaPublicKey { // Returns an empty vector on error. std::vector Serialize() const; + // Serializes the public key into an ASN.1 DER encoded RSAPublicKey + // representation. + std::vector SerializeAsRsaPublicKey() const; + // Verifies the |signature| matches the provided |message| by the // private equivalent of this public key. // The signature algorithm can be specified via the |algorithm| field. // See RsaSignatureAlgorithm for details on each algorithm. - // For RSASSA-PSS, the hash algorithm can be specified via |hash_algorithm|. - // This parameter is ignored for other signature algorithms. // // Returns: // OEMCrypto_SUCCESS if signature is valid @@ -135,17 +188,15 @@ class RsaPublicKey { // OEMCrypto_ERROR_UNKNOWN_FAILURE if any error occurs OEMCryptoResult VerifySignature( const uint8_t* message, size_t message_length, const uint8_t* signature, - size_t signature_length, RsaSignatureAlgorithm algorithm = kRsaPssDefault, - OEMCrypto_SignatureHashAlgorithm hash_algorithm = OEMCrypto_SHA1) const; + size_t signature_length, + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; OEMCryptoResult VerifySignature( const std::string& message, const std::string& signature, - RsaSignatureAlgorithm algorithm = kRsaPssDefault, - OEMCrypto_SignatureHashAlgorithm hash_algorithm = OEMCrypto_SHA1) const; + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; OEMCryptoResult VerifySignature( const std::vector& message, const std::vector& signature, - RsaSignatureAlgorithm algorithm = kRsaPssDefault, - OEMCrypto_SignatureHashAlgorithm hash_algorithm = OEMCrypto_SHA1) const; + RsaSignatureAlgorithm algorithm = kRsaPssDefault) const; // Encrypts the OEMCrypto session key used for deriving other keys. // On success, |enc_session_key_size| is populated with the number @@ -187,6 +238,9 @@ class RsaPublicKey { // In case of any failure, false is return and the key should be // discarded. bool InitFromSubjectPublicKeyInfo(const uint8_t* buffer, size_t length); + // Initializes from RSAPublicKey. + bool InitFromRsaPublicKey(const uint8_t* buffer, size_t length); + // Initializes from PrivateKeyInfo. bool InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length); // Initializes the public key object from a private. bool InitFromPrivateKey(const RsaPrivateKey& private_key); @@ -196,10 +250,12 @@ class RsaPublicKey { bool InitFromSslHandle(const RSA* rsa_handle, uint32_t allowed_schemes); // Signature specialization functions. - OEMCryptoResult VerifySignaturePss( - const uint8_t* message, size_t message_length, const uint8_t* signature, - size_t signature_length, - OEMCrypto_SignatureHashAlgorithm hash_algorithm) const; + // Note: |algorithm| must be a PSS type. + OEMCryptoResult VerifySignaturePss(const uint8_t* message, + size_t message_length, + const uint8_t* signature, + size_t signature_length, + RsaSignatureAlgorithm algorithm) const; OEMCryptoResult VerifySignaturePkcs1Cast(const uint8_t* message, size_t message_length, const uint8_t* signature, @@ -229,18 +285,16 @@ class RsaPrivateKey { // The provided |buffer| must contain a valid ASN.1 DER encoded // PrivateKeyInfo (RFC 5208). // - // buffer: PrivateKeyInfo = { + // buffer: PrivateKeyInfo ::= SEQUENCE { // version: INTEGER = v1(0), // privateKeyAlgorithm: OID = rsaEncryption, // privateKey: OCTET STRING = ..., // -- BER encoding of RSAPrivateKey (RFC 3447) // attributes: Attributes = ... -- Optional, not used by OEMCrypto // } - // Note: If the public key is not included, then it is computed from - // the private. // // Failure will occur if the provided |buffer| does not contain a - // valid RSAPrivateKey, or if the specified curve is not supported. + // valid PrivateKeyInfo, or if the specified key size is not supported. static std::unique_ptr Load(const uint8_t* buffer, size_t length); static std::unique_ptr Load(const std::string& buffer); @@ -348,13 +402,17 @@ class RsaPrivateKey { // Initializes the public key object using the provided |buffer|. // In case of any failure, false is return and the key should be // discarded. + // From PrivateKeyInfo. bool InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length); + // Generates a new key based on the provided field size. bool InitFromFieldSize(RsaFieldSize field_size); // Signature specialization functions. + // Note: |algorithm| must be a PSS type. OEMCryptoResult GenerateSignaturePss(const uint8_t* message, size_t message_length, + RsaSignatureAlgorithm algorithm, uint8_t* signature, size_t* signature_length) const; OEMCryptoResult GenerateSignaturePkcs1Cast(const uint8_t* message, diff --git a/oemcrypto/util/include/wvcrc32.h b/oemcrypto/util/include/wvcrc32.h index 67f5ed0..6b72830 100644 --- a/oemcrypto/util/include/wvcrc32.h +++ b/oemcrypto/util/include/wvcrc32.h @@ -1,22 +1,123 @@ -// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary // source code may only be used and distributed under the Widevine // License Agreement. // -// Compute CRC32/MPEG2 Checksum. Needed for verification of WV Keybox. +// Widevine's CRC-32 Algorithm. // #ifndef WVOEC_UTIL_WVCRC32_H_ #define WVOEC_UTIL_WVCRC32_H_ +// This is intended to be a stand-alone file. +// Do not include any non-standard library headers. #include +#include +#include + namespace wvoec { namespace util { -uint32_t wvcrc32(const uint8_t* p_begin, size_t i_count); -uint32_t wvcrc32Init(); -uint32_t wvcrc32Cont(const uint8_t* p_begin, size_t i_count, uint32_t prev_crc); +// ==== Context-Based Widevine CRC-32 ==== -// Convert to network byte order -uint32_t wvcrc32n(const uint8_t* p_begin, size_t i_count); +// Context-based implementation of Widevine's CRC-32 algorithm. +// Allows for stream-based CRC-32 calculations. +// Note: This class uses a "Fluent Interface" design pattern. +class Crc32Ctx final { + public: + // == CRC32/MPEG2 Parameters == + // Generator polynomial in "normal" form. + static constexpr uint32_t kPolynomial = 0x04c11db7; + // Initial accumulator value when starting a new CRC check. + static constexpr uint32_t kInitialValue = 0xffffffff; + // Output XOR bit mask. + static constexpr uint32_t kOutputMask = 0x00000000; + // Note: Input/output reflection is not supported. + + // == Constructors == + constexpr Crc32Ctx() = default; + // Default copy/move constructors/assignment operators. + constexpr Crc32Ctx(const Crc32Ctx&) = default; + constexpr Crc32Ctx(Crc32Ctx&&) = default; + constexpr Crc32Ctx& operator=(const Crc32Ctx&) = default; + constexpr Crc32Ctx& operator=(Crc32Ctx&&) = default; + // Resuming constructor. + explicit constexpr Crc32Ctx(uint32_t previous_crc) + : accumulator_(previous_crc ^ kOutputMask) {} + + // == State Accessors == + // Obtain the current accumulator value. + constexpr uint32_t accumulator() const { return accumulator_; } + + // == Data Processors == + + // Updates the CRC context with more data. + // + // Returns a self reference. + Crc32Ctx& Update(uint8_t datum) { return UpdateInternal(&datum, 1); } + Crc32Ctx& Update(const uint8_t* data, size_t data_length) { + return UpdateInternal(data, data_length); + } + Crc32Ctx& Update(const std::vector& data); + Crc32Ctx& Update(const std::string& data); + + // Finalizes and produces the resulting CRC value. + // + // The internal accumulator does not update. Future calls + // to Update() are allowed. + constexpr uint32_t Finalize() const { return accumulator_ ^ kOutputMask; } + // Finalize in network-byte-order. + uint32_t FinalizeNbo() const; + // Finalize in network-byte-order to the provided buffer. + // |crc_buffer| must be 4 at least bytes large. + // Always write 4 bytes of data. + bool FinalizeNbo(uint8_t* crc_buffer, size_t crc_buffer_length) const; + + // Restarts the CRC calculator to the initial state. + // Returns a self reference. + constexpr Crc32Ctx& Reset() { + accumulator_ = kInitialValue; + return *this; + } + + // Resumes the CRC calculator from a previously acquired + // CRC value. + // Returns a self reference. + constexpr Crc32Ctx& Resume(uint32_t previous_crc) { + accumulator_ = previous_crc ^ kOutputMask; + return *this; + } + + private: + Crc32Ctx& UpdateInternal(const uint8_t* data, size_t data_length); + + uint32_t accumulator_ = kInitialValue; +}; // class Crc32Ctx + +// ==== C-Like Widevine CRC-32 API ==== + +// Compute Widevine CRC-32 value over the entire message +// all at once. +// +// Parameters: +// |data| - Pointer to message buffer. +// |data_length| - Length of the data pointed to by |data|. +uint32_t wvcrc32(const uint8_t* data, size_t data_length); + +// Get the initial value of the CRC accumulator. +uint32_t wvcrc32Init(); +// Continue computing the Widevine CRC over a new chunk +// of the message. +// +// Parameters: +// |data| - Pointer to message buffer. +// |data_length| - Length of the data pointed to by |data|. +// |prev_crc| - Initial CRC value or previously computed +// CRC value from earlier data. +uint32_t wvcrc32Cont(const uint8_t* data, size_t data_length, + uint32_t prev_crc); + +// Compute Widevine CRC-32 value over the entire message +// all at once, returning the CRC value in network byte order. +uint32_t wvcrc32n(const uint8_t* data, size_t data_length); } // namespace util } // namespace wvoec #endif // WVOEC_UTIL_WVCRC32_H_ diff --git a/oemcrypto/util/oec_ref_util.gyp b/oemcrypto/util/oec_ref_util.gyp deleted file mode 100644 index c2c688c..0000000 --- a/oemcrypto/util/oec_ref_util.gyp +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2022 Google LLC. All Rights Reserved. This file and proprietary -# source code may only be used and distributed under the Widevine -# License Agreement. -{ - 'variables': { - 'oemcrypto_dir': '..', - 'util_dir': '../../util', - }, - 'targets': [ - { - 'target_name': 'oec_ref_util', - 'type': 'static_library', - 'standalone_static_library': 1, - 'hard_dependency': 1, - 'includes': [ - 'oec_ref_util.gypi', - ], - }, - ], -} diff --git a/oemcrypto/util/oec_ref_util.gypi b/oemcrypto/util/oec_ref_util.gypi deleted file mode 100644 index 1a098f9..0000000 --- a/oemcrypto/util/oec_ref_util.gypi +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2022 Google LLC. All Rights Reserved. This file and proprietary -# source code may only be used and distributed under the Widevine -# License Agreement. -{ - 'variables': { - 'privacy_crypto_impl%': 'boringssl', - 'boringssl_libcrypto_path%': '../../third_party/boringssl/boringssl.gyp:crypto', - 'libcppbor_path%': '../../third_party/libcppbor.gyp:cppbor', - }, - 'include_dirs': [ - '<(oemcrypto_dir)/include', - '<(oemcrypto_dir)/util/include', - '<(util_dir)/include', - ], - 'direct_dependent_settings': { - 'include_dirs': [ - '<(oemcrypto_dir)/include', - '<(oemcrypto_dir)/util/include', - ], - }, - 'sources': [ - '<(oemcrypto_dir)/util/src/bcc_validator.cpp', - '<(oemcrypto_dir)/util/src/cbor_validator.cpp', - '<(oemcrypto_dir)/util/src/cmac.cpp', - '<(oemcrypto_dir)/util/src/device_info_validator.cpp', - '<(oemcrypto_dir)/util/src/hmac.cpp', - '<(oemcrypto_dir)/util/src/oemcrypto_drm_key.cpp', - '<(oemcrypto_dir)/util/src/oemcrypto_ecc_key.cpp', - '<(oemcrypto_dir)/util/src/oemcrypto_key_deriver.cpp', - '<(oemcrypto_dir)/util/src/oemcrypto_oem_cert.cpp', - '<(oemcrypto_dir)/util/src/oemcrypto_rsa_key.cpp', - '<(oemcrypto_dir)/util/src/prov4_validation_helper.cpp', - '<(oemcrypto_dir)/util/src/signed_csr_payload_validator.cpp', - '<(oemcrypto_dir)/util/src/wvcrc.cpp', - ], - 'dependencies': [ - '<(libcppbor_path)', - ], - 'includes': [ - '../../util/libcrypto_dependency.gypi', - ], -} diff --git a/oemcrypto/util/oec_ref_util_unittests.gypi b/oemcrypto/util/oec_ref_util_unittests.gypi index eea92f7..67a5673 100644 --- a/oemcrypto/util/oec_ref_util_unittests.gypi +++ b/oemcrypto/util/oec_ref_util_unittests.gypi @@ -17,14 +17,17 @@ 'sources': [ '<(oemcrypto_dir)/util/test/bcc_validator_unittest.cpp', '<(oemcrypto_dir)/util/test/cmac_unittest.cpp', + '<(oemcrypto_dir)/util/test/cose_utils_unittest.cpp', '<(oemcrypto_dir)/util/test/device_info_validator_unittest.cpp', '<(oemcrypto_dir)/util/test/hmac_unittest.cpp', '<(oemcrypto_dir)/util/test/oem_cert_test.cpp', + '<(oemcrypto_dir)/util/test/oemcrypto_cose_key_unittest.cpp', '<(oemcrypto_dir)/util/test/oemcrypto_ecc_key_unittest.cpp', + '<(oemcrypto_dir)/util/test/oemcrypto_ed_key_unittest.cpp', + '<(oemcrypto_dir)/util/test/oemcrypto_oem_cert_chain_unittest.cpp', '<(oemcrypto_dir)/util/test/oemcrypto_oem_cert_unittest.cpp', '<(oemcrypto_dir)/util/test/oemcrypto_ref_test_utils.cpp', '<(oemcrypto_dir)/util/test/oemcrypto_rsa_key_unittest.cpp', - '<(oemcrypto_dir)/util/test/oemcrypto_wvcrc32_unittest.cpp', '<(oemcrypto_dir)/util/test/signed_csr_payload_validator_unittest.cpp', ], 'dependencies': [ diff --git a/oemcrypto/util/src/bcc_validator.cpp b/oemcrypto/util/src/bcc_validator.cpp index 3b6b5f1..03610e9 100644 --- a/oemcrypto/util/src/bcc_validator.cpp +++ b/oemcrypto/util/src/bcc_validator.cpp @@ -380,7 +380,7 @@ std::string Bcc::ToString() const { CborMessageStatus Bcc::Validate( std::vector>& msgs, - bool is_degenerated) const { + bool is_degenerated, const BccValidationConfig& validation_config) const { CborMessageStatus status = kCborParseOk; const std::string component = "Bcc"; CborMessageStatus cur_status = dk_pub.Validate(msgs); @@ -392,12 +392,13 @@ CborMessageStatus Bcc::Validate( ApplyStatus(status, cur_status); if (is_widevine_entry) found_widevine_entry = true; } - if (!is_degenerated && !found_widevine_entry) { - msgs.push_back( - std::make_pair(kCborValidateWarning, - component + ": Widevine cert not found. Expect a BCC " - "entry with component_name: widevine")); - ApplyStatus(status, kCborValidateWarning); + if (validation_config.check_widevine_component_name && !is_degenerated && + !found_widevine_entry) { + msgs.push_back(std::make_pair(kCborValidateError, + component + + ": Widevine cert not found. Expect a BCC " + "entry with component_name: widevine")); + ApplyStatus(status, kCborValidateError); } return status; } @@ -700,7 +701,8 @@ CborMessageStatus BccValidator::Validate() { // More validations on the BCC std::vector> msgs; - CborMessageStatus validate_status = bcc.Validate(msgs, is_degenerated); + CborMessageStatus validate_status = + bcc.Validate(msgs, is_degenerated, validation_config_); ApplyStatus(message_status_, validate_status); for (const auto& msg : msgs) { AddValidationMessage(msg.first, msg.second); diff --git a/oemcrypto/util/src/cose_utils.cpp b/oemcrypto/util/src/cose_utils.cpp new file mode 100644 index 0000000..2ba78c0 --- /dev/null +++ b/oemcrypto/util/src/cose_utils.cpp @@ -0,0 +1,757 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "cose_utils.h" + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include "log.h" + +namespace wvoec { +namespace util { +namespace { +// == cppbor Utils == + +bool IsItemMapType(const cppbor::Item& item) { + return item.type() == cppbor::MAP; +} + +bool IsItemArrayType(const cppbor::Item& item) { + return item.type() == cppbor::ARRAY; +} + +bool IsItemIntType(const cppbor::Item& item) { + return item.type() == cppbor::UINT || item.type() == cppbor::NINT; +} + +bool IsItemBstrType(const cppbor::Item& item) { + return item.type() == cppbor::BSTR; +} + +bool IsItemTstrType(const cppbor::Item& item) { + return item.type() == cppbor::TSTR; +} + +bool IsItemBoolType(const cppbor::Item& item) { + if (item.type() != cppbor::SIMPLE) return false; + return item.asSimple()->simpleType() == cppbor::BOOLEAN; +} + +bool IsItemNilType(const cppbor::Item& item) { + if (item.type() != cppbor::SIMPLE) return false; + return item.asSimple()->simpleType() == cppbor::NULL_T; +} + +const char* SimpleItemTypeToString(const cppbor::Simple& item) { + switch (item.simpleType()) { + case cppbor::BOOLEAN: + return "bool"; + case cppbor::NULL_T: + return "nil"; + // case cppbor::FLOAT: + // case cppbor::DOUBLE: + // return "float"; + // CDM's version of cppbor is out of date and does not support + // the simple types FLOAT and DOUBLE; however, Android's version does. + // Adding "default" to prevent compilation errors if/when CDM's cppbor + // is updated. + default: + break; + } + return "unknownSimple"; +} + +const char* ItemTypeToString(const cppbor::Item& item) { + switch (item.type()) { + case cppbor::UINT: + case cppbor::NINT: + return "int"; + case cppbor::BSTR: + return "bstr"; + case cppbor::TSTR: + return "tstr"; + case cppbor::ARRAY: + return "array"; + case cppbor::MAP: + return "map"; + case cppbor::SEMANTIC: + // Could be other things, but most common is tag. + return "tag"; + case cppbor::SIMPLE: + return SimpleItemTypeToString(*item.asSimple()); + } + return "unknown"; +} + +// == COSE EC2 & OKP Key Data == + +// Parses the common fields of all COSE_Key values; except kty, which +// should be checked by caller. +// +// Effective CDDL: +// +// COSE_Key = { +// ? 2 => [+ int] ; key_ops* +// ? 3 => int ; alg* +// ? -1 => int ; crv* +// ? -4 => bstr ; d (private key) +// } +// +// Note (*): tstr for 'key_ops' elements, 'alg' and 'crv' is not supported. +template +bool ParseCommonCoseKeyData(const cppbor::Map& cose_key, CoseAnyKeyData* data) { + if (data == nullptr) return false; + + // 'alg' - algorithm; optional + const auto& algorithm_item = cose_key.get(kCoseKeyAlgorithmLabel); + if (!algorithm_item) { + data->algorithm.reset(); + } else if (IsItemIntType(*algorithm_item)) { + data->algorithm = algorithm_item->asInt()->value(); + } else { + LOGE("Expected 'alg' field to be an integer: type = %s", + ItemTypeToString(*algorithm_item)); + return false; + } + + // 'key_ops' - key operations; optional + const auto& key_ops_item = cose_key.get(kCoseKeyKeyOpsLabel); + if (!key_ops_item) { + data->key_ops.clear(); + } else if (IsItemArrayType(*key_ops_item)) { + const cppbor::Array& key_ops = *(key_ops_item->asArray()); + const size_t key_op_count = key_ops.size(); + // Valid CBOR key_ops must have at least 1 element. + if (key_op_count == 0) { + LOGE("Expected 'key_ops' to have at least 1 item"); + return false; + } + data->key_ops.clear(); + data->key_ops.reserve(key_op_count); + for (size_t i = 0; i < key_op_count; i++) { + const auto& key_op_item = key_ops[i]; + if (!key_op_item) { + LOGE("Unexpected error getting key op: i = %zu, count = %zu", i, + key_op_count); + return false; + } + if (IsItemTstrType(*key_op_item)) { + LOGE("Tstr 'key_ops[%zu]' not supported", i); + LOGD("key_ops[%zu] = %s", i, key_op_item->asTstr()->value().c_str()); + return false; + } + if (!IsItemIntType(*key_op_item)) { + LOGE("Expected 'key_ops[%zu]' to be int: type = %s", i, + ItemTypeToString(*key_op_item)); + return false; + } + data->key_ops.push_back(key_op_item->asInt()->value()); + } + } else { + LOGE("Expected 'key_ops' to be an array: type = %s", + ItemTypeToString(*key_ops_item)); + return false; + } + + // 'crv' - curve; required. + const auto& curve_item = cose_key.get(kCoseKeyCurveLabel); + if (!curve_item) { + LOGE("Missing 'crv' field"); + return false; + } + if (!IsItemIntType(*curve_item)) { + LOGE("Expected 'crv' field to be an integer: type = %s", + ItemTypeToString(*curve_item)); + return false; + } + data->curve = curve_item->asInt()->value(); + + // Both EC2 and OKP have the same label for public key. + // 'private_key'; conditionally optional. + const auto& private_key_item = cose_key.get(kCoseKeyPrivateKeyLabel); + if (!private_key_item) { + data->private_key.clear(); + } else if (IsItemBstrType(*private_key_item)) { + data->private_key = private_key_item->asBstr()->value(); + } else { + LOGE("Expected 'private_key' to be bstr: type = %s", + ItemTypeToString(*private_key_item)); + return false; + } + + return true; +} + +// Parses fields of EC2 keys (kty=EC2(2)). +// +// Enforces the requirement that there must be a private key +// and/or a public key present. +// +// Effective CDDL: +// +// COSE_Key = { +// ; Common fields, see ParseCommonCoseKeyData() +// ? -2 => bstr ; x (X coord) +// ? -3 => bstr / bool ; y (Y coord / Y sign bit) +// } +// +bool ParseCoseEc2KeyData(const cppbor::Map& cose_key, CoseEc2KeyData* data) { + if (data == nullptr) return false; + // Parse common fields. + if (!ParseCommonCoseKeyData(cose_key, data)) return false; + + // 'x' - X-coord; conditionally optional + const auto& x_item = cose_key.get(kCoseKeyXCoordLabel); + if (!x_item) { + data->x.clear(); + } else if (IsItemBstrType(*x_item)) { + data->x = x_item->asBstr()->value(); + } else { + LOGE("Expected 'x' to be bstr: type = %s", ItemTypeToString(*x_item)); + return false; + } + + // 'y' - Y-coord; conditionally optional + const auto& y_item = cose_key.get(kCoseKeyYCoordLabel); + if (!y_item) { + // Using empty vector indicate no Y component. + data->y = std::vector(); + } else if (IsItemBstrType(*y_item)) { + // Case of Y as bstr (encoded Y value) + data->y = y_item->asBstr()->value(); + } else if (IsItemBoolType(*y_item)) { + // Case of Y as bool (Y sign bit) + data->y = y_item->asSimple()->asBool()->value(); + } else { + LOGE("Expected 'y' to be bstr or bool: type = %s", + ItemTypeToString(*y_item)); + return false; + } + + // From RFC 9053 section 7.1.1 + // > For public keys, it is REQUIRED that "crv", "x", and "y" be + // > present in the structure. For private keys, it is REQUIRED + // > that "crv" and "d" be present in the structure. + + // Check that if x is present than y must be present and vice versa. + if ((x_item || y_item) && !(x_item && y_item)) { + LOGE("Expected both 'x' and 'y' or neither: x = %s, y = %s", + x_item ? "yes" : "no", y_item ? "yes" : "no"); + return false; + } + + // Check that private key and/or public key are present. + if (data->private_key.empty() && !x_item) { + LOGE("Missing EC2 public and private key: public = %s, private = %s", + data->private_key.empty() ? "no" : "yes", x_item ? "yes" : "no"); + return false; + } + + return true; +} + +// Parses fields of OKP keys (kty=OKP(1)). +// +// Enforces the requirement that there must be a private key +// and/or a public key present. +// +// Effective CDDL: +// +// COSE_Key = { +// ; Common fields, see ParseCommonCoseKeyData() +// ? -2 => bstr ; public_key / x +// } +// +bool ParseCoseOkpKeyData(const cppbor::Map& cose_key, CoseOkpKeyData* data) { + if (data == nullptr) return false; + // Parse common fields. + if (!ParseCommonCoseKeyData(cose_key, data)) return false; + + // 'public_key' / 'x'; conditionally optional + const auto& public_key_item = cose_key.get(kCoseKeyPublicKeyLabel); + if (!public_key_item) { + data->public_key.clear(); + } else if (IsItemBstrType(*public_key_item)) { + data->public_key = public_key_item->asBstr()->value(); + } else { + LOGE("Expected 'public_key' to be bstr: type = %s", + ItemTypeToString(*public_key_item)); + return false; + } + + // From RFC 9053 section 7.2 + // > For public keys, it is REQUIRED that "crv" and "x" be present + // > in the structure. For private keys, it is REQUIRED that "crv" + // > and "d" be present in the structure. + + // Check that private key and/or public key are present. + if (data->private_key.empty() && !public_key_item) { + LOGE("Missing OKP public and/or private key: public = %s, private = %s", + data->private_key.empty() ? "no" : "yes", + public_key_item ? "yes" : "no"); + return false; + } + + return true; +} + +template +void AddCommonCoseKeyFields(const CoseAnyKeyData& data, cppbor::Map* cose_key) { + if (cose_key == nullptr) return; + // 'kty' - key type; required + cose_key->add(kCoseKeyKeyTypeLabel, data.key_type /* kty : int */); + + // 'key_ops' - key operations; optional + if (!data.key_ops.empty()) { + // Only add array to key if non-empty. + cppbor::Array key_ops_array; + for (const int64_t key_op : data.key_ops) { + key_ops_array.add(key_op); + } + cose_key->add(kCoseKeyKeyOpsLabel, + std::move(key_ops_array) /* key_ops : [+ int] */); + } + + // 'alg' - algorithm; optional + if (data.algorithm.has_value()) { + cose_key->add(kCoseKeyAlgorithmLabel, + data.algorithm.value() /* alg : int */); + } + // 'crv' - curve; required + cose_key->add(kCoseKeyCurveLabel, data.curve /* crv : int */); + // Both EC2 and OKP have the same label for public key. + // 'private_key'; conditionally optional. + if (!data.private_key.empty()) { + cose_key->add(kCoseKeyPrivateKeyLabel, data.private_key /* d : bstr */); + } +} + +void AddCoseEc2KeyFields(const CoseEc2KeyData& data, cppbor::Map* cose_key) { + if (cose_key == nullptr) return; + AddCommonCoseKeyFields(data, cose_key); + + // 'x' - X-coord; conditionally optional + if (!data.x.empty()) { + cose_key->add(kCoseKeyXCoordLabel, data.x /* x : bstr */); + } + + // 'y' - Y-coord; conditionally optional + if (std::holds_alternative(data.y)) { + cose_key->add(kCoseKeyYCoordLabel, + cppbor::Bool(std::get(data.y)) /* y : bool */); + } else { + const auto& y = std::get>(data.y); + if (!y.empty()) { + cose_key->add(kCoseKeyYCoordLabel, y /* y : bstr */); + } + } +} + +void AddCoseOkpKeyFields(const CoseOkpKeyData& data, cppbor::Map* cose_key) { + if (cose_key == nullptr) return; + AddCommonCoseKeyFields(data, cose_key); + + // 'public_key'; conditionally optional + if (!data.public_key.empty()) { + cose_key->add(kCoseKeyPublicKeyLabel, data.public_key /* x : bstr */); + } +} +} // namespace + +// == COSE Key Data == + +int64_t CoseKeyData::key_type() const { + if (IsEc2Data()) return GetEc2Data().key_type; + return GetOkpData().key_type; +} + +bool CoseKeyData::has_algorithm() const { + if (IsEc2Data()) return GetEc2Data().algorithm.has_value(); + return GetOkpData().algorithm.has_value(); +} + +std::optional CoseKeyData::algorithm() const { + if (IsEc2Data()) return GetEc2Data().algorithm; + return GetOkpData().algorithm; +} + +void CoseKeyData::set_algorithm(int64_t algorithm) { + if (IsEc2Data()) { + GetEc2Data().algorithm = algorithm; + } else { + GetOkpData().algorithm = algorithm; + } +} + +void CoseKeyData::clear_algorithm() { + if (IsEc2Data()) { + GetEc2Data().algorithm.reset(); + } else { + GetOkpData().algorithm.reset(); + } +} + +const std::vector& CoseKeyData::key_ops() const { + if (IsEc2Data()) return GetEc2Data().key_ops; + return GetOkpData().key_ops; +} + +// private +std::vector& CoseKeyData::mutable_key_ops() { + if (IsEc2Data()) return GetEc2Data().key_ops; + return GetOkpData().key_ops; +} + +bool CoseKeyData::has_key_ops() const { + if (IsEc2Data()) return !GetEc2Data().key_ops.empty(); + return !GetOkpData().key_ops.empty(); +} + +void CoseKeyData::set_key_ops(const std::vector& key_ops) { + if (IsEc2Data()) { + GetEc2Data().key_ops = key_ops; + } else { + GetOkpData().key_ops = key_ops; + } +} + +void CoseKeyData::clear_key_ops() { + if (IsEc2Data()) { + GetEc2Data().key_ops.clear(); + } else { + GetOkpData().key_ops.clear(); + } +} + +bool CoseKeyData::ContainsKeyOp(int64_t key_op) const { + const auto& key_ops = this->key_ops(); + return std::find(key_ops.begin(), key_ops.end(), key_op) != key_ops.end(); +} + +void CoseKeyData::AppendKeyOp(int64_t key_op) { + mutable_key_ops().push_back(key_op); +} + +void CoseKeyData::RemoveKeyOp(int64_t key_op) { + auto& key_ops = mutable_key_ops(); + // This will remove all instances of |key_op|, not just the first. + key_ops.erase(std::remove(key_ops.begin(), key_ops.end(), key_op), + key_ops.end()); +} + +int64_t CoseKeyData::curve() const { + if (IsEc2Data()) return GetEc2Data().curve; + return GetOkpData().curve; +} + +void CoseKeyData::set_curve(int64_t curve) { + if (IsEc2Data()) { + GetEc2Data().curve = curve; + } else { + GetOkpData().curve = curve; + } +} + +bool CoseKeyData::has_public_key() const { + if (IsEc2Data()) return GetEc2Data().has_public_key(); + return GetOkpData().has_public_key(); +} + +void CoseKeyData::clear_public_key() { + if (IsEc2Data()) { + GetEc2Data().clear_public_key(); + } else { + GetOkpData().clear_public_key(); + } +} + +const std::vector& CoseKeyData::private_key() const { + if (IsEc2Data()) return GetEc2Data().private_key; + return GetOkpData().private_key; +} + +bool CoseKeyData::has_private_key() const { + if (IsEc2Data()) return GetEc2Data().has_private_key(); + return GetOkpData().has_private_key(); +} + +void CoseKeyData::set_private_key(const std::vector& private_key) { + if (IsEc2Data()) { + GetEc2Data().private_key = private_key; + } else { + GetOkpData().private_key = private_key; + } +} + +void CoseKeyData::clear_private_key() { + if (IsEc2Data()) { + GetEc2Data().clear_private_key(); + } else { + GetOkpData().clear_private_key(); + } +} + +std::optional CoseKeyData::ec2_data() const { + if (IsEc2Data()) return GetEc2Data(); + return std::nullopt; +} + +std::optional CoseKeyData::okp_data() const { + if (IsOkpData()) return GetOkpData(); + return std::nullopt; +} + +void CoseKeyData::Clear() { + if (IsEc2Data()) { + GetEc2Data().Clear(); + } else { + GetOkpData().Clear(); + } +} + +bool CoseKeyData::ParseCbor(const uint8_t* buffer, size_t buffer_size) { + if (buffer == nullptr) { + LOGE("Input |buffer| is null"); + return false; + } + if (buffer_size == 0) { + LOGE("Input |buffer| is empty"); + return false; + } + Clear(); // Clear existing data. + + const auto [cose_key_item, end_pos, error_message] = + cppbor::parse(buffer, buffer_size); + if (!cose_key_item) { + LOGE("Failed to parse CBOR COSE_Key: %s", error_message.c_str()); + return false; + } + if (!IsItemMapType(*cose_key_item)) { + LOGE("CBOR item is not a map: type = %s", ItemTypeToString(*cose_key_item)); + return false; + } + const cppbor::Map* cose_key = cose_key_item->asMap(); + + // 'kty' - key type; required. + const auto& key_type_item = cose_key->get(kCoseKeyKeyTypeLabel); + if (!key_type_item) { + LOGE("Missing 'kty' field"); + return false; + } + if (!IsItemIntType(*key_type_item)) { + LOGE("Expected 'kty' field to be an integer: type = %s", + ItemTypeToString(*key_type_item)); + return false; + } + + // Using helper methods to parse the remaining fields. + const int64_t key_type = key_type_item->asInt()->value(); + if (key_type == kCoseKeyTypeOkp) { + // Case OKP + CoseOkpKeyData ed_data; + if (!ParseCoseOkpKeyData(*cose_key, &ed_data)) return false; + data_ = std::move(ed_data); + } else if (key_type == kCoseKeyTypeEc2) { + // Case EC2 + CoseEc2KeyData ec_data; + if (!ParseCoseEc2KeyData(*cose_key, &ec_data)) return false; + data_ = std::move(ec_data); + } else { + LOGE("Unsupported key type: kty = %zd", static_cast(key_type)); + return false; + } + return true; +} + +std::vector CoseKeyData::SerializeCbor() const { + auto cose_key = cppbor::Map(); + if (IsEc2Data()) { + AddCoseEc2KeyFields(GetEc2Data(), &cose_key); + } else { + AddCoseOkpKeyFields(GetOkpData(), &cose_key); + } + return cose_key.canonicalize().encode(); +} + +// == COSE_Sign1 Data == + +std::optional CoseSign1Data::ExtractAlgorithm() const { + if (protected_params_.empty()) { + LOGD("No protected parameters"); + return std::nullopt; + } + // First parse the "protected" field as CBOR. + const auto [header_map_item, end_pos, error_message] = + cppbor::parse(protected_params_); + if (!header_map_item) { + LOGE("Failed to parse 'protected' field: %s", error_message.c_str()); + return std::nullopt; + } + if (!IsItemMapType(*header_map_item)) { + LOGE("Parsed 'protected' is not a map: type = %s", + ItemTypeToString(*header_map_item)); + return std::nullopt; + } + // Extract out the 'alg' field. + const cppbor::Map* header_map = header_map_item->asMap(); + const auto& algorithm_item = header_map->get(kCoseGenHeaderAlgorithmLabel); + if (!algorithm_item) { + LOGD("No algorithm specified"); + return std::nullopt; + } + if (IsItemTstrType(*algorithm_item)) { + // Implementation does not support tstr algorithm types. + LOGW("Tstr 'alg' not supported"); + LOGD("alg : tstr = %s", algorithm_item->asTstr()->value().c_str()); + return std::nullopt; + } + if (!IsItemIntType(*algorithm_item)) { + LOGE("Expected 'alg' field to be an integer: type = %s", + ItemTypeToString(*algorithm_item)); + return std::nullopt; + } + return algorithm_item->asInt()->value(); +} + +void CoseSign1Data::SetAlgorithm(int64_t algorithm) { + protected_params_ = + cppbor::Map() + .add(kCoseGenHeaderAlgorithmLabel, algorithm /* alg : int */) + .canonicalize() + .encode(); +} + +bool CoseSign1Data::ParseCbor(const uint8_t* buffer, size_t buffer_size) { + if (buffer == nullptr) { + LOGE("Input |buffer| is null"); + return false; + } + if (buffer_size == 0) { + LOGE("Input |buffer| is empty"); + return false; + } + Clear(); // Clear existing data. + + const auto [cose_sign1_item, end_pos, error_message] = + cppbor::parse(buffer, buffer_size); + if (!cose_sign1_item) { + LOGE("Failed to parse CBOR COSE_Sign1: %s", error_message.c_str()); + return false; + } + if (!IsItemArrayType(*cose_sign1_item)) { + LOGE("COSE_Sign1 item is not a array: type = %s", + ItemTypeToString(*cose_sign1_item)); + return false; + } + // By checking the size, null checks are not needed later. + const cppbor::Array* cose_sign1 = cose_sign1_item->asArray(); + if (cose_sign1->size() != kCoseSign1Length) { + LOGE("Unexpected COSE_Sign1 length: length = %zu, expected = %zu", + cose_sign1->size(), kCoseSign1Length); + return false; + } + + // 'protected' - bstr .cbor {} / bstr .size 0; required. + const auto& protected_item = cose_sign1->get(kCoseSign1ProtectedIndex); + if (!IsItemBstrType(*protected_item)) { + LOGE("Expected 'protected' field to be a bstr: type = %s", + ItemTypeToString(*protected_item)); + return false; + } + protected_params_ = protected_item->asBstr()->value(); + if (!protected_params_.empty()) { + // Must be a CBOR serialized header_map. + const auto [header_map_item, p_end_pos, p_error_message] = + cppbor::parse(protected_params_); + if (!header_map_item) { + LOGE("Failed to parse CBOR COSE_Sign1 protected header: %s", + p_error_message.c_str()); + return false; + } + if (!IsItemMapType(*header_map_item)) { + LOGE("Expected 'protected' CBOR type to be a map: type = %s", + ItemTypeToString(*header_map_item)); + return false; + } + const cppbor::Map* header_map = header_map_item->asMap(); + // 'alg' - int; optional + const auto& algorithm_item = header_map->get(kCoseGenHeaderAlgorithmLabel); + // Verify if present algorithm is an integer type. + if (algorithm_item && !IsItemIntType(*algorithm_item)) { + LOGE("Expected 'alg' to be an integer: type = %s", + ItemTypeToString(*algorithm_item)); + return false; + } + // All other header_map fields are ignored. + } + + // 'unprotected' - map; required + // We only care that it is present, and is of type map. + const auto& unprotected_item = cose_sign1->get(kCoseSign1UnprotectedIndex); + if (!IsItemMapType(*unprotected_item)) { + LOGE("Expected 'unprotected' field to be a map: type = %s", + ItemTypeToString(*unprotected_item)); + return false; + } + + // 'payload' - bstr / nil; required + const auto& payload_item = cose_sign1->get(kCoseSign1PayloadIndex); + if (IsItemBstrType(*payload_item)) { + payload_ = payload_item->asBstr()->value(); + } else if (IsItemNilType(*payload_item)) { + payload_ = std::nullopt; + } else { + LOGE("Expected 'payload' field to be a bstr or nil: type = %s", + ItemTypeToString(*payload_item)); + return false; + } + + // 'signature' - bstr; required + const auto& signature_item = cose_sign1->get(kCoseSign1SignatureIndex); + if (IsItemBstrType(*signature_item)) { + signature_ = signature_item->asBstr()->value(); + } else { + LOGE("Expected 'signature' to be a bstr: type = %s", + ItemTypeToString(*signature_item)); + return false; + } + return true; +} + +std::vector CoseSign1Data::SerializeCbor() const { + cppbor::Array cose_sign1; + cose_sign1.add(protected_params_ /* protected : bstr */); + cose_sign1.add(cppbor::Map() /* unprotected */); + if (payload_.has_value()) { + cose_sign1.add(payload_.value() /* payload : bstr */); + } else { + cose_sign1.add(cppbor::Null() /* payload : nil */); + } + cose_sign1.add(signature_ /* signature */); + return cose_sign1.encode(); +} + +std::vector PackageCoseSign1SigStructure( + const std::vector& protected_params, + const std::vector& payload) { + return ::cppbor::Array() + .add(::cppbor::Tstr("Signature1") /* context : tstr */) + .add(protected_params /* body_protected : bstr */) + .add(std::vector() /* external_aad : bstr .size 0 */) + .add(payload /* payload : bstr */) + .encode(); +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/src/oemcrypto_cose_key.cpp b/oemcrypto/util/src/oemcrypto_cose_key.cpp new file mode 100644 index 0000000..23affb4 --- /dev/null +++ b/oemcrypto/util/src/oemcrypto_cose_key.cpp @@ -0,0 +1,874 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "oemcrypto_cose_key.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "OEMCryptoCENCCommon.h" +#include "cose_utils.h" +#include "log.h" +#include "oemcrypto_ecc_key.h" +#include "oemcrypto_ed_key.h" + +namespace wvoec { +namespace util { +using Ec2YCoordVariant = decltype(CoseEc2KeyData().y); +namespace { +int64_t KeyFamilyToAlgorithm(CoseKeyFamily family) { + switch (family) { + case kCoseKeyP256: + return kCoseAlgorithmEs256; + case kCoseKeyP384: + return kCoseAlgorithmEs384; + case kCoseKeyEd25519: + return kCoseAlgorithmEdDsa; + case kCoseKeyFamilyUnknown: + break; + } + return 0; +} + +int64_t KeyFamilyToCurve(CoseKeyFamily family) { + switch (family) { + case kCoseKeyP256: + return kCoseCurveP256; + case kCoseKeyP384: + return kCoseCurveP384; + case kCoseKeyEd25519: + return kCoseCurveEd25519; + case kCoseKeyFamilyUnknown: + break; + } + return 0; +} + +CoseKeyFamily EccCurveToCoseKeyFamily(EccCurve curve) { + switch (curve) { + case kEccSecp256r1: + return kCoseKeyP256; + case kEccSecp384r1: + return kCoseKeyP384; + case kEccSecp521r1: + case kEccCurveUnknown: + break; + } + LOGE("Unsupported curve: %d", static_cast(curve)); + return kCoseKeyFamilyUnknown; +} + +// == Key Builders == + +std::unique_ptr MakeEdPublicKey(const CoseOkpKeyData& key_data) { + // Only OKP supported curve is Ed25519 curves. + if (key_data.curve != kCoseCurveEd25519) { + LOGE("Unexpected curve: curve = %zd, expected = Ed25519(%zd)", + static_cast(key_data.curve), + static_cast(kCoseCurveEd25519)); + return nullptr; + } + + // If the COSE_Key specified an algorithm, ensure it is the correct + // one. + const int64_t expected_algorithm = kCoseAlgorithmEdDsa; + const int64_t specified_algorithm = + key_data.algorithm.value_or(expected_algorithm); + if (specified_algorithm != expected_algorithm) { + LOGE("Unexpected COSE_Key algorithm: algorithm = %zd, expected = %zd", + static_cast(specified_algorithm), + static_cast(expected_algorithm)); + return nullptr; + } + + // Ideally, this is a public key; but if not, the public key can be + // recovered from the private key. + if (key_data.has_public_key()) { + auto ed_key = EdPublicKey::LoadRaw(key_data.public_key); + if (!ed_key) { + LOGE("Failed to load Ed25519 public key: public_key_length = %zu", + key_data.public_key.size()); + return nullptr; + } + return ed_key; + } + if (!key_data.has_private_key()) { + // Unexpected, should have been caught by CoseKeyData parser. + LOGE("COSE_Key is missing key value fields for Ed25519"); + return nullptr; + } + + auto ed_key = EdPublicKey::LoadFromRawPrivate(key_data.private_key); + if (!ed_key) { + LOGE( + "Failed to load Ed25519 public key from private: " + "private_key_length = %zu", + key_data.private_key.size()); + return nullptr; + } + return ed_key; +} + +std::unique_ptr MakeEdPrivateKey(const CoseOkpKeyData& key_data) { + // Only OKP supported curve is Ed25519 curves. + if (key_data.curve != kCoseCurveEd25519) { + LOGE("Unexpected curve: curve = %zd, expected = Ed25519(%zd)", + static_cast(key_data.curve), + static_cast(kCoseCurveEd25519)); + return nullptr; + } + + // If the COSE_Key specified an algorithm, ensure it is the correct + // one. + const int64_t expected_algorithm = kCoseAlgorithmEdDsa; + const int64_t specified_algorithm = + key_data.algorithm.value_or(expected_algorithm); + if (specified_algorithm != expected_algorithm) { + LOGE("Unexpected COSE_Key algorithm: algorithm = %zd, expected = %zd", + static_cast(specified_algorithm), + static_cast(expected_algorithm)); + return nullptr; + } + + if (!key_data.has_private_key()) { + LOGE("COSE_Key is missing private key value for Ed25519"); + return nullptr; + } + + auto ed_key = EdPrivateKey::LoadRaw(key_data.private_key); + if (!ed_key) { + LOGE("Failed to load Ed25519 private key: private_key_length = %zu", + key_data.private_key.size()); + return nullptr; + } + return ed_key; +} + +std::unique_ptr MakeEccPublicKey(EccCurve ecc_curve, + const std::vector& x, + const Ec2YCoordVariant& y) { + if (std::holds_alternative(y)) { + // Load using point compression. + return EccPublicKey::LoadKeyPoint(ecc_curve, x, std::get(y)); + } + if (!std::holds_alternative>(y)) { + // Should not happen. + LOGE("Unexpected Y variant"); + return nullptr; + } + + const std::vector& y_coord = std::get>(y); + auto ecc_key = EccPublicKey::LoadKeyPoint(ecc_curve, x, y_coord); + if (!ecc_key) { + LOGE( + "Failed to load ECC public key from key components: " + "x_length = %zu, y_length = %zu", + x.size(), y_coord.size()); + return nullptr; + } + return ecc_key; +} + +std::unique_ptr MakeEccPublicKey(const CoseEc2KeyData& key_data) { + // Only EC2 supported curves are P-256 or P-384. + if (key_data.curve != kCoseCurveP256 && key_data.curve != kCoseCurveP384) { + LOGE("Unexpected curve: curve = %zu, expected = P-256(%zd)/P-384(%zd)", + static_cast(key_data.curve), + static_cast(kCoseCurveP256), + static_cast(kCoseCurveP384)); + return nullptr; + } + const CoseKeyFamily family = + key_data.curve == kCoseCurveP384 ? kCoseKeyP384 : kCoseKeyP256; + const EccCurve ecc_curve = + key_data.curve == kCoseCurveP384 ? kEccSecp384r1 : kEccSecp256r1; + + // If the COSE_Key specified an algorithm, ensure it is the correct + // one. + const int64_t expected_algorithm = KeyFamilyToAlgorithm(family); + const int64_t specified_algorithm = + key_data.algorithm.value_or(expected_algorithm); + if (specified_algorithm != expected_algorithm) { + LOGE("Unexpected COSE_Key algorithm: algorithm = %zd, expected = %zd", + static_cast(specified_algorithm), + static_cast(expected_algorithm)); + return nullptr; + } + + if (key_data.has_public_key()) { + // Load from the public key part. + return MakeEccPublicKey(ecc_curve, key_data.x, key_data.y); + } + if (!key_data.has_private_key()) { + // Unexpected, should have been caught by CoseKeyData parser. + LOGE("COSE_Key is missing key value fields for ECC"); + return nullptr; + } + + // Load from the private key part. + auto ecc_key = + EccPublicKey::LoadFromPrivateScalar(ecc_curve, key_data.private_key); + if (!ecc_key) { + LOGE( + "Failed to load ECC public key from private key: " + "private_key_length = %zu", + key_data.private_key.size()); + return nullptr; + } + return ecc_key; +} + +std::unique_ptr MakeEccPrivateKey( + const CoseEc2KeyData& key_data) { + // Only EC2 supported curves are P-256 or P-384. + if (key_data.curve != kCoseCurveP256 && key_data.curve != kCoseCurveP384) { + LOGE("Unexpected curve: curve = %zu, expected = P-256(%zd)/P-384(%zd)", + static_cast(key_data.curve), + static_cast(kCoseCurveP256), + static_cast(kCoseCurveP384)); + return nullptr; + } + const CoseKeyFamily family = + key_data.curve == kCoseCurveP384 ? kCoseKeyP384 : kCoseKeyP256; + const EccCurve ecc_curve = + key_data.curve == kCoseCurveP384 ? kEccSecp384r1 : kEccSecp256r1; + + // If the COSE_Key specified an algorithm, ensure it is the correct + // one. + const int64_t expected_algorithm = KeyFamilyToAlgorithm(family); + const int64_t specified_algorithm = + key_data.algorithm.value_or(expected_algorithm); + if (specified_algorithm != expected_algorithm) { + LOGE("Unexpected COSE_Key algorithm: algorithm = %zd, expected = %zd", + static_cast(specified_algorithm), + static_cast(expected_algorithm)); + return nullptr; + } + + if (!key_data.has_private_key()) { + LOGE("COSE_Key is missing private key value for ECC"); + return nullptr; + } + + auto ecc_key = + EccPrivateKey::LoadPrivateScalar(ecc_curve, key_data.private_key); + if (!ecc_key) { + LOGE("Failed to load ECC private key: private_key_length = %zu", + key_data.private_key.size()); + return nullptr; + } + return ecc_key; +} +} // namespace + +// static +std::string CoseKeyFamilyToString(CoseKeyFamily family) { + switch (family) { + case kCoseKeyP256: + return "P-256"; + case kCoseKeyP384: + return "P-384"; + case kCoseKeyEd25519: + return "Ed25519"; + case kCoseKeyFamilyUnknown: + return "Unknown"; + } + return "(family)) + ")>"; +} + +CosePublicKey::CosePublicKey(std::unique_ptr&& ecc_key) + : ecc_key_(std::move(ecc_key)) {} + +CosePublicKey::CosePublicKey(std::unique_ptr&& ed_key) + : ed_key_(std::move(ed_key)) {} + +// static +std::unique_ptr CosePublicKey::New( + const CosePrivateKey& private_key) { + if (private_key.ecc_key() != nullptr) { + auto ecc_key = private_key.ecc_key()->MakePublicKey(); + if (!ecc_key) { + LOGE("Failed to make public ECC public key"); + return nullptr; + } + return std::unique_ptr( + new CosePublicKey(std::move(ecc_key))); + } + auto ed_key = private_key.ed_key()->MakePublicKey(); + if (!ed_key) { + LOGE("Failed to make public Ed25519 public key"); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ed_key))); +} + +// static +std::unique_ptr CosePublicKey::LoadCborCoseKey( + const uint8_t* buffer, size_t buffer_size) { + if (buffer == nullptr) { + LOGE("Input |buffer| is null"); + return nullptr; + } + if (buffer_size == 0) { + LOGE("CBOR data is zero"); + return nullptr; + } + CoseKeyData cose_key_data; + if (!cose_key_data.ParseCbor(buffer, buffer_size)) { + LOGE("Failed to parse COSE_Key"); + return nullptr; + } + + if (cose_key_data.key_type() == kCoseKeyTypeOkp) { + const auto okp_key_data_opt = cose_key_data.okp_data(); + if (!okp_key_data_opt.has_value()) { + // Unexpected. + LOGE("Failed to get OKP key data"); + return nullptr; + } + auto ed_public_key = MakeEdPublicKey(okp_key_data_opt.value()); + if (!ed_public_key) { + LOGE("Failed to construct Ed25519 public key from COSE_Key"); + return nullptr; + } + return std::unique_ptr( + new CosePublicKey(std::move(ed_public_key))); + } + + if (cose_key_data.key_type() == kCoseKeyTypeEc2) { + const auto ec2_key_data_opt = cose_key_data.ec2_data(); + if (!ec2_key_data_opt.has_value()) { + // Unexpected. + LOGE("Failed to get EC2 key data"); + return nullptr; + } + auto ecc_public_key = MakeEccPublicKey(ec2_key_data_opt.value()); + if (!ecc_public_key) { + LOGE("Failed to construct ECC public key from COSE_Key"); + return nullptr; + } + return std::unique_ptr( + new CosePublicKey(std::move(ecc_public_key))); + } + + LOGE("Unexpected key type: kty = %zd, expected = EC2(%zd)/OKP(%zd)", + static_cast(cose_key_data.key_type()), + static_cast(kCoseKeyTypeEc2), + static_cast(kCoseKeyTypeOkp)); + return nullptr; +} + +// static +std::unique_ptr CosePublicKey::LoadCborCoseKey( + const std::string& buffer) { + if (buffer.empty()) { + LOGE("Input |buffer| is empty"); + return nullptr; + } + return LoadCborCoseKey(reinterpret_cast(buffer.data()), + buffer.size()); +} + +// static +std::unique_ptr CosePublicKey::LoadCborCoseKey( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Input |buffer| is empty"); + return nullptr; + } + return LoadCborCoseKey(buffer.data(), buffer.size()); +} + +// static +std::unique_ptr CosePublicKey::LoadEccKeyPoint( + CoseKeyFamily family, const std::vector& x, + const std::vector& y) { + if (family != kCoseKeyP256 && family != kCoseKeyP384) { + LOGE("Non-ECC family: %s", CoseKeyFamilyToString(family).c_str()); + return nullptr; + } + const EccCurve ecc_curve = + family == kCoseKeyP384 ? kEccSecp384r1 : kEccSecp256r1; + auto ecc_key = EccPublicKey::LoadKeyPoint(ecc_curve, x, y); + if (!ecc_key) { + // Could fail for a number of reasons, most likely is that the + // X and/or Y coordinates are the wrong length. + LOGE("Failed to load ECC key: family = %s, x_size = %zu, y_size = %zu", + CoseKeyFamilyToString(family).c_str(), x.size(), y.size()); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ecc_key))); +} + +// static +std::unique_ptr CosePublicKey::LoadEccKeyPoint( + CoseKeyFamily family, const std::vector& x, bool y_sign_bit) { + if (family != kCoseKeyP256 && family != kCoseKeyP384) { + LOGE("Non-ECC family: %s", CoseKeyFamilyToString(family).c_str()); + return nullptr; + } + const EccCurve ecc_curve = + family == kCoseKeyP384 ? kEccSecp384r1 : kEccSecp256r1; + auto ecc_key = EccPublicKey::LoadKeyPoint(ecc_curve, x, y_sign_bit); + if (!ecc_key) { + // Could fail for a number of reasons, most likely is that the + // X coordinate is the wrong length. + LOGE("Failed to load ECC key: family = %s, x_size = %zu, y_sign_bit = %s", + CoseKeyFamilyToString(family).c_str(), x.size(), + y_sign_bit ? "true" : "false"); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ecc_key))); +} + +// static +std::unique_ptr CosePublicKey::LoadEccSec1KeyPoint( + CoseKeyFamily family, const std::vector& key_point) { + if (family != kCoseKeyP256 && family != kCoseKeyP384) { + LOGE("Non-ECC family: %s", CoseKeyFamilyToString(family).c_str()); + return nullptr; + } + const EccCurve ecc_curve = + family == kCoseKeyP384 ? kEccSecp384r1 : kEccSecp256r1; + auto ecc_key = EccPublicKey::LoadKeyPoint(ecc_curve, key_point); + if (!ecc_key) { + LOGE("Failed to load ECC key: family = %s, key_point_length = %zu", + CoseKeyFamilyToString(family).c_str(), key_point.size()); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ecc_key))); +} + +// static +std::unique_ptr CosePublicKey::LoadEccFromPrivateScalar( + CoseKeyFamily family, const std::vector& private_scalar) { + if (family != kCoseKeyP256 && family != kCoseKeyP384) { + LOGE("Non-ECC family: %s", CoseKeyFamilyToString(family).c_str()); + return nullptr; + } + const EccCurve ecc_curve = + family == kCoseKeyP384 ? kEccSecp384r1 : kEccSecp256r1; + auto ecc_key = EccPublicKey::LoadFromPrivateScalar(ecc_curve, private_scalar); + if (!ecc_key) { + LOGE("Failed to load ECC key: family = %s, private_scalar_length = %zu", + CoseKeyFamilyToString(family).c_str(), private_scalar.size()); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ecc_key))); +} + +// static +std::unique_ptr CosePublicKey::LoadEd25519FromPublic( + const std::vector& public_key) { + auto ed_key = EdPublicKey::LoadRaw(public_key); + if (!ed_key) { + LOGE("Failed to load Ed25519 public key: public_key_length = %zu", + public_key.size()); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ed_key))); +} + +// static +std::unique_ptr CosePublicKey::LoadEd25519FromPrivate( + const std::vector& private_key) { + auto ed_key = EdPublicKey::LoadFromRawPrivate(private_key); + if (!ed_key) { + LOGE( + "Failed to load Ed25519 public key from private: private_key_length = " + "%zu", + private_key.size()); + return nullptr; + } + return std::unique_ptr(new CosePublicKey(std::move(ed_key))); +} + +CoseKeyFamily CosePublicKey::family() const { + if (ecc_key_) return EccCurveToCoseKeyFamily(ecc_key_->curve()); + return kCoseKeyEd25519; +} + +bool CosePublicKey::IsMatchingPrivateKey( + const CosePrivateKey& private_key) const { + if (family() != private_key.family()) return false; + if (ecc_key_) return ecc_key_->IsMatchingPrivateKey(*private_key.ecc_key()); + return ed_key_->IsMatchingPrivateKey(*private_key.ed_key()); +} + +std::vector CosePublicKey::SerializeCbor( + bool include_algorithm) const { + CoseKeyData key_data; + // Only serialize public key components. + if (ecc_key_) { + CoseEc2KeyData ec2_key_data; + ec2_key_data.x = ecc_key_->SerializeXCoord(); + // Point compression is discouraged, and not provided + // as an option here. + ec2_key_data.y = ecc_key_->SerializeYCoord(); + key_data.set_ec2_data(ec2_key_data); + } else { + CoseOkpKeyData okp_key_data; + okp_key_data.public_key = ed_key_->SerializeRaw(); + key_data.set_okp_data(okp_key_data); + } + + const CoseKeyFamily family = this->family(); + if (include_algorithm) { + key_data.set_algorithm(KeyFamilyToAlgorithm(family)); + } + key_data.set_curve(KeyFamilyToCurve(family)); + + return key_data.SerializeCbor(); +} + +OEMCryptoResult CosePublicKey::VerifyCoseSign1Signature( + const std::vector& cose_sign1) const { + CoseSign1Data sign1_data; + if (!sign1_data.ParseCbor(cose_sign1)) { + LOGE("Failed to parse COSE_Sign1 message"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + + const std::optional algorithm_opt = sign1_data.ExtractAlgorithm(); + if (algorithm_opt.has_value()) { + // If the algorithm is specified, then ensure it is + // the correct algorithm for this key. + const int64_t prescribed_algorithm = algorithm_opt.value(); + const CoseKeyFamily family = this->family(); + const int64_t expected_algorithm = KeyFamilyToAlgorithm(family); + + if (prescribed_algorithm != expected_algorithm) { + LOGE( + "COSE_Sign1 specifies different algorithm: " + "prescribed_algorithm = %zd, expected_algorithm = %zd, family = %s", + static_cast(prescribed_algorithm), + static_cast(expected_algorithm), + CoseKeyFamilyToString(family).c_str()); + // Consider returning a different error. + return OEMCrypto_ERROR_INVALID_KEY; + } + } + if (!sign1_data.has_payload()) { + LOGE("COSE_Sign1 message does not contain payload"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + const std::vector& payload = sign1_data.payload().value(); + return VerifyCoseSign1SignatureInternal(sign1_data.protected_params(), + payload, sign1_data.signature()); +} + +OEMCryptoResult CosePublicKey::VerifyCoseSign1Signature( + const std::vector& message, + const std::vector& cose_sign1) const { + CoseSign1Data sign1_data; + if (!sign1_data.ParseCbor(cose_sign1)) { + LOGE("Failed to parse COSE_Sign1 message"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + + const std::optional algorithm_opt = sign1_data.ExtractAlgorithm(); + if (algorithm_opt.has_value()) { + // If the algorithm is specified, then ensure it is + // the correct algorithm for this key. + const int64_t prescribed_algorithm = algorithm_opt.value(); + const CoseKeyFamily family = this->family(); + const int64_t expected_algorithm = KeyFamilyToAlgorithm(family); + + if (prescribed_algorithm != expected_algorithm) { + LOGE( + "COSE_Sign1 specifies different algorithm: " + "prescribed_algorithm = %zd, expected_algorithm = %zd, family = %s", + static_cast(prescribed_algorithm), + static_cast(expected_algorithm), + CoseKeyFamilyToString(family).c_str()); + // Consider returning a different error. + return OEMCrypto_ERROR_INVALID_KEY; + } + } + if (sign1_data.has_payload()) { + // If the COSE_Sign1 object contains a payload, but the caller is + // also providing the payload, then ensure that both the COSE_Sign1 + // payload and the provided payload are equal. + const std::vector& payload = sign1_data.payload().value(); + if (!std::equal(message.begin(), message.end(), payload.begin(), + payload.end())) { + LOGE("COSE_Sign1 payload and provided payload are not equal"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + } + return VerifyCoseSign1SignatureInternal(sign1_data.protected_params(), + message, sign1_data.signature()); +} + +// private +OEMCryptoResult CosePublicKey::VerifyCoseSign1SignatureInternal( + const std::vector& protected_params, + const std::vector& payload, + const std::vector& signature) const { + const std::vector sig_struct = + PackageCoseSign1SigStructure(protected_params, payload); + if (sig_struct.empty()) { + LOGE("Failed to pack Sig_structure"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + return VerifyRawSignatureInternal(sig_struct, signature); +} + +// private +OEMCryptoResult CosePublicKey::VerifyRawSignatureInternal( + const std::vector& verifying_payload, + const std::vector& signature) const { + if (ecc_key_) + return ecc_key_->VerifyRawSignature(verifying_payload, signature); + return ed_key_->VerifySignature(verifying_payload, signature); +} + +// ==== COSE Private Key ==== + +CosePrivateKey::CosePrivateKey(std::unique_ptr&& ecc_key) + : ecc_key_(std::move(ecc_key)) {} + +CosePrivateKey::CosePrivateKey(std::unique_ptr&& ed_key) + : ed_key_(std::move(ed_key)) {} + +// static +std::unique_ptr CosePrivateKey::New(CoseKeyFamily family) { + switch (family) { + case kCoseKeyP256: { + auto ecc_key = EccPrivateKey::New(kEccSecp256r1); + if (!ecc_key) { + LOGE("Failed to create ECC secp256r1/P-256 key"); + return nullptr; + } + return std::unique_ptr( + new CosePrivateKey(std::move(ecc_key))); + } + case kCoseKeyP384: { + auto ecc_key = EccPrivateKey::New(kEccSecp384r1); + if (!ecc_key) { + LOGE("Failed to create ECC secp384r1/P-384 key"); + return nullptr; + } + return std::unique_ptr( + new CosePrivateKey(std::move(ecc_key))); + } + case kCoseKeyEd25519: { + auto ed_key = EdPrivateKey::New(); + if (!ed_key) { + LOGE("Failed to create Ed25519 key"); + return nullptr; + } + return std::unique_ptr( + new CosePrivateKey(std::move(ed_key))); + } + case kCoseKeyFamilyUnknown: + break; + } + LOGE("Unexpected COSE key family: %s", CoseKeyFamilyToString(family).c_str()); + return nullptr; +} + +// static +std::unique_ptr CosePrivateKey::LoadCborCoseKey( + const uint8_t* buffer, size_t buffer_size) { + if (buffer == nullptr) { + LOGE("Input |buffer| is null"); + return nullptr; + } + if (buffer_size == 0) { + LOGE("CBOR data is zero"); + return nullptr; + } + CoseKeyData cose_key_data; + if (!cose_key_data.ParseCbor(buffer, buffer_size)) { + LOGE("Failed to parse COSE_Key"); + return nullptr; + } + + if (cose_key_data.key_type() == kCoseKeyTypeOkp) { + const auto okp_key_data_opt = cose_key_data.okp_data(); + if (!okp_key_data_opt.has_value()) { + // Unexpected. + LOGE("Failed to get OKP key data"); + return nullptr; + } + auto ed_private_key = MakeEdPrivateKey(okp_key_data_opt.value()); + if (!ed_private_key) { + LOGE("Failed to construct Ed25519 private key from COSE_Key"); + return nullptr; + } + return std::unique_ptr( + new CosePrivateKey(std::move(ed_private_key))); + } + + if (cose_key_data.key_type() == kCoseKeyTypeEc2) { + const auto ec2_key_data_opt = cose_key_data.ec2_data(); + if (!ec2_key_data_opt.has_value()) { + // Unexpected. + LOGE("Failed to get EC2 key data"); + return nullptr; + } + auto ecc_private_key = MakeEccPrivateKey(ec2_key_data_opt.value()); + if (!ecc_private_key) { + LOGE("Failed to construct ECC private key from COSE_Key"); + return nullptr; + } + return std::unique_ptr( + new CosePrivateKey(std::move(ecc_private_key))); + } + + LOGE("Unexpected key type: kty = %zd, expected = EC2(%zd)/OKP(%zd)", + static_cast(cose_key_data.key_type()), + static_cast(kCoseKeyTypeEc2), + static_cast(kCoseKeyTypeOkp)); + return nullptr; +} + +// static +std::unique_ptr CosePrivateKey::LoadCborCoseKey( + const std::string& buffer) { + if (buffer.empty()) { + LOGE("Input |buffer| is empty"); + return nullptr; + } + return LoadCborCoseKey(reinterpret_cast(buffer.data()), + buffer.size()); +} + +// static +std::unique_ptr CosePrivateKey::LoadCborCoseKey( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Input |buffer| is empty"); + return nullptr; + } + return LoadCborCoseKey(buffer.data(), buffer.size()); +} + +// static +std::unique_ptr CosePrivateKey::LoadEccPrivateScalar( + CoseKeyFamily family, const std::vector& private_scalar) { + if (family != kCoseKeyP256 && family != kCoseKeyP384) { + LOGE("Non-ECC family: %s", CoseKeyFamilyToString(family).c_str()); + return nullptr; + } + const EccCurve ecc_curve = + family == kCoseKeyP384 ? kEccSecp384r1 : kEccSecp256r1; + auto ecc_key = EccPrivateKey::LoadPrivateScalar(ecc_curve, private_scalar); + if (!ecc_key) { + LOGE("Failed to load ECC key: family = %s, private_scalar_length = %zu", + CoseKeyFamilyToString(family).c_str(), private_scalar.size()); + return nullptr; + } + return std::unique_ptr( + new CosePrivateKey(std::move(ecc_key))); +} + +// static +std::unique_ptr CosePrivateKey::LoadEd25519RawPrivate( + const std::vector& private_key) { + auto ed_key = EdPrivateKey::LoadRaw(private_key); + if (!ed_key) { + LOGE( + "Failed to load Ed25519 public key from private: private_key_length = " + "%zu", + private_key.size()); + return nullptr; + } + return std::unique_ptr(new CosePrivateKey(std::move(ed_key))); +} + +std::unique_ptr CosePrivateKey::MakePublicKey() const { + return CosePublicKey::New(*this); +} + +CoseKeyFamily CosePrivateKey::family() const { + if (ecc_key_) return EccCurveToCoseKeyFamily(ecc_key_->curve()); + return kCoseKeyEd25519; +} + +bool CosePrivateKey::IsMatchingPublicKey( + const CosePublicKey& public_key) const { + if (family() != public_key.family()) return false; + if (ecc_key_) return ecc_key_->IsMatchingPublicKey(*public_key.ecc_key()); + return ed_key_->IsMatchingPublicKey(*public_key.ed_key()); +} + +std::vector CosePrivateKey::SerializeCbor( + bool include_algorithm) const { + CoseKeyData key_data; + // Serialize both public and private key components. + if (ecc_key_) { + CoseEc2KeyData ec2_key_data; + ec2_key_data.x = ecc_key_->SerializeXCoord(); + // Point compression is discouraged, and not provided + // as an option here. + ec2_key_data.y = ecc_key_->SerializeYCoord(); + ec2_key_data.private_key = ecc_key_->SerializePrivateScalar(); + key_data.set_ec2_data(ec2_key_data); + } else { + CoseOkpKeyData okp_key_data; + okp_key_data.public_key = ed_key_->SerializeAsRawPublicKey(); + okp_key_data.private_key = ed_key_->SerializeRaw(); + key_data.set_okp_data(okp_key_data); + } + + const CoseKeyFamily family = this->family(); + if (include_algorithm) { + key_data.set_algorithm(KeyFamilyToAlgorithm(family)); + } + key_data.set_curve(KeyFamilyToCurve(family)); + + return key_data.SerializeCbor(); +} + +std::vector CosePrivateKey::GenerateCoseSign1Signature( + const std::vector& payload, bool include_payload) const { + CoseSign1Data sign1_data; + const CoseKeyFamily family = this->family(); + // Always include "algorithm" in COSE_Sign1. + sign1_data.SetAlgorithm(KeyFamilyToAlgorithm(family)); + if (sign1_data.protected_params().empty()) { + LOGE("Failed to set COSE_Sign1 algorithm"); + return std::vector(); + } + + const std::vector signing_payload = + PackageCoseSign1SigStructure(sign1_data.protected_params(), payload); + if (signing_payload.empty()) { + LOGE("Failed to package signing payload"); + return std::vector(); + } + + const std::vector signature = + GenerateRawSignatureInternal(signing_payload); + if (signature.empty()) { + LOGE("Failed to generate signature"); + return std::vector(); + } + + if (include_payload) { + sign1_data.set_payload(payload); + } + sign1_data.set_signature(signature); + + return sign1_data.SerializeCbor(); +} + +std::vector CosePrivateKey::GenerateRawSignatureInternal( + const std::vector& signing_payload) const { + if (ecc_key_) return ecc_key_->GenerateRawSignature(signing_payload); + return ed_key_->GenerateSignature(signing_payload); +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/src/oemcrypto_drm_key.cpp b/oemcrypto/util/src/oemcrypto_drm_key.cpp index 14a35fe..18b07c7 100644 --- a/oemcrypto/util/src/oemcrypto_drm_key.cpp +++ b/oemcrypto/util/src/oemcrypto_drm_key.cpp @@ -13,28 +13,177 @@ namespace wvoec { namespace util { +namespace { +bool IsSupportedHashAlgorithm(OEMCrypto_SignatureHashAlgorithm algorithm) { + switch (algorithm) { + case OEMCrypto_SHA1: + case OEMCrypto_SHA2_256: + case OEMCrypto_SHA2_384: + case OEMCrypto_SHA2_512: + return true; + } + return false; +} + +inline OEMCrypto_SignatureHashAlgorithm UnwrapHashAlgorithm( + DrmSignatureHashAlgorithmOption hash_algorithm_opt) { + // Default is SHA-1. + return hash_algorithm_opt.value_or(OEMCrypto_SHA1); +} + +inline RsaSignatureAlgorithm GetPssAlgorithm( + OEMCrypto_SignatureHashAlgorithm hash_algorithm) { + return RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm(hash_algorithm); +} +} // namespace + +// == DRM Public Key == + +// private +DrmPublicKey::DrmPublicKey(std::shared_ptr&& rsa_key, + OEMCrypto_SignatureHashAlgorithm rsa_hash_algorithm) + : rsa_key_(std::move(rsa_key)), rsa_hash_algorithm_(rsa_hash_algorithm) {} + +// private +DrmPublicKey::DrmPublicKey(std::shared_ptr&& ecc_key) + : ecc_key_(std::move(ecc_key)) {} + // static -std::unique_ptr DrmPrivateKey::Create( - std::shared_ptr&& rsa_key) { +std::unique_ptr DrmPublicKey::Create( + std::shared_ptr&& rsa_key, + std::optional hash_algorithm_opt) { if (!rsa_key) { LOGE("No RSA key provided"); - return std::unique_ptr(); + return nullptr; } - std::unique_ptr drm_key(new DrmPrivateKey()); - drm_key->rsa_key_ = std::move(rsa_key); - return drm_key; + const auto hash_algorithm = UnwrapHashAlgorithm(hash_algorithm_opt); + if (!IsSupportedHashAlgorithm(hash_algorithm)) { + LOGE("Unsupported hash algorithm: %d", static_cast(hash_algorithm)); + return nullptr; + } + return std::unique_ptr( + new DrmPublicKey(std::move(rsa_key), hash_algorithm)); +} + +// static +std::unique_ptr DrmPublicKey::Create( + std::unique_ptr&& rsa_key, + std::optional hash_algorithm_opt) { + if (!rsa_key) { + LOGE("No RSA key provided"); + return nullptr; + } + const auto hash_algorithm = UnwrapHashAlgorithm(hash_algorithm_opt); + if (!IsSupportedHashAlgorithm(hash_algorithm)) { + LOGE("Unsupported hash algorithm: %d", static_cast(hash_algorithm)); + return nullptr; + } + return std::unique_ptr( + new DrmPublicKey(std::move(rsa_key), hash_algorithm)); +} + +// static +std::unique_ptr DrmPublicKey::Create( + std::shared_ptr&& ecc_key) { + if (!ecc_key) { + LOGE("No ECC key provided"); + return nullptr; + } + return std::unique_ptr(new DrmPublicKey(std::move(ecc_key))); +} + +// static +std::unique_ptr DrmPublicKey::Create( + std::unique_ptr&& ecc_key) { + if (!ecc_key) { + LOGE("No ECC key provided"); + return nullptr; + } + return std::unique_ptr(new DrmPublicKey(std::move(ecc_key))); +} + +OEMCrypto_SignatureHashAlgorithm DrmPublicKey::GetSignatureHashAlgorithm() + const { + if (rsa_key_) return rsa_hash_algorithm_; + return ecc_key_->GetSignatureHashAlgorithm(); +} + +OEMCryptoResult DrmPublicKey::VerifySignature(const uint8_t* message, + size_t message_length, + const uint8_t* signature, + size_t signature_length) const { + if (rsa_key_) { + return rsa_key_->VerifySignature(message, message_length, signature, + signature_length, + GetPssAlgorithm(rsa_hash_algorithm_)); + } + return ecc_key_->VerifySignature(message, message_length, signature, + signature_length); +} + +OEMCryptoResult DrmPublicKey::VerifySignature( + const std::vector& message, + const std::vector& signature) const { + if (rsa_key_) { + return rsa_key_->VerifySignature(message, signature, + GetPssAlgorithm(rsa_hash_algorithm_)); + } + return ecc_key_->VerifySignature(message, signature); +} + +OEMCryptoResult DrmPublicKey::VerifySignature( + const std::string& message, const std::string& signature) const { + if (rsa_key_) { + return rsa_key_->VerifySignature(message, signature, + GetPssAlgorithm(rsa_hash_algorithm_)); + } + return ecc_key_->VerifySignature(message, signature); +} + +// == DRM Private Key == + +// private +DrmPrivateKey::DrmPrivateKey(std::shared_ptr&& ecc_key) + : ecc_key_(std::move(ecc_key)) {} + +// private +DrmPrivateKey::DrmPrivateKey( + std::shared_ptr&& rsa_key, + OEMCrypto_SignatureHashAlgorithm rsa_hash_algorithm) + : rsa_key_(std::move(rsa_key)), rsa_hash_algorithm_(rsa_hash_algorithm) {} + +// static +std::unique_ptr DrmPrivateKey::Create( + std::shared_ptr&& rsa_key, + DrmSignatureHashAlgorithmOption hash_algorithm_opt) { + if (!rsa_key) { + LOGE("No RSA key provided"); + return nullptr; + } + const auto hash_algorithm = UnwrapHashAlgorithm(hash_algorithm_opt); + if (!IsSupportedHashAlgorithm(hash_algorithm)) { + LOGE("Unsupported hash algorithm: %d", static_cast(hash_algorithm)); + return nullptr; + } + return std::unique_ptr( + new DrmPrivateKey(std::move(rsa_key), hash_algorithm)); } // static std::unique_ptr DrmPrivateKey::Create( - std::unique_ptr&& rsa_key) { + std::unique_ptr&& rsa_key, + DrmSignatureHashAlgorithmOption hash_algorithm_opt) { if (!rsa_key) { LOGE("No RSA key provided"); return std::unique_ptr(); } - std::unique_ptr drm_key(new DrmPrivateKey()); - drm_key->rsa_key_ = std::move(rsa_key); - return drm_key; + const auto hash_algorithm = UnwrapHashAlgorithm(hash_algorithm_opt); + if (!IsSupportedHashAlgorithm(hash_algorithm)) { + LOGE("Unsupported hash algorithm: %d", static_cast(hash_algorithm)); + return nullptr; + } + return std::unique_ptr( + new DrmPrivateKey(std::move(rsa_key), hash_algorithm)); } // static @@ -44,9 +193,7 @@ std::unique_ptr DrmPrivateKey::Create( LOGE("No ECC key provided"); return std::unique_ptr(); } - std::unique_ptr drm_key(new DrmPrivateKey()); - drm_key->ecc_key_ = std::move(ecc_key); - return drm_key; + return std::unique_ptr(new DrmPrivateKey(std::move(ecc_key))); } // static @@ -56,9 +203,25 @@ std::unique_ptr DrmPrivateKey::Create( LOGE("No ECC key provided"); return std::unique_ptr(); } - std::unique_ptr drm_key(new DrmPrivateKey()); - drm_key->ecc_key_ = std::move(ecc_key); - return drm_key; + return std::unique_ptr(new DrmPrivateKey(std::move(ecc_key))); +} + +OEMCrypto_SignatureHashAlgorithm DrmPrivateKey::GetSignatureHashAlgorithm() + const { + if (rsa_key_) { + return rsa_hash_algorithm_; + } + return ecc_key_->GetSignatureHashAlgorithm(); +} + +OEMCryptoResult DrmPrivateKey::GetSignatureHashAlgorithm( + OEMCrypto_SignatureHashAlgorithm* algorithm) const { + if (algorithm == nullptr) { + LOGE("Output |algorithm| is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + *algorithm = GetSignatureHashAlgorithm(); + return OEMCrypto_SUCCESS; } OEMCryptoResult DrmPrivateKey::GetSessionKey( @@ -141,7 +304,8 @@ OEMCryptoResult DrmPrivateKey::GenerateSignature( const uint8_t* message, size_t message_length, uint8_t* signature, size_t* signature_length) const { if (rsa_key_) { - return rsa_key_->GenerateSignature(message, message_length, kRsaPssDefault, + return rsa_key_->GenerateSignature(message, message_length, + GetPssAlgorithm(rsa_hash_algorithm_), signature, signature_length); } return ecc_key_->GenerateSignature(message, message_length, signature, @@ -151,7 +315,17 @@ OEMCryptoResult DrmPrivateKey::GenerateSignature( std::vector DrmPrivateKey::GenerateSignature( const std::vector& message) const { if (rsa_key_) { - return rsa_key_->GenerateSignature(message, kRsaPssDefault); + return rsa_key_->GenerateSignature(message, + GetPssAlgorithm(rsa_hash_algorithm_)); + } + return ecc_key_->GenerateSignature(message); +} + +std::vector DrmPrivateKey::GenerateSignature( + const std::string& message) const { + if (rsa_key_) { + return rsa_key_->GenerateSignature(message, + GetPssAlgorithm(rsa_hash_algorithm_)); } return ecc_key_->GenerateSignature(message); } @@ -167,7 +341,7 @@ OEMCryptoResult DrmPrivateKey::GenerateRsaSignature( const uint8_t* message, size_t message_length, uint8_t* signature, size_t* signature_length) const { if (!rsa_key_) { - LOGE("Only RSA DRM keys can generate PKCS1 signatures"); + LOGE("Only RSA DRM keys can generate CAST PKCS1 signatures"); return OEMCrypto_ERROR_INVALID_KEY; } return rsa_key_->GenerateSignature(message, message_length, kRsaPkcs1Cast, @@ -177,7 +351,7 @@ OEMCryptoResult DrmPrivateKey::GenerateRsaSignature( std::vector DrmPrivateKey::GenerateRsaSignature( const std::vector& message) const { if (!rsa_key_) { - LOGE("Only RSA DRM keys can generate PKCS1 signatures"); + LOGE("Only RSA DRM keys can generate CAST PKCS1 signatures"); return std::vector(); } return rsa_key_->GenerateSignature(message, kRsaPkcs1Cast); diff --git a/oemcrypto/util/src/oemcrypto_ecc_key.cpp b/oemcrypto/util/src/oemcrypto_ecc_key.cpp index 7f98687..39a84bf 100644 --- a/oemcrypto/util/src/oemcrypto_ecc_key.cpp +++ b/oemcrypto/util/src/oemcrypto_ecc_key.cpp @@ -28,7 +28,7 @@ namespace { // Estimated max size (in bytes) of a serialized ECC key (public or // private). These values are based on rough calculations for // secp521r1 (largest of the supported curves) and should be slightly -// larger needed. +// larger than needed. constexpr size_t kPrivateKeySize = 250; constexpr size_t kPublicKeySize = 164; // Estimated max size (in bytes) of a SEC 1 serialized ECC public key @@ -45,6 +45,19 @@ constexpr uint8_t kSec1LeadCompressedPositive = 0x02; constexpr uint8_t kSec1LeadCompressedNegative = 0x03; constexpr uint8_t kSec1LeadUncompressed = 0x04; +// Expected max byte sequence length of big-endian encoded X and Y +// components when using SEC 1 encoding. +// +// Values are computed from SEC 1 section 2.3.3 equation for +// "mlen" +// +// mlen = ceil(log2(q) / 8) +// +// Where log2(q) is one of 256, 384 or 521. +constexpr size_t kSec1Secp256r1ComponentSize = 32; +constexpr size_t kSec1Secp384r1ComponentSize = 48; +constexpr size_t kSec1Secp521r1ComponentSize = 66; + // Checks that the first byte of a SEC 1 encoded EC public key point. // Only supported values are compressed w/ sign and uncompressed. // Other X9.62 values are not supported. @@ -70,6 +83,25 @@ using ScopedEcPoint = ScopedObject; void OpensslFreeU8(uint8_t* ptr) { OPENSSL_free(ptr); } using ScopedBuffer = ScopedObject; +// Internal utility for converting EccCurve to OEMCrypto's signature +// hash algorithm enum. +// +// The provided |curve| must be one of the non-unknown curve values. +OEMCrypto_SignatureHashAlgorithm CurveToSigningHashAlgorithm(EccCurve curve) { + switch (curve) { + case kEccSecp256r1: + return OEMCrypto_SHA2_256; + case kEccSecp384r1: + return OEMCrypto_SHA2_384; + case kEccSecp521r1: + return OEMCrypto_SHA2_512; + case kEccCurveUnknown: // Suppress compiler warnings + break; + } + LOGE("Invalid curve: %d", static_cast(curve)); + return OEMCrypto_SHA1; // SHA-1 is never used with ECC. +} + // == EC Group Utilts == const EC_GROUP* GetEcGroup(EccCurve curve) { @@ -167,6 +199,37 @@ EccCurve GetCurveFromKeyGroup(const EC_KEY* key) { return kEccCurveUnknown; } +size_t Sec1ComponentSize(EccCurve curve) { + switch (curve) { + case kEccSecp256r1: + return kSec1Secp256r1ComponentSize; + case kEccSecp384r1: + return kSec1Secp384r1ComponentSize; + case kEccSecp521r1: + return kSec1Secp521r1ComponentSize; + case kEccCurveUnknown: + return 0; + } + LOGE("Unexpected curve: %d", static_cast(curve)); + return 0; +} + +// Takes the provided big-endian encoded component and attempts +// to make it a fixed length byte sequence. +// +// Assumes |comp_in| is equal to or less than |component_size|. +std::vector MakeComponentFixedWidth( + const std::vector& comp_in, size_t component_size) { + // Ignore case of |comp_in| being larger. + if (comp_in.size() >= component_size) return comp_in; + std::vector comp_out(component_size, 0); + const size_t offset = component_size - comp_in.size(); + for (size_t i = 0; i < comp_in.size(); i++) { + comp_out[i + offset] = comp_in[i]; + } + return comp_out; +} + // == EC Key Operation Utilities == // Performs a SHA2 digest on the provided |message| and outputs the @@ -296,12 +359,12 @@ bool IsMatchingKeyPair(const EC_KEY* public_key, const EC_KEY* private_key) { LOGE("Failed to allocate BN ctx"); return false; } - // Returns: 1 if not equal, 0 if equal, -1 if error. + // Returns: 1 if not equal, 0 if equal, -1 (or other negative) if error. const int res = EC_POINT_cmp(EC_KEY_get0_group(public_key), EC_KEY_get0_public_key(public_key), EC_KEY_get0_public_key(private_key), ctx.get()); - if (res == -1) { - LOGE("Error occurred comparing keys"); + if (res < 0) { + LOGE("Error occurred comparing keys: res = %d", res); } return res == 0; } @@ -411,6 +474,67 @@ std::vector SerializeEccPublicKeyAsSec1KeyPoint(const EC_KEY* key, return sec1_key_point; } +std::vector SerializeEccPublicXCoord(const EC_KEY* key) { + // Use SEC 1 point. + std::vector coord = + SerializeEccPublicKeyAsSec1KeyPoint(key, /* compressed = */ true); + if (coord.empty()) return {}; + // Trim leading bytes. + coord.erase(coord.begin()); + return coord; +} + +std::vector SerializeEccPublicYCoord(const EC_KEY* key) { + // Use SEC 1 point. + const std::vector key_point = + SerializeEccPublicKeyAsSec1KeyPoint(key, /* compressed = */ false); + if (key_point.empty()) return {}; + if ((key_point.size() % 2) != 1) { + LOGE("Unexpected key point length: %zu", key_point.size()); + return {}; + } + const size_t offset = (key_point.size() / 2) + 1; + return std::vector(key_point.begin() + offset, key_point.end()); +} + +bool GetEccPublicYSignBit(const EC_KEY* key) { + // Use SEC 1 point. + const std::vector key_point = + SerializeEccPublicKeyAsSec1KeyPoint(key, /* compressed = */ true); + if (key_point.empty()) return {}; + return key_point.front() == kSec1LeadCompressedNegative; +} + +std::vector SerializeEccPrivateScalar(const EC_KEY* key) { + if (key == nullptr) { + // Programmer error, internal to this module. + LOGE("Input |key| is null"); + return {}; + } + + const size_t scalar_size = Sec1ComponentSize(GetCurveFromKeyGroup(key)); + if (scalar_size == 0) { + LOGE("Failed to get scalar size"); + return {}; + } + + const BIGNUM* d = EC_KEY_get0_private_key(key); + if (d == nullptr) { + LOGE("Failed to get private key scalar"); + return {}; + } + + std::vector private_scalar(scalar_size, 0); + const int to_length = static_cast(scalar_size); + const int res = BN_bn2binpad(d, private_scalar.data(), to_length); + if (res < 0 || res > to_length) { + LOGE("Failed to serialize private scalar: res = %d", res); + return {}; + } + private_scalar.resize(static_cast(res)); + return private_scalar; +} + void SetIetfComplianceFlags(EC_KEY* key) { // Required flags for IETF compliance. EC_KEY_set_asn1_flag(key, OPENSSL_EC_NAMED_CURVE); @@ -484,7 +608,7 @@ bool ParseEccSec1PublicKey(EccCurve curve, const uint8_t* buffer, size_t length, if (check == 0) { LOGE("ECC key parameters are invalid"); return false; - } else if (check == -1) { + } else if (check < 0) { LOGE("Failed to check ECC key"); return false; } @@ -492,10 +616,73 @@ bool ParseEccSec1PublicKey(EccCurve curve, const uint8_t* buffer, size_t length, return true; } +bool ParseEccKeyPoint(EccCurve curve, const std::vector& x, + const std::vector& y, ScopedEcKey* public_key) { + if (x.empty() || y.empty()) { + LOGE("Empty components"); + return false; + } + const size_t component_size = Sec1ComponentSize(curve); + if (component_size == 0) { + LOGE("Failed to get component size: curve = %d", static_cast(curve)); + return false; + } + if (x.size() > component_size || y.size() > component_size) { + LOGE( + "Invalid component length(s): x_length = %zu, " + "y_length = %zu, expected = %zu", + x.size(), y.size(), component_size); + return false; + } + // Convert to SEC 1 uncompressed form. + // 0x04 || X || Y + // Note: SEC 1 requires X and Y to be the same length and + // the length expected for the curve. + std::vector sec1_key_point = {kSec1LeadUncompressed}; + std::vector component = MakeComponentFixedWidth(x, component_size); + sec1_key_point.insert(sec1_key_point.end(), component.begin(), + component.end()); + component = MakeComponentFixedWidth(y, component_size); + sec1_key_point.insert(sec1_key_point.end(), component.begin(), + component.end()); + + return ParseEccSec1PublicKey(curve, sec1_key_point.data(), + sec1_key_point.size(), public_key); +} + +bool ParseEccKeyPoint(EccCurve curve, const std::vector& x, + bool y_sign_bit, ScopedEcKey* public_key) { + if (x.empty()) { + LOGE("Empty component"); + return false; + } + const size_t component_size = Sec1ComponentSize(curve); + if (component_size == 0) { + LOGE("Failed to get component size: curve = %d", static_cast(curve)); + return false; + } + if (x.size() > component_size) { + LOGE("Invalid component length: x_length = %zu, expected = %zu", x.size(), + component_size); + return false; + } + // Convert to SEC 1 compressed form. + // (0x02 or 0x03) || X + // Note: SEC 1 requires X to be the length expected for the curve. + std::vector sec1_key_point = { + y_sign_bit ? kSec1LeadCompressedNegative : kSec1LeadCompressedPositive}; + std::vector component = MakeComponentFixedWidth(x, component_size); + sec1_key_point.insert(sec1_key_point.end(), component.begin(), + component.end()); + + return ParseEccSec1PublicKey(curve, sec1_key_point.data(), + sec1_key_point.size(), public_key); +} + bool ParseEccPrivateKeyInfo(const uint8_t* buffer, size_t length, ScopedEcKey* key, EccCurve* curve) { if (length == 0) { - LOGE("Public key is too small: length = %zu", length); + LOGE("Private key is too small: length = %zu", length); return false; } ScopedBio bio(BIO_new_mem_buf(buffer, static_cast(length))); @@ -531,7 +718,7 @@ bool ParseEccPrivateKeyInfo(const uint8_t* buffer, size_t length, if (check == 0) { LOGE("ECC key parameters are invalid"); return false; - } else if (check == -1) { + } else if (check < 0) { LOGE("Failed to check ECC key"); return false; } @@ -543,6 +730,85 @@ bool ParseEccPrivateKeyInfo(const uint8_t* buffer, size_t length, SetIetfComplianceFlags(key->get()); return true; } + +bool ParseEccPrivateScalar(EccCurve curve, const std::vector& d_bytes, + ScopedEcKey* key) { + // Step 1: Validate and deserializes private scalar. + const size_t component_size = Sec1ComponentSize(curve); + if (component_size == 0) { + LOGE("Failed to get component size: curve = %d", static_cast(curve)); + return false; + } + if (d_bytes.empty() || d_bytes.size() > component_size) { + LOGE("Invalid public secalar length: d_length = %zu, expected = %zu", + d_bytes.size(), component_size); + return false; + } + const EC_GROUP* group = GetEcGroup(curve); + if (group == nullptr) { + LOGE("Failed to get ECC group"); + return false; + } + ScopedBigNum d( + BN_bin2bn(d_bytes.data(), static_cast(d_bytes.size()), nullptr)); + if (!d) { + LOGE("Failed to decode private scalar: d_length = %zu", d_bytes.size()); + return false; + } + + // Step 2: Compute public component. + ScopedEcPoint point(EC_POINT_new(group)); + if (!point) { + LOGE("Failed to allocate EC_POINT"); + return false; + } + ScopedBigNumCtx ctx(BN_CTX_new()); + if (!ctx) { + LOGE("Failed to allocate BN CTX"); + return false; + } + // EC_POINT_mul(group, R, n, NULL, NULL, ctx): + // let G := basePoint(group) + // R = n * G + if (!EC_POINT_mul(group, point.get(), d.get(), nullptr, nullptr, ctx.get())) { + LOGE("Failed to compute public point"); + return false; + } + + // Step 3: Construct EC_KEY from parts. + key->reset(EC_KEY_new()); + if (!*key) { + LOGE("Failed to allocate key"); + return false; + } + // Note: ownership of |group| is not transferred. + if (!EC_KEY_set_group(key->get(), group)) { + LOGE("Failed to set group"); + return false; + } + // Note: ownership of |d| is not transferred. + if (!EC_KEY_set_private_key(key->get(), d.get())) { + LOGE("Failed to set private scalar"); + return false; + } + // Note: ownership of |point| is not transferred. + if (!EC_KEY_set_public_key(key->get(), point.get())) { + LOGE("Failed to set public point"); + return false; + } + + // Step 4: Validate and set compliance flags. + const int check = EC_KEY_check_key(key->get()); + if (check == 0) { + LOGE("ECC key parameters are invalid"); + return false; + } else if (check < 0) { + LOGE("Failed to check ECC key"); + return false; + } + SetIetfComplianceFlags(key->get()); + return true; +} } // namespace std::string EccCurveToString(EccCurve curve) { @@ -692,6 +958,50 @@ std::unique_ptr EccPublicKey::LoadPrivateKeyInfo( return LoadPrivateKeyInfo(buffer.data(), buffer.size()); } +// static +std::unique_ptr EccPublicKey::LoadKeyPoint( + EccCurve curve, const std::vector& x, + const std::vector& y) { + std::unique_ptr key(new EccPublicKey()); + if (!key->InitFromKeyPoint(curve, x, y)) { + LOGE( + "Failed to initialize public key from raw key point: " + "curve = %s, x_length = %zu, y_length = %zu", + EccCurveToString(curve).c_str(), x.size(), y.size()); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EccPublicKey::LoadKeyPoint( + EccCurve curve, const std::vector& x, bool y_sign_bit) { + std::unique_ptr key(new EccPublicKey()); + if (!key->InitFromKeyPoint(curve, x, y_sign_bit)) { + LOGE( + "Failed to initialize public key from raw key point: " + "curve = %s, x_length = %zu, y_sign_bit = %s", + EccCurveToString(curve).c_str(), x.size(), + y_sign_bit ? "true" : "false"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EccPublicKey::LoadFromPrivateScalar( + EccCurve curve, const std::vector& d) { + std::unique_ptr key(new EccPublicKey()); + if (!key->InitFromPrivateScalar(curve, d)) { + LOGE( + "Failed to initialize public key from private scalar: " + "curve = %s, d_length = %zu", + EccCurveToString(curve).c_str(), d.size()); + key.reset(); + } + return key; +} + bool EccPublicKey::IsMatchingPrivateKey( const EccPrivateKey& private_key) const { if (private_key.curve() != curve_) { @@ -700,6 +1010,17 @@ bool EccPublicKey::IsMatchingPrivateKey( return IsMatchingKeyPair(GetEcKey(), private_key.GetEcKey()); } +bool EccPublicKey::IsMatchingPublicKey(const EccPublicKey& other_key) const { + if (this == &other_key) return true; + if (other_key.curve() != curve_) return false; + return IsMatchingKeyPair(GetEcKey(), other_key.GetEcKey()); +} + +OEMCrypto_SignatureHashAlgorithm EccPublicKey::GetSignatureHashAlgorithm() + const { + return CurveToSigningHashAlgorithm(curve_); +} + OEMCryptoResult EccPublicKey::Serialize(uint8_t* buffer, size_t* buffer_size) const { return SerializeEccPublicKey(key_, buffer, buffer_size); @@ -714,6 +1035,16 @@ std::vector EccPublicKey::SerializeAsSec1KeyPoint( return SerializeEccPublicKeyAsSec1KeyPoint(key_, compressed); } +std::vector EccPublicKey::SerializeXCoord() const { + return SerializeEccPublicXCoord(key_); +} + +std::vector EccPublicKey::SerializeYCoord() const { + return SerializeEccPublicYCoord(key_); +} + +bool EccPublicKey::GetYSignBit() const { return GetEccPublicYSignBit(key_); } + OEMCryptoResult EccPublicKey::VerifySignature(const uint8_t* message, size_t message_length, const uint8_t* signature, @@ -834,7 +1165,7 @@ bool EccPublicKey::InitFromSubjectPublicKeyInfo(const uint8_t* buffer, if (check == 0) { LOGE("ECC key parameters are invalid"); return false; - } else if (check == -1) { + } else if (check < 0) { LOGE("Failed to check ECC key"); return false; } @@ -894,6 +1225,44 @@ bool EccPublicKey::InitFromSec1KeyPoint(EccCurve curve, const uint8_t* buffer, return true; } +bool EccPublicKey::InitFromKeyPoint(EccCurve curve, + const std::vector& x, + const std::vector& y) { + ScopedEcKey key; + if (!ParseEccKeyPoint(curve, x, y, &key)) { + LOGE("Failed to parse key point"); + return false; + } + curve_ = curve; + key_ = key.release(); + return true; +} + +bool EccPublicKey::InitFromKeyPoint(EccCurve curve, + const std::vector& x, + bool y_sign_bit) { + ScopedEcKey key; + if (!ParseEccKeyPoint(curve, x, y_sign_bit, &key)) { + LOGE("Failed to parse key point"); + return false; + } + curve_ = curve; + key_ = key.release(); + return true; +} + +bool EccPublicKey::InitFromPrivateScalar(EccCurve curve, + const std::vector& d) { + ScopedEcKey key; + if (!ParseEccPrivateScalar(curve, d, &key)) { + LOGE("Failed to parse private scalar"); + return false; + } + curve_ = curve; + key_ = key.release(); + return true; +} + OEMCryptoResult EccPublicKey::DigestAndVerify( const uint8_t* message, size_t message_length, const ECDSA_SIG* sig_point) const { @@ -969,6 +1338,20 @@ std::unique_ptr EccPrivateKey::Load( return Load(buffer.data(), buffer.size()); } +// static +std::unique_ptr EccPrivateKey::LoadPrivateScalar( + EccCurve curve, const std::vector& d) { + std::unique_ptr key(new EccPrivateKey()); + if (!key->InitFromPrivateScalar(curve, d)) { + LOGE( + "Failed to initialize private key from private scalar: " + "curve = %s, d_length = %zu", + EccCurveToString(curve).c_str(), d.size()); + key.reset(); + } + return key; +} + std::unique_ptr EccPrivateKey::MakePublicKey() const { return EccPublicKey::New(*this); } @@ -980,6 +1363,11 @@ bool EccPrivateKey::IsMatchingPublicKey(const EccPublicKey& public_key) const { return IsMatchingKeyPair(public_key.GetEcKey(), GetEcKey()); } +OEMCrypto_SignatureHashAlgorithm EccPrivateKey::GetSignatureHashAlgorithm() + const { + return CurveToSigningHashAlgorithm(curve_); +} + OEMCryptoResult EccPrivateKey::Serialize(uint8_t* buffer, size_t* buffer_size) const { if (buffer_size == nullptr) { @@ -1064,6 +1452,20 @@ std::vector EccPrivateKey::SerializeAsPublicSec1KeyPoint( return SerializeEccPublicKeyAsSec1KeyPoint(key_, compressed); } +std::vector EccPrivateKey::SerializeXCoord() const { + return SerializeEccPublicXCoord(key_); +} + +std::vector EccPrivateKey::SerializeYCoord() const { + return SerializeEccPublicYCoord(key_); +} + +bool EccPrivateKey::GetYSignBit() const { return GetEccPublicYSignBit(key_); } + +std::vector EccPrivateKey::SerializePrivateScalar() const { + return SerializeEccPrivateScalar(key_); +} + OEMCryptoResult EccPrivateKey::GenerateSignature( const uint8_t* message, size_t message_length, uint8_t* signature, size_t* signature_length) const { @@ -1283,5 +1685,17 @@ bool EccPrivateKey::InitFromCurve(EccCurve curve) { key_ = key.release(); return true; } + +bool EccPrivateKey::InitFromPrivateScalar(EccCurve curve, + const std::vector& d) { + ScopedEcKey key; + if (!ParseEccPrivateScalar(curve, d, &key)) { + LOGE("Failed to parse private scalar"); + return false; + } + curve_ = curve; + key_ = key.release(); + return true; +} } // namespace util } // namespace wvoec diff --git a/oemcrypto/util/src/oemcrypto_ed_key.cpp b/oemcrypto/util/src/oemcrypto_ed_key.cpp new file mode 100644 index 0000000..3fb65f5 --- /dev/null +++ b/oemcrypto/util/src/oemcrypto_ed_key.cpp @@ -0,0 +1,855 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "oemcrypto_ed_key.h" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "log.h" +#include "scoped_object.h" + +namespace wvoec { +namespace util { +namespace { +// Estimated max size (in bytes) of serialized PublicKeyInfo and +// SubjectPublicKeyInfo. Values are based on rough calculations. +// +// For PrivateKeyInfo, size is estimated based on containing both +// a private key and a public key. +constexpr size_t kPrivateKeyInfoSize = 128; +constexpr size_t kSubjectPublicKeyInfoSize = 64; + +// Key and signature sizes, from RFC 8032. +constexpr size_t kRawEd25519PublicKeySize = 32; +constexpr size_t kRawEd25519PrivateKeySize = 32; +constexpr size_t kRawEd25519SignatureSize = 64; + +// OpenSSL/BoringSSL wrappers. +using ScopedEvpPkey = ScopedObject; +using ScopedEvpPkeyCtx = ScopedObject; +using ScopedEvpMdCtx = ScopedObject; +using ScopedBio = ScopedObject; +using ScopedPrivateKeyInfo = + ScopedObject; +void OpensslFreeU8(uint8_t* ptr) { OPENSSL_free(ptr); } +using ScopedBuffer = ScopedObject; + +// Null engine variable. Clang-tidy will complain about trying +// to name function arguments in comments. OpenSSL and BoringSSL +// have different parameter names for some functions. +// Ex: Putting: /* engine = */ nullptr +// OpenSSL sometimes uses "engine" or "e". +// BoringSSL sometimes uses "engine", "e" or "unused". +constexpr ENGINE* const kNullEngine = nullptr; + +// == Key Checking == + +bool CheckKey(const ScopedEvpPkey& pkey, bool check_private) { + if (!pkey) return false; + // Check key only available on OpenSSL. +#ifndef OPENSSL_IS_BORINGSSL + // EVP_PKEY_CTX_new() will share ownership of the provided + // pkey (increment OpenSSL reference), but will release shared + // ownership once out of scope (decrement OpenSSL reference). + ScopedEvpPkeyCtx pkey_ctx = EVP_PKEY_CTX_new(pkey.get(), kNullEngine); + if (!pkey_ctx) { + LOGE("Failed to convert EVP_PKEY to EVP_PKEY_CTX"); + return false; + } + constexpr int kCheckSuccess = 1; + // Special return value if operation is not supported. + constexpr int kCheckNoSupported = -2; + // EVP_PKEY_check() will check both public and private. + // EVP_PKEY_public_check() will only check public. + const int check = check_private ? EVP_PKEY_check(pkey_ctx.get()) + : EVP_PKEY_public_check(pkey_ctx.get()); + if (check == kCheckNoSupported) { + LOGD("Ed25519 key checking not supported"); + return true; // Assume key is OK. + } + if (check != kCheckSuccess) { + LOGE("Key check failed: res = %d", check); + return false; + } + return true; +#endif + (void)check_private; + return true; +} + +// == PrivateKeyInfo Utilities == + +// Converts an ASN.1 DER encoded PrivateKeyInfo to an +// EVP_PKEY with type Ed25519. +bool ParseEdPrivateKeyInfo(const uint8_t* buffer, size_t length, + ScopedEvpPkey* pkey_out) { + if (length == 0) { + LOGE("Private key is too small: length = %zu", length); + return false; + } + // Unfortunately, neither OpenSSL or BoringSSL support the + // user-friend d2i_PrivateKey when using Ed25519 keys. + + // Step 1: Setup BIO with the buffer data. + ScopedBio bio = BIO_new_mem_buf(buffer, static_cast(length)); + if (!bio) { + LOGE("Failed to allocate BIO buffer"); + return false; + } + // Step 2: Deserializes PKCS8 PrivateKeyInfo containing an Ed25519 key. + ScopedPrivateKeyInfo priv_info( + d2i_PKCS8_PRIV_KEY_INFO_bio(bio.get(), nullptr)); + if (!priv_info) { + LOGE("Failed to parse PrivateKeyInfo"); + return false; + } + + // Step 3a: Extract out the EVP_PKEY. + ScopedEvpPkey pkey(EVP_PKCS82PKEY(priv_info.get())); + if (!pkey) { + LOGE("Failed to convert PKCS8 to EVP"); + return false; + } + // Step 3b: Verify that the key is Ed25519. + const int key_type = EVP_PKEY_base_id(pkey.get()); + if (key_type != EVP_PKEY_ED25519) { + LOGE("Decoded private key is not Ed25519"); + return false; + } + // Step 3c: Verify key data. + if (!CheckKey(pkey, /* check_private = */ true)) { + // CheckKey() will log details. + return false; + } + *pkey_out = std::move(pkey); + return true; +} + +// == SubjectPublicKeyInfo Utilities == + +OEMCryptoResult SerializeSubjectPublicKeyInfoInternal(const EVP_PKEY* pkey, + uint8_t* buffer, + size_t* buffer_size) { + if (buffer_size == nullptr) { + LOGE("Output buffer size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (buffer == nullptr && *buffer_size > 0) { + LOGE("Output buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + uint8_t* der_key_raw = nullptr; + // Const cast required for OpenSSL/BoringSSL compatibility. + // BoringSSL: + // int i2d_PUBKEY(const EVP_PKEY *a, unsigned char **pp) + // OpenSSL: + // int i2d_PUBKEY(EVP_PKEY *a, unsigned char **pp) + // OpenSSL does not modify the key, but changes reference + // counts internally. + const int der_res = i2d_PUBKEY(const_cast(pkey), &der_key_raw); + if (der_res < 0) { + LOGE("Public key serialization failed"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + ScopedBuffer der_key(der_key_raw); + der_key_raw = nullptr; + if (!der_key) { + LOGE("Encoded key is unexpectedly null"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (der_res == 0) { + LOGE("Unexpected DER encoded size"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + const size_t required_size = static_cast(der_res); + if (buffer == nullptr || *buffer_size < required_size) { + *buffer_size = required_size; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + memcpy(buffer, der_key.get(), required_size); + *buffer_size = required_size; + return OEMCrypto_SUCCESS; +} + +std::vector SerializeSubjectPublicKeyInfoInternal( + const EVP_PKEY* pkey) { + size_t subject_public_key_info_size = kSubjectPublicKeyInfoSize; + std::vector subject_public_key_info(subject_public_key_info_size); + const OEMCryptoResult result = SerializeSubjectPublicKeyInfoInternal( + pkey, subject_public_key_info.data(), &subject_public_key_info_size); + if (result != OEMCrypto_SUCCESS) { + LOGE("Failed to serialize public key: result = %d", + static_cast(result)); + subject_public_key_info.clear(); + } else { + subject_public_key_info.resize(subject_public_key_info_size); + } + return subject_public_key_info; +} + +std::vector SerializeRawPublicKeyInternal(const EVP_PKEY* pkey) { + size_t key_length = kRawEd25519PublicKeySize; + std::vector raw_public_key(key_length); + if (!EVP_PKEY_get_raw_public_key(pkey, raw_public_key.data(), &key_length)) { + LOGE("Failed to serialize raw public key"); + return {}; + } + if (key_length > kRawEd25519PublicKeySize) { + LOGE("Unexpected raw key size: actual = %zu, expected = %zu", key_length, + kRawEd25519PublicKeySize); + return {}; + } + raw_public_key.resize(key_length); + return raw_public_key; +} + +bool IsMatchingKeyPair(const EVP_PKEY* public_key, + const EVP_PKEY* private_key) { + const int res = EVP_PKEY_cmp(public_key, private_key); + if (res == 1) return true; + if (res != 0) { + // A value other than 1 or 0 indicates an error during the + // comparison process. + LOGE("Failed to compare keys: res = %d", res); + } + return false; +} +} // namespace + +// static +std::unique_ptr EdPublicKey::New(const EdPrivateKey& private_key) { + std::unique_ptr key(new EdPublicKey()); + if (!key->InitFromPrivateKey(private_key)) { + LOGE("Failed to initialize from private key"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPublicKey::Load(const uint8_t* buffer, + size_t length) { + if (buffer == nullptr) { + LOGE("Provided public key buffer is null"); + return nullptr; + } + if (length == 0) { + LOGE("Provided public key buffer is zero length"); + return nullptr; + } + std::unique_ptr key(new EdPublicKey()); + if (!key->InitFromSubjectPublicKeyInfo(buffer, length)) { + LOGE("Failed to initialize public key from SubjectPublicKeyInfo"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPublicKey::Load(const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return nullptr; + } + return Load(reinterpret_cast(buffer.data()), buffer.size()); +} + +// static +std::unique_ptr EdPublicKey::Load( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return nullptr; + } + return Load(buffer.data(), buffer.size()); +} + +// static +std::unique_ptr EdPublicKey::LoadRaw(const uint8_t* buffer, + size_t length) { + if (buffer == nullptr) { + LOGE("Provided public key buffer is null"); + return nullptr; + } + if (length == 0) { + LOGE("Provided public key buffer is zero length"); + return nullptr; + } + std::unique_ptr key(new EdPublicKey()); + if (!key->InitFromRawPublicKey(buffer, length)) { + LOGE("Failed to initialize public key from raw public key"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPublicKey::LoadRaw(const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return nullptr; + } + return LoadRaw(reinterpret_cast(buffer.data()), + buffer.size()); +} + +// static +std::unique_ptr EdPublicKey::LoadRaw( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return nullptr; + } + return LoadRaw(buffer.data(), buffer.size()); +} + +// static +std::unique_ptr EdPublicKey::LoadFromRawPrivate( + const uint8_t* buffer, size_t length) { + if (buffer == nullptr) { + LOGE("Provided private key buffer is null"); + return nullptr; + } + if (length == 0) { + LOGE("Provided private key buffer is zero length"); + return nullptr; + } + std::unique_ptr key(new EdPublicKey()); + if (!key->InitFromRawPrivateKey(buffer, length)) { + LOGE("Failed to initialize public key from raw private key"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPublicKey::LoadFromRawPrivate( + const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return nullptr; + } + return LoadFromRawPrivate(reinterpret_cast(buffer.data()), + buffer.size()); +} + +// static +std::unique_ptr EdPublicKey::LoadFromRawPrivate( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return nullptr; + } + return LoadFromRawPrivate(buffer.data(), buffer.size()); +} + +bool EdPublicKey::IsMatchingPrivateKey(const EdPrivateKey& private_key) const { + return IsMatchingKeyPair(key_, private_key.GetEvpPkey()); +} + +OEMCryptoResult EdPublicKey::SerializeSubjectPublicKeyInfo( + uint8_t* buffer, size_t* buffer_size) const { + return SerializeSubjectPublicKeyInfoInternal(key_, buffer, buffer_size); +} + +std::vector EdPublicKey::SerializeSubjectPublicKeyInfo() const { + return SerializeSubjectPublicKeyInfoInternal(key_); +} + +std::vector EdPublicKey::SerializeRaw() const { + return SerializeRawPublicKeyInternal(key_); +} + +OEMCryptoResult EdPublicKey::VerifySignature(const uint8_t* message, + size_t message_length, + const uint8_t* signature, + size_t signature_length) const { + if (signature == nullptr || signature_length == 0) { + LOGE("Signature is missing"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (message == nullptr && message_length > 0) { + LOGE("Bad message data"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + // Step 1: Setup OpenSSL/BoringSSL signing CTX. + ScopedEvpMdCtx md_ctx = EVP_MD_CTX_new(); + // |type| is null for PureEdDsa. + if (!EVP_DigestVerifyInit(md_ctx.get(), /* pctx = */ nullptr, + /* type = */ nullptr, kNullEngine, key_)) { + LOGE("Failed to initialize Ed25519 EVP_MD_CTX"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2: Perform verification. + const int res = EVP_DigestVerify(md_ctx.get(), signature, signature_length, + message, message_length); + if (res == 0) { + LOGD("Signatures did not match"); + return OEMCrypto_ERROR_SIGNATURE_FAILURE; + } + if (res < 0) { + LOGE("Error occurred verifying signature: res = %d", res); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + return OEMCrypto_SUCCESS; +} + +OEMCryptoResult EdPublicKey::VerifySignature( + const std::string& message, const std::string& signature) const { + if (signature.empty()) { + LOGE("Signature should not be empty"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + return VerifySignature( + reinterpret_cast(message.data()), message.size(), + reinterpret_cast(signature.data()), signature.size()); +} + +OEMCryptoResult EdPublicKey::VerifySignature( + const std::vector& message, + const std::vector& signature) const { + if (signature.empty()) { + LOGE("Signature should not be empty"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + return VerifySignature(message.data(), message.size(), signature.data(), + signature.size()); +} + +EdPublicKey::~EdPublicKey() { + if (key_ != nullptr) { + EVP_PKEY_free(key_); + key_ = nullptr; + } +} + +bool EdPublicKey::InitFromSubjectPublicKeyInfo(const uint8_t* buffer, + size_t length) { + // Deserialize SubjectPublicKeyInfo + const uint8_t* tp = buffer; + ScopedEvpPkey pkey = d2i_PUBKEY(nullptr, &tp, length); + if (!pkey) { + LOGE("Failed to parse SubjectPublicKeyInfo"); + return false; + } + const int nid = EVP_PKEY_base_id(pkey.get()); + if (nid < 0) { + LOGE("Failed to get NID of parsed SubjectPublicKeyInfo"); + return false; + } + if (nid != EVP_PKEY_ED25519) { + LOGE("Parsed SubjectPublicKeyInfo is not Ed25519: nid = %d", nid); + return false; + } + if (!CheckKey(pkey, /* check_private = */ false)) { + return false; // CheckKey() will log details + } + key_ = pkey.release(); + return true; +} + +bool EdPublicKey::InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length) { + ScopedEvpPkey pkey; + if (!ParseEdPrivateKeyInfo(buffer, length, &pkey)) return false; + // OpenSSL and BoringSSL do not have a common way of clearing + // the private key parts of the key. + key_ = pkey.release(); + return true; +} + +bool EdPublicKey::InitFromRawPublicKey(const uint8_t* buffer, size_t length) { + if (length != kRawEd25519PublicKeySize) { + LOGE("Unexpected raw public key length: actual = %zu, expected = %zu", + length, kRawEd25519PublicKeySize); + return false; + } + ScopedEvpPkey pkey = EVP_PKEY_new_raw_public_key(EVP_PKEY_ED25519, + kNullEngine, buffer, length); + if (!pkey) { + LOGE("Failed to create Ed25519 public key from raw public key"); + return false; + } + key_ = pkey.release(); + return true; +} + +bool EdPublicKey::InitFromRawPrivateKey(const uint8_t* buffer, size_t length) { + if (length != kRawEd25519PrivateKeySize) { + LOGE("Unexpected raw private key length: actual = %zu, expected = %zu", + length, kRawEd25519PrivateKeySize); + return false; + } + ScopedEvpPkey pkey = EVP_PKEY_new_raw_private_key( + EVP_PKEY_ED25519, kNullEngine, buffer, length); + if (!pkey) { + LOGE("Failed to create Ed25519 public key from raw private key"); + return false; + } + key_ = pkey.release(); + return true; +} + +bool EdPublicKey::InitFromPrivateKey(const EdPrivateKey& private_key) { + // No reliable way to duplicate the key between OpenSSL and + // BoringSSL. + // Instead, share the pointer, and increament the reference count. + EVP_PKEY* pkey = const_cast(private_key.GetEvpPkey()); + if (!EVP_PKEY_up_ref(pkey)) { + LOGE("Failed to duplicate EVP_PKEY"); + return false; + } + key_ = pkey; + return true; +} + +// == Ed25519 Private Key == + +// static +std::unique_ptr EdPrivateKey::New() { + std::unique_ptr key(new EdPrivateKey()); + if (!key->InitNew()) { + LOGE("Failed to initialize new private key"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPrivateKey::Load(const uint8_t* buffer, + size_t length) { + if (buffer == nullptr) { + LOGE("Provided private key buffer is null"); + return nullptr; + } + if (length == 0) { + LOGE("Provided private key buffer is zero length"); + return nullptr; + } + std::unique_ptr key(new EdPrivateKey()); + if (!key->InitFromPrivateKeyInfo(buffer, length)) { + LOGE("Failed to initialize private key from PrivateKeyInfo"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPrivateKey::Load(const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return nullptr; + } + return Load(reinterpret_cast(buffer.data()), buffer.size()); +} + +// static +std::unique_ptr EdPrivateKey::Load( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return nullptr; + } + return Load(buffer.data(), buffer.size()); +} + +// static +std::unique_ptr EdPrivateKey::LoadRaw(const uint8_t* buffer, + size_t length) { + if (buffer == nullptr) { + LOGE("Provided private key buffer is null"); + return nullptr; + } + if (length == 0) { + LOGE("Provided private key buffer is zero length"); + return nullptr; + } + std::unique_ptr key(new EdPrivateKey()); + if (!key->InitFromRaw(buffer, length)) { + LOGE("Failed to initialize private key from raw private key"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr EdPrivateKey::LoadRaw(const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return nullptr; + } + return LoadRaw(reinterpret_cast(buffer.data()), + buffer.size()); +} + +// static +std::unique_ptr EdPrivateKey::LoadRaw( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided private key buffer is empty"); + return nullptr; + } + return LoadRaw(buffer.data(), buffer.size()); +} + +std::unique_ptr EdPrivateKey::MakePublicKey() const { + return EdPublicKey::New(*this); +} + +bool EdPrivateKey::IsMatchingPublicKey(const EdPublicKey& public_key) const { + return IsMatchingKeyPair(public_key.GetEvpPkey(), GetEvpPkey()); +} + +OEMCryptoResult EdPrivateKey::SerializePrivateKeyInfo( + uint8_t* buffer, size_t* buffer_size) const { + if (buffer_size == nullptr) { + LOGE("Output buffer size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (buffer == nullptr && *buffer_size > 0) { + LOGE("Output buffer is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + // Step 1: Convert from EVP_PKEY to PKCS#8 PrivateKeyInfo. + ScopedPrivateKeyInfo priv_info(EVP_PKEY2PKCS8(key_)); + if (!priv_info) { + LOGE("Failed to convert Ed25519 key to PKCS8 info"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 2: Create a temporary BIO to hold ASN.1 DER encoded + // value. + ScopedBio bio(BIO_new(BIO_s_mem())); + if (!bio) { + LOGE("Failed to allocate IO buffer for Ed25519 key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 3: Serialize PKCS8 to DER encoding. + if (!i2d_PKCS8_PRIV_KEY_INFO_bio(bio.get(), priv_info.get())) { + LOGE("Failed to serialize Ed25519 key"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 4: Extract out a pointer too the raw key data. + char* key_ptr = nullptr; + const long key_size = BIO_get_mem_data(bio.get(), &key_ptr); + if (key_size < 0) { + LOGE("Failed to get Ed25519 PrivateKeyInfo size"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (key_ptr == nullptr) { + LOGE("Encoded key is unexpectedly null"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + // Step 5a: Check that the provided buffer is large enough + // for output. + const size_t required_size = static_cast(key_size); + if (buffer == nullptr || *buffer_size < required_size) { + *buffer_size = required_size; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + // Step 5b: Copy output. + memcpy(buffer, key_ptr, required_size); + *buffer_size = required_size; + return OEMCrypto_SUCCESS; +} + +std::vector EdPrivateKey::SerializePrivateKeyInfo() const { + size_t private_key_info_size = kPrivateKeyInfoSize; + std::vector private_key_info(private_key_info_size); + const OEMCryptoResult result = + SerializePrivateKeyInfo(private_key_info.data(), &private_key_info_size); + if (result != OEMCrypto_SUCCESS) { + LOGE("Failed to serialize public key: result = %d", + static_cast(result)); + private_key_info.clear(); + } else { + private_key_info.resize(private_key_info_size); + } + return private_key_info; +} + +OEMCryptoResult EdPrivateKey::SerializeAsSubjectPublicKeyInfo( + uint8_t* buffer, size_t* buffer_size) const { + return SerializeSubjectPublicKeyInfoInternal(key_, buffer, buffer_size); +} + +std::vector EdPrivateKey::SerializeAsSubjectPublicKeyInfo() const { + return SerializeSubjectPublicKeyInfoInternal(key_); +} + +std::vector EdPrivateKey::SerializeRaw() const { + size_t key_length = kRawEd25519PrivateKeySize; + std::vector raw_private_key(key_length); + if (!EVP_PKEY_get_raw_private_key(key_, raw_private_key.data(), + &key_length)) { + LOGE("Failed to serialize raw private key"); + return {}; + } + if (key_length > kRawEd25519PrivateKeySize) { + LOGE("Unexpected raw key size: actual = %zu, expected = %zu", key_length, + kRawEd25519PrivateKeySize); + return {}; + } + raw_private_key.resize(key_length); + return raw_private_key; +} + +std::vector EdPrivateKey::SerializeAsRawPublicKey() const { + return SerializeRawPublicKeyInternal(key_); +} + +OEMCryptoResult EdPrivateKey::GenerateSignature( + const uint8_t* message, size_t message_length, uint8_t* signature, + size_t* signature_length) const { + if (signature_length == nullptr) { + LOGE("Output signature size is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (signature == nullptr && *signature_length > 0) { + LOGE("Output signature is null"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + if (message == nullptr && message_length > 0) { + LOGE("Invalid message data"); + return OEMCrypto_ERROR_INVALID_CONTEXT; + } + // Step 1: Setup OpenSSL/BoringSSL signing CTX. + ScopedEvpMdCtx md_ctx = EVP_MD_CTX_new(); + // |type| is null for PureEdDsa. + if (!EVP_DigestSignInit(md_ctx.get(), /* pctx = */ nullptr, + /* type = */ nullptr, kNullEngine, key_)) { + LOGE("Failed to initialize Ed25519 EVP_MD_CTX"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + + // Step 2: Calculate estimated size. + size_t required_size = 0; + if (!EVP_DigestSign(md_ctx.get(), nullptr, &required_size, message, + message_length)) { + LOGE("Failed to determine size of Ed25519 signature"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (signature == nullptr || *signature_length < required_size) { + *signature_length = required_size; + return OEMCrypto_ERROR_SHORT_BUFFER; + } + + // Step 3: Generate signature. + if (!EVP_DigestSign(md_ctx.get(), signature, signature_length, message, + message_length)) { + LOGE("Failed to generate Ed25519 signature"); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + if (*signature_length > required_size) { + LOGE("Unexpected signature size: actual = %zu, expected = %zu", + *signature_length, required_size); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } + return OEMCrypto_SUCCESS; +} + +std::vector EdPrivateKey::GenerateSignature( + const std::vector& message) const { + size_t signature_size = SignatureSize(); + std::vector signature(signature_size, 0); + const OEMCryptoResult res = GenerateSignature( + message.data(), message.size(), signature.data(), &signature_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to generate signature: result = %d", static_cast(res)); + signature.clear(); + } else { + signature.resize(signature_size); + } + return signature; +} + +std::vector EdPrivateKey::GenerateSignature( + const std::string& message) const { + size_t signature_size = SignatureSize(); + std::vector signature(signature_size, 0); + const OEMCryptoResult res = + GenerateSignature(reinterpret_cast(message.data()), + message.size(), signature.data(), &signature_size); + if (res != OEMCrypto_SUCCESS) { + LOGE("Failed to generate signature: result = %d", static_cast(res)); + signature.clear(); + } else { + signature.resize(signature_size); + } + return signature; +} + +size_t EdPrivateKey::SignatureSize() const { return kRawEd25519SignatureSize; } + +EdPrivateKey::~EdPrivateKey() { + if (key_ != nullptr) { + EVP_PKEY_free(key_); + key_ = nullptr; + } +} + +bool EdPrivateKey::InitFromPrivateKeyInfo(const uint8_t* buffer, + size_t length) { + ScopedEvpPkey pkey; + if (!ParseEdPrivateKeyInfo(buffer, length, &pkey)) return false; + key_ = pkey.release(); + return true; +} + +bool EdPrivateKey::InitFromRaw(const uint8_t* buffer, size_t length) { + if (length != kRawEd25519PrivateKeySize) { + LOGE("Unexpected raw private key length: actual = %zu, expected = %zu", + length, kRawEd25519PrivateKeySize); + return false; + } + ScopedEvpPkey pkey = EVP_PKEY_new_raw_private_key( + EVP_PKEY_ED25519, kNullEngine, buffer, length); + if (!pkey) { + LOGE("Failed to create Ed25519 private key raw private key"); + return false; + } + key_ = pkey.release(); + return true; +} + +bool EdPrivateKey::InitNew() { + ScopedEvpPkeyCtx ed_ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_ED25519, kNullEngine); + if (!ed_ctx) { + LOGE("Failed to create EVP_PKEY_CTX for Ed25519"); + return false; + } + // EVP_PKEY_keygen_init() returns 0 or a negative value on error. + // Certain negative values indicate particular failures. + int res = EVP_PKEY_keygen_init(ed_ctx.get()); + if (res <= 0) { + LOGE("Failed to initialize EVP_PKEY_CTX for Ed25519: res = %d", res); + return false; + } + EVP_PKEY* pkey_ptr = nullptr; + res = EVP_PKEY_keygen(ed_ctx.get(), &pkey_ptr); + if (res <= 0) { + LOGE("Failed to generate private key for Ed25519: res = %d", res); + return false; + } + if (pkey_ptr == nullptr) { + LOGE("EVP_PKEY is unexpectedly null"); + return false; + } + key_ = pkey_ptr; + return true; +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/src/oemcrypto_oem_cert.cpp b/oemcrypto/util/src/oemcrypto_oem_cert.cpp index e4ec3f9..ed6fb87 100644 --- a/oemcrypto/util/src/oemcrypto_oem_cert.cpp +++ b/oemcrypto/util/src/oemcrypto_oem_cert.cpp @@ -8,40 +8,49 @@ #include -#include -#include -#include - #include "log.h" +#include "oemcrypto_ecc_key.h" +#include "oemcrypto_oem_cert_chain.h" #include "oemcrypto_rsa_key.h" -#include "scoped_object.h" -#include "wv_class_utils.h" namespace wvoec { namespace util { namespace { -using ScopedCertificate = ScopedObject; -using ScopedEvpKey = ScopedObject; -using ScopedPkcs7 = ScopedObject; - -constexpr size_t kExpectedCertCount = 2; // Leaf and intermediate. -constexpr int kDeviceCertIndex = 0; - -// Checks that the |public_key| from an X.509 certificate is the -// correct public key of the serialized |private_key_data|. -OEMCryptoResult VerifyRsaKey(const RSA* public_key, +// Checks that the serialized |public_key_data| from an X.509 +// certificate is the correct public key of the serialized +// |private_key_data|. +OEMCryptoResult VerifyRsaKey(const std::vector& public_key_data, const std::vector& private_key_data) { - if (public_key == nullptr) { - LOGE("RSA key is null"); - return OEMCrypto_ERROR_UNKNOWN_FAILURE; - } - std::unique_ptr private_key = - RsaPrivateKey::Load(private_key_data); + auto private_key = RsaPrivateKey::Load(private_key_data); if (!private_key) { LOGE("Failed to parse provided RSA private key"); return OEMCrypto_ERROR_INVALID_KEY; } - if (!RsaKeysAreMatchingPair(public_key, private_key->GetRsaKey())) { + auto public_key = RsaPublicKey::Load(public_key_data); + if (!public_key) { + LOGE("Failed to parse provided RSA public key"); + return OEMCrypto_ERROR_INVALID_KEY; + } + if (!private_key->IsMatchingPublicKey(*public_key)) { + LOGE("OEM certificate keys do not match"); + return OEMCrypto_ERROR_INVALID_KEY; + } + return OEMCrypto_SUCCESS; +} + +OEMCryptoResult VerifyEccKey(const std::vector& public_key_data, + const std::vector& private_key_data) { + auto private_key = EccPrivateKey::Load(private_key_data); + if (!private_key) { + LOGE("Failed to parse provided ECC private key"); + return OEMCrypto_ERROR_INVALID_KEY; + } + auto public_key = EccPublicKey::Load(public_key_data); + if (!public_key) { + LOGE("Failed to parse provided ECC public key"); + return OEMCrypto_ERROR_INVALID_KEY; + } + if (!private_key->IsMatchingPublicKey(*public_key)) { LOGE("OEM certificate keys do not match"); return OEMCrypto_ERROR_INVALID_KEY; } @@ -49,122 +58,49 @@ OEMCryptoResult VerifyRsaKey(const RSA* public_key, } } // namespace -// This utility class encapsulates the minimum functionality of an -// OEM Public Certificate required to verify a device's OEM Public -// Certificate. -class OemPublicCertificate { - public: - ~OemPublicCertificate() = default; - WVCDM_DISALLOW_COPY_AND_MOVE(OemPublicCertificate); - - // Loads a PKCS #7 signedData message with certificate chain. - // Minimum validation is performed. Only checks that the - // device's public key is of a known type (RSA). - static std::unique_ptr Load(const uint8_t* public_cert, - size_t public_cert_size) { - std::unique_ptr oem_public_cert; - if (public_cert == nullptr) { - LOGE("Public cert buffer is null"); - return oem_public_cert; - } - if (public_cert_size == 0) { - LOGE("Public cert buffer is empty"); - return oem_public_cert; - } - oem_public_cert.reset(new OemPublicCertificate()); - if (!oem_public_cert->InitFromBuffer(public_cert, public_cert_size)) { - oem_public_cert.reset(); - } - return oem_public_cert; - } - - OemCertificate::KeyType key_type() const { return key_type_; } - const std::vector& cert_data() const { return cert_data_; } - - const RSA* GetPublicRsaKey() const { - return EVP_PKEY_get0_RSA(device_public_key_.get()); - } - - private: - OemPublicCertificate() = default; - - bool InitFromBuffer(const uint8_t* public_cert, size_t public_cert_size) { - // Step 1: Parse the PKCS7 certificate chain as signedData. - const uint8_t* public_cert_ptr = public_cert; - pkcs7_.reset(d2i_PKCS7(nullptr, &public_cert_ptr, public_cert_size)); - if (!pkcs7_) { - LOGE("Failed to parse PKCS#7 certificate chain"); - return false; - } - if (!PKCS7_type_is_signed(pkcs7_.get())) { - LOGE("OEM Public Certificate is not PKCS#7 signed data"); - return false; - } - PKCS7_SIGNED* signed_data = pkcs7_->d.sign; - // Step 2: Get the leaf certificate. - const size_t cert_count = - static_cast(sk_X509_num(signed_data->cert)); - if (cert_count != kExpectedCertCount) { - LOGE("Unexpected number of certificates: expected = %zu, actual = %zu", - kExpectedCertCount, cert_count); - return false; - } - X509* leaf_cert = sk_X509_value(signed_data->cert, kDeviceCertIndex); - // Step 3a: Get the device's public key. - device_public_key_.reset(X509_get_pubkey(leaf_cert)); - if (!device_public_key_) { - LOGE("Device X.509 certificate is missing a public key"); - return false; - } - // Step 3b: Check key type. - if (EVP_PKEY_get0_RSA(device_public_key_.get()) == nullptr) { - LOGE("Device public key is not RSA"); - return false; - } - key_type_ = OemCertificate::kRsa; - cert_data_.assign(public_cert, public_cert + public_cert_size); - return true; - } - - OemCertificate::KeyType key_type_ = OemCertificate::kNone; - // OpenSSL/BoringSSL's implementation of PKCS7 objects. - ScopedPkcs7 pkcs7_; - ScopedEvpKey device_public_key_; - std::vector cert_data_; -}; - // ===== ===== ===== OEM Certificate ===== ===== ===== // static std::unique_ptr OemCertificate::Create( const uint8_t* private_key_data, size_t private_key_size, const uint8_t* public_cert_data, size_t public_cert_size) { - std::unique_ptr oem_cert; // Step 1: Verify public cert is well-formed. - std::unique_ptr oem_public_cert = - OemPublicCertificate::Load(public_cert_data, public_cert_size); - if (!oem_public_cert) { - LOGE("Invalid OEM Public Certificate"); - return oem_cert; + std::unique_ptr cert_chain = + OemCertificateChain::LoadPkcs7(public_cert_data, public_cert_size); + if (!cert_chain) { + LOGE("Invalid OEM Public Certificate Chain"); + return nullptr; } // Step 2: Verify private key is well-formed. - switch (oem_public_cert->key_type()) { - case kRsa: { + const OemCertKeyType device_cert_key_type = + cert_chain->device_cert()->GetSubjectPublicKeyType(); + switch (device_cert_key_type) { + case OemCertKeyType::kRsa: { std::unique_ptr oem_private_key = RsaPrivateKey::Load(private_key_data, private_key_size); if (!oem_private_key) { - LOGE("Invalid OEM Private Key"); - return oem_cert; + LOGE("Invalid OEM RSA Private Key"); + return nullptr; } } break; - case kNone: // Suppress compiler warnings. - return oem_cert; + case OemCertKeyType::kEcc: { + std::unique_ptr oem_private_key = + EccPrivateKey::Load(private_key_data, private_key_size); + if (!oem_private_key) { + LOGE("Invalid OEM ECC Private Key"); + return nullptr; + } + } break; + case OemCertKeyType::kUnknown: { + LOGE("Unsupported OEM public key type"); + return nullptr; + } } + std::unique_ptr oem_cert(new OemCertificate()); // Step 3: Copy over data. - oem_cert.reset(new OemCertificate()); oem_cert->private_key_.assign(private_key_data, private_key_data + private_key_size); - oem_cert->public_cert_ = std::move(oem_public_cert); + oem_cert->cert_chain_ = std::move(cert_chain); return oem_cert; } @@ -184,8 +120,12 @@ std::unique_ptr OemCertificate::Create( public_cert.size()); } -OemCertificate::KeyType OemCertificate::key_type() const { - return public_cert_->key_type(); +OemCertKeyType OemCertificate::GetKeyType() const { + return cert_chain_->device_cert()->GetSubjectPublicKeyType(); +} + +std::vector OemCertificate::SerializePublicKey() const { + return cert_chain_->device_cert()->SerializedSubjectPublicKeyInfo(); } OEMCryptoResult OemCertificate::GetPublicCertificate( @@ -198,7 +138,7 @@ OEMCryptoResult OemCertificate::GetPublicCertificate( LOGE("Output |public_cert| is null"); return OEMCrypto_ERROR_INVALID_CONTEXT; } - const std::vector& cert_data = public_cert_->cert_data(); + const std::vector& cert_data = cert_chain_->cert_data(); if (public_cert == nullptr || *public_cert_length < cert_data.size()) { *public_cert_length = cert_data.size(); return OEMCrypto_ERROR_SHORT_BUFFER; @@ -209,23 +149,24 @@ OEMCryptoResult OemCertificate::GetPublicCertificate( } const std::vector& OemCertificate::GetPublicCertificate() const { - return public_cert_->cert_data(); + return cert_chain_->cert_data(); } OEMCryptoResult OemCertificate::IsCertificateValid() const { - switch (key_type()) { - case kRsa: - return VerifyRsaKey(public_cert_->GetPublicRsaKey(), private_key_); - case kNone: // Suppress compiler warnings. + const OemPublicCertificate* device_cert = cert_chain_->device_cert(); + const OemCertKeyType key_type = device_cert->GetSubjectPublicKeyType(); + switch (key_type) { + case OemCertKeyType::kRsa: + return VerifyRsaKey(device_cert->SerializedSubjectPublicKeyInfo(), + private_key_); + case OemCertKeyType::kEcc: + return VerifyEccKey(device_cert->SerializedSubjectPublicKeyInfo(), + private_key_); + case OemCertKeyType::kUnknown: // Suppress compiler warnings. break; } - LOGE("Unexpected error key type: type = %d", static_cast(key_type())); + LOGE("Unexpected error key type: type = %d", static_cast(key_type)); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - -// Constructor and destructor do not perform anything special, but -// must be declared within a scope which defines OemPublicCertificate. -OemCertificate::OemCertificate() {} -OemCertificate::~OemCertificate() {} } // namespace util } // namespace wvoec diff --git a/oemcrypto/util/src/oemcrypto_oem_cert_chain.cpp b/oemcrypto/util/src/oemcrypto_oem_cert_chain.cpp new file mode 100644 index 0000000..1fbb1f5 --- /dev/null +++ b/oemcrypto/util/src/oemcrypto_oem_cert_chain.cpp @@ -0,0 +1,179 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "oemcrypto_oem_cert_chain.h" + +#include + +#include +#include + +#include +#include +#include + +#include "log.h" +#include "scoped_object.h" + +namespace wvoec { +namespace util { +namespace { +// OpenSSL/BoringSSL NID for rsaEncryption. +constexpr int kRsaBaseId = EVP_PKEY_RSA; +// OpenSSL/BoringSSL NID for ecPublicKey. +constexpr int kEcBaseId = EVP_PKEY_EC; +// OpenSSL/BoringSSL value for errors when retrieving +// the NID of a EVP_PKEY. +constexpr int kErrorNid = EVP_PKEY_NONE; + +void OpensslFreeU8(uint8_t* ptr) { OPENSSL_free(ptr); } +using ScopedBuffer = ScopedObject; +} // namespace + +const char* OemCertKeyTypeToString(OemCertKeyType type) { + switch (type) { + case OemCertKeyType::kUnknown: + return "Unknown"; + case OemCertKeyType::kRsa: + return "RSA"; + case OemCertKeyType::kEcc: + return "ECC"; + } + return ""; +} + +// static +std::unique_ptr OemCertificateChain::LoadPkcs7( + const uint8_t* message, size_t message_length) { + if (message == nullptr) { + LOGE("PKCS#7 |message| is null"); + return nullptr; + } + if (message_length == 0) { + LOGE("PKCS#7 |message| is empty"); + return nullptr; + } + std::unique_ptr cert_chain(new OemCertificateChain()); + if (!cert_chain->InitFromPkcs7(message, message_length)) { + LOGE("Failed to initialize OEM certificate chain from PKCS#7"); + return nullptr; + } + return cert_chain; +} + +std::unique_ptr OemCertificateChain::LoadPkcs7( + const std::vector& message) { + if (message.empty()) { + LOGE("PKCS#7 |message| is empty"); + return nullptr; + } + return LoadPkcs7(message.data(), message.size()); +} + +bool OemCertificateChain::InitFromPkcs7(const uint8_t* message, + size_t message_length) { + // Step 1a: Parse the PKCS7 message. + const uint8_t* public_cert_ptr = message; + pkcs7_.reset(d2i_PKCS7(nullptr, &public_cert_ptr, message_length)); + if (!pkcs7_) { + LOGE("Failed to parse PKCS#7 certificate chain"); + return false; + } + // Step 1b: Convert to SignedData type. + if (!PKCS7_type_is_signed(pkcs7_.get())) { + LOGE("OEM Certificate Chain is not PKCS#7 signed data"); + return false; + } + PKCS7_SIGNED* signed_data = pkcs7_->d.sign; + if (signed_data == nullptr) { + // OpenSSL error. + LOGE("SignedData is unexpectedly null"); + return false; + } + + // Step 2a: Verify number of certs + const size_t cert_count = static_cast(sk_X509_num(signed_data->cert)); + if (cert_count != kExpectedCertCount) { + LOGE("Unexpected number of certificates: expected = %zu, actual = %zu", + kExpectedCertCount, cert_count); + return false; + } + + // Step 2b: Initialize leaf cert. + X509* cert = + sk_X509_value(signed_data->cert, static_cast(kDeviceCertIndex)); + if (cert == nullptr) { + // OpenSSL error. + LOGE("Failed to get OEM leaf X.509 certificate"); + return false; + } + device_cert_.reset(new OemPublicCertificate(cert)); + + // Step 2c: Initialize intermediate cert. + cert = sk_X509_value(signed_data->cert, + static_cast(kIntermediateCertIndex)); + if (cert == nullptr) { + // OpenSSL error. + LOGE("Failed to get OEM intermedia X.509 certificate"); + return false; + } + intermediate_cert_.reset(new OemPublicCertificate(cert)); + + // Keep copy of cert chain data. + cert_data_.assign(message, &message[message_length]); + return true; +} + +// == OEM Public Certificate == + +OemCertKeyType OemPublicCertificate::GetSubjectPublicKeyType() const { + EVP_PKEY* pkey = X509_get0_pubkey(cert_); + if (pkey == nullptr) { + LOGE("Failed to get OEM public key"); + return OemCertKeyType::kUnknown; + } + + const int base_id = EVP_PKEY_base_id(pkey); + if (base_id == kErrorNid) { + LOGE("Error encounted when retrieving key type"); + return OemCertKeyType::kUnknown; + } + if (base_id == kRsaBaseId) { + return OemCertKeyType::kRsa; + } + if (base_id == kEcBaseId) { + return OemCertKeyType::kEcc; + } + LOGE("Unsupported key type: base_id = %d", base_id); + return OemCertKeyType::kUnknown; +} + +std::vector OemPublicCertificate::SerializedSubjectPublicKeyInfo() + const { + EVP_PKEY* pkey = X509_get0_pubkey(cert_); + if (pkey == nullptr) { + LOGE("Failed to get OEM public key"); + return {}; + } + + uint8_t* der_key_raw = nullptr; + // OpenSSL will allocated |der_key_raw|. + const int der_res = i2d_PUBKEY(pkey, &der_key_raw); + if (der_res < 0) { + LOGE("Public key serialization failed"); + return {}; + } + // Wrapping with ScopedBuffer() to ensure buffer is freed. + ScopedBuffer der_key(der_key_raw); + if (!der_key) { + LOGE("Encoded key is unexpectedly null"); + return {}; + } + const size_t key_size = static_cast(der_res); + return std::vector(der_key_raw, &der_key_raw[key_size]); +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/src/oemcrypto_rsa_key.cpp b/oemcrypto/util/src/oemcrypto_rsa_key.cpp index bd8ba7b..1972933 100644 --- a/oemcrypto/util/src/oemcrypto_rsa_key.cpp +++ b/oemcrypto/util/src/oemcrypto_rsa_key.cpp @@ -83,6 +83,23 @@ bool IsValidAllowedSchemes(uint32_t allowed_schemes) { return (allowed_schemes & kAllSchemesMask) != 0; } +const EVP_MD* GetEvpDigestAlgorithm(RsaSignatureAlgorithm algorithm) { + switch (algorithm) { + case kRsaPssSha1: + return EVP_sha1(); + case kRsaPssSha256: + return EVP_sha256(); + case kRsaPssSha384: + return EVP_sha384(); + case kRsaPssSha512: + return EVP_sha512(); + case kRsaPkcs1Cast: // Suppress compiler warnings + break; + } + // Caller should log the error. + return nullptr; +} + bool ParseRsaPrivateKeyInfo(const uint8_t* buffer, size_t length, ScopedRsaKey* key, uint32_t* allowed_schemes, bool* explicit_schemes, RsaFieldSize* field_size) { @@ -94,8 +111,7 @@ bool ParseRsaPrivateKeyInfo(const uint8_t* buffer, size_t length, // Check allowed scheme type. if (memcmp("SIGN", buffer, 4) == 0) { uint32_t allowed_schemes_bno; - memcpy(&allowed_schemes_bno, reinterpret_cast(&buffer[4]), - 4); + memcpy(&allowed_schemes_bno, &buffer[4], 4); *allowed_schemes = ntohl(allowed_schemes_bno); if (!IsValidAllowedSchemes(*allowed_schemes)) { LOGE("Invalid allowed schemes value: allowed_schemes = %08x", @@ -155,8 +171,26 @@ bool ParseRsaPrivateKeyInfo(const uint8_t* buffer, size_t length, } void OpensslFreeU8(uint8_t* ptr) { OPENSSL_free(ptr); } +using ScopedBuffer = ScopedObject; } // namespace +RsaSignatureAlgorithm RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm( + OEMCrypto_SignatureHashAlgorithm hash_algorithm) { + switch (hash_algorithm) { + case OEMCrypto_SHA1: + return kRsaPssSha1; + case OEMCrypto_SHA2_256: + return kRsaPssSha256; + case OEMCrypto_SHA2_384: + return kRsaPssSha384; + case OEMCrypto_SHA2_512: + return kRsaPssSha512; + } + LOGE("Unexpected OEMCrypto hash algorithm: %d", + static_cast(hash_algorithm)); + return kRsaPssDefault; +} + std::string RsaFieldSizeToString(RsaFieldSize field_size) { switch (field_size) { case kRsa2048Bit: @@ -169,6 +203,7 @@ std::string RsaFieldSizeToString(RsaFieldSize field_size) { return "Unknown(" + std::to_string(static_cast(field_size)) + ")"; } +// Note: Only the public components of |private_key| are checked. bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key) { if (public_key == nullptr) { LOGE("Public key is null"); @@ -181,7 +216,7 @@ bool RsaKeysAreMatchingPair(const RSA* public_key, const RSA* private_key) { // Step 1: Extract public key components. const BIGNUM* public_n = nullptr; const BIGNUM* public_e = nullptr; - const BIGNUM* d = nullptr; + const BIGNUM* d = nullptr; // Unused RSA_get0_key(public_key, &public_n, &public_e, &d); if (public_n == nullptr || public_e == nullptr) { LOGE("Failed to get RSA public key components"); @@ -318,6 +353,46 @@ std::unique_ptr RsaPublicKey::LoadPrivateKeyInfo( return LoadPrivateKeyInfo(buffer.data(), buffer.size()); } +// static +std::unique_ptr RsaPublicKey::LoadRsaPublicKey( + const uint8_t* buffer, size_t length) { + if (buffer == nullptr) { + LOGE("Provided public key buffer is null"); + return nullptr; + } + if (length == 0) { + LOGE("Provided public key buffer is zero length"); + return nullptr; + } + std::unique_ptr key(new RsaPublicKey()); + if (!key->InitFromRsaPublicKey(buffer, length)) { + LOGE("Failed to initialize public key from PrivateKeyInfo"); + key.reset(); + } + return key; +} + +// static +std::unique_ptr RsaPublicKey::LoadRsaPublicKey( + const std::string& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return nullptr; + } + return LoadRsaPublicKey(reinterpret_cast(buffer.data()), + buffer.size()); +} + +// static +std::unique_ptr RsaPublicKey::LoadRsaPublicKey( + const std::vector& buffer) { + if (buffer.empty()) { + LOGE("Provided public key buffer is empty"); + return nullptr; + } + return LoadRsaPublicKey(buffer.data(), buffer.size()); +} + bool RsaPublicKey::IsMatchingPrivateKey( const RsaPrivateKey& private_key) const { if (private_key.field_size() != field_size_) { @@ -326,6 +401,16 @@ bool RsaPublicKey::IsMatchingPrivateKey( return RsaKeysAreMatchingPair(GetRsaKey(), private_key.GetRsaKey()); } +bool RsaPublicKey::IsMatchingPublicKey(const RsaPublicKey& other_key) const { + if (this == &other_key) return true; + if (other_key.field_size() != field_size_) { + return false; + } + // Note: RsaKeysAreMatchingPair() will work even if + // |private_key| is actully a public key. + return RsaKeysAreMatchingPair(GetRsaKey(), other_key.GetRsaKey()); +} + std::vector RsaPrivateKey::GetPrivateExponent() const { const BIGNUM* d = RSA_get0_d(key_); if (d == nullptr) { @@ -364,7 +449,7 @@ OEMCryptoResult RsaPublicKey::Serialize(uint8_t* buffer, LOGE("Public key serialization failed"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } - ScopedObject der_key(der_key_raw); + ScopedBuffer der_key(der_key_raw); der_key_raw = nullptr; if (!der_key) { LOGE("Encoded key is unexpectedly null"); @@ -397,10 +482,30 @@ std::vector RsaPublicKey::Serialize() const { return key_data; } +std::vector RsaPublicKey::SerializeAsRsaPublicKey() const { + uint8_t* der_key_raw = nullptr; + const int der_res = i2d_RSAPublicKey(key_, &der_key_raw); + if (der_res < 0) { + LOGE("Public key serialization failed"); + return {}; + } + ScopedBuffer der_key(der_key_raw); + der_key_raw = nullptr; + if (!der_key) { + LOGE("Encoded key is unexpectedly null"); + return {}; + } + if (der_res == 0) { + LOGE("Unexpected DER encoded size"); + return {}; + } + const size_t key_length = static_cast(der_res); + return std::vector(der_key.get(), der_key.get() + key_length); +} + OEMCryptoResult RsaPublicKey::VerifySignature( const uint8_t* message, size_t message_length, const uint8_t* signature, - size_t signature_length, RsaSignatureAlgorithm algorithm, - OEMCrypto_SignatureHashAlgorithm hash_algorithm) const { + size_t signature_length, RsaSignatureAlgorithm algorithm) const { if (signature == nullptr || signature_length == 0) { LOGE("Signature is missing"); return OEMCrypto_ERROR_INVALID_CONTEXT; @@ -410,9 +515,12 @@ OEMCryptoResult RsaPublicKey::VerifySignature( return OEMCrypto_ERROR_INVALID_CONTEXT; } switch (algorithm) { - case kRsaPssDefault: + case kRsaPssSha1: + case kRsaPssSha256: + case kRsaPssSha384: + case kRsaPssSha512: return VerifySignaturePss(message, message_length, signature, - signature_length, hash_algorithm); + signature_length, algorithm); case kRsaPkcs1Cast: return VerifySignaturePkcs1Cast(message, message_length, signature, signature_length); @@ -423,8 +531,7 @@ OEMCryptoResult RsaPublicKey::VerifySignature( OEMCryptoResult RsaPublicKey::VerifySignature( const std::string& message, const std::string& signature, - RsaSignatureAlgorithm algorithm, - OEMCrypto_SignatureHashAlgorithm hash_algorithm) const { + RsaSignatureAlgorithm algorithm) const { if (signature.empty()) { LOGE("Signature should not be empty"); return OEMCrypto_ERROR_INVALID_CONTEXT; @@ -432,19 +539,18 @@ OEMCryptoResult RsaPublicKey::VerifySignature( return VerifySignature(reinterpret_cast(message.data()), message.size(), reinterpret_cast(signature.data()), - signature.size(), algorithm, hash_algorithm); + signature.size(), algorithm); } OEMCryptoResult RsaPublicKey::VerifySignature( const std::vector& message, const std::vector& signature, - RsaSignatureAlgorithm algorithm, - OEMCrypto_SignatureHashAlgorithm hash_algorithm) const { + RsaSignatureAlgorithm algorithm) const { if (signature.empty()) { LOGE("Signature should not be empty"); return OEMCrypto_ERROR_INVALID_CONTEXT; } return VerifySignature(message.data(), message.size(), signature.data(), - signature.size(), algorithm, hash_algorithm); + signature.size(), algorithm); } OEMCryptoResult RsaPublicKey::EncryptSessionKey( @@ -587,6 +693,7 @@ RsaPublicKey::~RsaPublicKey() { field_size_ = kRsaFieldUnknown; } +// private bool RsaPublicKey::InitFromSubjectPublicKeyInfo(const uint8_t* buffer, size_t length) { // Step 1: Deserialize SubjectPublicKeyInfo as RSA key. @@ -608,6 +715,28 @@ bool RsaPublicKey::InitFromSubjectPublicKeyInfo(const uint8_t* buffer, return true; } +// private +bool RsaPublicKey::InitFromRsaPublicKey(const uint8_t* buffer, size_t length) { + // Step 1: Deserialize RSAPublicKey as RSA key. + const uint8_t* tp = buffer; + ScopedRsaKey key(d2i_RSAPublicKey(nullptr, &tp, length)); + if (!key) { + LOGE("Failed to parse key"); + return false; + } + // Step 2: Verify key. + const int bits = RSA_bits(key.get()); + field_size_ = RealBitSizeToFieldSize(bits); + if (field_size_ == kRsaFieldUnknown) { + LOGE("Unsupported RSA key size: bits = %d", bits); + return false; + } + allowed_schemes_ = kSign_RSASSA_PSS; + key_ = key.release(); + return true; +} + +// private bool RsaPublicKey::InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length) { ScopedRsaKey private_key; @@ -620,11 +749,13 @@ bool RsaPublicKey::InitFromPrivateKeyInfo(const uint8_t* buffer, return InitFromSslHandle(private_key.get(), allowed_schemes_); } +// private bool RsaPublicKey::InitFromPrivateKey(const RsaPrivateKey& private_key) { return InitFromSslHandle(private_key.GetRsaKey(), private_key.allowed_schemes()); } +// private bool RsaPublicKey::InitFromSslHandle(const RSA* rsa_handle, uint32_t allowed_schemes) { ScopedRsaKey key(RSA_new()); @@ -665,10 +796,10 @@ bool RsaPublicKey::InitFromSslHandle(const RSA* rsa_handle, return true; } +// private OEMCryptoResult RsaPublicKey::VerifySignaturePss( const uint8_t* message, size_t message_length, const uint8_t* signature, - size_t signature_length, - OEMCrypto_SignatureHashAlgorithm hash_algorithm) const { + size_t signature_length, RsaSignatureAlgorithm algorithm) const { // Step 0: Ensure the signature algorithm is supported by key. if (!(allowed_schemes_ & kSign_RSASSA_PSS)) { LOGE("RSA key cannot verify using PSS"); @@ -685,24 +816,12 @@ OEMCryptoResult RsaPublicKey::VerifySignaturePss( return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2a: Choose the correct digest algorithm. - const EVP_MD* digest = nullptr; - switch (hash_algorithm) { - case OEMCrypto_SHA1: - digest = EVP_sha1(); - break; - case OEMCrypto_SHA2_256: - digest = EVP_sha256(); - break; - case OEMCrypto_SHA2_384: - digest = EVP_sha384(); - break; - case OEMCrypto_SHA2_512: - digest = EVP_sha512(); - break; - } + const EVP_MD* digest = GetEvpDigestAlgorithm(algorithm); if (digest == nullptr) { - LOGE("Unrecognized hash algorithm %d", hash_algorithm); - return OEMCrypto_ERROR_INVALID_CONTEXT; + // |algorithm| should have been validated by calling + // method. This is an internal error. + LOGE("Unrecognized signing algorithm %d", algorithm); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; } // Step 2b: Setup an EVP MD CTX for PSS Verification. ScopedEvpMdCtx md_ctx = EVP_MD_CTX_new(); @@ -754,6 +873,7 @@ OEMCryptoResult RsaPublicKey::VerifySignaturePss( return res ? OEMCrypto_SUCCESS : OEMCrypto_ERROR_SIGNATURE_FAILURE; } +// private OEMCryptoResult RsaPublicKey::VerifySignaturePkcs1Cast( const uint8_t* message, size_t message_length, const uint8_t* signature, size_t signature_length) const { @@ -791,6 +911,7 @@ OEMCryptoResult RsaPublicKey::VerifySignaturePkcs1Cast( return OEMCrypto_SUCCESS; } +// private OEMCryptoResult RsaPublicKey::EncryptOaep(const uint8_t* message, size_t message_size, uint8_t* enc_message, @@ -976,8 +1097,11 @@ OEMCryptoResult RsaPrivateKey::GenerateSignature( return OEMCrypto_ERROR_INVALID_CONTEXT; } switch (algorithm) { - case kRsaPssDefault: - return GenerateSignaturePss(message, message_length, signature, + case kRsaPssSha1: + case kRsaPssSha256: + case kRsaPssSha384: + case kRsaPssSha512: + return GenerateSignaturePss(message, message_length, algorithm, signature, signature_length); case kRsaPkcs1Cast: return GenerateSignaturePkcs1Cast(message, message_length, signature, @@ -1154,6 +1278,7 @@ RsaPrivateKey::~RsaPrivateKey() { field_size_ = kRsaFieldUnknown; } +// private bool RsaPrivateKey::InitFromPrivateKeyInfo(const uint8_t* buffer, size_t length) { ScopedRsaKey key; @@ -1165,6 +1290,7 @@ bool RsaPrivateKey::InitFromPrivateKeyInfo(const uint8_t* buffer, return true; } +// private bool RsaPrivateKey::InitFromFieldSize(RsaFieldSize field_size) { if (field_size != kRsa2048Bit && field_size != kRsa3072Bit) { LOGE("Unsupported RSA field size: bits = %d", static_cast(field_size)); @@ -1197,8 +1323,10 @@ bool RsaPrivateKey::InitFromFieldSize(RsaFieldSize field_size) { return true; } +// private OEMCryptoResult RsaPrivateKey::GenerateSignaturePss( - const uint8_t* message, size_t message_length, uint8_t* signature, + const uint8_t* message, size_t message_length, + RsaSignatureAlgorithm algorithm, uint8_t* signature, size_t* signature_length) const { // Step 0: Ensure the signature algorithm is supported by key. if (!(allowed_schemes_ & kSign_RSASSA_PSS)) { @@ -1215,6 +1343,14 @@ OEMCryptoResult RsaPrivateKey::GenerateSignaturePss( LOGE("Failed to set PKEY RSA key"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; } + // Step 2a: Choose the correct digest algorithm. + const EVP_MD* digest = GetEvpDigestAlgorithm(algorithm); + if (digest == nullptr) { + // |algorithm| should have been validated by calling + // method. This is an internal error. + LOGE("Unrecognized signing algorithm %d", algorithm); + return OEMCrypto_ERROR_UNKNOWN_FAILURE; + } // Step 2a: Setup a EVP MD CTX for PSS Signature Generation. ScopedEvpMdCtx md_ctx(EVP_MD_CTX_new()); if (!md_ctx) { @@ -1222,8 +1358,8 @@ OEMCryptoResult RsaPrivateKey::GenerateSignaturePss( return OEMCrypto_ERROR_UNKNOWN_FAILURE; } EVP_PKEY_CTX* pkey_ctx = nullptr; // Ownership is maintained by |md_ctx| - int res = EVP_DigestSignInit(md_ctx.get(), &pkey_ctx, EVP_sha1(), nullptr, - pkey.get()); + int res = + EVP_DigestSignInit(md_ctx.get(), &pkey_ctx, digest, nullptr, pkey.get()); if (res != 1) { LOGE("Failed to initialize MD CTX for signing"); return OEMCrypto_ERROR_UNKNOWN_FAILURE; @@ -1276,6 +1412,7 @@ OEMCryptoResult RsaPrivateKey::GenerateSignaturePss( return OEMCrypto_SUCCESS; } +// private OEMCryptoResult RsaPrivateKey::GenerateSignaturePkcs1Cast( const uint8_t* message, size_t message_length, uint8_t* signature, size_t* signature_length) const { @@ -1307,6 +1444,7 @@ OEMCryptoResult RsaPrivateKey::GenerateSignaturePkcs1Cast( return OEMCrypto_SUCCESS; } +// private OEMCryptoResult RsaPrivateKey::DecryptOaep( const uint8_t* enc_message, size_t enc_message_size, uint8_t* message, size_t expected_message_length) const { diff --git a/oemcrypto/util/src/wvcrc.cpp b/oemcrypto/util/src/wvcrc.cpp deleted file mode 100644 index 097ca70..0000000 --- a/oemcrypto/util/src/wvcrc.cpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary -// source code may only be used and distributed under the Widevine -// License Agreement. -// -// Compute CRC32/MPEG2 Checksum. Needed for verification of WV Keybox. -// -#include "platform.h" -#include "wvcrc32.h" - -namespace wvoec { -namespace util { -#define INIT_CRC32 0xffffffff - -uint32_t wvrunningcrc32(const uint8_t* p_begin, size_t i_count, - uint32_t i_crc) { - constexpr uint32_t CRC32[256] = { - 0x00000000, 0x04c11db7, 0x09823b6e, 0x0d4326d9, 0x130476dc, 0x17c56b6b, - 0x1a864db2, 0x1e475005, 0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, 0x2b4bcb61, - 0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd, 0x4c11db70, 0x48d0c6c7, - 0x4593e01e, 0x4152fda9, 0x5f15adac, 0x5bd4b01b, 0x569796c2, 0x52568b75, - 0x6a1936c8, 0x6ed82b7f, 0x639b0da6, 0x675a1011, 0x791d4014, 0x7ddc5da3, - 0x709f7b7a, 0x745e66cd, 0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039, - 0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5, 0xbe2b5b58, 0xbaea46ef, - 0xb7a96036, 0xb3687d81, 0xad2f2d84, 0xa9ee3033, 0xa4ad16ea, 0xa06c0b5d, - 0xd4326d90, 0xd0f37027, 0xddb056fe, 0xd9714b49, 0xc7361b4c, 0xc3f706fb, - 0xceb42022, 0xca753d95, 0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1, - 0xe13ef6f4, 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d, 0x34867077, 0x30476dc0, - 0x3d044b19, 0x39c556ae, 0x278206ab, 0x23431b1c, 0x2e003dc5, 0x2ac12072, - 0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16, 0x018aeb13, 0x054bf6a4, - 0x0808d07d, 0x0cc9cdca, 0x7897ab07, 0x7c56b6b0, 0x71159069, 0x75d48dde, - 0x6b93dddb, 0x6f52c06c, 0x6211e6b5, 0x66d0fb02, 0x5e9f46bf, 0x5a5e5b08, - 0x571d7dd1, 0x53dc6066, 0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba, - 0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e, 0xbfa1b04b, 0xbb60adfc, - 0xb6238b25, 0xb2e29692, 0x8aad2b2f, 0x8e6c3698, 0x832f1041, 0x87ee0df6, - 0x99a95df3, 0x9d684044, 0x902b669d, 0x94ea7b2a, 0xe0b41de7, 0xe4750050, - 0xe9362689, 0xedf73b3e, 0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2, - 0xc6bcf05f, 0xc27dede8, 0xcf3ecb31, 0xcbffd686, 0xd5b88683, 0xd1799b34, - 0xdc3abded, 0xd8fba05a, 0x690ce0ee, 0x6dcdfd59, 0x608edb80, 0x644fc637, - 0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb, 0x4f040d56, 0x4bc510e1, - 0x46863638, 0x42472b8f, 0x5c007b8a, 0x58c1663d, 0x558240e4, 0x51435d53, - 0x251d3b9e, 0x21dc2629, 0x2c9f00f0, 0x285e1d47, 0x36194d42, 0x32d850f5, - 0x3f9b762c, 0x3b5a6b9b, 0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff, - 0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623, 0xf12f560e, 0xf5ee4bb9, - 0xf8ad6d60, 0xfc6c70d7, 0xe22b20d2, 0xe6ea3d65, 0xeba91bbc, 0xef68060b, - 0xd727bbb6, 0xd3e6a601, 0xdea580d8, 0xda649d6f, 0xc423cd6a, 0xc0e2d0dd, - 0xcda1f604, 0xc960ebb3, 0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7, - 0xae3afba2, 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b, 0x9b3660c6, 0x9ff77d71, - 0x92b45ba8, 0x9675461f, 0x8832161a, 0x8cf30bad, 0x81b02d74, 0x857130c3, - 0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640, 0x4e8ee645, 0x4a4ffbf2, - 0x470cdd2b, 0x43cdc09c, 0x7b827d21, 0x7f436096, 0x7200464f, 0x76c15bf8, - 0x68860bfd, 0x6c47164a, 0x61043093, 0x65c52d24, 0x119b4be9, 0x155a565e, - 0x18197087, 0x1cd86d30, 0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec, - 0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088, 0x2497d08d, 0x2056cd3a, - 0x2d15ebe3, 0x29d4f654, 0xc5a92679, 0xc1683bce, 0xcc2b1d17, 0xc8ea00a0, - 0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, 0xdbee767c, 0xe3a1cbc1, 0xe760d676, - 0xea23f0af, 0xeee2ed18, 0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4, - 0x89b8fd09, 0x8d79e0be, 0x803ac667, 0x84fbdbd0, 0x9abc8bd5, 0x9e7d9662, - 0x933eb0bb, 0x97ffad0c, 0xafb010b1, 0xab710d06, 0xa6322bdf, 0xa2f33668, - 0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4}; - - /* Calculate the CRC */ - while (i_count > 0) { - i_crc = (i_crc << 8) ^ CRC32[(i_crc >> 24) ^ ((uint32_t) * p_begin)]; - p_begin++; - i_count--; - } - - return(i_crc); -} - -uint32_t wvcrc32(const uint8_t* p_begin, size_t i_count) { - return(wvrunningcrc32(p_begin, i_count, INIT_CRC32)); -} - -uint32_t wvcrc32Init() { - return INIT_CRC32; -} - -uint32_t wvcrc32Cont(const uint8_t* p_begin, size_t i_count, - uint32_t prev_crc) { - return(wvrunningcrc32(p_begin, i_count, prev_crc)); -} - -uint32_t wvcrc32n(const uint8_t* p_begin, size_t i_count) { - return htonl(wvrunningcrc32(p_begin, i_count, INIT_CRC32)); -} -} // namespace util -} // namespace wvoec diff --git a/oemcrypto/util/src/wvcrc32.cpp b/oemcrypto/util/src/wvcrc32.cpp new file mode 100644 index 0000000..b519eaf --- /dev/null +++ b/oemcrypto/util/src/wvcrc32.cpp @@ -0,0 +1,156 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine +// License Agreement. +// +#include "wvcrc32.h" + +#include + +#include +#include + +// This is intended to be a stand-alone file. +// Do not include any non-standard library headers. +namespace wvoec { +namespace util { +namespace { +// clang-format off +// Pre-computed CRC-32/MPEG-2 lookup table. +constexpr uint32_t kCrcLookupTable[256] = { + 0x00000000, 0x04c11db7, 0x09823b6e, 0x0d4326d9, + 0x130476dc, 0x17c56b6b, 0x1a864db2, 0x1e475005, + 0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, 0x2b4bcb61, + 0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd, + 0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9, + 0x5f15adac, 0x5bd4b01b, 0x569796c2, 0x52568b75, + 0x6a1936c8, 0x6ed82b7f, 0x639b0da6, 0x675a1011, + 0x791d4014, 0x7ddc5da3, 0x709f7b7a, 0x745e66cd, + 0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039, + 0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5, + 0xbe2b5b58, 0xbaea46ef, 0xb7a96036, 0xb3687d81, + 0xad2f2d84, 0xa9ee3033, 0xa4ad16ea, 0xa06c0b5d, + 0xd4326d90, 0xd0f37027, 0xddb056fe, 0xd9714b49, + 0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95, + 0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1, + 0xe13ef6f4, 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d, + 0x34867077, 0x30476dc0, 0x3d044b19, 0x39c556ae, + 0x278206ab, 0x23431b1c, 0x2e003dc5, 0x2ac12072, + 0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16, + 0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca, + 0x7897ab07, 0x7c56b6b0, 0x71159069, 0x75d48dde, + 0x6b93dddb, 0x6f52c06c, 0x6211e6b5, 0x66d0fb02, + 0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, 0x53dc6066, + 0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba, + 0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e, + 0xbfa1b04b, 0xbb60adfc, 0xb6238b25, 0xb2e29692, + 0x8aad2b2f, 0x8e6c3698, 0x832f1041, 0x87ee0df6, + 0x99a95df3, 0x9d684044, 0x902b669d, 0x94ea7b2a, + 0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e, + 0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2, + 0xc6bcf05f, 0xc27dede8, 0xcf3ecb31, 0xcbffd686, + 0xd5b88683, 0xd1799b34, 0xdc3abded, 0xd8fba05a, + 0x690ce0ee, 0x6dcdfd59, 0x608edb80, 0x644fc637, + 0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb, + 0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f, + 0x5c007b8a, 0x58c1663d, 0x558240e4, 0x51435d53, + 0x251d3b9e, 0x21dc2629, 0x2c9f00f0, 0x285e1d47, + 0x36194d42, 0x32d850f5, 0x3f9b762c, 0x3b5a6b9b, + 0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff, + 0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623, + 0xf12f560e, 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7, + 0xe22b20d2, 0xe6ea3d65, 0xeba91bbc, 0xef68060b, + 0xd727bbb6, 0xd3e6a601, 0xdea580d8, 0xda649d6f, + 0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3, + 0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7, + 0xae3afba2, 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b, + 0x9b3660c6, 0x9ff77d71, 0x92b45ba8, 0x9675461f, + 0x8832161a, 0x8cf30bad, 0x81b02d74, 0x857130c3, + 0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640, + 0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c, + 0x7b827d21, 0x7f436096, 0x7200464f, 0x76c15bf8, + 0x68860bfd, 0x6c47164a, 0x61043093, 0x65c52d24, + 0x119b4be9, 0x155a565e, 0x18197087, 0x1cd86d30, + 0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec, + 0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088, + 0x2497d08d, 0x2056cd3a, 0x2d15ebe3, 0x29d4f654, + 0xc5a92679, 0xc1683bce, 0xcc2b1d17, 0xc8ea00a0, + 0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, 0xdbee767c, + 0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18, + 0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4, + 0x89b8fd09, 0x8d79e0be, 0x803ac667, 0x84fbdbd0, + 0x9abc8bd5, 0x9e7d9662, 0x933eb0bb, 0x97ffad0c, + 0xafb010b1, 0xab710d06, 0xa6322bdf, 0xa2f33668, + 0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4 +}; +// clang-format on + +constexpr uint32_t CrcLookup(uint8_t b) { + return kCrcLookupTable[static_cast(b)]; +} + +// Note: Assumes |buffer| is at least 4 bytes in length. +void HostToNetworkU32Serialize(uint32_t value, uint8_t* buffer) { + if (buffer == nullptr) return; + buffer[0] = static_cast(value >> 24); + buffer[1] = static_cast(value >> 16); + buffer[2] = static_cast(value >> 8); + buffer[3] = static_cast(value); +} + +uint32_t HostToNetworkU32(uint32_t value) { + static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes"); + uint32_t value_nbo = 0; + HostToNetworkU32Serialize(value, reinterpret_cast(&value_nbo)); + return value_nbo; +} +} // namespace + +Crc32Ctx& Crc32Ctx::Update(const std::vector& data) { + return Update(data.empty() ? nullptr : data.data(), data.size()); +} + +Crc32Ctx& Crc32Ctx::Update(const std::string& data) { + return Update( + data.empty() ? nullptr : reinterpret_cast(data.data()), + data.size()); +} + +Crc32Ctx& Crc32Ctx::UpdateInternal(const uint8_t* data, size_t data_length) { + if (data == nullptr) return *this; // Skip if null. + for (size_t i = 0; i < data_length; i++) { + const uint8_t carry = static_cast(accumulator_ >> 24); + accumulator_ <<= 8; + accumulator_ ^= CrcLookup(data[i] ^ carry); + } + return *this; +} + +uint32_t Crc32Ctx::FinalizeNbo() const { return HostToNetworkU32(Finalize()); } + +bool Crc32Ctx::FinalizeNbo(uint8_t* crc_buffer, + size_t crc_buffer_length) const { + if (crc_buffer == nullptr || crc_buffer_length < 4) return false; + const uint32_t crc = Finalize(); + HostToNetworkU32Serialize(crc, crc_buffer); + return true; +} + +uint32_t wvcrc32Init() { return Crc32Ctx::kInitialValue; } + +uint32_t wvcrc32(const uint8_t* data, size_t data_length) { + if (data == nullptr) return 0; + return Crc32Ctx().Update(data, data_length).Finalize(); +} + +uint32_t wvcrc32Cont(const uint8_t* data, size_t data_length, + uint32_t prev_crc) { + if (data == nullptr) return 0; + return Crc32Ctx(prev_crc).Update(data, data_length).Finalize(); +} + +uint32_t wvcrc32n(const uint8_t* data, size_t data_length) { + if (data == nullptr) return 0; + return Crc32Ctx().Update(data, data_length).FinalizeNbo(); +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/test/cose_utils_unittest.cpp b/oemcrypto/util/test/cose_utils_unittest.cpp new file mode 100644 index 0000000..50a6dc1 --- /dev/null +++ b/oemcrypto/util/test/cose_utils_unittest.cpp @@ -0,0 +1,1335 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "cose_utils.h" + +#include + +#include +#include +#include + +#include +#include + +#include "oemcrypto_ref_test_utils.h" + +namespace wvoec { +namespace util { +TEST(CoseKeyDataTest, ParseBad_UnknownKeyType) { + const int64_t bad_key_type = 42; + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ bad_key_type) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBad_KeyTypeType) { + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ cppbor::Bool(false)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +// Parse an EC2 COSE_Key containing the minimum amount of data. +TEST(CoseKeyDataTest, ParseEc2_Minimum) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeEc2); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveP256); + EXPECT_FALSE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsEc2Data()); + EXPECT_FALSE(cose_key.IsOkpData()); + + const auto ec2_data_opt = cose_key.ec2_data(); + ASSERT_TRUE(ec2_data_opt.has_value()) << "Failed to get EC2 data"; + const auto& ec2_data = ec2_data_opt.value(); + + EXPECT_EQ(ec2_data.key_type, kCoseKeyTypeEc2); + EXPECT_TRUE(ec2_data.key_ops.empty()); + EXPECT_EQ(ec2_data.algorithm, std::nullopt); + EXPECT_EQ(ec2_data.curve, kCoseCurveP256); + EXPECT_FALSE(ec2_data.has_public_key()); + EXPECT_TRUE(ec2_data.x.empty()); + ASSERT_TRUE(std::holds_alternative>(ec2_data.y)); + EXPECT_TRUE(std::get>(ec2_data.y).empty()); + EXPECT_TRUE(ec2_data.has_private_key()); + EXPECT_EQ(ec2_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseEc2_WithKeyOps) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + const std::vector key_ops = {kCoseKeyOpVerify}; + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, + /* key_ops = */ cppbor::Array().add(kCoseKeyOpVerify)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeEc2); + EXPECT_TRUE(cose_key.has_key_ops()); + EXPECT_EQ(cose_key.key_ops(), key_ops); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveP256); + EXPECT_FALSE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsEc2Data()); + EXPECT_FALSE(cose_key.IsOkpData()); + + const auto ec2_data_opt = cose_key.ec2_data(); + ASSERT_TRUE(ec2_data_opt.has_value()) << "Failed to get EC2 data"; + const auto& ec2_data = ec2_data_opt.value(); + + EXPECT_EQ(ec2_data.key_type, kCoseKeyTypeEc2); + EXPECT_EQ(ec2_data.key_ops, key_ops); + EXPECT_EQ(ec2_data.algorithm, std::nullopt); + EXPECT_EQ(ec2_data.curve, kCoseCurveP256); + EXPECT_FALSE(ec2_data.has_public_key()); + EXPECT_TRUE(ec2_data.x.empty()); + ASSERT_TRUE(std::holds_alternative>(ec2_data.y)); + EXPECT_TRUE(std::get>(ec2_data.y).empty()); + EXPECT_TRUE(ec2_data.has_private_key()); + EXPECT_EQ(ec2_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseEc2_WithAlgorithm) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyAlgorithmLabel, /* alg = */ kCoseAlgorithmEs256) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeEc2); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_TRUE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), kCoseAlgorithmEs256); + EXPECT_EQ(cose_key.curve(), kCoseCurveP256); + EXPECT_FALSE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsEc2Data()); + EXPECT_FALSE(cose_key.IsOkpData()); + + const auto ec2_data_opt = cose_key.ec2_data(); + ASSERT_TRUE(ec2_data_opt.has_value()) << "Failed to get EC2 data"; + const auto& ec2_data = ec2_data_opt.value(); + + EXPECT_EQ(ec2_data.key_type, kCoseKeyTypeEc2); + EXPECT_TRUE(ec2_data.key_ops.empty()); + EXPECT_EQ(ec2_data.algorithm, kCoseAlgorithmEs256); + EXPECT_EQ(ec2_data.curve, kCoseCurveP256); + EXPECT_FALSE(ec2_data.has_public_key()); + EXPECT_TRUE(ec2_data.x.empty()); + ASSERT_TRUE(std::holds_alternative>(ec2_data.y)); + EXPECT_TRUE(std::get>(ec2_data.y).empty()); + EXPECT_TRUE(ec2_data.has_private_key()); + EXPECT_EQ(ec2_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseEc2_PublicOnly_YBytes) { + // Fake P-256 public key. + const std::vector x_coord = RandomData(32); + const std::vector y_coord = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyXCoordLabel, /* x = */ x_coord) + .add(kCoseKeyYCoordLabel, /* y = */ y_coord) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeEc2); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveP256); + EXPECT_TRUE(cose_key.has_public_key()); + EXPECT_FALSE(cose_key.has_private_key()); + + EXPECT_TRUE(cose_key.IsEc2Data()); + EXPECT_FALSE(cose_key.IsOkpData()); + + const auto ec2_data_opt = cose_key.ec2_data(); + ASSERT_TRUE(ec2_data_opt.has_value()) << "Failed to get EC2 data"; + const auto& ec2_data = ec2_data_opt.value(); + + EXPECT_EQ(ec2_data.key_type, kCoseKeyTypeEc2); + EXPECT_TRUE(ec2_data.key_ops.empty()); + EXPECT_EQ(ec2_data.algorithm, std::nullopt); + EXPECT_EQ(ec2_data.curve, kCoseCurveP256); + EXPECT_TRUE(ec2_data.has_public_key()); + EXPECT_EQ(ec2_data.x, x_coord); + ASSERT_TRUE(std::holds_alternative>(ec2_data.y)); + EXPECT_EQ(std::get>(ec2_data.y), y_coord); + EXPECT_FALSE(ec2_data.has_private_key()); + EXPECT_TRUE(ec2_data.private_key.empty()); +} + +TEST(CoseKeyDataTest, ParseEc2_PublicOnly_YSignBit) { + // Fake P-256 public key. + const std::vector x_coord = RandomData(32); + const bool y_sign_bit = true; + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyXCoordLabel, /* x = */ x_coord) + .add(kCoseKeyYCoordLabel, /* y = */ cppbor::Bool(y_sign_bit)) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeEc2); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveP256); + EXPECT_TRUE(cose_key.has_public_key()); + EXPECT_FALSE(cose_key.has_private_key()); + + EXPECT_TRUE(cose_key.IsEc2Data()); + EXPECT_FALSE(cose_key.IsOkpData()); + + const auto ec2_data_opt = cose_key.ec2_data(); + ASSERT_TRUE(ec2_data_opt.has_value()) << "Failed to get EC2 data"; + const auto& ec2_data = ec2_data_opt.value(); + + EXPECT_EQ(ec2_data.key_type, kCoseKeyTypeEc2); + EXPECT_TRUE(ec2_data.key_ops.empty()); + EXPECT_EQ(ec2_data.algorithm, std::nullopt); + EXPECT_EQ(ec2_data.curve, kCoseCurveP256); + EXPECT_TRUE(ec2_data.has_public_key()); + EXPECT_EQ(ec2_data.x, x_coord); + ASSERT_TRUE(std::holds_alternative(ec2_data.y)); + EXPECT_EQ(std::get(ec2_data.y), y_sign_bit); + EXPECT_FALSE(ec2_data.has_private_key()); + EXPECT_TRUE(ec2_data.private_key.empty()); +} + +TEST(CoseKeyDataTest, ParseEc2_AllFields) { + // Fake P-256 private/public key. + const std::vector private_key = RandomData(32); + // Order of key_ops must match that of the CBOR map. + const std::vector key_ops = {kCoseKeyOpSign, kCoseKeyOpVerify}; + const std::vector x_coord = RandomData(32); + const bool y_sign_bit = false; + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array() + .add(kCoseKeyOpSign) + .add(kCoseKeyOpVerify)) + .add(kCoseKeyAlgorithmLabel, /* alg = */ kCoseAlgorithmEs256) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyXCoordLabel, /* x = */ x_coord) + .add(kCoseKeyYCoordLabel, /* y = */ cppbor::Bool(y_sign_bit)) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeEc2); + EXPECT_TRUE(cose_key.has_key_ops()); + EXPECT_EQ(cose_key.key_ops(), key_ops); + EXPECT_TRUE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), kCoseAlgorithmEs256); + EXPECT_EQ(cose_key.curve(), kCoseCurveP256); + EXPECT_TRUE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsEc2Data()); + EXPECT_FALSE(cose_key.IsOkpData()); + + const auto ec2_data_opt = cose_key.ec2_data(); + ASSERT_TRUE(ec2_data_opt.has_value()) << "Failed to get EC2 data"; + const auto& ec2_data = ec2_data_opt.value(); + + EXPECT_EQ(ec2_data.key_type, kCoseKeyTypeEc2); + EXPECT_EQ(ec2_data.key_ops, key_ops); + EXPECT_EQ(ec2_data.algorithm, kCoseAlgorithmEs256); + EXPECT_EQ(ec2_data.curve, kCoseCurveP256); + EXPECT_TRUE(ec2_data.has_public_key()); + EXPECT_EQ(ec2_data.x, x_coord); + ASSERT_TRUE(std::holds_alternative(ec2_data.y)); + EXPECT_EQ(std::get(ec2_data.y), y_sign_bit); + EXPECT_TRUE(ec2_data.has_private_key()); + EXPECT_EQ(ec2_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseBadEc2_MissingCurve) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_NoKeyValue) { + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_XWithoutY) { + const std::vector x_coord = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyXCoordLabel, /* x = */ x_coord) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_YWithoutX_YBytes) { + const std::vector y_coord = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyYCoordLabel, /* y = */ y_coord) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_YWithoutX_YSignBit) { + const bool y_sign_bit = false; + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyYCoordLabel, /* y = */ cppbor::Bool(y_sign_bit)) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadKeyOpsType) { + const std::vector private_key = RandomData(32); + + // 'key_ops' must be an array of ints or not present. + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ kCoseKeyOpSign) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_EmptyKeyOps) { + const std::vector private_key = RandomData(32); + + // 'key_ops' must be non-empty or not present. + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array()) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadKeyOpsElementType) { + const std::vector private_key = RandomData(32); + + // 'key_ops' elements must be int. + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, + /* key_ops = */ cppbor::Array().add(cppbor::Bool(false))) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadAlgorithmType) { + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyAlgorithmLabel, /* alg = */ cppbor::Null()) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadCurveType) { + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ cppbor::Tstr("String")) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadXType) { + // Fake P-256 public key. + const bool x_sign_bit = false; // X cannot be a sign bit. + const std::vector y_coord = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyXCoordLabel, /* x = */ cppbor::Bool(x_sign_bit)) + .add(kCoseKeyYCoordLabel, /* y = */ y_coord) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadYType) { + // Fake P-256 public key. + const std::vector x_coord = RandomData(32); + const int64_t y_coord = 12345; // Must be bstr, not integer. + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyXCoordLabel, /* x = */ x_coord) + .add(kCoseKeyYCoordLabel, /* y = */ y_coord) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadEc2_BadPrivateType) { + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ cppbor::Tstr("Not a key")) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseOkp_Minimum) { + // Fake Ed25519 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeOkp); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveEd25519); + EXPECT_FALSE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsOkpData()); + EXPECT_FALSE(cose_key.IsEc2Data()); + + const auto okp_data_opt = cose_key.okp_data(); + ASSERT_TRUE(okp_data_opt.has_value()) << "Failed to get OKP data"; + const auto& okp_data = okp_data_opt.value(); + + EXPECT_EQ(okp_data.key_type, kCoseKeyTypeOkp); + EXPECT_TRUE(okp_data.key_ops.empty()); + EXPECT_EQ(okp_data.algorithm, std::nullopt); + EXPECT_EQ(okp_data.curve, kCoseCurveEd25519); + EXPECT_FALSE(okp_data.has_public_key()); + EXPECT_TRUE(okp_data.public_key.empty()); + EXPECT_TRUE(okp_data.has_private_key()); + EXPECT_EQ(okp_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseOkp_WithKeyOps) { + // Fake Ed25519 private key. + const std::vector private_key = RandomData(32); + const std::vector key_ops = {kCoseKeyOpVerify}; + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyKeyOpsLabel, + /* key_ops = */ cppbor::Array().add(kCoseKeyOpVerify)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeOkp); + EXPECT_TRUE(cose_key.has_key_ops()); + EXPECT_EQ(cose_key.key_ops(), key_ops); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveEd25519); + EXPECT_FALSE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsOkpData()); + EXPECT_FALSE(cose_key.IsEc2Data()); + + const auto okp_data_opt = cose_key.okp_data(); + ASSERT_TRUE(okp_data_opt.has_value()) << "Failed to get OKP data"; + const auto& okp_data = okp_data_opt.value(); + + EXPECT_EQ(okp_data.key_type, kCoseKeyTypeOkp); + EXPECT_EQ(okp_data.key_ops, key_ops); + EXPECT_EQ(okp_data.algorithm, std::nullopt); + EXPECT_EQ(okp_data.curve, kCoseCurveEd25519); + EXPECT_FALSE(okp_data.has_public_key()); + EXPECT_TRUE(okp_data.public_key.empty()); + EXPECT_TRUE(okp_data.has_private_key()); + EXPECT_EQ(okp_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseOkp_WithAlgorithm) { + // Fake Ed25519 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyAlgorithmLabel, /* alg = */ kCoseAlgorithmEdDsa) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeOkp); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_TRUE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), kCoseAlgorithmEdDsa); + EXPECT_EQ(cose_key.curve(), kCoseCurveEd25519); + EXPECT_FALSE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsOkpData()); + EXPECT_FALSE(cose_key.IsEc2Data()); + + const auto okp_data_opt = cose_key.okp_data(); + ASSERT_TRUE(okp_data_opt.has_value()) << "Failed to get OKP data"; + const auto& okp_data = okp_data_opt.value(); + + EXPECT_EQ(okp_data.key_type, kCoseKeyTypeOkp); + EXPECT_TRUE(okp_data.key_ops.empty()); + EXPECT_EQ(okp_data.algorithm, kCoseAlgorithmEdDsa); + EXPECT_EQ(okp_data.curve, kCoseCurveEd25519); + EXPECT_FALSE(okp_data.has_public_key()); + EXPECT_TRUE(okp_data.public_key.empty()); + EXPECT_TRUE(okp_data.has_private_key()); + EXPECT_EQ(okp_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseOkp_PublicOnly) { + // Fake Ed25519 public key. + const std::vector public_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPublicKeyLabel, /* public_key = */ public_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeOkp); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), std::nullopt); + EXPECT_EQ(cose_key.curve(), kCoseCurveEd25519); + EXPECT_TRUE(cose_key.has_public_key()); + EXPECT_FALSE(cose_key.has_private_key()); + + EXPECT_TRUE(cose_key.IsOkpData()); + EXPECT_FALSE(cose_key.IsEc2Data()); + + const auto okp_data_opt = cose_key.okp_data(); + ASSERT_TRUE(okp_data_opt.has_value()) << "Failed to get OKP data"; + const auto& okp_data = okp_data_opt.value(); + + EXPECT_EQ(okp_data.key_type, kCoseKeyTypeOkp); + EXPECT_TRUE(okp_data.key_ops.empty()); + EXPECT_EQ(okp_data.algorithm, std::nullopt); + EXPECT_EQ(okp_data.curve, kCoseCurveEd25519); + EXPECT_TRUE(okp_data.has_public_key()); + EXPECT_EQ(okp_data.public_key, public_key); + EXPECT_FALSE(okp_data.has_private_key()); + EXPECT_TRUE(okp_data.private_key.empty()); +} + +TEST(CoseKeyDataTest, ParseOkp_AllFields) { + // Fake Ed25519 public / private key. + const std::vector public_key = RandomData(32); + const std::vector private_key = RandomData(32); + // Order of key_ops must match that of the CBOR map. + const std::vector key_ops = {kCoseKeyOpSign, kCoseKeyOpVerify}; + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array() + .add(kCoseKeyOpSign) + .add(kCoseKeyOpVerify)) + .add(kCoseKeyAlgorithmLabel, /* alg = */ kCoseAlgorithmEdDsa) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .add(kCoseKeyPublicKeyLabel, /* public_key = */ public_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + EXPECT_EQ(cose_key.key_type(), kCoseKeyTypeOkp); + EXPECT_TRUE(cose_key.has_key_ops()); + EXPECT_EQ(cose_key.key_ops(), key_ops); + EXPECT_TRUE(cose_key.has_algorithm()); + EXPECT_EQ(cose_key.algorithm(), kCoseAlgorithmEdDsa); + EXPECT_EQ(cose_key.curve(), kCoseCurveEd25519); + EXPECT_TRUE(cose_key.has_public_key()); + EXPECT_TRUE(cose_key.has_private_key()); + EXPECT_EQ(cose_key.private_key(), private_key); + + EXPECT_TRUE(cose_key.IsOkpData()); + EXPECT_FALSE(cose_key.IsEc2Data()); + + const auto okp_data_opt = cose_key.okp_data(); + ASSERT_TRUE(okp_data_opt.has_value()) << "Failed to get OKP data"; + const auto& okp_data = okp_data_opt.value(); + + EXPECT_EQ(okp_data.key_type, kCoseKeyTypeOkp); + EXPECT_EQ(okp_data.key_ops, key_ops); + EXPECT_EQ(okp_data.algorithm, kCoseAlgorithmEdDsa); + EXPECT_EQ(okp_data.curve, kCoseCurveEd25519); + EXPECT_TRUE(okp_data.has_public_key()); + EXPECT_EQ(okp_data.public_key, public_key); + EXPECT_TRUE(okp_data.has_private_key()); + EXPECT_EQ(okp_data.private_key, private_key); +} + +TEST(CoseKeyDataTest, ParseBadOkp_MissingCurve) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_NoKeyValue) { + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_BadKeyOpsType) { + const std::vector private_key = RandomData(32); + + // 'key_ops' must be an array of ints or not present. + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ kCoseKeyOpSign) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_EmptyKeyOps) { + const std::vector private_key = RandomData(32); + + // 'key_ops' must be non-empty or not present. + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array()) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_BadKeyOpsElementType) { + const std::vector private_key = RandomData(32); + + // 'key_ops' elements must be int. + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyKeyOpsLabel, + /* key_ops = */ cppbor::Array().add(cppbor::Bool(false))) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_BadAlgorithmType) { + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyAlgorithmLabel, /* alg = */ cppbor::Bool(true)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_BadCurveType) { + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyCurveLabel, /* crv = */ private_key) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ private_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_BadPublicType) { + const int64_t bad_public_key = 99999; + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPublicKeyLabel, /* public_key = */ bad_public_key) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, ParseBadOkp_BadPrivateType) { + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeOkp) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveEd25519) + .add(kCoseKeyPrivateKeyLabel, /* private_key = */ cppbor::Null()) + .canonicalize() + .encode(); + CoseKeyData cose_key; + EXPECT_FALSE(cose_key.ParseCbor(cose_key_cbor)); +} + +TEST(CoseKeyDataTest, KeyOps_Contains) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array() + .add(kCoseKeyOpSign) + .add(kCoseKeyOpVerify)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + ASSERT_TRUE(cose_key.has_key_ops()); + EXPECT_TRUE(cose_key.ContainsKeyOp(kCoseKeyOpSign)); + EXPECT_TRUE(cose_key.ContainsKeyOp(kCoseKeyOpVerify)); + EXPECT_FALSE(cose_key.ContainsKeyOp(kCoseKeyOpEncrypt)); +} + +TEST(CoseKeyDataTest, KeyOps_Remove) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array() + .add(kCoseKeyOpSign) + .add(kCoseKeyOpVerify)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + ASSERT_TRUE(cose_key.has_key_ops()); + EXPECT_TRUE(cose_key.ContainsKeyOp(kCoseKeyOpSign)); + cose_key.RemoveKeyOp(kCoseKeyOpSign); + + // Check that the remaining still exist. + EXPECT_TRUE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.ContainsKeyOp(kCoseKeyOpSign)); + EXPECT_TRUE(cose_key.ContainsKeyOp(kCoseKeyOpVerify)); + + cose_key.RemoveKeyOp(kCoseKeyOpVerify); + EXPECT_FALSE(cose_key.has_key_ops()); + EXPECT_FALSE(cose_key.ContainsKeyOp(kCoseKeyOpVerify)); +} + +TEST(CoseKeyDataTest, KeyOps_RemoveMultiple) { + // Fake P-256 private key. + const std::vector private_key = RandomData(32); + + const std::vector cose_key_cbor = + cppbor::Map() + .add(kCoseKeyKeyTypeLabel, /* kty = */ kCoseKeyTypeEc2) + .add(kCoseKeyKeyOpsLabel, /* key_ops = */ cppbor::Array() + .add(kCoseKeyOpSign) + .add(kCoseKeyOpVerify) + .add(kCoseKeyOpSign)) + .add(kCoseKeyCurveLabel, /* crv = */ kCoseCurveP256) + .add(kCoseKeyPrivateKeyLabel, /* d = */ private_key) + .canonicalize() + .encode(); + + CoseKeyData cose_key; + ASSERT_TRUE(cose_key.ParseCbor(cose_key_cbor)) << "Failed to parse COSE_Key"; + + constexpr size_t kInitialKeyOpsCount = 3; + ASSERT_EQ(cose_key.key_ops().size(), kInitialKeyOpsCount); + EXPECT_TRUE(cose_key.ContainsKeyOp(kCoseKeyOpSign)); + cose_key.RemoveKeyOp(kCoseKeyOpSign); // Should remove both + + // Check that the remaining still exist. + constexpr size_t kExpectedKeyOpsCount = 1; + EXPECT_EQ(cose_key.key_ops().size(), kExpectedKeyOpsCount); + EXPECT_FALSE(cose_key.ContainsKeyOp(kCoseKeyOpSign)); + EXPECT_TRUE(cose_key.ContainsKeyOp(kCoseKeyOpVerify)); +} + +TEST(CoseSign1DataTest, Parse_Minimum) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), kEmptyVector); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), kEmptyVector); +} + +TEST(CoseSign1DataTest, Parse_EmptyProtectedOnly) { + const std::vector kEmptyVector; + const std::vector protected_cbor = + cppbor::Map().canonicalize().encode(); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), protected_cbor); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), kEmptyVector); +} + +TEST(CoseSign1DataTest, Parse_ProtectedWithAlgorithmOnly) { + const std::vector kEmptyVector; + const std::vector protected_cbor = + cppbor::Map() + .add(kCoseGenHeaderAlgorithmLabel, /* alg = */ kCoseAlgorithmEdDsa) + .canonicalize() + .encode(); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), protected_cbor); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), kCoseAlgorithmEdDsa); + EXPECT_FALSE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), kEmptyVector); +} + +TEST(CoseSign1DataTest, Parse_ProtectedWithoutAlgorithmOnly) { + const std::vector kEmptyVector; + // Adding other fields which are not used. + const std::vector protected_cbor = + cppbor::Map() + .add(3, /* content_type = */ cppbor::Tstr("application/json")) + .add(6, /* iv = */ std::vector(16, 0x42)) + .canonicalize() + .encode(); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), protected_cbor); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), kEmptyVector); +} + +TEST(CoseSign1DataTest, Parse_EmptyPayloadOnly) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ kEmptyVector) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), kEmptyVector); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), std::nullopt); + // An empty payload is still a payload. + EXPECT_TRUE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), kEmptyVector); + EXPECT_FALSE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), kEmptyVector); +} + +TEST(CoseSign1DataTest, Parse_PayloadOnly) { + const std::vector kEmptyVector; + const std::vector payload = RandomData(256); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ payload) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), kEmptyVector); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), std::nullopt); + EXPECT_TRUE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), payload); + EXPECT_FALSE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), kEmptyVector); +} + +TEST(CoseSign1DataTest, Parse_SignatureOnly) { + const std::vector kEmptyVector; + const std::vector signature = RandomData(64); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ signature) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), kEmptyVector); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), std::nullopt); + EXPECT_FALSE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), std::nullopt); + EXPECT_TRUE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), signature); +} + +TEST(CoseSign1DataTest, Parse_AllFields) { + const std::vector protected_cbor = + cppbor::Map() + .add(kCoseGenHeaderAlgorithmLabel, /* alg = */ kCoseAlgorithmEs256) + .canonicalize() + .encode(); + const std::vector payload = RandomData(256); + const std::vector signature = RandomData(128); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ payload) + .add(/* signature = */ signature) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), protected_cbor); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), kCoseAlgorithmEs256); + EXPECT_TRUE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), payload); + EXPECT_TRUE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), signature); +} + +// Having many fields other than "payload" is a common form +// of COSE_Sign1, called "detached content". +TEST(CoseSign1DataTest, Parse_AllButPayloadFields) { + const std::vector protected_cbor = + cppbor::Map() + .add(kCoseGenHeaderAlgorithmLabel, /* alg = */ kCoseAlgorithmEs256) + .canonicalize() + .encode(); + const std::vector signature = RandomData(128); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ signature) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_TRUE(cose_sign1.ParseCbor(cose_sign1_cbor)); + + EXPECT_EQ(cose_sign1.protected_params(), protected_cbor); + EXPECT_EQ(cose_sign1.ExtractAlgorithm(), kCoseAlgorithmEs256); + EXPECT_FALSE(cose_sign1.has_payload()); + EXPECT_EQ(cose_sign1.payload(), std::nullopt); + EXPECT_TRUE(cose_sign1.has_signature()); + EXPECT_EQ(cose_sign1.signature(), signature); +} + +TEST(CoseSign1DataTest, ParseBad_TooFewFields) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + // No signature. + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_TooManyFields) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .add(cppbor::Tstr("extra value")) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadProtected_NotCbor) { + const std::vector kEmptyVector; + const std::vector bad_protected = {'n', 'o', 't', 'C', 'B', + 'O', 'R', 'm', 'a', 'p'}; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ bad_protected) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadProtected_WrongType) { + const std::vector kEmptyVector; + const std::vector cose_sign1_cbor = + cppbor::Array() + // Must be an empty bstr not an empty tstr. + .add(/* protected = */ cppbor::Tstr("")) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadProtected_NotCborMap) { + const std::vector kEmptyVector; + // Protected CBOR bytes should be a map. + const std::vector bad_protected_cbor = + cppbor::Array().add(kCoseAlgorithmEs256).encode(); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ bad_protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadProtected_WrongAlgorithmType) { + const std::vector kEmptyVector; + // Algorithm value must be an int. + const std::vector bad_protected_cbor = + cppbor::Map() + .add(kCoseGenHeaderAlgorithmLabel, /* alg = */ cppbor::Bool(true)) + .canonicalize() + .encode(); + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ bad_protected_cbor) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadUnprotectedType) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + // "unprotected" must be a map. + .add(/* unprotected = */ cppbor::Null()) + .add(/* payload = */ cppbor::Null()) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadPayloadType) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + // "payload" must be nil or bstr + .add(/* payload = */ cppbor::Tstr("string payload")) + .add(/* signature = */ kEmptyVector) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} + +TEST(CoseSign1DataTest, ParseBad_BadSignatureType) { + const std::vector kEmptyVector; + + const std::vector cose_sign1_cbor = + cppbor::Array() + .add(/* protected = */ kEmptyVector) + .add(/* unprotected = */ cppbor::Map()) + .add(/* payload = */ kEmptyVector) + // "signature" must be bstr + .add(/* signature = */ cppbor::Tstr("string signature")) + .encode(); + + CoseSign1Data cose_sign1; + EXPECT_FALSE(cose_sign1.ParseCbor(cose_sign1_cbor)); +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/test/hmac_unittest.cpp b/oemcrypto/util/test/hmac_unittest.cpp index bc4f355..0cbdb61 100644 --- a/oemcrypto/util/test/hmac_unittest.cpp +++ b/oemcrypto/util/test/hmac_unittest.cpp @@ -13,7 +13,8 @@ namespace wvoec { namespace util { -namespace { + +// Putting type in non-anonymous namespace to prevent linkage warnings. struct HmacTestVector { std::vector key; std::vector message; @@ -43,6 +44,7 @@ void PrintTo(const HmacTestVector& v, std::ostream* os) { *os << "signature_sha1 = " << wvutil::b2a_hex(v.signature_sha1) << "}"; } +namespace { std::vector FromString(const std::string& s) { return std::vector(s.begin(), s.end()); } diff --git a/oemcrypto/util/test/oemcrypto_cose_key_unittest.cpp b/oemcrypto/util/test/oemcrypto_cose_key_unittest.cpp new file mode 100644 index 0000000..fcc0e3b --- /dev/null +++ b/oemcrypto/util/test/oemcrypto_cose_key_unittest.cpp @@ -0,0 +1,1073 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "oemcrypto_cose_key.h" + +#include + +#include +#include +#include + +#include +#include + +#include "cose_utils.h" +#include "oemcrypto_ref_test_utils.h" + +namespace wvoec { +namespace util { +namespace { +constexpr size_t kMessageSize = 1024; + +constexpr bool kWithoutAlgorithm = false; +constexpr bool kWithAlgorithm = true; + +constexpr bool kDetachedContent = false; +constexpr bool kAttachedContent = true; +} // namespace + +class OEMCryptoCoseKeyTest : public ::testing::TestWithParam { + public: + void SetUp() override { + key_ = CosePrivateKey::New(GetParam()); + ASSERT_TRUE(key_) << "Key initialization failed: key = " + << CoseKeyFamilyToString(GetParam()); + } + + void TearDown() override { key_.reset(); } + + int64_t GetExpectedCoseAlgorithm() const { + switch (GetParam()) { + case kCoseKeyP256: + return kCoseAlgorithmEs256; + case kCoseKeyP384: + return kCoseAlgorithmEs384; + case kCoseKeyEd25519: + return kCoseAlgorithmEdDsa; + case kCoseKeyFamilyUnknown: + break; + } + return 0; + } + + int64_t GetExpectedCoseKeyType() const { + switch (GetParam()) { + case kCoseKeyP256: + case kCoseKeyP384: + return kCoseKeyTypeEc2; + case kCoseKeyEd25519: + return kCoseKeyTypeOkp; + case kCoseKeyFamilyUnknown: + break; + } + return 0; + } + + int64_t GetExpectedCoseCurve() const { + switch (GetParam()) { + case kCoseKeyP256: + return kCoseCurveP256; + case kCoseKeyP384: + return kCoseCurveP384; + case kCoseKeyEd25519: + return kCoseCurveEd25519; + case kCoseKeyFamilyUnknown: + break; + } + return 0; + } + + protected: + std::unique_ptr key_; +}; + +// Basic verification of COSE private key generation. +TEST_P(OEMCryptoCoseKeyTest, KeyProperties) { + const CoseKeyFamily expected_family = GetParam(); + ASSERT_EQ(key_->family(), expected_family); + switch (expected_family) { + case kCoseKeyP256: + case kCoseKeyP384: { + EXPECT_NE(key_->ecc_key(), nullptr); + EXPECT_EQ(key_->ed_key(), nullptr); + } break; + case kCoseKeyEd25519: { + EXPECT_EQ(key_->ecc_key(), nullptr); + EXPECT_NE(key_->ed_key(), nullptr); + } break; + case kCoseKeyFamilyUnknown: // Suppress compiler warning. + break; + } +} + +// Tests that serializing private key without the optional algorithm. +TEST_P(OEMCryptoCoseKeyTest, SerializePrivateKey_WithoutAlgorithm) { + const std::vector key_data = key_->SerializeCbor(kWithoutAlgorithm); + ASSERT_FALSE(key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(key_data)); + EXPECT_EQ(data.key_type(), GetExpectedCoseKeyType()); + EXPECT_FALSE(data.has_algorithm()); + EXPECT_EQ(data.algorithm(), std::nullopt); + EXPECT_EQ(data.curve(), GetExpectedCoseCurve()); + + // Expect both public and private key components. + EXPECT_TRUE(data.has_public_key()); + EXPECT_TRUE(data.has_private_key()); +} + +// Tests that serializing private key with the optional algorithm. +TEST_P(OEMCryptoCoseKeyTest, SerializePrivateKey_WithAlgorithm) { + const std::vector key_data = key_->SerializeCbor(kWithAlgorithm); + ASSERT_FALSE(key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(key_data)); + EXPECT_EQ(data.key_type(), GetExpectedCoseKeyType()); + EXPECT_TRUE(data.has_algorithm()); + EXPECT_EQ(data.algorithm(), GetExpectedCoseAlgorithm()); + EXPECT_EQ(data.curve(), GetExpectedCoseCurve()); + + // Expect both public and private key components. + EXPECT_TRUE(data.has_public_key()); + EXPECT_TRUE(data.has_private_key()); +} + +// Tests that serializing public key without the optional algorithm. +TEST_P(OEMCryptoCoseKeyTest, SerializePublicKey_WithoutAlgorithm) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = + pub_key->SerializeCbor(kWithoutAlgorithm); + ASSERT_FALSE(key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(key_data)); + EXPECT_EQ(data.key_type(), GetExpectedCoseKeyType()); + EXPECT_FALSE(data.has_algorithm()); + EXPECT_EQ(data.algorithm(), std::nullopt); + EXPECT_EQ(data.curve(), GetExpectedCoseCurve()); + + EXPECT_TRUE(data.has_public_key()); + EXPECT_FALSE(data.has_private_key()); +} + +// Tests that serializing public key with the optional algorithm. +TEST_P(OEMCryptoCoseKeyTest, SerializePublicKey_WithAlgorithm) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = pub_key->SerializeCbor(kWithAlgorithm); + ASSERT_FALSE(key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(key_data)); + EXPECT_EQ(data.key_type(), GetExpectedCoseKeyType()); + EXPECT_TRUE(data.has_algorithm()); + EXPECT_EQ(data.algorithm(), GetExpectedCoseAlgorithm()); + EXPECT_EQ(data.curve(), GetExpectedCoseCurve()); + + EXPECT_TRUE(data.has_public_key()); + EXPECT_FALSE(data.has_private_key()); +} + +TEST_P(OEMCryptoCoseKeyTest, SerializeAndReloadPrivateKey_WithoutAlgorithm) { + const std::vector key_data = key_->SerializeCbor(kWithoutAlgorithm); + ASSERT_FALSE(key_data.empty()); + + auto loaded_key_by_buffer = + CosePrivateKey::LoadCborCoseKey(key_data.data(), key_data.size()); + ASSERT_TRUE(loaded_key_by_buffer); + EXPECT_EQ(key_->family(), loaded_key_by_buffer->family()); + + auto loaded_key_by_vector = CosePrivateKey::LoadCborCoseKey(key_data); + ASSERT_TRUE(loaded_key_by_vector); + EXPECT_EQ(key_->family(), loaded_key_by_vector->family()); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto loaded_key_by_string = CosePrivateKey::LoadCborCoseKey(key_data_str); + ASSERT_TRUE(loaded_key_by_string); + EXPECT_EQ(key_->family(), loaded_key_by_string->family()); +} + +TEST_P(OEMCryptoCoseKeyTest, SerializeAndReloadPrivateKey_WithAlgorithm) { + const std::vector key_data = key_->SerializeCbor(kWithAlgorithm); + ASSERT_FALSE(key_data.empty()); + + auto loaded_key_by_buffer = + CosePrivateKey::LoadCborCoseKey(key_data.data(), key_data.size()); + ASSERT_TRUE(loaded_key_by_buffer); + EXPECT_EQ(key_->family(), loaded_key_by_buffer->family()); + + auto loaded_key_by_vector = CosePrivateKey::LoadCborCoseKey(key_data); + ASSERT_TRUE(loaded_key_by_vector); + EXPECT_EQ(key_->family(), loaded_key_by_vector->family()); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto loaded_key_by_string = CosePrivateKey::LoadCborCoseKey(key_data_str); + ASSERT_TRUE(loaded_key_by_string); + EXPECT_EQ(key_->family(), loaded_key_by_string->family()); +} + +TEST_P(OEMCryptoCoseKeyTest, + SerializePrivateKeyAndReloadAsPublicKey_WithoutAlgorithm) { + const std::vector key_data = key_->SerializeCbor(kWithoutAlgorithm); + ASSERT_FALSE(key_data.empty()); + + auto loaded_key_by_buffer = + CosePublicKey::LoadCborCoseKey(key_data.data(), key_data.size()); + ASSERT_TRUE(loaded_key_by_buffer); + EXPECT_EQ(key_->family(), loaded_key_by_buffer->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_buffer)); + EXPECT_TRUE(loaded_key_by_buffer->IsMatchingPrivateKey(*key_)); + + auto loaded_key_by_vector = CosePublicKey::LoadCborCoseKey(key_data); + ASSERT_TRUE(loaded_key_by_vector); + EXPECT_EQ(key_->family(), loaded_key_by_vector->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_vector)); + EXPECT_TRUE(loaded_key_by_vector->IsMatchingPrivateKey(*key_)); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto loaded_key_by_string = CosePublicKey::LoadCborCoseKey(key_data_str); + ASSERT_TRUE(loaded_key_by_string); + EXPECT_EQ(key_->family(), loaded_key_by_string->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_string)); + EXPECT_TRUE(loaded_key_by_string->IsMatchingPrivateKey(*key_)); +} + +TEST_P(OEMCryptoCoseKeyTest, + SerializePrivateKeyAndReloadAsPublicKey_WithAlgorithm) { + const std::vector key_data = key_->SerializeCbor(kWithAlgorithm); + ASSERT_FALSE(key_data.empty()); + + auto loaded_key_by_buffer = + CosePublicKey::LoadCborCoseKey(key_data.data(), key_data.size()); + ASSERT_TRUE(loaded_key_by_buffer); + EXPECT_EQ(key_->family(), loaded_key_by_buffer->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_buffer)); + EXPECT_TRUE(loaded_key_by_buffer->IsMatchingPrivateKey(*key_)); + + auto loaded_key_by_vector = CosePublicKey::LoadCborCoseKey(key_data); + ASSERT_TRUE(loaded_key_by_vector); + EXPECT_EQ(key_->family(), loaded_key_by_vector->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_vector)); + EXPECT_TRUE(loaded_key_by_vector->IsMatchingPrivateKey(*key_)); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto loaded_key_by_string = CosePublicKey::LoadCborCoseKey(key_data_str); + ASSERT_TRUE(loaded_key_by_string); + EXPECT_EQ(key_->family(), loaded_key_by_string->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_string)); + EXPECT_TRUE(loaded_key_by_string->IsMatchingPrivateKey(*key_)); +} + +TEST_P(OEMCryptoCoseKeyTest, SerializeAndReloadPublicKey_WithoutAlgorithm) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = + pub_key->SerializeCbor(kWithoutAlgorithm); + ASSERT_FALSE(key_data.empty()); + + auto loaded_key_by_buffer = + CosePublicKey::LoadCborCoseKey(key_data.data(), key_data.size()); + ASSERT_TRUE(loaded_key_by_buffer); + EXPECT_EQ(key_->family(), loaded_key_by_buffer->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_buffer)); + EXPECT_TRUE(loaded_key_by_buffer->IsMatchingPrivateKey(*key_)); + + auto loaded_key_by_vector = CosePublicKey::LoadCborCoseKey(key_data); + ASSERT_TRUE(loaded_key_by_vector); + EXPECT_EQ(key_->family(), loaded_key_by_vector->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_vector)); + EXPECT_TRUE(loaded_key_by_vector->IsMatchingPrivateKey(*key_)); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto loaded_key_by_string = CosePublicKey::LoadCborCoseKey(key_data_str); + ASSERT_TRUE(loaded_key_by_string); + EXPECT_EQ(key_->family(), loaded_key_by_string->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_string)); + EXPECT_TRUE(loaded_key_by_string->IsMatchingPrivateKey(*key_)); +} + +TEST_P(OEMCryptoCoseKeyTest, SerializeAndReloadPublicKey_WithAlgorithm) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = pub_key->SerializeCbor(kWithAlgorithm); + ASSERT_FALSE(key_data.empty()); + + auto loaded_key_by_buffer = + CosePublicKey::LoadCborCoseKey(key_data.data(), key_data.size()); + ASSERT_TRUE(loaded_key_by_buffer); + EXPECT_EQ(key_->family(), loaded_key_by_buffer->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_buffer)); + EXPECT_TRUE(loaded_key_by_buffer->IsMatchingPrivateKey(*key_)); + + auto loaded_key_by_vector = CosePublicKey::LoadCborCoseKey(key_data); + ASSERT_TRUE(loaded_key_by_vector); + EXPECT_EQ(key_->family(), loaded_key_by_vector->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_vector)); + EXPECT_TRUE(loaded_key_by_vector->IsMatchingPrivateKey(*key_)); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto loaded_key_by_string = CosePublicKey::LoadCborCoseKey(key_data_str); + ASSERT_TRUE(loaded_key_by_string); + EXPECT_EQ(key_->family(), loaded_key_by_string->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key_by_string)); + EXPECT_TRUE(loaded_key_by_string->IsMatchingPrivateKey(*key_)); +} + +// Tests that a serialized public key cannot be reloaded as a private +// key (missing private key component). +TEST_P(OEMCryptoCoseKeyTest, + SerializePublicKeyAndReloadAsPrivateKey_NotAllowed) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Without algorithm. + std::vector key_data = pub_key->SerializeCbor(kWithoutAlgorithm); + // Expect to fail + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(key_data)); + + // With algorithm. + key_data = pub_key->SerializeCbor(kWithAlgorithm); + // Expect to fail + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(key_data)); +} + +// A typical private COSE_Key should include both the private +// component and the public component; however, it should still +// be loadable as a private key without the public key component. +TEST_P(OEMCryptoCoseKeyTest, LoadPrivateKey_NoPublicKeyComponents) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // Clear the public components. + data.clear_public_key(); + const std::vector partial_key_data = data.SerializeCbor(); + ASSERT_FALSE(partial_key_data.empty()); + + auto loaded_key = CosePrivateKey::LoadCborCoseKey(partial_key_data); + ASSERT_TRUE(loaded_key); + + EXPECT_EQ(key_->family(), loaded_key->family()); + + // When reserialized, the public components should be + // repopulated. + EXPECT_EQ(loaded_key->SerializeCbor(), full_key_data); +} + +TEST_P(OEMCryptoCoseKeyTest, LoadPrivateKey_NoPrivateKeyComponent) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // Clear private key component. + data.clear_private_key(); + const std::vector partial_key_data = data.SerializeCbor(); + ASSERT_FALSE(partial_key_data.empty()); + + // Should not be able to load. + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(partial_key_data)); +} + +// A typical private COSE_Key should include both the private +// component and the public component; however, it should still +// be loadable as a public key without the public key component. +TEST_P(OEMCryptoCoseKeyTest, LoadPublicKey_NoPublicKeyComponents) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // Clear the public components. + data.clear_public_key(); + const std::vector partial_key_data = data.SerializeCbor(); + ASSERT_FALSE(partial_key_data.empty()); + + // Public should be loadable using the private key + // component. + auto loaded_key = CosePublicKey::LoadCborCoseKey(partial_key_data); + ASSERT_TRUE(loaded_key); + + EXPECT_EQ(key_->family(), loaded_key->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key)); + + // When reserialized, the public components should be + // repopulated (but without private component). + const std::vector public_key_data = loaded_key->SerializeCbor(); + + ASSERT_TRUE(data.ParseCbor(public_key_data)); + EXPECT_TRUE(data.has_public_key()); + EXPECT_FALSE(data.has_private_key()); +} + +TEST_P(OEMCryptoCoseKeyTest, LoadPublicKey_NoPrivateKeyComponent) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // Clear private key component. + data.clear_private_key(); + const std::vector partial_key_data = data.SerializeCbor(); + ASSERT_FALSE(partial_key_data.empty()); + + // Loading should not cause any problems for public keys. + auto loaded_key = CosePublicKey::LoadCborCoseKey(partial_key_data); + ASSERT_TRUE(loaded_key); + + EXPECT_EQ(key_->family(), loaded_key->family()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key)); + + // When reserialized, it should appear the same. + EXPECT_EQ(loaded_key->SerializeCbor(), partial_key_data); +} + +// For hardening purposes, the COSE key classes reject +// keys if the algorithm is incorrect for key type / curve. +// In this case, the algorithm is something unsupported for +// any EC2 or OKP format. +TEST_P(OEMCryptoCoseKeyTest, LoadKey_UnsupportedAlgorithm) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + constexpr int64_t kUnknownAlgorithm = 14; // AES-MAC 128/64 + data.set_algorithm(kUnknownAlgorithm); + const std::vector bad_key_data = data.SerializeCbor(); + ASSERT_FALSE(bad_key_data.empty()); + + // Should not be able to load. + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(bad_key_data)); + EXPECT_FALSE(CosePublicKey::LoadCborCoseKey(bad_key_data)); +} + +// For hardening purposes, the COSE key classes reject +// keys if the algorithm is incorrect for key type / curve. +// Slightly different from LoadPrivateKey_UnsupportedAlgorithm +// where the algorithm is known, but the incorrect algorithm +// for the key. +TEST_P(OEMCryptoCoseKeyTest, LoadKey_IncorrectAlgorithm) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // If ECC use EdDSA if Ed use ECDSA w/ SHA-256. + const int64_t bad_algorithm = + data.IsOkpData() ? kCoseAlgorithmEs256 : kCoseAlgorithmEdDsa; + data.set_algorithm(bad_algorithm); + const std::vector bad_key_data = data.SerializeCbor(); + ASSERT_FALSE(bad_key_data.empty()); + + // Should not be able to load. + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(bad_key_data)); + EXPECT_FALSE(CosePublicKey::LoadCborCoseKey(bad_key_data)); +} + +// Checks that COSE_Key with unsupported curves cannot be loaded. +// Note: the curve is a defined curve and correct for the key type, +// but is not one of the supported curves. +TEST_P(OEMCryptoCoseKeyTest, LoadKey_UnsupportedCurve) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // If ECC use secp256k1(8) if Ed use Ed448(7). + constexpr int64_t kCoseCurveSecp256k1 = 8; + constexpr int64_t kCoseCurveEd448 = 7; + const int64_t unsupported_curve = + data.IsEc2Data() ? kCoseCurveSecp256k1 : kCoseCurveEd448; + data.set_curve(unsupported_curve); + const std::vector bad_key_data = data.SerializeCbor(); + ASSERT_FALSE(bad_key_data.empty()); + + // Should not be able to load. + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(bad_key_data)); + EXPECT_FALSE(CosePublicKey::LoadCborCoseKey(bad_key_data)); +} + +// Checks that COSE_Key with incorrect curves cannot be loaded. +// Note: the curve is one of the support curve types, but does +// is not the correct curve for the key type. +TEST_P(OEMCryptoCoseKeyTest, LoadKey_IncorrectCurve) { + const std::vector full_key_data = key_->SerializeCbor(); + ASSERT_FALSE(full_key_data.empty()); + + CoseKeyData data; + ASSERT_TRUE(data.ParseCbor(full_key_data)); + + // If ECC use Ed25519 if Ed use P-256. + const int64_t bad_curve = + data.IsOkpData() ? kCoseCurveP256 : kCoseCurveEd25519; + data.set_curve(bad_curve); + const std::vector bad_key_data = data.SerializeCbor(); + ASSERT_FALSE(bad_key_data.empty()); + + // Should not be able to load. + EXPECT_FALSE(CosePrivateKey::LoadCborCoseKey(bad_key_data)); + EXPECT_FALSE(CosePublicKey::LoadCborCoseKey(bad_key_data)); +} + +// Basic signature generation using detached content (payload +// is not contain in the COSE_Sign1 signature). +TEST_P(OEMCryptoCoseKeyTest, GenerateCoseSign1Signature_DetachedContent) { + const std::vector payload = RandomData(kMessageSize); + + const std::vector signature = + key_->GenerateCoseSign1Signature(payload, kDetachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + + // Algorithm should be included. + EXPECT_EQ(data.ExtractAlgorithm(), GetExpectedCoseAlgorithm()); + // No payload for detached content. + EXPECT_FALSE(data.has_payload()); + EXPECT_TRUE(data.has_signature()); +} + +// Basic signature generation using attached content (payload +// is contain in the COSE_Sign1 signature). +TEST_P(OEMCryptoCoseKeyTest, GenerateCoseSign1Signature_AttachedContent) { + const std::vector payload = RandomData(kMessageSize); + + const std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + + // Algorithm should be included. + EXPECT_EQ(data.ExtractAlgorithm(), GetExpectedCoseAlgorithm()); + // Payload included for detached content. + EXPECT_TRUE(data.has_payload()); + EXPECT_EQ(data.payload(), payload); + EXPECT_TRUE(data.has_signature()); +} + +// Basic signature generation using detached, empty content (payload +// is not contain in the COSE_Sign1 signature). +// Note: COSE_Sign1 should distinguish between empty content +// and no content. +TEST_P(OEMCryptoCoseKeyTest, + GenerateCoseSign1Signature_EmptyPayload_DetachedContent) { + const std::vector empty_payload; + + const std::vector signature = + key_->GenerateCoseSign1Signature(empty_payload, kDetachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + + // Algorithm should be included. + EXPECT_EQ(data.ExtractAlgorithm(), GetExpectedCoseAlgorithm()); + // No payload for detached content. + EXPECT_FALSE(data.has_payload()); + EXPECT_TRUE(data.has_signature()); +} + +// Basic signature generation using attached, empty content (payload +// is contain in the COSE_Sign1 signature). +// Note: COSE_Sign1 should distinguish between empty content +// and no content. +TEST_P(OEMCryptoCoseKeyTest, + GenerateCoseSign1Signature_EmptyPayload_AttachedContent) { + const std::vector empty_payload; + + const std::vector signature = + key_->GenerateCoseSign1Signature(empty_payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + + // Algorithm should be included. + EXPECT_EQ(data.ExtractAlgorithm(), GetExpectedCoseAlgorithm()); + // Payload included for detached content. + EXPECT_TRUE(data.has_payload()); + EXPECT_EQ(data.payload(), empty_payload); + EXPECT_TRUE(data.has_signature()); +} + +// Basic signature verification using detached content (payload +// is not contain in the COSE_Sign1 signature). +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_DetachedContent) { + const std::vector payload = RandomData(kMessageSize); + + const std::vector signature = + key_->GenerateCoseSign1Signature(payload, kDetachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_SUCCESS); +} + +// Basic signature verification using attached content (payload +// is contain in the COSE_Sign1 signature). +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_AttachedContent) { + const std::vector payload = RandomData(kMessageSize); + + const std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Payload is found within the signature. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), OEMCrypto_SUCCESS); +} + +// Basic signature verification using detached, empty content (payload +// is not contain in the COSE_Sign1 signature). +// Note: COSE_Sign1 should distinguish between empty content +// and no content. +TEST_P(OEMCryptoCoseKeyTest, + VerifyCoseSign1Signature_EmptyPayload_DetachedContent) { + const std::vector empty_payload; + + const std::vector signature = + key_->GenerateCoseSign1Signature(empty_payload, kDetachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(empty_payload, signature), + OEMCrypto_SUCCESS); +} + +// Basic signature verification using attached, empty content (payload +// is contain in the COSE_Sign1 signature). +// Note: COSE_Sign1 should distinguish between empty content +// and no content. +TEST_P(OEMCryptoCoseKeyTest, + VerifyCoseSign1Signature_EmptyPayload_AttachedContent) { + const std::vector empty_payload; + + const std::vector signature = + key_->GenerateCoseSign1Signature(empty_payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Payload is found within the signature. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), OEMCrypto_SUCCESS); +} + +// Tests that if the COSE_Sign1 signature contains the payload +// but the caller also provides the payload that the signature +// can still be verified. +// Note: Payloads must match. +TEST_P(OEMCryptoCoseKeyTest, + VerifyCoseSign1Signature_AttachedContentWithContent) { + const std::vector payload = RandomData(kMessageSize); + + const std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // An attached signature should be able to be verify + // with a separate payload as well; so long as they match. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_SUCCESS); +} + +// Same as VerifyCoseSign1Signature_AttachedContentWithContent, but +// using an empty payload. +TEST_P(OEMCryptoCoseKeyTest, + VerifyCoseSign1Signature_EmptyPayload_AttachedContentWithContent) { + const std::vector empty_payload; + + const std::vector signature = + key_->GenerateCoseSign1Signature(empty_payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // An attached signature should be able to be verify + // with a separate payload as well; so long as they match. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(empty_payload, signature), + OEMCrypto_SUCCESS); +} + +// Tests that signature verification fails if the protected +// parameters of the COSE_Sign1 are modified. +// In this case, the parameters are modified by removing the +// algorithm within the protected parameters. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_RemoveAlgorithm) { + const std::vector payload = RandomData(kMessageSize); + + std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + // Remove the algorithm from signature. + // This modifies the "protected" parameters which should + // cause the signature validation to fail. + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + data.clear_protected_params(); + + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Should fail. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); + + // Now remove the payload from the signature + // and include it separately. + data.clear_payload(); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); +} + +// Tests that signature verification fails if the protected +// parameters of the COSE_Sign1 invalid. +// +// Not enforcing exact error, only that it is not successful. +// +// Note: The internal signature is actually valid. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_InvalidAlgorithm) { + const std::vector payload = RandomData(kMessageSize); + + // The CosePrivateKey class intentionally does not allow for + // omitting the algorithm. Need to parcel our own. + CoseSign1Data data; + // If ECC use EdDSA if Ed use ECDSA w/ SHA-256. + const int64_t bad_algorithm = + GetParam() == kCoseKeyEd25519 ? kCoseAlgorithmEs256 : kCoseAlgorithmEdDsa; + data.SetAlgorithm(bad_algorithm); + ASSERT_FALSE(data.protected_params().empty()) + << "Failed to set invalid algorithm"; + const std::vector signing_payload = + PackageCoseSign1SigStructure(data.protected_params(), payload); + + const std::vector raw_signature = + key_->GenerateRawSignatureForTest(signing_payload); + ASSERT_FALSE(raw_signature.empty()) << "Failed to generate raw signature"; + data.set_signature(raw_signature); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Attempt as detached content. + std::vector signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()) + << "Failed to serialize COSE_Sign1 (detached)"; + + // Should reject with invalid signature algorithm. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_SUCCESS); + + // Attempt as attached content. + data.set_payload(payload); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()) + << "Failed to serialize COSE_Sign1 (attached)"; + + // Should reject with invalid signature algorithm. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(signature), OEMCrypto_SUCCESS); +} + +// Tests that signature verification succeeds even if the algorithm +// parameters of the COSE_Sign1 missing. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_WithoutAlgorithm) { + const std::vector payload = RandomData(kMessageSize); + + // The CosePrivateKey class intentionally does not allow for + // omitting the algorithm. Need to parcel our own. + CoseSign1Data data; + ASSERT_TRUE(data.protected_params().empty()); + const std::vector signing_payload = + PackageCoseSign1SigStructure(data.protected_params(), payload); + + const std::vector raw_signature = + key_->GenerateRawSignatureForTest(signing_payload); + ASSERT_FALSE(raw_signature.empty()) << "Failed to generate raw signature"; + data.set_signature(raw_signature); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Attempt as detached content. + std::vector signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()) + << "Failed to serialize COSE_Sign1 (detached)"; + + // Without an specified algorithm, the CosePublicKey should + // still be able to select the appropriate algorithm. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_SUCCESS); + + // Attempt as attached content. + data.set_payload(payload); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()) + << "Failed to serialize COSE_Sign1 (attached)"; + + // Without an specified algorithm, the CosePublicKey should + // still be able to select the appropriate algorithm. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), OEMCrypto_SUCCESS); +} + +// Tests that signature verification succeeds even if payload +// is removed. +// Note: The internal payload of COSE_Sign1 is only considered +// protected if if both the sender and the receiver are expected +// attached content. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_RemoveAttachedContent) { + const std::vector payload = RandomData(kMessageSize); + + std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + // Remove the payload from the signature block. + data.clear_payload(); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Should still be able to verify with content is passed + // separately. + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_SUCCESS); + + // Should not be able to verify if content is not passed. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(signature), OEMCrypto_SUCCESS); +} + +// Tests that signature verification fails if the payload +// is missing and the receiver is expecting attached content. +// +// Not enforcing exact error, only that it is not successful. +TEST_P(OEMCryptoCoseKeyTest, + VerifyCoseSign1Signature_DetachedContent_MissingPayload) { + const std::vector payload = RandomData(kMessageSize); + + const std::vector signature = + key_->GenerateCoseSign1Signature(payload, kDetachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Should not be able to verify if content is not passed. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(signature), OEMCrypto_SUCCESS); +} + +// Tests that signature verification fails if signature +// is not a CBOR encoded COSE_Sign1. +// +// Not enforcing exact error, only that it is not successful. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_NotCoseSign1) { + const std::vector payload = RandomData(kMessageSize); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector not_cose_sign1(128, 0x42); + // Provided signature is not a COSE_Sign1 struct. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(not_cose_sign1), + OEMCrypto_SUCCESS); + EXPECT_NE(pub_key->VerifyCoseSign1Signature(payload, not_cose_sign1), + OEMCrypto_SUCCESS); +} + +// Tests that signature verification fails if signature has been +// modified for the same payload. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_ModifySignature) { + const std::vector payload = RandomData(kMessageSize); + + std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + // Modify the raw signature. + std::vector raw_signature = data.signature(); + for (uint8_t& b : raw_signature) { + b ^= 0x42; + } + data.set_signature(raw_signature); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); + + // Same thing, but with detached content. + data.clear_payload(); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); +} + +// Tests that signature verification fails if the payload has been +// modified for the same signature. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_ModifyPayload) { + const std::vector payload = RandomData(kMessageSize); + + std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + // Modify payload. + std::vector payload2 = payload; + for (uint8_t& b : payload2) { + b ^= 0x42; + } + data.set_payload(payload2); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); + + // Same thing, but with detached content. + data.clear_payload(); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload2, signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); +} + +// Tests that signature verification fails if the protected parameters +// have been modified for the same signature. +// +// Note: Even if the algorithm is valid, COSE_Sign1 enforces +// a protection on all parameters. +TEST_P(OEMCryptoCoseKeyTest, VerifyCoseSign1Signature_ModifyProtectedParams) { + const std::vector payload = RandomData(kMessageSize); + + std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + // Modify protected parameters. + // Note: Algorithm is still valid. + data.set_protected_params( + cppbor::Map() + .add(kCoseGenHeaderAlgorithmLabel, GetExpectedCoseAlgorithm()) + .add(-1337, cppbor::Tstr("Something new")) + .canonicalize() + .encode()); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); + + // Same thing, but with detached content. + data.clear_payload(); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + EXPECT_EQ(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_ERROR_SIGNATURE_FAILURE); +} + +// Tests that signature verification fails if the attached +// payload and the detached payload do not match. +// +// Not enforcing exact error, only that it is not successful. +TEST_P(OEMCryptoCoseKeyTest, + VerifyCoseSign1Signature_AttachedVsDetachedPayloadMismatch) { + const std::vector payload = RandomData(kMessageSize); + + std::vector signature = + key_->GenerateCoseSign1Signature(payload, kAttachedContent); + ASSERT_FALSE(signature.empty()) << "Failed to generate signature"; + + // Modify detached payload. + std::vector payload2 = payload; + for (uint8_t& b : payload2) { + b ^= 0x42; + } + + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // When using attached content signatures with separate payload + // the two payloads must match. + // Here attached payload is correct, but detached payload is + // invalid. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(payload2, signature), + OEMCrypto_SUCCESS); + + // Swap the payloads. + CoseSign1Data data; + ASSERT_TRUE(data.ParseCbor(signature)); + data.set_payload(payload2); + signature = data.SerializeCbor(); + ASSERT_FALSE(signature.empty()); + + // Here attached payload is invalid, but detached payload is + // correct. + EXPECT_NE(pub_key->VerifyCoseSign1Signature(payload, signature), + OEMCrypto_SUCCESS); +} + +INSTANTIATE_TEST_SUITE_P(AllFamilies, OEMCryptoCoseKeyTest, + ::testing::Values(kCoseKeyP256, kCoseKeyP384, + kCoseKeyEd25519)); +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/test/oemcrypto_ecc_key_unittest.cpp b/oemcrypto/util/test/oemcrypto_ecc_key_unittest.cpp index 827d755..80aa432 100644 --- a/oemcrypto/util/test/oemcrypto_ecc_key_unittest.cpp +++ b/oemcrypto/util/test/oemcrypto_ecc_key_unittest.cpp @@ -35,47 +35,30 @@ TEST_P(OEMCryptoEccKeyTest, KeyProperties) { EXPECT_NE(nullptr, key_->GetEcKey()); } -// Checks that the private key serialization APIs are compatible -// and performing in a manner that is similar to other OEMCrypto methods -// that retrieve data. -TEST_P(OEMCryptoEccKeyTest, SerializePrivateKey) { - constexpr size_t kInitialBufferSize = 10; // Definitely too small. - size_t buffer_size = kInitialBufferSize; - std::vector buffer(buffer_size); - - EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, - key_->Serialize(buffer.data(), &buffer_size)); - EXPECT_GT(buffer_size, kInitialBufferSize); - - buffer.resize(buffer_size); - EXPECT_EQ(OEMCrypto_SUCCESS, key_->Serialize(buffer.data(), &buffer_size)); - buffer.resize(buffer_size); - - const std::vector direct_key_data = key_->Serialize(); - EXPECT_FALSE(direct_key_data.empty()); - ASSERT_EQ(buffer.size(), direct_key_data.size()); - for (size_t i = 0; i < buffer.size(); i++) { - ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; - } -} - -// Checks that a private key that is serialized can be deserialized and -// reload. Also checks that the serialization of a key produces the -// same data to ensure consistency. -TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadPrivateKey) { - const std::vector key_data = key_->Serialize(); - std::unique_ptr loaded_key = EccPrivateKey::Load(key_data); - ASSERT_TRUE(loaded_key); - - EXPECT_EQ(key_->curve(), loaded_key->curve()); - - const std::vector loaded_key_data = loaded_key->Serialize(); - ASSERT_EQ(key_data.size(), loaded_key_data.size()); - for (size_t i = 0; i < key_data.size(); i++) { - ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; - } -} +// ==== Serializing / Deserializing Tests ==== +// Within OEMCrypto several transport formats of ECC keys +// are used. The following is a list of all serializing tests +// currently available. +// +// Case # | Format | From | To +// -------|----------------------------|------|------ +// 1 | ASN.1 SubjectPublicKeyInfo | priv | pub +// 2 | ASN.1 SubjectPublicKeyInfo | pub | pub +// 3 | ASN.1 PrivateKeyInfo | priv | priv +// 4 | ASN.1 PrivateKeyInfo | priv | pub +// 5 | SEC 1 Key Point (comp) | priv | pub +// 6 | SEC 1 Key Point (comp) | pub | pub +// 7 | SEC 1 Key Point (uncomp) | priv | pub +// 8 | SEC 1 Key Point (uncomp) | pub | pub +// 9 | X & Y Coord | priv | pub +// 10 | X & Y Coord | pub | pub +// 11 | X Coord & Y Sign Bit | priv | pub +// 12 | X Coord & Y Sign Bit | pub | pub +// 13 | Private Scalar | priv | priv +// 14 | Private Scalar | priv | pub +// Case 1: ASN.1 SubjectPublicKeyInfo - priv to pub +// // Checks that a private key can be serialized as a public key, and // that the serialized public key and be reloaded. TEST_P(OEMCryptoEccKeyTest, SerializePrivateKeyAsPublicKey) { @@ -84,70 +67,17 @@ TEST_P(OEMCryptoEccKeyTest, SerializePrivateKeyAsPublicKey) { auto loaded_key = EccPublicKey::Load(key_data); ASSERT_TRUE(loaded_key) << "Failed to deserialize public key"; + + EXPECT_EQ(key_->curve(), loaded_key->curve()); EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key)); EXPECT_TRUE(loaded_key->IsMatchingPrivateKey(*key_)); + + const std::vector loaded_key_data = loaded_key->Serialize(); + EXPECT_EQ(key_data, loaded_key_data); } -// Checks that a public key can be initialized from a ASN.1 DER encoded -// PrivateKeyInfo message. -TEST_P(OEMCryptoEccKeyTest, SerializePrivateKeyAndReloadAsPublicKey) { - const std::vector key_data = key_->Serialize(); - ASSERT_FALSE(key_data.empty()) << "Failed to serialize as private key"; - - auto key_by_buffer = - EccPublicKey::LoadPrivateKeyInfo(key_data.data(), key_data.size()); - ASSERT_TRUE(key_by_buffer) - << "Failed to deserialize private key into public key"; - EXPECT_TRUE(key_->IsMatchingPublicKey(*key_by_buffer)); - key_by_buffer.reset(); - - auto key_by_vector = EccPublicKey::LoadPrivateKeyInfo(key_data); - ASSERT_TRUE(key_by_vector) - << "Failed to deserialize private key into public key"; - EXPECT_TRUE(key_->IsMatchingPublicKey(*key_by_vector)); - key_by_vector.reset(); - - const std::string key_data_str(key_data.begin(), key_data.end()); - auto key_by_string = EccPublicKey::LoadPrivateKeyInfo(key_data_str); - ASSERT_TRUE(key_by_string) - << "Failed to deserialize private key into public key"; - EXPECT_TRUE(key_->IsMatchingPublicKey(*key_by_string)); -} - -// Checks that a public key can be created from the private key. -TEST_P(OEMCryptoEccKeyTest, DerivePublicKey) { - std::unique_ptr pub_key = key_->MakePublicKey(); - ASSERT_TRUE(pub_key); - EXPECT_TRUE(key_->IsMatchingPublicKey(*pub_key)); -} - -// Checks that a public key that is serialized can be deserialized and -// reload. Also checks that the serialization of a key produces the -// same data to ensure consistency. -TEST_P(OEMCryptoEccKeyTest, SerializePublicKey) { - std::unique_ptr pub_key = key_->MakePublicKey(); - ASSERT_TRUE(pub_key); - - constexpr size_t kInitialBufferSize = 10; // Definitely too small. - size_t buffer_size = kInitialBufferSize; - std::vector buffer(buffer_size); - - EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, - pub_key->Serialize(buffer.data(), &buffer_size)); - EXPECT_GT(buffer_size, kInitialBufferSize); - - buffer.resize(buffer_size); - EXPECT_EQ(OEMCrypto_SUCCESS, pub_key->Serialize(buffer.data(), &buffer_size)); - buffer.resize(buffer_size); - - const std::vector direct_key_data = pub_key->Serialize(); - EXPECT_FALSE(direct_key_data.empty()); - ASSERT_EQ(buffer.size(), direct_key_data.size()); - for (size_t i = 0; i < buffer.size(); i++) { - ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; - } -} - +// Case 2: ASN.1 SubjectPublicKeyInfo - pub to pub +// // Checks that a public key that is serialized can be deserialized and // reload. Also checks that the serialization of a key produces the // same data to ensure consistency. @@ -158,53 +88,68 @@ TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadPublicKey) { ASSERT_TRUE(pub_key); const std::vector key_data = pub_key->Serialize(); + ASSERT_FALSE(key_data.empty()) << "Failed to serialize public key"; std::unique_ptr loaded_key = EccPublicKey::Load(key_data); ASSERT_TRUE(loaded_key); EXPECT_EQ(pub_key->curve(), loaded_key->curve()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key)); + EXPECT_TRUE(loaded_key->IsMatchingPrivateKey(*key_)); const std::vector loaded_key_data = loaded_key->Serialize(); - ASSERT_EQ(key_data.size(), loaded_key_data.size()); - for (size_t i = 0; i < key_data.size(); i++) { - ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; - } + EXPECT_EQ(key_data, loaded_key_data); } +// Case 3: ASN.1 PrivateKeyInfo - priv to priv +// +// Checks that a private key that is serialized can be deserialized and +// reload. Also checks that the serialization of a key produces the +// same data to ensure consistency. +TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadPrivateKey) { + const std::vector key_data = key_->Serialize(); + ASSERT_FALSE(key_data.empty()) << "Failed to serialize as private key"; + + std::unique_ptr loaded_key = EccPrivateKey::Load(key_data); + ASSERT_TRUE(loaded_key); + + EXPECT_EQ(key_->curve(), loaded_key->curve()); + + const std::vector loaded_key_data = loaded_key->Serialize(); + EXPECT_EQ(key_data, loaded_key_data); +} + +// Case 4: ASN.1 PrivateKeyInfo - priv to pub +// +// Checks that a public key can be initialized from a ASN.1 DER encoded +// PrivateKeyInfo message. +TEST_P(OEMCryptoEccKeyTest, SerializePrivateKeyAndReloadAsPublicKey) { + const std::vector key_data = key_->Serialize(); + ASSERT_FALSE(key_data.empty()) << "Failed to serialize private key"; + + auto key_by_buffer = + EccPublicKey::LoadPrivateKeyInfo(key_data.data(), key_data.size()); + ASSERT_TRUE(key_by_buffer) + << "Failed to deserialize private key into public key"; + EXPECT_TRUE(key_->IsMatchingPublicKey(*key_by_buffer)); + + auto key_by_vector = EccPublicKey::LoadPrivateKeyInfo(key_data); + ASSERT_TRUE(key_by_vector) + << "Failed to deserialize private key into public key"; + EXPECT_TRUE(key_->IsMatchingPublicKey(*key_by_vector)); + + const std::string key_data_str(key_data.begin(), key_data.end()); + auto key_by_string = EccPublicKey::LoadPrivateKeyInfo(key_data_str); + ASSERT_TRUE(key_by_string) + << "Failed to deserialize private key into public key"; + EXPECT_TRUE(key_->IsMatchingPublicKey(*key_by_string)); +} + +// Case 5 & 6: SEC 1 Key Point (comp) - priv and pub to pub +// // Checks that a public and private key can be serialized as a SEC 1 -// key point, and that a public key can be reloaded from the key point. -// Uses the default option for point compression. -TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadAsPublicKeySec1KeyPoint) { - auto pub_key = key_->MakePublicKey(); - ASSERT_TRUE(pub_key); - - const std::vector key_point_from_priv = - key_->SerializeAsPublicSec1KeyPoint(); - ASSERT_FALSE(key_point_from_priv.empty()) - << "Failed to serialize from private key"; - const std::vector key_point_from_pub = - pub_key->SerializeAsSec1KeyPoint(); - ASSERT_FALSE(key_point_from_pub.empty()) - << "Failed to serialize from public key"; - - // Check that both are equal. - EXPECT_EQ(key_point_from_priv, key_point_from_pub); - - // Reload new public key from key point data. - auto loaded_key = - EccPublicKey::LoadKeyPoint(key_->curve(), key_point_from_priv); - ASSERT_TRUE(loaded_key) << "Failed to load public key from key point"; - - const std::vector key_point_from_loaded = - loaded_key->SerializeAsSec1KeyPoint(); - ASSERT_FALSE(key_point_from_loaded.empty()) - << "Failed to serialize from loaded key"; - - EXPECT_EQ(key_point_from_priv, key_point_from_loaded); -} - -// Same as above, except explicitly serializes key point in its -// compressed form. +// key point with point compression, and that a public key can be +// reloaded from the key point. TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadAsPublicKeySec1KeyPointCompressed) { constexpr bool kCompressed = true; @@ -236,8 +181,11 @@ TEST_P(OEMCryptoEccKeyTest, EXPECT_EQ(key_point_from_priv, key_point_from_loaded); } -// Same as above, except explicitly serializes key point in its -// uncompressed form. +// Case 7 & 8: SEC 1 Key Point (uncomp) - priv and pub to pub +// +// Checks that a public and private key can be serialized as a SEC 1 +// key point without point compression, and that a public key can be +// reloaded from the key point. TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadAsPublicKeySec1KeyPointUncompressed) { constexpr bool kUncompressed = false; @@ -261,6 +209,8 @@ TEST_P(OEMCryptoEccKeyTest, EccPublicKey::LoadKeyPoint(key_->curve(), key_point_from_priv); ASSERT_TRUE(loaded_key) << "Failed to load public key from key point"; + EXPECT_EQ(key_->curve(), loaded_key->curve()); + const std::vector key_point_from_loaded = loaded_key->SerializeAsSec1KeyPoint(kUncompressed); ASSERT_FALSE(key_point_from_loaded.empty()) @@ -269,6 +219,176 @@ TEST_P(OEMCryptoEccKeyTest, EXPECT_EQ(key_point_from_priv, key_point_from_loaded); } +// Case 9 & 10: X & Y Coord - priv and pub to pub +// +// Checks that a public and private key can be serialized as a raw +// key point coordinates, and that a public key can be reloaded from +// the key point. +TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadAsPublicKeyCoords) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Serialize X & Y from private key. + const std::vector x_coord_from_priv = key_->SerializeXCoord(); + ASSERT_FALSE(x_coord_from_priv.empty()) + << "Failed to serialize X coord from private key"; + const std::vector y_coord_from_priv = key_->SerializeYCoord(); + ASSERT_FALSE(y_coord_from_priv.empty()) + << "Failed to serialize Y coord from private key"; + + // Serialize X & Y from public key. + const std::vector x_coord_from_pub = key_->SerializeXCoord(); + ASSERT_FALSE(x_coord_from_pub.empty()) + << "Failed to serialize X coord from public key"; + const std::vector y_coord_from_pub = key_->SerializeYCoord(); + ASSERT_FALSE(y_coord_from_pub.empty()) + << "Failed to serialize Y coord from public key"; + + // Ensure they are equal. + ASSERT_EQ(x_coord_from_priv, x_coord_from_pub); + ASSERT_EQ(y_coord_from_priv, y_coord_from_pub); + + // Reload new public key from coordinates key point data. + auto loaded_key = EccPublicKey::LoadKeyPoint(key_->curve(), x_coord_from_priv, + y_coord_from_priv); + ASSERT_TRUE(loaded_key) << "Failed to load public key from key point"; + + EXPECT_EQ(key_->curve(), loaded_key->curve()); + + EXPECT_EQ(x_coord_from_priv, loaded_key->SerializeXCoord()); + EXPECT_EQ(y_coord_from_priv, loaded_key->SerializeYCoord()); +} + +// Case 11 & 12: X Coord & Y sign bit - priv and pub to pub +// +// Checks that a public and private key can be serialized as a raw +// key X coordinate and Y sign bit, and that a public key can be +// reloaded from the X coord and Y sign bit. +TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadAsPublicXCoordAndYSignBit) { + auto pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + // Serialize X & Y from private key. + const std::vector x_coord_from_priv = key_->SerializeXCoord(); + ASSERT_FALSE(x_coord_from_priv.empty()) + << "Failed to serialize X coord from private key"; + const bool y_sign_bit_from_priv = key_->GetYSignBit(); + + // Serialize X & Y from public key. + const std::vector x_coord_from_pub = key_->SerializeXCoord(); + ASSERT_FALSE(x_coord_from_pub.empty()) + << "Failed to serialize X coord from public key"; + const bool y_sign_bit_from_pub = key_->GetYSignBit(); + + // Ensure they are equal. + ASSERT_EQ(x_coord_from_priv, x_coord_from_pub); + ASSERT_EQ(y_sign_bit_from_priv, y_sign_bit_from_pub); + + // Reload new public key from coordinates key point data. + auto loaded_key = EccPublicKey::LoadKeyPoint(key_->curve(), x_coord_from_priv, + y_sign_bit_from_priv); + ASSERT_TRUE(loaded_key) << "Failed to load public key from key point"; + + EXPECT_EQ(key_->curve(), loaded_key->curve()); + EXPECT_EQ(x_coord_from_priv, loaded_key->SerializeXCoord()); + EXPECT_EQ(y_sign_bit_from_priv, loaded_key->GetYSignBit()); +} + +// Case 13: Private Scalar - priv to priv +// +// Checks that a private key can be serialized as a raw +// private scalar, and that a private key can be reloaded +// using the scalar. +TEST_P(OEMCryptoEccKeyTest, SerializeAndReloadAsPrivateScalar) { + const std::vector key_data = key_->SerializePrivateScalar(); + ASSERT_FALSE(key_data.empty()) << "Failed to serialize as private key scalar"; + + std::unique_ptr loaded_key = + EccPrivateKey::LoadPrivateScalar(key_->curve(), key_data); + ASSERT_TRUE(loaded_key); + + EXPECT_EQ(key_->curve(), loaded_key->curve()); + const std::vector loaded_key_data = + loaded_key->SerializePrivateScalar(); + EXPECT_EQ(key_data, loaded_key_data); +} + +// Case 14: Private Scalar - priv to pub +// +// Checks that a private key can be serialized as a raw +// private scalar, and that a public key can be reloaded +// using the scalar. +TEST_P(OEMCryptoEccKeyTest, SerializePrivateKeyScalarAndReloadAsPublicKey) { + const std::vector key_data = key_->SerializePrivateScalar(); + ASSERT_FALSE(key_data.empty()) << "Failed to serialize as private key scalar"; + + auto loaded_key = + EccPublicKey::LoadFromPrivateScalar(key_->curve(), key_data); + ASSERT_TRUE(loaded_key) + << "Failed to deserialize private key into public key"; + + EXPECT_EQ(key_->curve(), loaded_key->curve()); + EXPECT_TRUE(key_->IsMatchingPublicKey(*loaded_key)); +} + +// Checks that the private key serialization APIs are compatible +// and performing in a manner that is similar to other OEMCrypto methods +// that retrieve data. +TEST_P(OEMCryptoEccKeyTest, SerializePrivateKey_APICheck) { + constexpr size_t kInitialBufferSize = 10; // Definitely too small. + size_t buffer_size = kInitialBufferSize; + std::vector buffer(buffer_size); + + EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, + key_->Serialize(buffer.data(), &buffer_size)); + EXPECT_GT(buffer_size, kInitialBufferSize); + + buffer.resize(buffer_size); + EXPECT_EQ(OEMCrypto_SUCCESS, key_->Serialize(buffer.data(), &buffer_size)); + buffer.resize(buffer_size); + + const std::vector direct_key_data = key_->Serialize(); + EXPECT_FALSE(direct_key_data.empty()); + ASSERT_EQ(buffer.size(), direct_key_data.size()); + for (size_t i = 0; i < buffer.size(); i++) { + ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; + } +} + +// Checks that a public key that is serialized can be deserialized and +// reload. Also checks that the serialization of a key produces the +// same data to ensure consistency. +TEST_P(OEMCryptoEccKeyTest, SerializePublicKey_APICheck) { + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + constexpr size_t kInitialBufferSize = 10; // Definitely too small. + size_t buffer_size = kInitialBufferSize; + std::vector buffer(buffer_size); + + EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, + pub_key->Serialize(buffer.data(), &buffer_size)); + EXPECT_GT(buffer_size, kInitialBufferSize); + + buffer.resize(buffer_size); + EXPECT_EQ(OEMCrypto_SUCCESS, pub_key->Serialize(buffer.data(), &buffer_size)); + buffer.resize(buffer_size); + + const std::vector direct_key_data = pub_key->Serialize(); + EXPECT_FALSE(direct_key_data.empty()); + ASSERT_EQ(buffer.size(), direct_key_data.size()); + for (size_t i = 0; i < buffer.size(); i++) { + ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; + } +} + +// Checks that a public key can be created from the private key. +TEST_P(OEMCryptoEccKeyTest, DerivePublicKey) { + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + EXPECT_TRUE(key_->IsMatchingPublicKey(*pub_key)); +} + // Checks that the ECC signature generating API operates similar to // existing signature generation functions. TEST_P(OEMCryptoEccKeyTest, GenerateSignature) { @@ -380,5 +500,29 @@ TEST_P(OEMCryptoEccKeyTest, DeriveSessionKey) { INSTANTIATE_TEST_SUITE_P(AllCurves, OEMCryptoEccKeyTest, ::testing::Values(kEccSecp256r1, kEccSecp384r1, kEccSecp521r1)); + +TEST(OEMCryptoBasicEccKeyTest, GetSignatureHashAlgorithm) { + // secp256r1 -> SHA-256 + auto private_key = EccPrivateKey::New(kEccSecp256r1); + ASSERT_TRUE(private_key); + EXPECT_EQ(private_key->GetSignatureHashAlgorithm(), OEMCrypto_SHA2_256); + auto public_key = private_key->MakePublicKey(); + ASSERT_TRUE(public_key); + EXPECT_EQ(public_key->GetSignatureHashAlgorithm(), OEMCrypto_SHA2_256); + // secp384r1 -> SHA-384 + private_key = EccPrivateKey::New(kEccSecp384r1); + ASSERT_TRUE(private_key); + EXPECT_EQ(private_key->GetSignatureHashAlgorithm(), OEMCrypto_SHA2_384); + public_key = private_key->MakePublicKey(); + ASSERT_TRUE(public_key); + EXPECT_EQ(public_key->GetSignatureHashAlgorithm(), OEMCrypto_SHA2_384); + // secp521r1 -> SHA-512 + private_key = EccPrivateKey::New(kEccSecp521r1); + ASSERT_TRUE(private_key); + EXPECT_EQ(private_key->GetSignatureHashAlgorithm(), OEMCrypto_SHA2_512); + public_key = private_key->MakePublicKey(); + ASSERT_TRUE(public_key); + EXPECT_EQ(public_key->GetSignatureHashAlgorithm(), OEMCrypto_SHA2_512); +} } // namespace util } // namespace wvoec diff --git a/oemcrypto/util/test/oemcrypto_ed_key_unittest.cpp b/oemcrypto/util/test/oemcrypto_ed_key_unittest.cpp new file mode 100644 index 0000000..6730c77 --- /dev/null +++ b/oemcrypto/util/test/oemcrypto_ed_key_unittest.cpp @@ -0,0 +1,415 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include +#include + +#include + +#include "OEMCryptoCENCCommon.h" +#include "oemcrypto_ed_key.h" +#include "oemcrypto_ref_test_utils.h" +#include "string_conversions.h" + +namespace wvoec { +namespace util { +constexpr size_t kMessageSize = 4 * 1024; // 4 kB + +class OEMCryptoEdKeyTest : public ::testing::Test { + public: + void SetUp() override { + key_ = EdPrivateKey::New(); + ASSERT_TRUE(key_) << "Key initialization failed"; + } + + void TearDown() override { key_.reset(); } + + protected: + std::unique_ptr key_; +}; + +// Checks that the private key serialization APIs are compatible +// and performing in a manner that is similar to other OEMCrypto methods +// that retrieve data. +TEST_F(OEMCryptoEdKeyTest, PrivateKey_SerializePrivateKeyInfo) { + constexpr size_t kInitialBufferSize = 10; // Definitely too small. + size_t buffer_size = kInitialBufferSize; + std::vector buffer(buffer_size); + + EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, + key_->SerializePrivateKeyInfo(buffer.data(), &buffer_size)); + EXPECT_GT(buffer_size, kInitialBufferSize); + + buffer.resize(buffer_size); + EXPECT_EQ(OEMCrypto_SUCCESS, + key_->SerializePrivateKeyInfo(buffer.data(), &buffer_size)); + buffer.resize(buffer_size); + + const std::vector direct_key_data = key_->SerializePrivateKeyInfo(); + EXPECT_FALSE(direct_key_data.empty()); + ASSERT_EQ(buffer.size(), direct_key_data.size()); + for (size_t i = 0; i < buffer.size(); i++) { + ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; + } +} + +// Checks that a private key that is serialized can be deserialized and +// reload. Also checks that the serialization of a key produces the +// same data to ensure consistency. +TEST_F(OEMCryptoEdKeyTest, PrivateKey_SerializeAndReloadPrivateKeyInfo) { + const std::vector key_data = key_->SerializePrivateKeyInfo(); + std::unique_ptr loaded_key = EdPrivateKey::Load(key_data); + ASSERT_TRUE(loaded_key); + + const std::vector loaded_key_data = + loaded_key->SerializePrivateKeyInfo(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; + } +} + +// Checks that a private key that is serialized can be deserialized and +// reload. Also checks that the serialization of a key produces the +// same data to ensure consistency. +TEST_F(OEMCryptoEdKeyTest, PrivateKey_SerializeAndReloadRawPrivateKey) { + const std::vector key_data = key_->SerializeRaw(); + std::unique_ptr loaded_key = EdPrivateKey::LoadRaw(key_data); + ASSERT_TRUE(loaded_key); + + const std::vector loaded_key_data = loaded_key->SerializeRaw(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; + } +} + +// Checks that the private key to public key serialization APIs are compatible +// and performing in a manner that is similar to other OEMCrypto methods +// that retrieve data. +TEST_F(OEMCryptoEdKeyTest, PrivateKey_SerializeSubjectPublicKeyInfo) { + constexpr size_t kInitialBufferSize = 10; // Definitely too small. + size_t buffer_size = kInitialBufferSize; + std::vector buffer(buffer_size); + + EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, + key_->SerializeAsSubjectPublicKeyInfo(buffer.data(), &buffer_size)); + EXPECT_GT(buffer_size, kInitialBufferSize); + + buffer.resize(buffer_size); + EXPECT_EQ(OEMCrypto_SUCCESS, + key_->SerializeAsSubjectPublicKeyInfo(buffer.data(), &buffer_size)); + buffer.resize(buffer_size); + + const std::vector direct_key_data = + key_->SerializeAsSubjectPublicKeyInfo(); + EXPECT_FALSE(direct_key_data.empty()); + ASSERT_EQ(buffer.size(), direct_key_data.size()); + for (size_t i = 0; i < buffer.size(); i++) { + ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; + } +} + +// Checks that the private key to public key that is serialized can be +// desierialized and reloaded. Also checks that the serialization of a +// key produces the same data to ensure consistency. +TEST_F(OEMCryptoEdKeyTest, PrivateKey_SerializeAndReloadSubjectPublicKeyInfo) { + const std::vector key_data = key_->SerializeAsSubjectPublicKeyInfo(); + std::unique_ptr loaded_key = EdPublicKey::Load(key_data); + ASSERT_TRUE(loaded_key); + + const std::vector loaded_key_data = + loaded_key->SerializeSubjectPublicKeyInfo(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; + } +} + +// Checks that the private key to public key that is serialized can be +// desierialized and reloaded. Also checks that the serialization of a +// key produces the same data to ensure consistency. +TEST_F(OEMCryptoEdKeyTest, PrivateKey_SerializeAndReloadRawPublicKey) { + const std::vector key_data = key_->SerializeAsRawPublicKey(); + std::unique_ptr loaded_key = EdPublicKey::LoadRaw(key_data); + ASSERT_TRUE(loaded_key); + + const std::vector loaded_key_data = loaded_key->SerializeRaw(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; + } +} + +// Checks that a public key that is serialized can be deserialized and +// reload. Also checks that the serialization of a key produces the +// same data to ensure consistency. +TEST_F(OEMCryptoEdKeyTest, PublicKey_SerializeSubjectPublicKeyInfo) { + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + constexpr size_t kInitialBufferSize = 10; // Definitely too small. + size_t buffer_size = kInitialBufferSize; + std::vector buffer(buffer_size); + + EXPECT_EQ( + OEMCrypto_ERROR_SHORT_BUFFER, + pub_key->SerializeSubjectPublicKeyInfo(buffer.data(), &buffer_size)); + EXPECT_GT(buffer_size, kInitialBufferSize); + + buffer.resize(buffer_size); + EXPECT_EQ(OEMCrypto_SUCCESS, pub_key->SerializeSubjectPublicKeyInfo( + buffer.data(), &buffer_size)); + buffer.resize(buffer_size); + + const std::vector direct_key_data = + pub_key->SerializeSubjectPublicKeyInfo(); + EXPECT_FALSE(direct_key_data.empty()); + ASSERT_EQ(buffer.size(), direct_key_data.size()); + for (size_t i = 0; i < buffer.size(); i++) { + ASSERT_EQ(buffer[i], direct_key_data[i]) << "i = " << i; + } +} + +// Checks that the public key serialization APIs are compatible +// and performing in a manner that is similar to other OEMCrypto methods +// that retrieve data. +TEST_F(OEMCryptoEdKeyTest, PublicKey_SerializeAndReloadSubjectPublicKeyInfo) { + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = + pub_key->SerializeSubjectPublicKeyInfo(); + std::unique_ptr loaded_key = EdPublicKey::Load(key_data); + ASSERT_TRUE(loaded_key); + + const std::vector loaded_key_data = + loaded_key->SerializeSubjectPublicKeyInfo(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; + } +} + +// Checks that a public key that is serialized can be deserialized and +// reload. Also checks that the serialization of a key produces the +// same data to ensure consistency. +TEST_F(OEMCryptoEdKeyTest, PublicKey_SerializeAndReloadRawPublicKey) { + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = pub_key->SerializeRaw(); + std::unique_ptr loaded_key = EdPublicKey::LoadRaw(key_data); + ASSERT_TRUE(loaded_key); + + const std::vector loaded_key_data = loaded_key->SerializeRaw(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << i; + } +} + +// Checks that the GenerateSignature() follows the expected +// API behavior as OEMCrypto APIs. +TEST_F(OEMCryptoEdKeyTest, GenerateSignature) { + const std::vector message = RandomData(kMessageSize); + ASSERT_FALSE(message.empty()) << "CdmRandom failed"; + + constexpr size_t kInitialBufferSize = 10; // Definitely too small. + size_t signature_size = kInitialBufferSize; + std::vector signature(signature_size); + EXPECT_EQ(OEMCrypto_ERROR_SHORT_BUFFER, + key_->GenerateSignature(message.data(), message.size(), + signature.data(), &signature_size)); + EXPECT_GT(signature_size, kInitialBufferSize); + + signature.resize(signature_size); + EXPECT_EQ(OEMCrypto_SUCCESS, + key_->GenerateSignature(message.data(), message.size(), + signature.data(), &signature_size)); + signature.resize(signature_size); + + EXPECT_LE(signature_size, key_->SignatureSize()); +} + +// Checks that a generate signature can be verify using the +// equivalent private key. +TEST_F(OEMCryptoEdKeyTest, VerifySignature) { + const std::vector message = RandomData(kMessageSize); + ASSERT_FALSE(message.empty()) << "CdmRandom failed"; + + const std::vector signature = key_->GenerateSignature(message); + ASSERT_FALSE(signature.empty()); + + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + EXPECT_EQ(OEMCrypto_SUCCESS, pub_key->VerifySignature(message, signature)); + + // Check with different message. + const std::vector message_two = RandomData(kMessageSize); + EXPECT_EQ(OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message_two, signature)); + + // Check with bad signature. + const std::vector bad_signature = RandomData(signature.size()); + EXPECT_EQ(OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message, bad_signature)); +} + +struct OEMCryptoEdKeyTestVector { + const char* raw_priv_key_hex; + const char* raw_pub_key_hex; + const char* message_hex; // May be empty. + const char* signature_hex; +}; + +// Test vectors are taken from RFC 8032, section 7.1 (PureEdDSA for Ed25519). +// Note: PureEdDSA is a deterministic function. The same key-data pair +// will alway produce the same signature value. +const OEMCryptoEdKeyTestVector kTestVectors[] = { + // Test 1 + {"9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60", + "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a", "", + "e5564300c360ac729086e2cc806e828a84877f1eb8e5d974d873e06522490155" + "5fb8821590a33bacc61e39701cf9b46bd25bf5f0595bbe24655141438e7a100b"}, + // Test 2 + {"4ccd089b28ff96da9db6c346ec114e0f5b8a319f35aba624da8cf6ed4fb8a6fb", + "3d4017c3e843895a92b70aa74d1b7ebc9c982ccf2ec4968cc0cd55f12af4660c", "72", + "92a009a9f0d4cab8720e820b5f642540a2b27b5416503f8fb3762223ebdb69da" + "085ac1e43e15996e458f3613d0f11d8c387b2eaeb4302aeeb00d291612bb0c00"}, + // Test 3 + {"c5aa8df43f9f837bedb7442f31dcb7b166d38535076f094b85ce3a2e0b4458f7", + "fc51cd8e6218a1a38da47ed00230f0580816ed13ba3303ac5deb911548908025", "af82", + "6291d657deec24024827e69c3abe01a30ce548a284743a445e3680d7db5ac3ac" + "18ff9b538d16f290ae67f760984dc6594a7c15e9716ed28dc027beceea1ec40a"}, + // Test 1024 + {"f5e5767cf153319517630f226876b86c8160cc583bc013744c6bf255f5cc0ee5", + "278117fc144c72340f67d0f2316e8386ceffbf2b2428c9c51fef7c597f1d426e", + "08b8b2b733424243760fe426a4b54908632110a66c2f6591eabd3345e3e4eb98" + "fa6e264bf09efe12ee50f8f54e9f77b1e355f6c50544e23fb1433ddf73be84d8" + "79de7c0046dc4996d9e773f4bc9efe5738829adb26c81b37c93a1b270b20329d" + "658675fc6ea534e0810a4432826bf58c941efb65d57a338bbd2e26640f89ffbc" + "1a858efcb8550ee3a5e1998bd177e93a7363c344fe6b199ee5d02e82d522c4fe" + "ba15452f80288a821a579116ec6dad2b3b310da903401aa62100ab5d1a36553e" + "06203b33890cc9b832f79ef80560ccb9a39ce767967ed628c6ad573cb116dbef" + "efd75499da96bd68a8a97b928a8bbc103b6621fcde2beca1231d206be6cd9ec7" + "aff6f6c94fcd7204ed3455c68c83f4a41da4af2b74ef5c53f1d8ac70bdcb7ed1" + "85ce81bd84359d44254d95629e9855a94a7c1958d1f8ada5d0532ed8a5aa3fb2" + "d17ba70eb6248e594e1a2297acbbb39d502f1a8c6eb6f1ce22b3de1a1f40cc24" + "554119a831a9aad6079cad88425de6bde1a9187ebb6092cf67bf2b13fd65f270" + "88d78b7e883c8759d2c4f5c65adb7553878ad575f9fad878e80a0c9ba63bcbcc" + "2732e69485bbc9c90bfbd62481d9089beccf80cfe2df16a2cf65bd92dd597b07" + "07e0917af48bbb75fed413d238f5555a7a569d80c3414a8d0859dc65a46128ba" + "b27af87a71314f318c782b23ebfe808b82b0ce26401d2e22f04d83d1255dc51a" + "ddd3b75a2b1ae0784504df543af8969be3ea7082ff7fc9888c144da2af58429e" + "c96031dbcad3dad9af0dcbaaaf268cb8fcffead94f3c7ca495e056a9b47acdb7" + "51fb73e666c6c655ade8297297d07ad1ba5e43f1bca32301651339e22904cc8c" + "42f58c30c04aafdb038dda0847dd988dcda6f3bfd15c4b4c4525004aa06eeff8" + "ca61783aacec57fb3d1f92b0fe2fd1a85f6724517b65e614ad6808d6f6ee34df" + "f7310fdc82aebfd904b01e1dc54b2927094b2db68d6f903b68401adebf5a7e08" + "d78ff4ef5d63653a65040cf9bfd4aca7984a74d37145986780fc0b16ac451649" + "de6188a7dbdf191f64b5fc5e2ab47b57f7f7276cd419c17a3ca8e1b939ae49e4" + "88acba6b965610b5480109c8b17b80e1b7b750dfc7598d5d5011fd2dcc5600a3" + "2ef5b52a1ecc820e308aa342721aac0943bf6686b64b2579376504ccc493d97e" + "6aed3fb0f9cd71a43dd497f01f17c0e2cb3797aa2a2f256656168e6c496afc5f" + "b93246f6b1116398a346f1a641f3b041e989f7914f90cc2c7fff357876e506b5" + "0d334ba77c225bc307ba537152f3f1610e4eafe595f6d9d90d11faa933a15ef1" + "369546868a7f3a45a96768d40fd9d03412c091c6315cf4fde7cb68606937380d" + "b2eaaa707b4c4185c32eddcdd306705e4dc1ffc872eeee475a64dfac86aba41c" + "0618983f8741c5ef68d3a101e8a3b8cac60c905c15fc910840b94c00a0b9d0", + "0aab4c900501b3e24d7cdf4663326a3a87df5e4843b2cbdb67cbf6e460fec350" + "aa5371b1508f9f4528ecea23c436d94b5e8fcd4f681e30a6ac00a9704a188a03"}, + // Test SHA(abc) + {"833fe62409237b9d62ec77587520911e9a759cec1d19755b7da901b96dca3d42", + "ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf", + "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a" + "2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f", + "dc2a4459e7369633a52b1bf277839a00201009a3efbf3ecb69bea2186c26b589" + "09351fc9ac90b3ecfdfbc7c66431e0303dca179c138ac17ad9bef1177331a704"}}; + +class OEMCryptoEdKeyVectorTest + : public ::testing::TestWithParam { + public: + void SetUp() override { + const auto& params = GetParam(); + // Load private key. + raw_private_key_ = wvutil::a2b_hex(params.raw_priv_key_hex); + ASSERT_FALSE(raw_private_key_.empty()) + << "Failed to decode raw private key: " << params.raw_priv_key_hex; + private_key_ = EdPrivateKey::LoadRaw(raw_private_key_); + ASSERT_TRUE(private_key_) + << "Failed to parse raw private key: " << params.raw_priv_key_hex; + // Load public key. + raw_public_key_ = wvutil::a2b_hex(params.raw_pub_key_hex); + ASSERT_FALSE(raw_public_key_.empty()) + << "Failed to decode raw public key: " << params.raw_pub_key_hex; + public_key_ = EdPublicKey::LoadRaw(raw_public_key_); + ASSERT_TRUE(private_key_) + << "Failed to parse raw public key: " << params.raw_pub_key_hex; + // Parse message. + const std::string message_hex = params.message_hex; + if (!message_hex.empty()) { + message_ = wvutil::a2b_hex(message_hex); + ASSERT_FALSE(message_.empty()) + << "Failed to decode message: " << message_hex; + } + // Parse signature. + signature_ = wvutil::a2b_hex(params.signature_hex); + ASSERT_FALSE(signature_.empty()) + << "Failed to decode signature: " << params.signature_hex; + } + + std::vector raw_private_key_; + std::unique_ptr private_key_; + + std::vector raw_public_key_; + std::unique_ptr public_key_; + + std::vector message_; + std::vector signature_; +}; + +TEST_P(OEMCryptoEdKeyVectorTest, IsMatchingPair) { + EXPECT_TRUE(private_key_->IsMatchingPublicKey(*public_key_)); + EXPECT_TRUE(public_key_->IsMatchingPrivateKey(*private_key_)); +} + +TEST_P(OEMCryptoEdKeyVectorTest, PrivateKey_SerializeRaw) { + const std::vector raw = private_key_->SerializeRaw(); + ASSERT_FALSE(raw.empty()); + EXPECT_EQ(raw, raw_private_key_); +} + +TEST_P(OEMCryptoEdKeyVectorTest, PrivateKey_SerializeRawPublicKey) { + const std::vector raw = private_key_->SerializeAsRawPublicKey(); + ASSERT_FALSE(raw.empty()); + EXPECT_EQ(raw, raw_public_key_); +} + +TEST_P(OEMCryptoEdKeyVectorTest, PrivateKey_GenerateSignature) { + const std::vector computed_signature = + private_key_->GenerateSignature(message_); + ASSERT_FALSE(computed_signature.empty()); + EXPECT_EQ(computed_signature, signature_); +} + +TEST_P(OEMCryptoEdKeyVectorTest, PublicKey_SerializeRaw) { + const std::vector raw = public_key_->SerializeRaw(); + ASSERT_FALSE(raw.empty()); + EXPECT_EQ(raw, raw_public_key_); +} + +TEST_P(OEMCryptoEdKeyVectorTest, PublicKey_VerifySignature) { + const OEMCryptoResult result = + public_key_->VerifySignature(message_, signature_); + EXPECT_EQ(result, OEMCrypto_SUCCESS); +} + +INSTANTIATE_TEST_SUITE_P(Rfc8032, OEMCryptoEdKeyVectorTest, + ::testing::ValuesIn(kTestVectors)); +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/test/oemcrypto_oem_cert_chain_unittest.cpp b/oemcrypto/util/test/oemcrypto_oem_cert_chain_unittest.cpp new file mode 100644 index 0000000..f53637a --- /dev/null +++ b/oemcrypto/util/test/oemcrypto_oem_cert_chain_unittest.cpp @@ -0,0 +1,226 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +// +// Reference implementation utilities of OEMCrypto APIs +// +#include "oemcrypto_oem_cert_chain.h" + +#include + +#include +#include + +#include + +#include "oemcrypto_drm_key.h" +#include "string_conversions.h" + +namespace wvoec { +namespace util { +namespace { +// Sample OEM Certificate Chain w/ ECC Leaf Certificate. +// +// The following is a human readable representation of the +// contents. +// +// deviceCert = { +// serialNumber = 17962145401824487874, +// subject = { +// countryName = "US", +// stateOrProvinceName = "Washington", +// localityName = "Seattle", +// organizationName = "Google", +// organizationalUnitName = "Widevine", +// commonName = "{25676}-leaf" +// }, +// subjectPublicKeyInfo = { +// algorithm = { +// algorithm = ecPublicKey, +// parameters = prime256v1 +// }, +// subjectPublicKey = +// 0x04 || +// 0x589b668ba516222ff25222d01bcf124afc86ef023d7e3aef7867e379d428000b || +// 0xa88d38787d911d47020475cfda07396266b7eea766d336c71d320ff4f8652d5d, +// } +// }, +// extensions = { +// subjectKeyIdentifier = 0x1BBC63269CBFECA8432145ADE4679F58F56D1043, +// authorityKeyIdentifier = 0x9B047CEA4BEA39828EA20BB779358FC4F9036793, +// systemId = 25676 +// } +// }, +// leafCert = { +// serialNumber = 103280693203998492265742473323976560102, +// subject = { +// countryName = "US", +// stateOrProvinceName = "Washington", +// localityName = "Seattle", +// organizationName = "Google", +// organizationalUnitName = "Widevine", +// commonName = "Provisioning 4.0; system id: 25676" +// }, +// subjectPublicKeyInfo = { +// algorithm = { +// algorithm = rsaEncryption, +// parameters = null +// }, +// subjectPublicKey = { +// modulos = // Note: base10 +// 2212313408739241601759361433600117663153478943283760939064040828 || +// 9863818494081550091984549656244030116693481379492747559361927679 || +// 4132689719233272093921143381342138433601357997596667767394647997 || +// 5623221505050854448044073563590298743731429721095025561591228331 || +// 6868598890841310791397063262243203501233033288435072965387540422 || +// 3522388263167185658009649249613740268155231131120254496134188007 || +// 8544062420698625200678212956751967041487698953402818835366022930 || +// 6782626000346612966301260932056306971569229014672622118771547852 || +// 5963633536158635155467924864521142324709706455318729358005289005 || +// 17664686160742744082752683211029730200339 +// publicExponent = 65537 +// } +// }, +// extensions = { +// subjectKeyIdentifier = 0x9B047CEA4BEA39828EA20BB779358FC4F9036793, +// authorityKeyIdentifier = 0x049466AAF96189B6DBB5F713383D6284B8180A8F, +// systemId = 25676 +// } +// } +const char kEccOemCertChainHex[] = + "3082091906092a864886f70d010702a082090a308209060201013100300b0609" + "2a864886f70d010701a08208ee308203e7308202cfa003020102020900f9465c" + "134faa29c2300d06092a864886f70d0101050500307d310b3009060355040613" + "025553310b300906035504080c0257413110300e06035504070c075365617474" + "6c65310f300d060355040a0c06476f6f676c653111300f060355040b0c085769" + "646576696e65312b302906035504030c2250726f766973696f6e696e6720342e" + "303b2073797374656d2069643a203235363736301e170d323530323235313834" + "3331355a170d3434313131323138343331355a3067310b300906035504061302" + "5553310b300906035504080c0257413110300e06035504070c0753656174746c" + "65310f300d060355040a0c06476f6f676c653111300f060355040b0c08576964" + "6576696e653115301306035504030c0c7b32353637367d2d6c65616630593013" + "06072a8648ce3d020106082a8648ce3d03010703420004589b668ba516222ff2" + "5222d01bcf124afc86ef023d7e3aef7867e379d428000ba88d38787d911d4702" + "0475cfda07396266b7eea766d336c71d320ff4f8652d5da38201493082014530" + "0c0603551d130101ff04023000301d0603551d0e041604141bbc63269cbfeca8" + "432145ade4679f58f56d1043301f0603551d230418301680149b047cea4bea39" + "828ea20bb779358fc4f90367933012060a2b06010401d6790401010404020264" + "4c3011060a2b06010401d67904010204030101ff3081cd060a2b06010401d679" + "0401040481be0381bb00080110001a9101041b1d4e645d1a37e9a49df6c9a248" + "572b4dec1230d404d35ba76a54546cffc3dae3cd95710c67025aa52853c42ba2" + "3dd44f1aa455e692eda395cb7db06c74c2fb0077b0c6e68af1e9a058886b9fd6" + "1e89a28ce2417acb243dc0106fdb467ff7e7400e2dcc5cd87fec2b551e0681b3" + "4a8366bcbf6af15e5292d99bb201f027ef7de5a1cc5efd5098fafa65f2e9a9b9" + "b768222098645f56603ccab6aaaa9315e8d24d02688112b5dd38219d363edeeb" + "a5d38db8300d06092a864886f70d01010505000382010100442c40c54dc71157" + "e5a86a9bc3ba0dbf7c8e8e1e792c3c9aeee84979b7adbe4c1658dd9d812baf97" + "5d666754e2ba84e0b4047cead9e7eda9a213d8fd530d81ec3157b7b8c7155598" + "52cd9b151aa33e0377b5b02ee9f090bb6ac5f8913e5c47137180604417f03838" + "d611192ade04d2c12439764197be01d31c751f16e7e01ab77b08f64c2f48149f" + "c5d59053e3397edb6b868636490ca5888624ae4bc042a164c80656653e89d595" + "1cc685aaef8548469ac6f4326161ee463c01c50b8bfedb8ee0e9a3d282fba956" + "1af12745b1447317248498d7be8d0457a85553a83e4ed62c18111bf99dce5b73" + "10f416548a510dbffd0c4a07b8e499d35ac7119d03c22beb308204ff308202e7" + "a00302010202104db323521fc8cb98db7e17faead78de6300d06092a864886f7" + "0d01010b0500307e310b30090603550406130255533113301106035504080c0a" + "57617368696e67746f6e3111300f06035504070c084b69726b6c616e64310f30" + "0d060355040a0c06476f6f676c653111300f060355040b0c085769646576696e" + "653123302106035504030c1a7769646576696e652e636f6d2f6f656d2d726f6f" + "742d70726f64301e170d3232303231363138333634365a170d34323032313131" + "38333634365a307d310b3009060355040613025553310b300906035504080c02" + "57413110300e06035504070c0753656174746c65310f300d060355040a0c0647" + "6f6f676c653111300f060355040b0c085769646576696e65312b302906035504" + "030c2250726f766973696f6e696e6720342e303b2073797374656d2069643a20" + "323536373630820122300d06092a864886f70d01010105000382010f00308201" + "0a0282010100af3fbd1aa2b42e3626b75fda46cb0f7d7ab6bf80746fa839603d" + "ccf8dbd911b971f62640700be3592b24b9a7452aa7c92de644fe2a9679f924c1" + "67cb9f7b2d00f73ac2befe44077ccbe9788a5ff28d7a44efa1df93251c1c9367" + "5deaf9838b71c41eb95c99ad0c7e249eb171f33d64f77c0a39731958030301d3" + "99df4af4138e65d4db4bb0523e18d1222b475d70f971fd9622c4570e7eadac51" + "336284a0c92ce2ff0e57f916486c9f6ac793448c325f1d2016b8073fe20ae7e4" + "4473bbb2a2d089f93e4c712842f7632dca74acfd1f4e8cbdcc1bb649ca922dba" + "b03a7943a4a190a4640ee3cbe5c04cf293cb0ce89f761289caca90114cb8896f" + "56452ab28b130203010001a37a307830120603551d130101ff040830060101ff" + "020100300e0603551d0f0101ff040403020204301d0603551d0e041604149b04" + "7cea4bea39828ea20bb779358fc4f9036793301f0603551d2304183016801404" + "9466aaf96189b6dbb5f713383d6284b8180a8f3012060a2b06010401d6790401" + "0104040202644c300d06092a864886f70d01010b050003820201002858a916ed" + "57b77875e8ac70f9e0522b1d03c17c8bb30785577b82f6fddaf12f4c84fbc82e" + "8fce5d557715b5f4a3550fd4a0facb9ee9a0c6aaab9f662b9b43e355d23adee1" + "d57a15280ff8a64fdb659ea1c7c69d6049e214806ac384c234ecf324b3cb7180" + "ed83ca0bede7aa5c2034e5c9aa576618e1eb8e96c6d7dc9dd299806c4598c95a" + "bfb73fa7adf274e1022fc934ec14d770875c7eab8d6f9d3fdb7cbfeea095ce1b" + "f6989e0ca65fd32d32d1ba597db56e67478492bb5db8a90ab692191da4c2407d" + "741b63bbd5d99318586bd8205ca06ede857932e01773fdad0a4759e08e45900e" + "2dc1d2b4e05f9dc915131fc9a7325a04841c7c24dbec009d3c8588ef3992ab0a" + "2209827d03fec9f716f82ec8e25c54297d00d34b471dbee366605ac08a0700d0" + "23d6346f15c91a1986fd1a1d359fe8d6515ac02d69e2f9ab4046873cfb48d799" + "a58abf686e61cf73035605bc1643000e13744035cc8290794bc5456ae1ea1f5c" + "5aff87c3240d2cc2672f943ab756c8f60eb6c3d158b41972c4e714f69196e314" + "69bf2c6aa47c687354d12e6bffbc8164461ca5e7190352643d882f6b2efcad8e" + "9ddd627a26a89980bf3aef715c6b2c5efdd28d3f1b54511719a20408b76ca279" + "9daa58ec245e1b9122657474715c47d51bcc459cae9702e35447bed3e9018a83" + "95ad89ee2135e8290a9168d3757fa8f1a5a1c1e5185fb2519614bb3100"; + +const char kEccOemLeafEccSubjectPublicKeyInfoHex[] = + "3059301306072a8648ce3d020106082a8648ce3d03010703420004589b668ba5" + "16222ff25222d01bcf124afc86ef023d7e3aef7867e379d428000ba88d38787d" + "911d47020475cfda07396266b7eea766d336c71d320ff4f8652d5d"; + +const char kEccOemIntermediateRsaSubjectPublicKeyInfoHex[] = + "30820122300d06092a864886f70d01010105000382010f003082010a02820101" + "00af3fbd1aa2b42e3626b75fda46cb0f7d7ab6bf80746fa839603dccf8dbd911" + "b971f62640700be3592b24b9a7452aa7c92de644fe2a9679f924c167cb9f7b2d" + "00f73ac2befe44077ccbe9788a5ff28d7a44efa1df93251c1c93675deaf9838b" + "71c41eb95c99ad0c7e249eb171f33d64f77c0a39731958030301d399df4af413" + "8e65d4db4bb0523e18d1222b475d70f971fd9622c4570e7eadac51336284a0c9" + "2ce2ff0e57f916486c9f6ac793448c325f1d2016b8073fe20ae7e44473bbb2a2" + "d089f93e4c712842f7632dca74acfd1f4e8cbdcc1bb649ca922dbab03a7943a4" + "a190a4640ee3cbe5c04cf293cb0ce89f761289caca90114cb8896f56452ab28b" + "130203010001"; +} // namespace + +TEST(OEMCryptoOemCertChainTest, EccLeafCert) { + const std::vector cert_chain_data = + wvutil::a2b_hex(kEccOemCertChainHex); + ASSERT_FALSE(cert_chain_data.empty()) + << "Failed to decode test OEM cert chain"; + + auto cert_chain = OemCertificateChain::LoadPkcs7(cert_chain_data); + ASSERT_TRUE(cert_chain) << "Failed to parse cert chain"; + + // Check leaf certificate. + const OemPublicCertificate* device_cert = cert_chain->device_cert(); + ASSERT_NE(device_cert, nullptr); + + EXPECT_EQ(device_cert->GetSubjectPublicKeyType(), OemCertKeyType::kEcc); + + const std::vector expected_device_cert_data = + wvutil::a2b_hex(kEccOemLeafEccSubjectPublicKeyInfoHex); + ASSERT_FALSE(expected_device_cert_data.empty()) + << "Failed to decode test leaf SubjectPublicKeyInfo"; + const std::vector device_cert_data = + device_cert->SerializedSubjectPublicKeyInfo(); + ASSERT_FALSE(device_cert_data.empty()) + << "Serialize to leaf SubjectPublicKeyInfo"; + EXPECT_EQ(device_cert_data, expected_device_cert_data); + + // Check intermediate certificate. + const OemPublicCertificate* intermediate_cert = + cert_chain->intermediate_cert(); + ASSERT_NE(intermediate_cert, nullptr); + + EXPECT_EQ(intermediate_cert->GetSubjectPublicKeyType(), OemCertKeyType::kRsa); + + const std::vector expected_intermediate_cert_data = + wvutil::a2b_hex(kEccOemIntermediateRsaSubjectPublicKeyInfoHex); + ASSERT_FALSE(expected_intermediate_cert_data.empty()) + << "Failed to decode test leaf SubjectPublicKeyInfo"; + const std::vector intermediate_cert_data = + intermediate_cert->SerializedSubjectPublicKeyInfo(); + ASSERT_FALSE(intermediate_cert_data.empty()) + << "Serialize to leaf SubjectPublicKeyInfo"; + EXPECT_EQ(intermediate_cert_data, expected_intermediate_cert_data); +} +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/test/oemcrypto_oem_cert_unittest.cpp b/oemcrypto/util/test/oemcrypto_oem_cert_unittest.cpp index 6239045..06e0642 100644 --- a/oemcrypto/util/test/oemcrypto_oem_cert_unittest.cpp +++ b/oemcrypto/util/test/oemcrypto_oem_cert_unittest.cpp @@ -34,7 +34,7 @@ TEST(OEMCryptoOemCertTest, CreateFromArray) { kOEMPrivateKey, kOEMPrivateKeySize, kOEMPublicCert, kOEMPublicCertSize); ASSERT_TRUE(oem_cert); - EXPECT_EQ(OemCertificate::kRsa, oem_cert->key_type()); + EXPECT_EQ(OemCertKeyType::kRsa, oem_cert->GetKeyType()); const std::vector private_key = oem_cert->GetPrivateKey(); EXPECT_EQ(kOEMPrivateKeyVector, private_key); @@ -58,7 +58,7 @@ TEST(OEMCryptoOemCertTest, CreateFromVector) { OemCertificate::Create(kOEMPrivateKeyVector, kOEMPublicCertVector); ASSERT_TRUE(oem_cert); - EXPECT_EQ(OemCertificate::kRsa, oem_cert->key_type()); + EXPECT_EQ(OemCertKeyType::kRsa, oem_cert->GetKeyType()); const std::vector private_key = oem_cert->GetPrivateKey(); EXPECT_EQ(kOEMPrivateKeyVector, private_key); @@ -107,7 +107,7 @@ TEST(OEMCryptoOemCertTest, CreateWithDifferentPrivateRsaKey) { OemCertificate::Create(private_key, kOEMPublicCertVector); ASSERT_TRUE(oem_cert); - EXPECT_EQ(OemCertificate::kRsa, oem_cert->key_type()); + EXPECT_EQ(OemCertKeyType::kRsa, oem_cert->GetKeyType()); // Validating key should return an error. OEMCryptoResult status = oem_cert->IsCertificateValid(); diff --git a/oemcrypto/util/test/oemcrypto_rsa_key_unittest.cpp b/oemcrypto/util/test/oemcrypto_rsa_key_unittest.cpp index f2cb96c..92a88e2 100644 --- a/oemcrypto/util/test/oemcrypto_rsa_key_unittest.cpp +++ b/oemcrypto/util/test/oemcrypto_rsa_key_unittest.cpp @@ -56,8 +56,7 @@ class OEMCryptoRsaKeyTest : public ::testing::TestWithParam { } } break; case kRsaFieldUnknown: // Suppress compiler warnings - LOGE("RSA test was incorrectly instantiation"); - exit(EXIT_FAILURE); + FAIL() << "RSA test was incorrectly instantiation"; break; } ASSERT_TRUE(key_) << "Key initialization failed " @@ -193,6 +192,7 @@ TEST_P(OEMCryptoRsaKeyTest, SerializeAndReloadPublicKey) { EXPECT_EQ(pub_key->field_size(), loaded_key->field_size()); EXPECT_EQ(pub_key->allowed_schemes(), loaded_key->allowed_schemes()); + EXPECT_TRUE(pub_key->IsMatchingPublicKey(*loaded_key)); const std::vector loaded_key_data = loaded_key->Serialize(); ASSERT_EQ(key_data.size(), loaded_key_data.size()); @@ -201,6 +201,31 @@ TEST_P(OEMCryptoRsaKeyTest, SerializeAndReloadPublicKey) { } } +// Checks that a public key that is serialized can be deserialized and +// reload as RSAPublicKey. Also checks that the serialization of a key +// produces the same data to ensure consistency. +TEST_P(OEMCryptoRsaKeyTest, SerializeAndReloadPublicKeyAsRsaPublicKey) { + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + const std::vector key_data = pub_key->SerializeAsRsaPublicKey(); + + std::unique_ptr loaded_key = + RsaPublicKey::LoadRsaPublicKey(key_data); + ASSERT_TRUE(loaded_key); + + EXPECT_EQ(pub_key->field_size(), loaded_key->field_size()); + EXPECT_EQ(pub_key->allowed_schemes(), loaded_key->allowed_schemes()); + EXPECT_TRUE(pub_key->IsMatchingPublicKey(*loaded_key)); + + const std::vector loaded_key_data = + loaded_key->SerializeAsRsaPublicKey(); + ASSERT_EQ(key_data.size(), loaded_key_data.size()); + for (size_t i = 0; i < key_data.size(); i++) { + ASSERT_EQ(key_data[i], loaded_key_data[i]) << "i = " << std::to_string(i); + } +} + // Checks that a public key can be initialized from a ASN.1 DER encoded // PrivateKeyInfo message. TEST_P(OEMCryptoRsaKeyTest, SerializePrivateKeyAndReloadAsPublicKey) { @@ -253,26 +278,123 @@ TEST_P(OEMCryptoRsaKeyTest, GenerateSignature) { } // Checks that RSA signatures can be verified by an RSA public key. -TEST_P(OEMCryptoRsaKeyTest, VerifySignature) { +// using SHA-1 as the hash algorithm. +TEST_P(OEMCryptoRsaKeyTest, VerifySignatureSha1) { + constexpr RsaSignatureAlgorithm kSigningAlgorithm = kRsaPssSha1; + const std::vector message = RandomData(kMessageSize); ASSERT_FALSE(message.empty()) << "CdmRandom failed"; - const std::vector signature = key_->GenerateSignature(message); + const std::vector signature = + key_->GenerateSignature(message, kSigningAlgorithm); std::unique_ptr pub_key = key_->MakePublicKey(); ASSERT_TRUE(pub_key); - EXPECT_EQ(OEMCrypto_SUCCESS, pub_key->VerifySignature(message, signature)); + EXPECT_EQ(OEMCrypto_SUCCESS, + pub_key->VerifySignature(message, signature, kSigningAlgorithm)); // Check with different message. const std::vector message_two = RandomData(kMessageSize); - EXPECT_EQ(OEMCrypto_ERROR_SIGNATURE_FAILURE, - pub_key->VerifySignature(message_two, signature)); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message_two, signature, kSigningAlgorithm)); // Check with bad signature. const std::vector bad_signature = RandomData(signature.size()); - EXPECT_EQ(OEMCrypto_ERROR_SIGNATURE_FAILURE, - pub_key->VerifySignature(message, bad_signature)); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message, bad_signature, kSigningAlgorithm)); +} + +// Checks that RSA signatures can be verified by an RSA public key +// using SHA-256 as the hash algorithm. +TEST_P(OEMCryptoRsaKeyTest, VerifySignatureSha256) { + constexpr RsaSignatureAlgorithm kSigningAlgorithm = kRsaPssSha256; + + const std::vector message = RandomData(kMessageSize); + ASSERT_FALSE(message.empty()) << "CdmRandom failed"; + + const std::vector signature = + key_->GenerateSignature(message, kSigningAlgorithm); + + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + EXPECT_EQ(OEMCrypto_SUCCESS, + pub_key->VerifySignature(message, signature, kSigningAlgorithm)); + + // Check with different message. + const std::vector message_two = RandomData(kMessageSize); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message_two, signature, kSigningAlgorithm)); + + // Check with bad signature. + const std::vector bad_signature = RandomData(signature.size()); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message, bad_signature, kSigningAlgorithm)); +} + +// Checks that RSA signatures can be verified by an RSA public key +// using SHA-384 as the hash algorithm. +TEST_P(OEMCryptoRsaKeyTest, VerifySignatureSha384) { + constexpr RsaSignatureAlgorithm kSigningAlgorithm = kRsaPssSha384; + + const std::vector message = RandomData(kMessageSize); + ASSERT_FALSE(message.empty()) << "CdmRandom failed"; + + const std::vector signature = + key_->GenerateSignature(message, kSigningAlgorithm); + + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + EXPECT_EQ(OEMCrypto_SUCCESS, + pub_key->VerifySignature(message, signature, kSigningAlgorithm)); + + // Check with different message. + const std::vector message_two = RandomData(kMessageSize); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message_two, signature, kSigningAlgorithm)); + + // Check with bad signature. + const std::vector bad_signature = RandomData(signature.size()); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message, bad_signature, kSigningAlgorithm)); +} + +// Checks that RSA signatures can be verified by an RSA public key +// using SHA-512 as the hash algorithm. +TEST_P(OEMCryptoRsaKeyTest, VerifySignatureSha512) { + constexpr RsaSignatureAlgorithm kSigningAlgorithm = kRsaPssSha512; + + const std::vector message = RandomData(kMessageSize); + ASSERT_FALSE(message.empty()) << "CdmRandom failed"; + + const std::vector signature = + key_->GenerateSignature(message, kSigningAlgorithm); + + std::unique_ptr pub_key = key_->MakePublicKey(); + ASSERT_TRUE(pub_key); + + EXPECT_EQ(OEMCrypto_SUCCESS, + pub_key->VerifySignature(message, signature, kSigningAlgorithm)); + + // Check with different message. + const std::vector message_two = RandomData(kMessageSize); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message_two, signature, kSigningAlgorithm)); + + // Check with bad signature. + const std::vector bad_signature = RandomData(signature.size()); + EXPECT_EQ( + OEMCrypto_ERROR_SIGNATURE_FAILURE, + pub_key->VerifySignature(message, bad_signature, kSigningAlgorithm)); } // Checks that the special CAST receiver signature scheme works @@ -413,5 +535,16 @@ TEST_P(OEMCryptoRsaKeyTest, ShareEncryptionKey) { INSTANTIATE_TEST_SUITE_P(AllFieldSizes, OEMCryptoRsaKeyTest, ::testing::Values(kRsa2048Bit, kRsa3072Bit)); + +TEST(OEMCryptoBasicRsaKeyTest, RsaSignatureAlgorithmConverter) { + EXPECT_EQ(RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm(OEMCrypto_SHA1), + kRsaPssSha1); + EXPECT_EQ(RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm(OEMCrypto_SHA2_256), + kRsaPssSha256); + EXPECT_EQ(RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm(OEMCrypto_SHA2_384), + kRsaPssSha384); + EXPECT_EQ(RsaPssSignatureAlgorithmFromOEMCryptoAlgorithm(OEMCrypto_SHA2_512), + kRsaPssSha512); +} } // namespace util } // namespace wvoec diff --git a/oemcrypto/util/test/oemcrypto_wvcrc32_unittest.cpp b/oemcrypto/util/test/oemcrypto_wvcrc32_unittest.cpp deleted file mode 100644 index 3d62b17..0000000 --- a/oemcrypto/util/test/oemcrypto_wvcrc32_unittest.cpp +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2021 Google LLC. All Rights Reserved. This file and proprietary -// source code may only be used and distributed under the Widevine License -// Agreement. -// -// Reference implementation utilities of OEMCrypto APIs -// -#include - -#include "wvcrc32.h" - -namespace wvoec { -namespace util { -uint32_t ComputeCrc32(const std::string& s) { - return wvcrc32(reinterpret_cast(s.data()), s.size()); -} - -uint32_t ComputeCrc32Cont(const std::string& s, uint32_t prev_crc) { - return wvcrc32Cont(reinterpret_cast(s.data()), s.size(), - prev_crc); -} - -TEST(OEMCryptoWvCrc32Test, BasicTest) { - EXPECT_EQ(0xF88AC628u, ComputeCrc32("abcdefg")); - EXPECT_EQ(0xDF520F72u, ComputeCrc32("Widevine")); - EXPECT_EQ(0x0376E6E7u, ComputeCrc32("123456789")); - EXPECT_EQ(0xBA62119Eu, - ComputeCrc32("The quick brown fox jumps over the lazy dog")); -} - -TEST(OEMCryptoWvCrc32Test, StreamTest) { - const std::vector parts = {"The ", "quick", " brown ", - "fox", " jumps ", "over", - " the ", "lazy", " dog"}; - uint32_t crc = wvcrc32Init(); - for (const auto& part : parts) { - crc = ComputeCrc32Cont(part, crc); - } - EXPECT_EQ(0xBA62119Eu, crc); -} - -TEST(OEMCryptoWvCrc32Test, Keybox) { - // clang-format off - const uint8_t kKeyboxData[128] = { - // deviceID = WidevineCRCTestKeyBox - 0x57, 0x69, 0x64, 0x65, 0x76, 0x69, 0x6e, 0x65, - 0x43, 0x52, 0x43, 0x54, 0x65, 0x73, 0x74, 0x4b, - 0x65, 0x79, 0x62, 0x6f, 0x78, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - // key = random - 0x8a, 0x7c, 0xda, 0x3e, 0x09, 0xd9, 0x8e, 0xd5, - 0x47, 0x47, 0x00, 0x84, 0x5a, 0x1f, 0x52, 0xd4, - // data = random - 0x98, 0xa5, 0x00, 0x19, 0x8b, 0xfe, 0x54, 0xfd, - 0xca, 0x4d, 0x26, 0xa3, 0xfa, 0xaa, 0x3b, 0x6c, - 0x35, 0xfe, 0x03, 0x7c, 0xbf, 0x35, 0xba, 0xce, - 0x31, 0xb5, 0x1e, 0x3c, 0x49, 0xd6, 0x3f, 0x9c, - 0x3a, 0xde, 0x9b, 0x58, 0xcc, 0x54, 0x8d, 0xc0, - 0x4b, 0x04, 0xcc, 0xee, 0xae, 0x4d, 0x9f, 0x90, - 0xd3, 0xf3, 0xfe, 0x23, 0x26, 0x13, 0x56, 0x80, - 0xe4, 0x3b, 0x79, 0x22, 0x69, 0x5d, 0xd6, 0xb7, - 0xa0, 0x0e, 0x7e, 0x07, 0xcd, 0x1a, 0x15, 0xca, - // magic - 'k', 'b', 'o', 'x', - // crc - 0x09, 0x7b, 0x7e, 0xcc - }; - // clang-format on - const uint32_t crc_computed = wvcrc32n(kKeyboxData, 124); - uint32_t crc_current; - memcpy(&crc_current, &kKeyboxData[124], 4); - EXPECT_EQ(crc_computed, crc_current); -} -} // namespace util -} // namespace wvoec diff --git a/oemcrypto/util/test/wvcrc32_unittest.cpp b/oemcrypto/util/test/wvcrc32_unittest.cpp new file mode 100644 index 0000000..9eccd47 --- /dev/null +++ b/oemcrypto/util/test/wvcrc32_unittest.cpp @@ -0,0 +1,248 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine +// License Agreement. +// +#include "wvcrc32.h" + +#include +#include + +#include +#include + +#include + +namespace wvoec { +namespace util { +namespace test { +namespace { +uint32_t ComputeCrc32(const std::string& s) { + return wvcrc32(reinterpret_cast(s.data()), s.size()); +} + +uint32_t ComputeCrc32Cont(const std::string& s, uint32_t prev_crc) { + return wvcrc32Cont(reinterpret_cast(s.data()), s.size(), + prev_crc); +} + +// Sample keybox data. +// clang-format off +constexpr uint8_t kKeyboxData[128] = { + // deviceID = WidevineCRCTestKeyBox + 0x57, 0x69, 0x64, 0x65, 0x76, 0x69, 0x6e, 0x65, + 0x43, 0x52, 0x43, 0x54, 0x65, 0x73, 0x74, 0x4b, + 0x65, 0x79, 0x62, 0x6f, 0x78, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // key = random + 0x8a, 0x7c, 0xda, 0x3e, 0x09, 0xd9, 0x8e, 0xd5, + 0x47, 0x47, 0x00, 0x84, 0x5a, 0x1f, 0x52, 0xd4, + // data = random + 0x98, 0xa5, 0x00, 0x19, 0x8b, 0xfe, 0x54, 0xfd, + 0xca, 0x4d, 0x26, 0xa3, 0xfa, 0xaa, 0x3b, 0x6c, + 0x35, 0xfe, 0x03, 0x7c, 0xbf, 0x35, 0xba, 0xce, + 0x31, 0xb5, 0x1e, 0x3c, 0x49, 0xd6, 0x3f, 0x9c, + 0x3a, 0xde, 0x9b, 0x58, 0xcc, 0x54, 0x8d, 0xc0, + 0x4b, 0x04, 0xcc, 0xee, 0xae, 0x4d, 0x9f, 0x90, + 0xd3, 0xf3, 0xfe, 0x23, 0x26, 0x13, 0x56, 0x80, + 0xe4, 0x3b, 0x79, 0x22, 0x69, 0x5d, 0xd6, 0xb7, + 0xa0, 0x0e, 0x7e, 0x07, 0xcd, 0x1a, 0x15, 0xca, + // magic + 'k', 'b', 'o', 'x', + // crc + 0x09, 0x7b, 0x7e, 0xcc +}; +// clang-format on +} // namespace + +// ==== Original API ==== + +// Tests the original API for Widevine CRC-32 calculations. +TEST(WvCrc32Test, Init) { + // Initial value. + constexpr uint32_t kExpected = 0xffffffffu; + EXPECT_EQ(wvcrc32Init(), kExpected); +} + +TEST(WvCrc32Test, BasicTest) { + constexpr uint32_t kExpected1 = 0xF88AC628u; + EXPECT_EQ(ComputeCrc32("abcdefg"), kExpected1); + + constexpr uint32_t kExpected2 = 0xDF520F72u; + EXPECT_EQ(ComputeCrc32("Widevine"), kExpected2); + + constexpr uint32_t kExpected3 = 0x0376E6E7u; + EXPECT_EQ(ComputeCrc32("123456789"), kExpected3); + + constexpr uint32_t kExpected4 = 0xBA62119Eu; + EXPECT_EQ(ComputeCrc32("The quick brown fox jumps over the lazy dog"), + kExpected4); +} + +TEST(WvCrc32Test, NetworkByteOrder) { + const std::string kDataStr = "Hello, World!"; + const std::vector kDataVec(kDataStr.begin(), kDataStr.end()); + // Compute all at once, in network-byte-order (NBO). + const uint8_t kExpectedNboBytes[4] = {0x19, 0x27, 0x01, 0x20}; + uint32_t expected_nbo = 0; + memcpy(&expected_nbo, kExpectedNboBytes, 4); + EXPECT_EQ(wvcrc32n(kDataVec.data(), kDataVec.size()), expected_nbo); +} + +TEST(WvCrc32Test, StreamTest) { + // Compute in chunks. + const std::vector parts = {"The ", "quick", " brown ", + "fox", " jumps ", "over", + " the ", "lazy", " dog"}; + uint32_t crc = wvcrc32Init(); + for (const auto& part : parts) { + crc = ComputeCrc32Cont(part, crc); + } + EXPECT_EQ(0xBA62119Eu, crc); +} + +TEST(WvCrc32Test, Keybox) { + const uint32_t crc_computed = wvcrc32n(kKeyboxData, 124); + uint32_t crc_current; + memcpy(&crc_current, &kKeyboxData[124], 4); + EXPECT_EQ(crc_computed, crc_current); +} + +// == Context Based API ==== + +TEST(WvCrc32CtxTest, InitialCtx) { + const Crc32Ctx ctx; + constexpr uint32_t kExpected = 0xffffffffu; + EXPECT_EQ(ctx.accumulator(), kExpected); + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ComputeSingleByte) { + Crc32Ctx ctx; + ctx.Update(0x42); + constexpr uint32_t kExpected = 0x730cf4ad; + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ComputeByArray) { + Crc32Ctx ctx; + constexpr uint8_t kData[4] = {1, 3, 3, 7}; + ctx.Update(kData, sizeof(kData)); + constexpr uint32_t kExpected = 0x75acbd93; + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ComputeByVector) { + Crc32Ctx ctx; + const std::vector kData = {0xde, 0xad, 0xbe, 0xef}; + ctx.Update(kData); + constexpr uint32_t kExpected = 0x81da1a18; + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ComputeByString) { + Crc32Ctx ctx; + const std::string kData = "Hello, World!"; + ctx.Update(kData); + constexpr uint32_t kExpected = 0x19270120; + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ComputeInChunks) { + Crc32Ctx ctx; + ctx.Update("Hello, ").Update("World").Update('!'); + constexpr uint32_t kExpected = 0x19270120; + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ResumeConstructor) { + Crc32Ctx ctx1; + ctx1.Update("Hello, "); + + Crc32Ctx ctx2 = Crc32Ctx(ctx1.Finalize()); + ctx2.Update("World!"); + + constexpr uint32_t kExpected = 0x19270120; + EXPECT_EQ(ctx2.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, ResumeMethod) { + Crc32Ctx ctx1; + ctx1.Update("Hello, "); + + Crc32Ctx ctx2; + ctx2.Resume(ctx1.Finalize()).Update("World!"); + + constexpr uint32_t kExpected = 0x19270120; + EXPECT_EQ(ctx2.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, Reset) { + Crc32Ctx ctx; + ctx.Update("Foo bar fiz bang").Reset().Update("Hello, World!"); + + constexpr uint32_t kExpected = 0x19270120; + EXPECT_EQ(ctx.Finalize(), kExpected); +} + +TEST(WvCrc32CtxTest, FinalizeNbo) { + Crc32Ctx ctx; + ctx.Update("Hello, World!"); + + const uint8_t kExpectedNboBytes[4] = {0x19, 0x27, 0x01, 0x20}; + uint32_t expected_nbo = 0; + memcpy(&expected_nbo, kExpectedNboBytes, 4); + EXPECT_EQ(ctx.FinalizeNbo(), expected_nbo); +} + +TEST(WvCrc32CtxTest, FinalizeNboBuffer) { + Crc32Ctx ctx; + ctx.Update("Hello, World!"); + + const std::vector kExpected = {0x19, 0x27, 0x01, 0x20}; + std::vector result(4, 0x00); + ASSERT_TRUE(ctx.FinalizeNbo(result.data(), result.size())); + EXPECT_EQ(result, kExpected); +} + +TEST(WvCrc32CtxTest, BadParameters) { + Crc32Ctx ctx; + // Output buffer is null. + EXPECT_FALSE(ctx.FinalizeNbo(nullptr, 4)); + + // Output buffer is too small. + uint8_t buffer[4]; + EXPECT_FALSE(ctx.FinalizeNbo(buffer, 3)); + EXPECT_FALSE(ctx.FinalizeNbo(buffer, 0)); +} + +TEST(WvCrc32CtxTest, Keybox) { + Crc32Ctx ctx; + + constexpr size_t kDeviceIdOffset = 0; + constexpr size_t kDeviceIdLength = 32; + ctx.Update(&kKeyboxData[kDeviceIdOffset], kDeviceIdLength); + + constexpr size_t kKeyOffset = kDeviceIdOffset + kDeviceIdLength; + constexpr size_t kKeyLength = 16; + ctx.Update(&kKeyboxData[kKeyOffset], kKeyLength); + + constexpr size_t kDataOffset = kKeyOffset + kKeyLength; + constexpr size_t kDataLength = 72; + ctx.Update(&kKeyboxData[kDataOffset], kDataLength); + + constexpr size_t kMagicOffset = kDataOffset + kDataLength; + constexpr size_t kMagicLength = 4; + ctx.Update(&kKeyboxData[kMagicOffset], kMagicLength); + + // Extract CRC from keybox. + constexpr size_t kCrcOffset = kMagicOffset + kMagicLength; + uint32_t crc_keybox; + memcpy(&crc_keybox, &kKeyboxData[kCrcOffset], 4); + + // Compute CRC, and compare to keybox value. + const uint32_t crc_computed = ctx.FinalizeNbo(); + EXPECT_EQ(crc_computed, crc_keybox); +} +} // namespace test +} // namespace util +} // namespace wvoec diff --git a/oemcrypto/util/wvcrc32.gyp b/oemcrypto/util/wvcrc32.gyp new file mode 100644 index 0000000..1dcd1f5 --- /dev/null +++ b/oemcrypto/util/wvcrc32.gyp @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +# source code may only be used and distributed under the Widevine +# License Agreement. +{ + 'variables': { + 'oemcrypto_dir%': '..', + }, + 'targets': [ + { + 'target_name': 'libwvcrc32', + 'type': 'static_library', + 'standalone_static_library': 1, + 'hard_dependency': 1, + 'include_dirs': ['<(oemcrypto_dir)/util/include'], + 'direct_dependent_settings': { + 'include_dirs': ['<(oemcrypto_dir)/util/include'], + }, + 'sources': ['<(oemcrypto_dir)/util/src/wvcrc32.cpp'], + }, + ], +} diff --git a/oemcrypto/util/wvcrc32_unittests.gypi b/oemcrypto/util/wvcrc32_unittests.gypi new file mode 100644 index 0000000..26ac31f --- /dev/null +++ b/oemcrypto/util/wvcrc32_unittests.gypi @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +# source code may only be used and distributed under the Widevine +# License Agreement. +{ + 'include_dirs': ['<(oemcrypto_dir)/util/include'], + 'direct_dependent_settings': { + 'include_dirs': ['<(oemcrypto_dir)/util/include',], + }, + 'sources': [ + '<(oemcrypto_dir)/util/test/wvcrc32_unittest.cpp', + ], + 'dependencies': [ + '<(oemcrypto_dir)/util/wvcrc32.gyp:libwvcrc32', + ], +} diff --git a/util/include/buffer_reader.h b/util/include/buffer_reader.h new file mode 100644 index 0000000..58ffc4a --- /dev/null +++ b/util/include/buffer_reader.h @@ -0,0 +1,78 @@ +// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_BUFFER_READER_H_ +#define WVCDM_UTIL_BUFFER_READER_H_ + +#include + +#include +#include + +#include "wv_class_utils.h" + +namespace wvutil { +// Annotate a function indicating the caller must examine the return value. +// Use like: +// int foo() WARN_UNUSED_RESULT; +// To explicitly ignore a result, see |ignore_result()| in . +#if defined(COMPILER_GCC) +# define WARN_UNUSED_RESULT __attribute__((warn_unused_result)) +#else +# define WARN_UNUSED_RESULT +#endif + +class BufferReader { + public: + BufferReader() = delete; + WVCDM_DISALLOW_COPY_AND_MOVE(BufferReader); + + BufferReader(const uint8_t* buf, size_t size) + : buf_(buf), size_(buf != nullptr ? size : 0) {} + + // == Accessors == + + constexpr const uint8_t* data() const { return buf_; } + constexpr size_t size() const { return size_; } + constexpr size_t pos() const { return pos_; } + + constexpr size_t BytesRemaining() const { + return size_ >= pos_ ? size_ - pos_ : 0; + } + + constexpr bool HasBytes(size_t count) const { + return count <= BytesRemaining(); + } + constexpr bool IsEof() const { return pos_ >= size_; } + + // Read a value from the stream, performing endian correction, + // and advance the stream pointer. + bool Read1(uint8_t* v) WARN_UNUSED_RESULT; + bool Read2(uint16_t* v) WARN_UNUSED_RESULT; + bool Read2s(int16_t* v) WARN_UNUSED_RESULT; + bool Read4(uint32_t* v) WARN_UNUSED_RESULT; + bool Read4s(int32_t* v) WARN_UNUSED_RESULT; + bool Read8(uint64_t* v) WARN_UNUSED_RESULT; + bool Read8s(int64_t* v) WARN_UNUSED_RESULT; + + bool ReadString(std::string* str, size_t count) WARN_UNUSED_RESULT; + bool ReadVec(std::vector* t, size_t count) WARN_UNUSED_RESULT; + + // These variants read a 4-byte integer of the corresponding signedness and + // store it in the 8-byte return type. + bool Read4Into8(uint64_t* v) WARN_UNUSED_RESULT; + bool Read4sInto8s(int64_t* v) WARN_UNUSED_RESULT; + + // Advance the stream by this many bytes. + bool SkipBytes(size_t count) WARN_UNUSED_RESULT; + + private: + const uint8_t* buf_ = nullptr; + size_t size_ = 0; + size_t pos_ = 0; + + template + bool Read(T* t) WARN_UNUSED_RESULT; +}; // class BufferReader +} // namespace wvutil +#endif // WVCDM_UTIL_BUFFER_READER_H_ diff --git a/util/include/clock.h b/util/include/clock.h index 98e7ab6..1946cdf 100644 --- a/util/include/clock.h +++ b/util/include/clock.h @@ -9,18 +9,19 @@ #include -namespace wvutil { +#include "wv_timestamp.h" +namespace wvutil { // Provides time related information. The implementation is platform dependent. class Clock { public: - Clock() {} - virtual ~Clock() {} + Clock() = default; + virtual ~Clock() = default; // Provides the number of seconds since an epoch - 01/01/1970 00:00 UTC virtual int64_t GetCurrentTime(); + + virtual Timestamp GetCurrentTimestamp(); }; - } // namespace wvutil - #endif // WVCDM_UTIL_CLOCK_H_ diff --git a/util/include/hls_attribute_list.h b/util/include/hls_attribute_list.h new file mode 100644 index 0000000..ee76097 --- /dev/null +++ b/util/include/hls_attribute_list.h @@ -0,0 +1,211 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_HLS_ATTRIBUTE_LIST_H_ +#define WVCDM_UTIL_HLS_ATTRIBUTE_LIST_H_ + +#include + +#include +#include +#include +#include +#include + +#include "wv_class_utils.h" + +namespace wvutil { +// An HLS Attribute List is a loosely defined HLS tag value +// type representing a dictionary of attribute name-value pairs. +// No attribute name may appear twice in a valid HLS Attribute +// List. +// +// When serialized, an HLS attribute list appears as a comma +// separated list of = pairs, without any line breaks, +// and whitespace only within quoted string attribute values. +// +// The HLS specification defines them as context sensitive-types, +// as the format of certain value types are ambiguous with other +// value types (ex. something that looks like a hex sequence could +// actually be an enum string). The exact value types of attributes +// depends on the HLS tag. +// +// This implementation is intended to be context free when parsing. +// Internally, value will be assigned a type based on their +// appearance to the most restrictive type; however, accessing as a +// particular type will be allowed so long as they match the format +// that type. The only exception to this is quoted strings, which +// are fully unambiguous with all other types. +// +// The standard HLS attribute list allows for UTF-8 encoded unicode +// characters; however, for Widevine's use case, we only allow +// basic ASCII. +// +// This class is based on RFC 8216 section 4.2. +class HlsAttributeList { + public: + enum ValueType { + kUnsetType = 0, + // HLS integers are a sequence of 1 to 20 base10 digits + // values, which must fit into an unsigned 64-bit integer. + // Note: + // All integers can appear as an enum string. + // All integers can appear as floats; however, precision + // restrictions may limit ability to parse accurately. + kIntegerType, + // HLS hex sequence are a sequence of 1 or more upper case + // hexadecimal digits, with the prefix "0x" or "0X". + // Note: + // All hex sequences can appear as an enum string. + // Certain hex sequences can appear as a resolution. + kHexSequenceType, + // HLS float and sign float are a sequence of base10 digits + // with a possible leading negative sign ('-') and at most one + // decimal point ('.'). + // Note: + // All floats can appear as an enum string. + // Certain floats can appear are an integer. + kFloatType, + // For simplicity, signed and unsigned floats are internally + // treated as the same. + kSignedFloatType = kFloatType, // Aliasing + // HLS quoted strings are a sequence of printable characters + // (except a double quote), or space characters, contained within + // a pair of double quotes. + // Note: Quote string are unambiguous with all other types. + kQuotedStringType, + // HLS enum strings are a sequence of any printable characters + // (except quotes or commas) and do not contain whitespace. + // Note: + // Certain enum strings can appear as integers, floats, hex + // sequences or resolutions. + kEnumStringType, + // HLS resolutions are a pair of base10 integers (similar to + // integer types) separated by an 'x' character. + // Note: + // All resolutions can appear as enum strings + // Certain resolutions can appear as hex sequences. + kResolutionType, + }; + static const char* ValueTypeToString(ValueType type); + + // Checks if the provided |name| is valid HLS attribute name. + static bool IsValidName(const std::string& name); + + // Checks if the provided |value| is a valid HLS enumerated + // string. + static bool IsValidEnumStringValue(const std::string& value); + // Checks if the provided |value| is an allowed content + // of a quote string (i.e., |value| is the portion that is + // to be contained within the double quotes). + static bool IsValidQuotedStringValue(const std::string& value); + + // Validator for an HLS unsigned integer representation. + // Checks that the value representation |value_rep| conforms + // to HLS requirements for a serialized integer. + static bool IsValidIntegerRep(const std::string& value_rep); + // Parses the provided |integer_rep| as an HLS integer, assign + // the parsed integer to |value|. + static bool ParseInteger(const std::string& integer_rep, uint64_t* value); + + HlsAttributeList() = default; + WVCDM_DEFAULT_COPY_AND_MOVE(HlsAttributeList); + + // == Basic Accessors == + bool IsEmpty() const { return members_.empty(); } + size_t Count() const { return members_.size(); } + void Clear() { members_.clear(); } + + // == Value Getters == + + // Returns a list of attribute names. + std::vector GetNames() const; + // Returns true if the provided attribute |name| is contained + // within list. + bool Contains(const std::string& name) const; + + // Checks if the provided attribute |name| could be the + // specified |type|. + // Certain value types are ambiguous; so they may be marked + // internally as one type, but are still valid forms of a + // different type. + bool IsType(const std::string& name, ValueType type) const; + + // Gets the attribute value for the provided attribute |name|, + // assigning the deserialized value to the output parameter(s). + // + // If a type is internally stored as a different type, but the format + // matches the requested type, it will be allowed. + // + // Returns true if the attribute value was successfully obtained; + // false otherwise (attribute does not exist, or the value could + // not be parsed as the specified type). + bool GetEnumString(const std::string& name, std::string* value) const; + bool GetQuotedString(const std::string& name, std::string* value) const; + bool GetHexSequence(const std::string& name, std::string* value) const; + bool GetHexSequence(const std::string& name, + std::vector* value) const; + bool GetInteger(const std::string& name, uint64_t* value) const; + bool GetFloat(const std::string& name, double* value) const; + bool GetResolution(const std::string& name, uint64_t* width, + uint64_t* height) const; + + // == Value Setters == + + // The value setters attempt to serializes the provided value + // into an HLS attribute value format of the type indicated + // by the method name. These methods will overwrite any existing + // attribute of the same name. + // Setting will be rejected if |name| is not a valid HLS attribute + // name, or if the provided |value| is valid (only applicable to + // certain types). + bool SetEnumString(const std::string& name, const std::string& value); + bool SetQuotedString(const std::string& name, const std::string& value); + // Note: This implementation will use lower "0x" prefix for + // hex sequences. + bool SetHexSequence(const std::string& name, const std::string& value); + bool SetHexSequence(const std::string& name, + const std::vector& value); + bool SetInteger(const std::string& name, uint64_t value); + bool SetFloat(const std::string& name, double value); + bool SetResolution(const std::string& name, uint64_t width, uint64_t height); + + // Removes the specified attribute. Returns true if the attribute + // was removed; false otherwise. + bool Remove(const std::string& name); + + // == Parsing / Serialization == + + // Attempts to parse the provided HLS Attribute List. + // Always clears the existing contents, even if the attribute + // list cannot be parsed. + // + // Internally, values are not parsed, but checked to see which + // type they resemble. This does not restrict accessing them + // as different types. + // + // Returns true if successfully parsed; false otherwise. + bool Parse(const std::string& hls_attr_list_rep); + + // Serializes the contents of the HLS attribute list into a + // valid HLS attribute list form. + std::string Serialize() const; + bool SerializeToStream(std::ostream* out) const; + + // Internally, values are stored in their serialized form + // in case their value is ambiguous between two or more types. + // The assigned ValueType is intended to be a hint for improve + // access speeds. + using ValueInfo = + std::pair; + + private: + // Internal utility to obtain the value info. + // Silently returns null if the name does not exist. + const ValueInfo* GetInfo(const std::string& name) const; + + // Values in |members_| will always be validated before assigned. + std::map members_; +}; // class HlsAttributeList +} // namespace wvutil +#endif // WVCDM_UTIL_HLS_ATTRIBUTE_LIST_H_ diff --git a/util/include/rw_lock.h b/util/include/rw_lock.h deleted file mode 100644 index 40d7902..0000000 --- a/util/include/rw_lock.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2019 Google LLC. All Rights Reserved. This file and proprietary -// source code may only be used and distributed under the Widevine License -// Agreement. - -#ifndef WVCDM_UTIL_RW_LOCK_H_ -#define WVCDM_UTIL_RW_LOCK_H_ - -#include - -#include -#include - -#include "disallow_copy_and_assign.h" -#include "util_common.h" - -namespace wvutil { - -// A simple reader-writer mutex implementation that mimics the one from C++17 -class shared_mutex { - public: - shared_mutex() : reader_count_(0), has_writer_(false) {} - ~shared_mutex(); - - // These methods take the mutex as a reader. They do not fulfill the - // SharedMutex requirement from the C++14 STL, but they fulfill enough of it - // to be used with |shared_lock| below. - void lock_shared(); - void unlock_shared(); - - // These methods take the mutex as a writer. They fulfill the Mutex - // requirement from the C++11 STL so that this mutex can be used with - // |std::unique_lock|. - void lock() { lock_implementation(false); } - bool try_lock() { return lock_implementation(true); } - void unlock(); - - private: - bool lock_implementation(bool abort_if_unavailable); - - uint32_t reader_count_; - bool has_writer_; - - std::mutex mutex_; - std::condition_variable condition_variable_; - - CORE_DISALLOW_COPY_AND_ASSIGN(shared_mutex); -}; - -// A simple reader lock implementation that mimics the one from C++14 -template -class shared_lock { - public: - explicit shared_lock(Mutex& lock) : lock_(&lock) { lock_->lock_shared(); } - explicit shared_lock(Mutex* lock) : lock_(lock) { lock_->lock_shared(); } - ~shared_lock() { lock_->unlock_shared(); } - - private: - Mutex* lock_; - - CORE_DISALLOW_COPY_AND_ASSIGN(shared_lock); -}; - -} // namespace wvutil - -#endif // WVCDM_UTIL_RW_LOCK_H_ diff --git a/util/include/string_utils.h b/util/include/string_utils.h new file mode 100644 index 0000000..32b7ff2 --- /dev/null +++ b/util/include/string_utils.h @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_STRING_UTILS_H_ +#define WVCDM_UTIL_STRING_UTILS_H_ + +#include +#include + +// Small set of simple string utilities. +// +// It is inspired by Python strings and Abseil. The CDM tends to not +// be too fancy with its use of templated containers; none of the +// Abseil container magic is not provided, nor are the fancy templated +// configurations. +namespace wvutil { + +// Splits a string into several substrings based on the provided +// delimiter. +// +// Special cases: +// - An empty delimiter will split the string into each character. +std::vector StringSplit(const std::string& s, char delim); +std::vector StringSplit(const std::string& s, + const std::string& delim); +std::vector StringSplit(const std::string& s, const char* delim); + +// Joins a list of strings into a single string, using the specified +// "glue" character(s) between tokens. Note: |glue| can be empty. +std::string StringJoin(const std::vector& tokens, char glue); +std::string StringJoin(const std::vector& tokens, + const std::string& glue = ""); +std::string StringJoin(const std::vector& tokens, + const char* glue); + +// Counts the number of instances of |needle| sequences found in the +// |haystack|. +// +// Special cases: +// - An empty needle will return the length of |haystack| plus 1, +// including for an empty string. +// Note: This is a convention used by many string utility +// libraries. +size_t StringCount(const std::string& haystack, char needle); +size_t StringCount(const std::string& haystack, const std::string& needle); +size_t StringCount(const std::string& haystack, const char* needle); + +// Checks if any instances of |needle| sequences found in the |haystack|. +// +// Special cases: +// - An empty |needle| is always present, even if |haystack| is empty. +// Note: This is a convention used by many string utility +// libraries. +bool StringContains(const std::string& haystack, char needle); +bool StringContains(const std::string& haystack, const std::string& needle); +bool StringContains(const std::string& haystack, const char* needle); + +// Checks if the |needle| sequences found at the beginning of |haystack|. +// +// Special cases: +// - An empty |needle| is always present, even if |haystack| is empty. +// Note: This is a convention used by many string utility +// libraries. +bool StringStartsWith(const std::string& haystack, char needle); +bool StringStartsWith(const std::string& haystack, const std::string& needle); +bool StringStartsWith(const std::string& haystack, const char* needle); + +// Checks if the |needle| sequences found at the end of |haystack|. +// +// Special cases: +// - An empty |needle| is always present, even if |haystack| is empty. +// Note: This is a convention used by many string utility +// libraries. +bool StringEndsWith(const std::string& haystack, char needle); +bool StringEndsWith(const std::string& haystack, const std::string& needle); +bool StringEndsWith(const std::string& haystack, const char* needle); + +// Removes any leading or trailing white space from the provided string. +std::string StringTrim(const std::string& s); + +// Checks if the vector of strings contains any instance of the +// specified |needle|. +// +// Note: Unlike the other utilities, an empty |needle| is treated +// as a value. +bool StringVecContains(const std::vector& haystack, + const std::string& needle); +} // namespace wvutil +#endif // WVCDM_UTIL_STRING_UTILS_H_ diff --git a/util/include/wv_date_time.h b/util/include/wv_date_time.h new file mode 100644 index 0000000..0cef398 --- /dev/null +++ b/util/include/wv_date_time.h @@ -0,0 +1,223 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_DATE_TIME_H_ +#define WVCDM_UTIL_DATE_TIME_H_ + +#include + +#include +#include +#include +#include + +#include "wv_class_utils.h" +#include "wv_duration.h" +#include "wv_timestamp.h" + +namespace wvutil { +// The DateTime class represents a time point measured in UTC +// Gregorian calendar dates and 24-hour timekeeping system. +// +// Internally, the time is measured in Unix Time (milliseconds +// since January 1st, 1970 (epoch)). +// +// For the CDM, we are not concerned about time points before +// January 1st, 1970; this class does not support negative epoch +// seconds. In addition, zero is treated as a special value +// which indicates an uninitialized time point. +// +// Min DateTime: 1970-01-01 00:00:00.001 (epoch ms = 1) +// Max DateTime: 9999-12-31 23:59:59.999 (epoch ms = 253402300799999) +// +// Converting from Unix Time to datetime components is a potentially +// expensive operation for their amount of utility; to reduce this +// overhead date components are calculated once per change, and stored +// individually. +// +// For testing, use the DateTime::Builder class for create DateTime +// instances from their individual components. +// +// The DateTime class provides a PrintTo()/ToString() method which +// returns an initialized DateTime in ISO 8601 format without timezone +// indicator and with milliseconds only being printed if non-zero. +// Use DateTime::Formatter for specific formatting needs. +class DateTime { + public: + constexpr DateTime() = default; // Uninitialized DateTime. + WVCDM_CONSTEXPR_DEFAULT_COPY_AND_MOVE(DateTime); + + private: + // Note: This private section is needed here. + // clang++ with -Wundefined-inline will complain about certain + // inline members not being declared before use in other inline + // methods. + template + constexpr ToDuration EpochFloor() const { + return std::chrono::floor(epoch_milliseconds()); + } + + public: + static DateTime FromTimestamp(const Timestamp& timestamp); + + // Create a DateTime instance from Unix Time in epoch seconds. + // Optionally allowed to include |millisecond| (0 to 999). + // If |epoch_seconds| or |millisecond| are invalid values, then an + // uninitlaized DateTime instance is returned. + static DateTime FromUnixSeconds(uint64_t epoch_seconds, + uint32_t milliseconds = 0) { + return FromTimestamp( + Timestamp::FromUnixSeconds(epoch_seconds, milliseconds)); + } + static DateTime FromUnixSeconds(const Seconds& epoch_seconds, + uint32_t milliseconds = 0) { + return FromTimestamp( + Timestamp::FromUnixSeconds(epoch_seconds, milliseconds)); + } + + // Create a DateTime instance from Unix Time in epoch milliseconds. + // If |epoch_milliseconds| is an invalid value (zero), then an + // uninitlaized DateTime instance is returned. + static DateTime FromUnixMilliseconds(uint64_t epoch_milliseconds) { + return FromTimestamp(Timestamp::FromUnixMilliseconds(epoch_milliseconds)); + } + static DateTime FromUnixMilliseconds(const Milliseconds& epoch_milliseconds) { + return FromTimestamp(Timestamp::FromUnixMilliseconds(epoch_milliseconds)); + } + + // Obtain the minimum and maximum representable DateTime. + static DateTime Min() { return FromTimestamp(Timestamp::Min()); } + static DateTime Max() { return FromTimestamp(Timestamp::Max()); } + + constexpr const Timestamp& timestamp() const { return timestamp_; } + + // == Epoch Accessors == + constexpr Seconds epoch_seconds() const { return timestamp_.epoch_seconds(); } + constexpr Milliseconds epoch_milliseconds() const { + return timestamp_.epoch_milliseconds(); + } + + // == Component Accessors == + // Year on Gregorian calendar (1970 - 9999; 0 if unset) + constexpr uint32_t year() const { return year_; } + // Month of the year (January = 1, December = 12; 0 if unset) + constexpr uint32_t month() const { return month_; } + // Day of the month (1 - 28/29/30/31; 0 if unset) + constexpr uint32_t day() const { return day_; } + // Day of the year (1 - 365/366; 0 if unset) + constexpr uint32_t day_of_year() const { return day_of_year_; } + // Day of the week (1 (Sunday) - 7 (Saturday); 0 if unset) + constexpr uint32_t day_of_week() const { return day_of_week_; } + + // Time component accessors. + // Hour of day (0 - 23; 0 if unset) + constexpr uint32_t hour() const { + if (!IsSet()) return 0; + return static_cast( + (EpochFloor() - EpochFloor()).count()); + } + // Minute of hour (0 - 59; 0 if unset) + constexpr uint32_t minute() const { + if (!IsSet()) return 0; + return static_cast( + (EpochFloor() - EpochFloor()).count()); + } + // Second of minute (0 - 59; 0 if unset) + constexpr uint32_t second() const { + if (!IsSet()) return 0; + return static_cast( + (EpochFloor() - EpochFloor()).count()); + } + // Millisecond of second (0 - 999; 0 if unset) + constexpr uint32_t millisecond() const { return timestamp_.milliseconds(); } + + // Checks if the DateTime instance is set. + constexpr bool IsSet() const { return timestamp_.IsSet(); } + constexpr explicit operator bool() const { return IsSet(); } + + constexpr void Clear() { + timestamp_.Clear(); + ClearComponents(); + } + + // == Comparison Operators == + constexpr bool IsEqualTo(const DateTime& other) const { + return timestamp_.IsEqualTo(other.timestamp_); + } + constexpr int64_t CompareTo(const DateTime& other) const { + return timestamp_.CompareTo(other.timestamp_); + } + WVCDM_DEFINE_CONSTEXPR_EQ_AND_CMP_OPERATORS(DateTime); + + constexpr bool IsEqualTo(const Timestamp& other) const { + return timestamp_.IsEqualTo(other); + } + constexpr int64_t CompareTo(const Timestamp& other) const { + return timestamp_.CompareTo(other); + } + WVCDM_DEFINE_CONSTEXPR_EQ_AND_CMP_OPERATORS(Timestamp); + + // == Duration Operators == + // Duration-based addition/subtraction operations. + // Returns a new DateTime instance with time point adjusted by + // the provided Duration. + // If current DateTime instance is unset, or if result value + // is outside the range of valid DateTime values, then an unset + // DateTime is returned. + DateTime operator+(const Duration& duration) const; + DateTime operator-(const Duration& duration) const; + + // DateTime difference. + // Returns the duration between the two provided DateTime + // instances. + // If either DateTime instance is unset, then the resulting + // Duration is zero. + Duration operator-(const DateTime& other) const; + + // Duration-based increment/decrement operators. + // Increment or decrement the DateTime by the provided Duration + // amount. + // If current DateTime instance is unset, then no action will + // occur. If result value is outside the range of valid DateTime + // values, then DateTime will be unset. + DateTime& operator+=(const Duration& duration); + DateTime& operator-=(const Duration& duration); + + // For initialized DateTime instances, returns the datetime + // in ISO 8601 format. Milliseconds are only printed + // if non-zero. + // An uninitialized DateTime will return an empty string. + // Ex: + // Without milli: 2024-01-19T14:49:13Z + // With milli: 2024-01-19T14:49:13.507Z + // Use DateTime::Formatter for more formatting options. + bool PrintTo(std::ostream* out) const; + std::string ToString() const; + + private: + // Special constructor. + explicit DateTime(const Timestamp& timestamp); + + // Attempts to calculate subcomponents. + // Failure to calculate subcomponents will clear |timestamp_|. + void UpdateTimestamp(const Timestamp& new_timestamp); + + constexpr void ClearComponents() { + year_ = month_ = day_ = day_of_year_ = day_of_week_ = 0; + } + + // Source of truth for DateTime class. + Timestamp timestamp_; + + // Components (derived from |timestamp_|). + uint16_t year_ = 0; + uint8_t month_ = 0; + uint8_t day_ = 0; + uint16_t day_of_year_ = 0; + uint8_t day_of_week_ = 0; +}; // class DateTime + +// == GTest Printer == +void PrintTo(const DateTime& date_time, std::ostream* out); +} // namespace wvutil +#endif // WVCDM_UTIL_DATE_TIME_H_ diff --git a/util/include/wv_duration.h b/util/include/wv_duration.h new file mode 100644 index 0000000..c5140c9 --- /dev/null +++ b/util/include/wv_duration.h @@ -0,0 +1,286 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_WV_DURATION_H_ +#define WVCDM_UTIL_WV_DURATION_H_ + +#include + +#include +#include +#include +#include + +#include "wv_class_utils.h" + +namespace wvutil { +// Wrappers around various std::chono::duration units. +using Nanoseconds = std::chrono::nanoseconds; +using Microseconds = std::chrono::microseconds; +using Milliseconds = std::chrono::milliseconds; +using Seconds = std::chrono::seconds; +using Minutes = std::chrono::minutes; +using Hours = std::chrono::hours; +#if __cplusplus >= 202002L // C++20 +using Days = std::chrono::days; +#else +// Days is not declared in C++17, the standard C++ library +// uses the following implementation for C++20. +using Days = std::chrono::duration >; +#endif + +// A high-level duration struct for storing a duration and +// easily accessing the components. +// A duration measures a concrete amount of time between two +// different points in time. +// +// Precision of Duration is in milliseconds. +class Duration final { + public: + constexpr Duration() = default; + // Initialize duration from some kind of chrono duration type that + // can easily convert to milliseconds. + // Compiler will not allow irreconcilable duration types to be used + // without caller explicitly chrono::duration_cast. + template + constexpr Duration(const std::chrono::duration& total_duration) + : total_milliseconds_(total_duration) {} + WVCDM_CONSTEXPR_DEFAULT_COPY_AND_MOVE(Duration); + + private: + // Note: This private section is needed here. + // clang++ with -Wundefined-inline will complain about certain + // inline members not being declared before use in other inline + // methods. + + // Helper function for performing truncate (round towards zero) + // operation on chrono::duration types. + // The standard C++ chrono library does not provide such a function. + template + static constexpr ToDuration DurationTruncate( + const std::chrono::duration& duration) { + return duration >= std::chrono::duration::zero() + ? std::chrono::floor(duration) + : std::chrono::ceil(duration); + } + + // Helper function for performing truncate on Duration class's + // |total_milliseconds_|, returning it in a chrono::duration type + // which is specified by template parameter |ToDuration|. + template + constexpr Duration GetTruncateInternal() const { + return Duration(DurationTruncate(total_milliseconds_)); + } + + public: + // == Special Initializers == + + static constexpr Duration Zero() { return Duration(); } + static constexpr Duration FromMilliseconds(int64_t milliseconds) { + return Duration(Milliseconds(milliseconds)); + } + static constexpr Duration FromSeconds(int64_t seconds) { + return Duration(Seconds(seconds)); + } + + constexpr bool IsZero() const { + return total_milliseconds_ == Milliseconds::zero(); + } + constexpr bool IsNegative() const { + return total_milliseconds_ < Milliseconds::zero(); + } + constexpr bool IsPositive() const { + return total_milliseconds_ > Milliseconds::zero(); + } + constexpr bool IsNonNegative() const { + return total_milliseconds_ >= Milliseconds::zero(); + } + + // Basic accessor and conversions (rounds down). + constexpr Milliseconds total_milliseconds() const { + return total_milliseconds_; + } + constexpr Seconds total_seconds() const { + return DurationTruncate(total_milliseconds_); + } + constexpr Minutes total_minutes() const { + return DurationTruncate(total_milliseconds_); + } + constexpr Hours total_hours() const { + return DurationTruncate(total_milliseconds_); + } + constexpr Days total_days() const { + return DurationTruncate(total_milliseconds_); + } + + // Access duration absolute components bounded within a + // particular unit. + // Milliseconds modulo milliseconds per-second (0-999) + constexpr uint32_t milliseconds() const { + if (IsNegative()) return GetAbsolute().milliseconds(); + return static_cast( + (total_milliseconds() - total_seconds()).count()); + } + // Seconds modulo seconds per-minute (0-59) + constexpr uint32_t seconds() const { + if (IsNegative()) return GetAbsolute().seconds(); + return static_cast((total_seconds() - total_minutes()).count()); + } + // Minutes modulo minutes per-hour (0-59). + constexpr uint32_t minutes() const { + if (IsNegative()) return GetAbsolute().minutes(); + return static_cast((total_minutes() - total_hours()).count()); + } + // Hours modulo hours per-day (0-23) + constexpr uint32_t hours() const { + if (IsNegative()) return GetAbsolute().hours(); + return static_cast((total_hours() - total_days()).count()); + } + // Days (total). + constexpr uint32_t days() const { + if (IsNegative()) return GetAbsolute().days(); + return static_cast(total_days().count()); + } + + // Obtain the absolute value. + constexpr Duration GetAbsolute() const { + return Duration(std::chrono::abs(total_milliseconds_)); + } + + // Obtain the truncated (math term for rounding towards zero) + // value by units. + constexpr Duration GetTruncateBySeconds() const { + return GetTruncateInternal(); + } + constexpr Duration GetTruncateByMinutes() const { + return GetTruncateInternal(); + } + constexpr Duration GetTruncateByHours() const { + return GetTruncateInternal(); + } + constexpr Duration GetTruncateByDays() const { + return GetTruncateInternal(); + } + + // Comparison operators (between another Duration class) + constexpr bool IsEqualTo(const Duration& other) const { + return total_milliseconds_ == other.total_milliseconds_; + } + constexpr int64_t CompareTo(const Duration& other) const { + return static_cast(total_milliseconds_.count() - + other.total_milliseconds_.count()); + } + WVCDM_DEFINE_CONSTEXPR_EQ_AND_CMP_OPERATORS(Duration); + + // Comparison operators (between other duration types) + template + constexpr bool operator==( + const std::chrono::duration& other) const { + return total_milliseconds_ == other; + } + template + constexpr bool operator!=( + const std::chrono::duration& other) const { + return total_milliseconds_ != other; + } + template + constexpr bool operator<=( + const std::chrono::duration& other) const { + return total_milliseconds_ <= other; + } + template + constexpr bool operator<( + const std::chrono::duration& other) const { + return total_milliseconds_ < other; + } + template + constexpr bool operator>=( + const std::chrono::duration& other) const { + return total_milliseconds_ >= other; + } + template + constexpr bool operator>( + const std::chrono::duration& other) const { + return total_milliseconds_ > other; + } + + // Unary plus/minus operators + constexpr Duration operator+() const { return Duration(total_milliseconds_); } + constexpr Duration operator-() const { + return Duration(-total_milliseconds_); + } + + // Add/subtract operators (between another Duration class) + constexpr Duration operator+(const Duration& other) const { + return Duration(total_milliseconds_ + other.total_milliseconds_); + } + constexpr Duration operator-(const Duration& other) const { + return Duration(total_milliseconds_ - other.total_milliseconds_); + } + + // Add/subtract operators (between other duration types) + // Note 1: This is intentionally implemented as a member function to + // force the caller the put the Duration class first in the + // expression. + // Note 2: These operators are not defined for duration types that + // smaller than chrono::milliseconds. This could cause + // unexpected behavior as operations would be truncated. + // Ex: (Duration(5s) + 500us + 500us) = Duration(5s) + // Compiler will not allow such operations. + template + constexpr Duration operator+( + const std::chrono::duration& other) const { + return Duration(total_milliseconds_ + other); + } + template + constexpr Duration operator-( + const std::chrono::duration& other) const { + return Duration(total_milliseconds_ - other); + } + + // Increment/decrement operators (between another Duration class) + constexpr Duration& operator+=(const Duration& other) { + total_milliseconds_ += other.total_milliseconds_; + return *this; + } + constexpr Duration& operator-=(const Duration& other) { + total_milliseconds_ -= other.total_milliseconds_; + return *this; + } + + // Increment/decrement operators (between another duration types) + // Note: These operators are not defined for duration types that + // smaller than chrono::milliseconds (see '+' operator + // comment for details). + template + constexpr Duration& operator+=( + const std::chrono::duration& other) { + total_milliseconds_ += other; + return *this; + } + template + constexpr Duration& operator-=( + const std::chrono::duration& other) { + total_milliseconds_ -= other; + return *this; + } + + // To string operator. + // Converts the Duration to its string representation. + // Examples: + // Duration::FromMilliseconds(4000123).ToString() + // ==> "1h6m40s123ms" + // Duration::FromSeconds(4000).ToString() + // ==> "1h6m40s" + bool PrintTo(std::ostream* out) const; + std::string ToString() const; + + private: + // Internally, the Duration class stores durations in milliseconds. + Milliseconds total_milliseconds_ = Milliseconds::zero(); +}; // class Duration + +// == GTest Printer == +void PrintTo(const Duration& duration, std::ostream* out); +} // namespace wvutil +#endif // WVCDM_UTIL_WV_DURATION_H_ diff --git a/util/include/wv_timestamp.h b/util/include/wv_timestamp.h new file mode 100644 index 0000000..c8ee8f3 --- /dev/null +++ b/util/include/wv_timestamp.h @@ -0,0 +1,170 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_WV_TIMESTAMP_H_ +#define WVCDM_UTIL_WV_TIMESTAMP_H_ + +#include + +#include + +#include "wv_class_utils.h" +#include "wv_duration.h" + +namespace wvutil { +// The Timestamp class is a light-weight representation of a +// time point. +// +// Internally, the time is measured in Unix Time (milliseconds +// since January 1st, 1970 UTC (epoch)). +// +// For this library, we are not concerned about time points before +// January 1st, 1970; this class does not support negative epoch +// seconds. In addition, zero is treated as a special value +// which indicates an uninitialized time point. +// +// Min Timestamp: 1970-01-01 00:00:00.001 (epoch ms = 1) +// Max Timestamp: 9999-12-31 23:59:59.999 (epoch ms = 253402300799999) +class Timestamp { + public: + constexpr Timestamp() = default; // Defaults to "unset". + WVCDM_CONSTEXPR_DEFAULT_COPY_AND_MOVE(Timestamp); + + // Create a Timestamp instance from Unix Time in epoch seconds. + // Optionally allowed to include |millisecond| (0 to 999). + // If |epoch_seconds| or |millisecond| are invalid values, then an + // uninitlaized Timestamp instance is returned. + static constexpr Timestamp FromUnixSeconds(uint64_t epoch_seconds, + uint32_t milliseconds = 0) { + if (milliseconds > 999u) return Timestamp(); + return Timestamp(Seconds(epoch_seconds) + Milliseconds(milliseconds)); + } + static constexpr Timestamp FromUnixSeconds(const Seconds& epoch_seconds, + uint32_t milliseconds = 0) { + if (milliseconds > 999u) return Timestamp(); + return Timestamp(epoch_seconds + Milliseconds(milliseconds)); + } + + // Create a Timestamp instance from Unix Time in epoch milliseconds. + // If |epoch_milliseconds| is an invalid value (zero), then an + // uninitlaized Timestamp instance is returned. + static constexpr Timestamp FromUnixMilliseconds(uint64_t epoch_milliseconds) { + return Timestamp(Milliseconds(epoch_milliseconds)); + } + static constexpr Timestamp FromUnixMilliseconds( + const Milliseconds& epoch_milliseconds) { + return Timestamp(epoch_milliseconds); + } + + // Obtain the minimum and maximum representable Timestamp. + static constexpr Timestamp Min() { return Timestamp(kMinMsDuration); } + static constexpr Timestamp Max() { return Timestamp(kMaxMsDuration); } + + // == General Accessors == + + // Get the epoch seconds duration (rounding down milliseconds). + constexpr Seconds epoch_seconds() const { + return std::chrono::floor(epoch_milliseconds_); + } + // Get the milliseconds fraction of the epoch time (0-999). + constexpr uint32_t milliseconds() const { + return static_cast( + (epoch_milliseconds_ - epoch_seconds()).count()); + } + + constexpr Milliseconds epoch_milliseconds() const { + return epoch_milliseconds_; + } + + // As noted, the timestamp is considered "unset" if zero. + constexpr bool IsSet() const { + return epoch_milliseconds_ > Milliseconds::zero(); + } + explicit constexpr operator bool() const { return IsSet(); } + + constexpr void Clear() { epoch_milliseconds_ = kUnsetMsDuration; } + + // == Comparisons == + + constexpr bool IsEqualTo(const Timestamp& other) const { + return epoch_milliseconds_ == other.epoch_milliseconds_; + } + constexpr int64_t CompareTo(const Timestamp& other) const { + return static_cast(epoch_milliseconds_.count() - + other.epoch_milliseconds_.count()); + } + WVCDM_DEFINE_CONSTEXPR_EQ_AND_CMP_OPERATORS(Timestamp); + + // == Arithmetic Operations == + + // Duration-based addition/subtraction operations. + // Returns a new Timestamp instance with time point adjusted by + // the provided Duration. + // If current Timestamp instance is unset, or if result value + // is outside the range of valid Timestamp values, then an unset + // Timestamp is returned. + constexpr Timestamp operator+(const Duration& duration) const { + if (!IsSet()) return Timestamp(); + return Timestamp(epoch_milliseconds_ + duration.total_milliseconds()); + } + constexpr Timestamp operator-(const Duration& duration) const { + if (!IsSet()) return Timestamp(); + return Timestamp(epoch_milliseconds_ - duration.total_milliseconds()); + } + + // Duration-based increment/decrement operators. + // Increment or decrement the Timestamp by the provided Duration + // amount. + // If current Timestamp instance is unset, then no action will + // occur. If result value is outside the range of valid Timestamp + // values, then Timestamp will be unset. + constexpr Timestamp& operator+=(const Duration& duration) { + if (!IsSet()) return *this; + epoch_milliseconds_ = + Normalize(epoch_milliseconds_ + duration.total_milliseconds()); + return *this; + } + constexpr Timestamp& operator-=(const Duration& duration) { + if (!IsSet()) return *this; + epoch_milliseconds_ = + Normalize(epoch_milliseconds_ - duration.total_milliseconds()); + return *this; + } + + // Timestamp difference. + // Returns the duration between the two provided Timestamp + // instances. + // If either Timestamp instance is unset, then the resulting + // Duration is zero. + constexpr Duration operator-(const Timestamp& other) const { + if (!IsSet() || !other.IsSet()) return Duration(); + return Duration(epoch_milliseconds_ - other.epoch_milliseconds_); + } + + private: + static constexpr const uint64_t kMinMs = 1; + static constexpr const uint64_t kMaxMs = 253402300799999; + static constexpr const Milliseconds kMinMsDuration = Milliseconds(kMinMs); + static constexpr const Milliseconds kMaxMsDuration = Milliseconds(kMaxMs); + static constexpr const Milliseconds kUnsetMsDuration = Milliseconds::zero(); + + static constexpr bool IsInRange(const Milliseconds& epoch_milliseconds) { + return epoch_milliseconds >= kMinMsDuration && + epoch_milliseconds <= kMaxMsDuration; + } + + // Ensures the provided |epoch_milliseconds| is in range of the + // Timestamp class; otherwise returns an unset duration value. + static constexpr Milliseconds Normalize( + const Milliseconds& epoch_milliseconds) { + return IsInRange(epoch_milliseconds) ? epoch_milliseconds + : kUnsetMsDuration; + } + + constexpr Timestamp(const Milliseconds& epoch_milliseconds) + : epoch_milliseconds_(Normalize(epoch_milliseconds)) {} + + Milliseconds epoch_milliseconds_ = kUnsetMsDuration; +}; // class Timestamp +} // namespace wvutil +#endif // WVCDM_UTIL_WV_TIMESTAMP_H_ diff --git a/util/src/buffer_reader.cpp b/util/src/buffer_reader.cpp new file mode 100644 index 0000000..2b8612e --- /dev/null +++ b/util/src/buffer_reader.cpp @@ -0,0 +1,133 @@ +// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "buffer_reader.h" + +#include + +#include +#include + +#include "log.h" +#include "platform.h" + +namespace wvutil { +bool BufferReader::Read1(uint8_t* v) { + if (v == nullptr) { + LOGE("Parse failure: Null output parameter when expecting non-null"); + return false; + } + + if (!HasBytes(1)) { + LOGV("Parse failure: No bytes available"); + return false; + } + + *v = buf_[pos_++]; + return true; +} + +// Internal implementation of multi-byte reads +template +bool BufferReader::Read(T* v) { + static constexpr size_t kTypeSize = sizeof(T); + if (v == nullptr) { + LOGE("Parse failure: Null output parameter when expecting non-null (%s)", + __PRETTY_FUNCTION__); + return false; + } + + if (!HasBytes(kTypeSize)) { + LOGV("Parse failure: Not enough bytes (%zu)", kTypeSize); + return false; + } + + T tmp = 0; + for (size_t i = 0; i < kTypeSize; i++) { + tmp <<= 8; + // This works for signed values. + tmp += buf_[pos_ + i]; + } + pos_ += kTypeSize; + *v = tmp; + return true; +} + +bool BufferReader::Read2(uint16_t* v) { return Read(v); } +bool BufferReader::Read2s(int16_t* v) { return Read(v); } +bool BufferReader::Read4(uint32_t* v) { return Read(v); } +bool BufferReader::Read4s(int32_t* v) { return Read(v); } +bool BufferReader::Read8(uint64_t* v) { return Read(v); } +bool BufferReader::Read8s(int64_t* v) { return Read(v); } + +bool BufferReader::ReadString(std::string* str, size_t count) { + if (str == nullptr) { + LOGE("Parse failure: Null output parameter when expecting non-null"); + return false; + } + + if (!HasBytes(count)) { + LOGV("Parse failure: Not enough bytes (%zu)", count); + return false; + } + + str->assign(buf_ + pos_, buf_ + pos_ + count); + pos_ += count; + return true; +} + +bool BufferReader::ReadVec(std::vector* vec, size_t count) { + if (vec == nullptr) { + LOGE("Parse failure: Null output parameter when expecting non-null"); + return false; + } + + if (!HasBytes(count)) { + LOGV("Parse failure: Not enough bytes (%zu)", count); + return false; + } + + vec->assign(buf_ + pos_, buf_ + pos_ + count); + pos_ += count; + return true; +} + +bool BufferReader::SkipBytes(size_t count) { + if (!HasBytes(count)) { + LOGV("Parse failure: Not enough bytes (%zu)", count); + return false; + } + + pos_ += count; + return true; +} + +bool BufferReader::Read4Into8(uint64_t* v) { + if (v == nullptr) { + LOGE("Parse failure: Null output parameter when expecting non-null"); + return false; + } + + uint32_t tmp = 0; + if (!Read4(&tmp)) { + return false; + } + *v = tmp; + return true; +} + +bool BufferReader::Read4sInto8s(int64_t* v) { + if (v == nullptr) { + LOGE("Parse failure: Null output parameter when expecting non-null"); + return false; + } + + // Beware of the need for sign extension. + int32_t tmp = 0; + if (!Read4s(&tmp)) { + return false; + } + *v = tmp; + return true; +} +} // namespace wvutil diff --git a/util/src/hls_attribute_list.cpp b/util/src/hls_attribute_list.cpp new file mode 100644 index 0000000..21bbc58 --- /dev/null +++ b/util/src/hls_attribute_list.cpp @@ -0,0 +1,981 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "hls_attribute_list.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "log.h" +#include "string_conversions.h" +#include "wv_class_utils.h" + +namespace wvutil { +using HlsValueType = HlsAttributeList::ValueType; +using HlsValueInfo = HlsAttributeList::ValueInfo; + +namespace { +// Grammatically significant HLS characters. +constexpr char kDashChar = '-'; +constexpr char kDecimalPointChar = '.'; +constexpr char kSpaceChar = ' '; +constexpr char kLineFeedChar = '\n'; +constexpr char kCarriageReturnChar = '\r'; +constexpr char kNameValueSeparatorChar = '='; +constexpr char kCommaChar = ','; +constexpr char kQuoteChar = '"'; +constexpr char kMemberSeparatorChar = kCommaChar; +constexpr char kResolutionSplitChar = 'x'; + +// Note: Hex sequences can start with "0x" or "0X". This +// implementation uses lower "0x" when serializing. +constexpr char kHexSequencePrefix[] = "0x"; + +// ASCII non-control character points. +constexpr char kNonAsciiControlFirstChar = ' '; // 0x20 +constexpr char kNonAsciiControlLastChar = '~'; // 0x7e +constexpr bool IsNonAsciiControl(char ch) { + return ch >= kNonAsciiControlFirstChar && ch <= kNonAsciiControlLastChar; +} + +// Character points allowed in HLS Playlist files. +// 1. UTF-8 encoded unicode +// * WV CDM does not support unicode or even UTF-8, we restrict +// it to basic ASCII (0x00 to 0x7f) +// 2. Does not allow for control sequences +// * Defined for values 0x00 to 0x1f and 0x7f to 0x9f, except for +// Line Feed (0x0a) and Carriage Return (0x0d). +// * Our restriction to basic ASCII limits this to 0x00 to 0x1f and 0x7f, +// except for LF and CR. +// 3. Requires Unicode Normalized Form type C (NFC); but not applicable +// to our case. +// +// Ref: RFC 8216, section 4.1. +constexpr bool IsAllowedHlsChar(char ch) { + return IsNonAsciiControl(ch) || ch == kLineFeedChar || + ch == kCarriageReturnChar; +} + +constexpr bool IsUpperAlpha(char ch) { return ch >= 'A' && ch <= 'Z'; } + +constexpr bool IsLowerAlpha(char ch) { return ch >= 'a' && ch <= 'z'; } + +constexpr bool IsDecimalChar(char ch) { return ch >= '0' && ch <= '9'; } + +constexpr bool IsUpperHexadecimalChar(char ch) { + return IsDecimalChar(ch) || (ch >= 'A' && ch <= 'Z'); +} + +constexpr bool IsWhitespaceChar(char ch) { + return ch == kSpaceChar || ch == kLineFeedChar || ch == kCarriageReturnChar; +} + +constexpr bool IsValidNameChar(char ch) { + return IsUpperAlpha(ch) || IsDecimalChar(ch) || ch == kDashChar; +} + +constexpr bool IsValidEnumStringChar(char ch) { + return IsAllowedHlsChar(ch) && !IsWhitespaceChar(ch) && ch != kQuoteChar && + ch != kCommaChar; +} + +constexpr bool IsValidQuotedStringValueChar(char ch) { + return IsAllowedHlsChar(ch) && ch != kQuoteChar && ch != kLineFeedChar && + ch != kCarriageReturnChar; +} + +constexpr bool IsValidSignedFloatChar(char ch) { + return ch == kDashChar || ch == kDecimalPointChar || IsDecimalChar(ch); +} + +constexpr bool IsValidUnquotedValueChar(char ch) { + return IsValidSignedFloatChar(ch) || IsDecimalChar(ch) || + IsValidEnumStringChar(ch); +} + +constexpr char ToUpper(char ch) { + return IsLowerAlpha(ch) ? (ch + ('A' - 'a')) : ch; +} + +// == Type Pattern Matching == + +constexpr size_t kMinIntegerLength = 1; +constexpr size_t kMaxIntegerLength = 20; +// HLS integers must fit within an unsigned 64-bit value. +constexpr char kMaxIntegerRep[] = "18446744073709551615"; + +// Checks if the serialized attribute value appears like an integer. +bool IsIntegerLike(const std::string& value_rep) { + if (value_rep.size() < kMinIntegerLength || + value_rep.size() > kMaxIntegerLength) + return false; + if (!std::all_of(value_rep.begin(), value_rep.end(), IsDecimalChar)) + return false; + if (value_rep.size() < kMaxIntegerLength) return true; + // Must be able to fit into a uint64_t. + return value_rep <= kMaxIntegerRep; +} + +// Checks if the serialized attribute value appears like a hex +// sequence. +bool IsHexSequenceLike(const std::string& value_rep) { + // Must have hex prefix, and at least 1 hex character. + if (value_rep.size() < 3) return false; + // Check prefix ("0x" or "0X"). + if (value_rep.front() != '0' || (value_rep[1] != 'x' && value_rep[1] != 'X')) + return false; + return std::all_of(value_rep.begin() + 2, value_rep.end(), + IsUpperHexadecimalChar); +} + +// Checks if the serialized attribute value appears like a signed +// float (not this can be used for unsigned floats as well). +bool IsSignedFloatLike(const std::string& value_rep) { + if (value_rep.empty()) return false; + // Initial check that all are valid signed float characters. + if (!std::all_of(value_rep.begin(), value_rep.end(), IsValidSignedFloatChar)) + return false; + // Skip initial dash. + const size_t int_start_pos = (value_rep.front() == kDashChar) ? 1 : 0; + if (int_start_pos == value_rep.size()) return false; + // Ensure there are no more dashes. + if (value_rep.find(kDashChar, int_start_pos) != std::string::npos) + return false; + + const size_t decimal_pos = value_rep.find(kDecimalPointChar); + // No decimal point implies all other character are decimal digits, + // and is a valid floating point value. + if (decimal_pos == std::string::npos) return true; + // Ensure the decimal is not the first item. + if (decimal_pos == int_start_pos) return false; + // Ensure all values after the first decimal point are + // digits. + if ((decimal_pos + 1) == value_rep.size()) return false; + return std::all_of(value_rep.begin() + decimal_pos + 1, value_rep.end(), + IsDecimalChar); +} + +bool IsQuotedStringLike(const std::string& value_rep) { + if (value_rep.size() < 2) return false; + if (value_rep.front() != kQuoteChar || value_rep.back() != kQuoteChar) + return false; + return std::all_of(value_rep.begin() + 1, value_rep.end() - 1, + IsValidQuotedStringValueChar); +} + +bool IsEnumStringLike(const std::string& value_rep) { + return HlsAttributeList::IsValidEnumStringValue(value_rep); +} + +bool IsResolutionLike(const std::string& value_rep) { + if (value_rep.size() < 3) return false; + // Find the resolution split. + const size_t split_pos = value_rep.find(kResolutionSplitChar); + if (split_pos == std::string::npos) return false; + // Ensure it is not at the beginning or end of the value. + if (split_pos == 0 || (split_pos + 1) == value_rep.size()) return false; + return IsIntegerLike(value_rep.substr(0, split_pos)) && + IsIntegerLike(value_rep.substr(split_pos + 1)); +} + +std::string MakeInteger(uint64_t value) { return std::to_string(value); } + +std::string MakeSignedFloat(double value) { + // Avoid potential issues with negative zero. + if (value == 0.0 || value == -0.0) return "0"; + std::ostringstream out; + out << std::fixed << value; + std::string result = out.str(); + // Remove trailing zeros. + while (!result.empty() && result.back() == '0') result.pop_back(); + // Check if there were only zeros after the decimal. + if (!result.empty() && result.back() == '.') result.pop_back(); + return result; +} + +std::string MakeQuoted(const std::string& value) { + std::string result; + result.reserve(value.size() + 2); + result.push_back(kQuoteChar); + result.append(value); + result.push_back(kQuoteChar); + return result; +} + +std::string MakeHex(const std::string& data) { + std::string result; + result.reserve(data.size() * 2 + sizeof(kHexSequencePrefix) - 1); + result.append(kHexSequencePrefix); + result.append(b2a_hex(data)); + // HLS hex must be upper case. + for (size_t i = 2; i < result.size(); i++) { + result[i] = ToUpper(result[i]); + } + return result; +} + +std::string MakeHex(const std::vector& data) { + std::string result; + result.reserve(data.size() * 2 + sizeof(kHexSequencePrefix) - 1); + result.append(kHexSequencePrefix); + result.append(b2a_hex(data)); + return result; +} + +std::string MakeResolution(uint64_t width, uint64_t height) { + std::ostringstream out; + out << width << kResolutionSplitChar << height; + return out.str(); +} + +// Assumes |integer_rep| is a valid HLS integer representation. +bool ParseIntegerInternal(const std::string& integer_rep, uint64_t* value) { + std::istringstream in(integer_rep); + in >> (*value); + return !in.fail(); +} + +// Assumes |integer_rep| is a valid HLS float representation. +bool ParseFloat(const std::string& float_rep, double* value) { + std::istringstream in(float_rep); + in >> *value; + return !in.fail(); +} + +// == Tokenizer == + +// The HLS Member Token represents the name-value pair. +// The |value_rep| is the exact representation from the +// HLS-serialized value (quoted values include the quotes). +struct HlsMemberToken { + std::string name; + std::string value_rep; + void Clear() { + name.clear(); + value_rep.clear(); + } +}; + +// The HLS Attribute List Reader will reader a character sequence +// representing a serialized HLS Attribute List, and convert them +// into a sequence of name-value tokens. +// +// This performs more like a lexeme tokenizer, rather than a lexical +// tokenizers (without producing structural tokens). +// +// The |name| will be a well-formed name; but values are simply check +// if they match a quoted or unquoted values. The contents of the +// values might not match lexical rules of an HLS value types. +// +// Rough grammar this reader enforces: +// list := BEGIN member (',' member)* END +// member := name '=' value +// name := name-char+ +// value := quoted-value | unquoted-value +// quoted-value := '"' quoted-char* '"' +// unquoted-value := unquoted-char+ +class HlsAttributeListReader { + public: + // Filler value. + static constexpr char kEocChar = 0; + + HlsAttributeListReader() = delete; + WVCDM_DISALLOW_COPY_AND_MOVE(HlsAttributeListReader); + HlsAttributeListReader(const std::string& contents) : contents_(contents) {} + + size_t index() const { return index_; } + size_t remaining() const { return contents_.size() - index_; } + bool IsBegin() const { return index_ == 0; } + bool IsEnd() const { return index_ >= contents_.size(); } + void Reset() { index_ = 0; } + + // Reads the entire HLS list string, enforcing the following + // grammar: + // list := BEGIN member (',' member)* END + bool ReadMembers(std::vector* members); + + private: + char Peak() const { return IsEnd() ? kEocChar : contents_[index_]; } + char Pop() { return IsEnd() ? kEocChar : contents_[index_++]; } + + bool IsNext(char ch) const { return IsEnd() ? false : Peak() == ch; } + bool IsNextQuote() const { return IsNext(kQuoteChar); } + bool IsNextNameValueSeparator() const { + return IsNext(kNameValueSeparatorChar); + } + bool IsNextMemberSeparator() const { return IsNext(kMemberSeparatorChar); } + bool IsNextNameChar() const { + return IsEnd() ? false : IsValidNameChar(Peak()); + } + bool IsNextUnquotedValueChar() const { + return IsEnd() ? false : IsValidUnquotedValueChar(Peak()); + } + bool IsNextQuotedValueChar() const { + return IsEnd() ? false : IsValidQuotedStringValueChar(Peak()); + } + + // Reads the next sequence of tokens representing a member. + // On a successful read, the reader index started at the first + // character of an attribute name, and will end with the index at + // the first character after the attribute value (or end). + // + // Enforces the name-value grammar: + // member := name '=' value + // + // Returns true if the member was read correctly, false otherwise. + bool ReadMember(HlsMemberToken* member); + + // Reads the next token representing an attribute name. + // On a successful read, the reader index started at the first + // character of the attribute name, and will end with the index at + // the first character after the attribute name. + // + // Returns true if the name was read correctly, false otherwise. + bool ReadName(std::string* name); + // Reads the next token representing a quoted attribute value. + // On a successful read, the reader index started at the first + // character of the attribute value (must be a quote), and will end + // with the index at the first character after the closing quote. + // + // Enforces the quoted-value grammar: + // quoted-value := '"' quoted-char* '"' + // + // The produce value rep will include the beginning and end quote. + // Returns true if the quote was read correctly, false otherwise. + bool ReadQuotedValue(std::string* value_rep); + // Reads the next token representing an unquoted attribute value. + // On a successful read, the reader index started at the first + // character of the attribute value, and will end with the index at + // the first character after the last possible attribute value. + // + // Enforces the unquoted-value grammar: + // unquoted-value := unquoted-char+ + // + // Returns true if the quote was read correctly, false otherwise. + bool ReadUnquotedValue(std::string* value_rep); + // Reads the next token representing a value (quoted or unquoted). + // Simply checks the first attribute value character if a quote + // or not, then uses either ReadQuotedValue() or ReadUnquotedValue() + // + // Enforces the value grammar: + // value := quoted-value | unquoted-value + // + // Returns true if the value was read correctly, false otherwise. + bool ReadValue(std::string* value_rep); + + const std::string& contents_; + size_t index_ = 0; +}; // class HlsAttributeListReader + +bool HlsAttributeListReader::ReadMembers(std::vector* members) { + members->clear(); + if (!IsBegin()) { + LOGV("Not at the beginning of the contents: index = %zu", index()); + return false; + } + while (!IsEnd()) { + if (!members->empty()) { + // Check for member separator. + if (!IsNextMemberSeparator()) { + LOGV( + "Expected member separator: char = %c, index = %zu, " + "member_index = %zu", + Peak(), index(), members->size()); + members->clear(); + return false; + } + Pop(); + } + HlsMemberToken member; + if (!ReadMember(&member)) { + LOGV("Failed to read member: member_index = %zu", members->size()); + members->clear(); + return false; + } + members->push_back(std::move(member)); + } + return true; +} + +bool HlsAttributeListReader::ReadMember(HlsMemberToken* member) { + member->Clear(); + if (!ReadName(&member->name)) { + LOGV("Failed to read name"); + member->Clear(); + return false; + } + if (!IsNextNameValueSeparator()) { + LOGV("Expected name-value separator: char = %c, index = %zu", Peak(), + index()); + member->Clear(); + return false; + } + Pop(); + if (!ReadValue(&member->value_rep)) { + LOGV("Failed to read value"); + member->Clear(); + return false; + } + return true; +} + +bool HlsAttributeListReader::ReadName(std::string* name) { + name->clear(); + while (!IsEnd()) { + if (!IsNextNameChar()) break; + name->push_back(Pop()); + } + return !name->empty(); +} + +bool HlsAttributeListReader::ReadQuotedValue(std::string* value_rep) { + value_rep->clear(); + if (!IsNextQuote()) return false; + value_rep->push_back(Pop()); + while (!IsEnd()) { + if (IsNextQuote()) { + value_rep->push_back(Pop()); + break; + } + if (!IsNextQuotedValueChar()) break; + value_rep->push_back(Pop()); + } + // Verify start and end quotes are in the output value. + return value_rep->size() >= 2 && value_rep->front() == kQuoteChar && + value_rep->back() == kQuoteChar; +} + +bool HlsAttributeListReader::ReadUnquotedValue(std::string* value_rep) { + value_rep->clear(); + while (!IsEnd()) { + if (!IsNextUnquotedValueChar()) break; + value_rep->push_back(Pop()); + } + return !value_rep->empty(); +} + +bool HlsAttributeListReader::ReadValue(std::string* value_rep) { + if (IsEnd()) { + LOGV("Expected value char, reached end"); + return false; + } + if (IsNextQuote()) return ReadQuotedValue(value_rep); + if (IsNextUnquotedValueChar()) return ReadUnquotedValue(value_rep); + LOGV("Expected value start char: char = %c, index = %zu", Peak(), index()); + return false; +} +} // namespace + +// static +const char* HlsAttributeList::ValueTypeToString(ValueType type) { + switch (type) { + case kUnsetType: + return "Unset"; + case kIntegerType: + return "Decimal-Integer"; + case kHexSequenceType: + return "Hexadecimal-Sequence"; + case kFloatType: + return "Decimal-Floating-Point"; + case kQuotedStringType: + return "Quoted-String"; + case kEnumStringType: + return "Enumerated-String"; + case kResolutionType: + return "Decimal-Resolution"; + } + return ""; +} + +// static +bool HlsAttributeList::IsValidName(const std::string& name) { + if (name.empty()) return false; + return std::all_of(name.begin(), name.end(), IsValidNameChar); +} + +// static +bool HlsAttributeList::IsValidEnumStringValue(const std::string& value) { + if (value.empty()) return false; + return std::all_of(value.begin(), value.end(), IsValidEnumStringChar); +} + +// static +bool HlsAttributeList::IsValidQuotedStringValue(const std::string& value) { + if (value.empty()) return true; + return std::all_of(value.begin(), value.end(), IsValidQuotedStringValueChar); +} + +// static +bool HlsAttributeList::IsValidIntegerRep(const std::string& value) { + return IsIntegerLike(value); +} + +// static +bool HlsAttributeList::ParseInteger(const std::string& integer_rep, + uint64_t* value) { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + if (!IsIntegerLike(integer_rep)) { + LOGV("Not a valid HLS integer rep: %s", + SafeByteIdToString(integer_rep).c_str()); + return false; + } + return ParseIntegerInternal(integer_rep, value); +} + +std::vector HlsAttributeList::GetNames() const { + if (IsEmpty()) return std::vector(); + std::vector names; + names.reserve(Count()); + for (const auto& pair : members_) { + names.push_back(pair.first); + } + std::sort(names.begin(), names.end()); + return names; +} + +bool HlsAttributeList::Contains(const std::string& name) const { + return members_.find(name) != members_.end(); +} + +const HlsValueInfo* HlsAttributeList::GetInfo(const std::string& name) const { + const auto it = members_.find(name); + if (it == members_.end()) return nullptr; + return &it->second; +} + +bool HlsAttributeList::IsType(const std::string& name, ValueType type) const { + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return kUnsetType; + } + if (info->first == type) return true; + switch (type) { + case kUnsetType: + return false; // Bad caller value. + case kIntegerType: + return IsIntegerLike(info->second); + case kHexSequenceType: + return IsHexSequenceLike(info->second); + case kFloatType: + return IsSignedFloatLike(info->second); + case kQuotedStringType: + return false; // Quoted strings are unambiguous. + case kEnumStringType: + return IsEnumStringLike(info->second); + case kResolutionType: + return IsResolutionLike(info->second); + } + LOGE("Unexpected type: type = %d", static_cast(type)); + return false; +} + +bool HlsAttributeList::GetEnumString(const std::string& name, + std::string* value) const { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + if (info->first == kEnumStringType) { + value->assign(info->second); + return true; + } + // Certain values could appear as something else, but are still + // valid enum strings. + if (!IsEnumStringLike(info->second)) { + LOGV("Attribute is not enum string: name = %s, type = %s, rep = %s", + name.c_str(), ValueTypeToString(info->first), info->second.c_str()); + return false; + } + value->assign(info->second); + return true; +} + +bool HlsAttributeList::GetQuotedString(const std::string& name, + std::string* value) const { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + // Quote strings are unambiguous to all other value types. + if (info->first != kQuotedStringType) { + LOGV("Attribute is not quoted string: name = %s, type = %s", name.c_str(), + ValueTypeToString(info->first)); + return false; + } + const std::string& rep = info->second; + if (rep.size() < 2) { + // Should not have allowed this to happen when + // assigning the value. + LOGE("Internal error: quote string too small: name = %s, rep = %s", + name.c_str(), SafeByteIdToString(rep).c_str()); + return false; + } + // Beginning and end quote. + value->assign(rep.begin() + 1, rep.end() - 1); + return true; +} + +bool HlsAttributeList::GetHexSequence(const std::string& name, + std::string* value) const { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + if (info->first != kHexSequenceType && !IsHexSequenceLike(info->second)) { + LOGV("Attribute is not hex sequence: name = %s, type = %s", name.c_str(), + ValueTypeToString(info->first)); + return false; + } + const std::string& rep = info->second; + // Leading "0x" or "0X"; + if (rep.size() < 2) { + LOGE("Internal error: hex sequence is too small: name = %s, rep = %s", + name.c_str(), SafeByteIdToString(rep).c_str()); + return false; + } + std::string hex_only = rep.substr(2); // Strip prefix. + if ((hex_only.size() % 2) != 0) { + // HLS attributes are not required to have hex sequence of even + // length. Prepending an '0' for Widevine's hex to bytes converter. + hex_only.insert(0, 1, '0'); + } + *value = a2bs_hex(hex_only); + if (value->empty()) { + LOGE("Internal error: failed to decode hex sequence: name = %s, rep = %s", + name.c_str(), SafeByteIdToString(rep).c_str()); + return false; + } + return true; +} + +bool HlsAttributeList::GetHexSequence(const std::string& name, + std::vector* value) const { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + if (info->first != kHexSequenceType && !IsHexSequenceLike(info->second)) { + LOGV("Attribute is not hex sequence: name = %s, type = %s", name.c_str(), + ValueTypeToString(info->first)); + return false; + } + const std::string& rep = info->second; + // Leading "0x" or "0X"; + if (rep.size() < 2) { + LOGE("Internal error: hex sequence is too small: name = %s, rep = %s", + name.c_str(), SafeByteIdToString(rep).c_str()); + return false; + } + std::string hex_only = rep.substr(2); // Strip prefix. + if ((hex_only.size() % 2) != 0) { + // HLS attributes are not required to have hex sequence of even + // length. Prepending an '0' for Widevine's hex to bytes converter. + hex_only.insert(0, 1, '0'); + } + *value = a2b_hex(hex_only); + if (value->empty()) { + LOGE("Internal error: failed to decode hex sequence: name = %s, rep = %s", + name.c_str(), SafeByteIdToString(rep).c_str()); + return false; + } + return true; +} + +bool HlsAttributeList::GetInteger(const std::string& name, + uint64_t* value) const { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + if (info->first != kIntegerType && !IsIntegerLike(info->second)) { + LOGV("Attribute is not integer: name = %s, type = %s, rep = %s", + name.c_str(), ValueTypeToString(info->first), info->second.c_str()); + return false; + } + if (!ParseIntegerInternal(info->second, value)) { + LOGV("Failed to parse value as integer: name = %s, rep = %s", name.c_str(), + info->second.c_str()); + *value = 0; + return false; + } + return true; +} + +bool HlsAttributeList::GetFloat(const std::string& name, double* value) const { + if (value == nullptr) { + LOGE("Output |value| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + if (info->first != kFloatType && !IsSignedFloatLike(info->second)) { + LOGV("Attribute is not float: name = %s, type = %s, rep = %s", name.c_str(), + ValueTypeToString(info->first), info->second.c_str()); + return false; + } + if (!ParseFloat(info->second, value)) { + LOGV("Failed to parse value as float: name = %s, rep = %s", name.c_str(), + info->second.c_str()); + *value = 0; + return false; + } + return true; +} + +bool HlsAttributeList::GetResolution(const std::string& name, uint64_t* width, + uint64_t* height) const { + if (width == nullptr) { + LOGE("Output |width| is null"); + return false; + } + if (height == nullptr) { + LOGE("Output |height| is null"); + return false; + } + const auto* info = GetInfo(name); + if (info == nullptr) { + LOGV("Attribute does not exist: %s", name.c_str()); + return false; + } + if (info->first != kResolutionType && !IsResolutionLike(info->second)) { + LOGV("Attribute is not resolution: name = %s, type = %s, rep = %s", + name.c_str(), ValueTypeToString(info->first), info->second.c_str()); + return false; + } + const size_t res_split_pos = info->second.find(kResolutionSplitChar); + if (res_split_pos == std::string::npos) { + LOGE( + "Internal error: resolution is missing resolution split: " + "name = %s, rep = %s", + name.c_str(), info->second.c_str()); + return false; + } + if (!ParseIntegerInternal(info->second.substr(0, res_split_pos), width) || + !ParseIntegerInternal(info->second.substr(res_split_pos + 1), height)) { + LOGV("Failed to parse resolution: name = %s, rep = %s", name.c_str(), + info->second.c_str()); + *width = *height = 0; + return false; + } + return true; +} + +bool HlsAttributeList::SetEnumString(const std::string& name, + const std::string& value) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + if (!IsValidEnumStringValue(value)) { + LOGV("Invalid HLS enum string value: %s", value.c_str()); + return false; + } + members_[name] = HlsValueInfo(kEnumStringType, value); + return true; +} + +bool HlsAttributeList::SetQuotedString(const std::string& name, + const std::string& value) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + if (!IsValidQuotedStringValue(value)) { + LOGV("Invalid HLS quoted string value: %s", value.c_str()); + return false; + } + members_[name] = HlsValueInfo(kQuotedStringType, MakeQuoted(value)); + return true; +} + +bool HlsAttributeList::SetHexSequence(const std::string& name, + const std::string& value) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + if (value.empty()) { + LOGV("Hex sequence data cannot be empty: %s", name.c_str()); + return false; + } + members_[name] = HlsValueInfo(kHexSequenceType, MakeHex(value)); + return true; +} + +bool HlsAttributeList::SetHexSequence(const std::string& name, + const std::vector& value) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + if (value.empty()) { + LOGV("Hex sequence data cannot be empty: %s", name.c_str()); + return false; + } + members_[name] = HlsValueInfo(kHexSequenceType, MakeHex(value)); + return true; +} + +bool HlsAttributeList::SetInteger(const std::string& name, uint64_t value) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + members_[name] = HlsValueInfo(kIntegerType, MakeInteger(value)); + return true; +} + +bool HlsAttributeList::SetFloat(const std::string& name, double value) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + members_[name] = HlsValueInfo(kSignedFloatType, MakeSignedFloat(value)); + return true; +} + +bool HlsAttributeList::SetResolution(const std::string& name, uint64_t width, + uint64_t height) { + if (!IsValidName(name)) { + LOGV("Invalid HLS attribute name: %s", name.c_str()); + return false; + } + members_[name] = HlsValueInfo(kResolutionType, MakeResolution(width, height)); + return true; +} + +bool HlsAttributeList::Remove(const std::string& name) { + return members_.erase(name) != 0; +} + +bool HlsAttributeList::Parse(const std::string& hls_attr_list_rep) { + Clear(); + if (hls_attr_list_rep.empty()) { + // Technically not an error. + LOGV("Empty HLS attribute list"); + return true; + } + std::vector tokens; + if (!HlsAttributeListReader(hls_attr_list_rep).ReadMembers(&tokens)) { + LOGV("Failed to tokenize"); + return false; + } + if (tokens.empty()) return true; + + for (auto& token : tokens) { + if (!IsValidName(token.name)) { + // Internal error as the tokenizer should have + // caught this. + LOGE("Invalid name: name = %s, value_rep = %s", + SafeByteIdToString(token.name).c_str(), token.value_rep.c_str()); + Clear(); + return false; + } + if (Contains(token.name)) { + // HLS specification recommends clients to refuse to parse + // lists with repeated attribute names. + LOGV("HLS list contains repeated name: name = %s, list_rep = %s", + token.name.c_str(), hls_attr_list_rep.c_str()); + Clear(); + return false; + } + + // The "type" is used as a hint, but does not restrict + // how the value may be used. + if (IsQuotedStringLike(token.value_rep)) { + members_.emplace(token.name, HlsValueInfo{kQuotedStringType, + std::move(token.value_rep)}); + } else if (IsIntegerLike(token.value_rep)) { + members_.emplace(token.name, + HlsValueInfo{kIntegerType, std::move(token.value_rep)}); + } else if (IsSignedFloatLike(token.value_rep)) { + members_.emplace(token.name, HlsValueInfo{kSignedFloatType, + std::move(token.value_rep)}); + } else if (IsHexSequenceLike(token.value_rep)) { + members_.emplace(token.name, HlsValueInfo{kHexSequenceType, + std::move(token.value_rep)}); + } else if (IsResolutionLike(token.value_rep)) { + members_.emplace(token.name, HlsValueInfo{kResolutionType, + std::move(token.value_rep)}); + } else if (IsEnumStringLike(token.value_rep)) { + members_.emplace(token.name, HlsValueInfo{kEnumStringType, + std::move(token.value_rep)}); + } else { + // Note: HlsAttributeListReader is a lexeme tokenizer not a + // lexical tokenizer; this is not an internal error. + LOGV("Unrecognized value: name = %s, rep = %s", token.name.c_str(), + SafeByteIdToString(token.value_rep).c_str()); + Clear(); + return false; + } + } + return true; +} + +std::string HlsAttributeList::Serialize() const { + std::ostringstream hls_stream; + if (!SerializeToStream(&hls_stream)) { + // Should not occur. + LOGE("Failed to serialize HLS attribute list"); + return std::string(); + } + return hls_stream.str(); +} + +bool HlsAttributeList::SerializeToStream(std::ostream* out) const { + if (out == nullptr) { + LOGE("Output stream is null"); + return false; + } + bool first_member = true; + for (const auto& member : members_) { + if (first_member) { + first_member = false; + } else { + *out << kMemberSeparatorChar; + } + *out << member.first << kNameValueSeparatorChar << member.second.second; + } + return true; +} +} // namespace wvutil diff --git a/util/src/rw_lock.cpp b/util/src/rw_lock.cpp deleted file mode 100644 index 96218a2..0000000 --- a/util/src/rw_lock.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2019 Google LLC. All Rights Reserved. This file and proprietary -// source code may only be used and distributed under the Widevine License -// Agreement. - -#include "rw_lock.h" - -#include "log.h" - -namespace wvutil { - -shared_mutex::~shared_mutex() { - if (reader_count_ > 0) { - LOGE("shared_mutex destroyed with active readers!"); - } - if (has_writer_) { - LOGE("shared_mutex destroyed with an active writer!"); - } -} - -void shared_mutex::lock_shared() { - std::unique_lock lock(mutex_); - - while (has_writer_) { - condition_variable_.wait(lock); - } - - ++reader_count_; -} - -void shared_mutex::unlock_shared() { - std::unique_lock lock(mutex_); - - --reader_count_; - - if (reader_count_ == 0) { - condition_variable_.notify_all(); - } -} - -bool shared_mutex::lock_implementation(bool abort_if_unavailable) { - std::unique_lock lock(mutex_); - - while (reader_count_ > 0 || has_writer_) { - if (abort_if_unavailable) return false; - condition_variable_.wait(lock); - } - - has_writer_ = true; - return true; -} - -void shared_mutex::unlock() { - std::unique_lock lock(mutex_); - - has_writer_ = false; - - condition_variable_.notify_all(); -} - -} // namespace wvutil diff --git a/util/src/string_utils.cpp b/util/src/string_utils.cpp new file mode 100644 index 0000000..7461f7f --- /dev/null +++ b/util/src/string_utils.cpp @@ -0,0 +1,244 @@ +// Copyright 2024 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "string_utils.h" + +#include +#include +#include +#include + +#include "log.h" + +namespace wvutil { +namespace { +// ASCII whitespace characters. +constexpr char kWhiteSpace[] = " \t\n\r"; + +constexpr size_t kNotFound = std::string::npos; + +// Splits input string |s| into a vector of strings containing +// a single character. +std::vector StringSplitAll(const std::string& s) { + std::vector tokens; + for (const char& c : s) { + tokens.emplace_back(1, c); + } + return tokens; +} + +std::vector StringSplitCommon(const std::string& s, + const std::string_view& delim) { + if (s.empty()) return {}; + if (delim.empty()) return StringSplitAll(s); + + std::vector tokens; + size_t start = 0; + size_t end = 0; + while ((end = s.find(delim, start)) != kNotFound) { + tokens.push_back(s.substr(start, end - start)); + start = end + delim.size(); + } + tokens.push_back(s.substr(start)); + return tokens; +} + +// Returns the total length of all the |tokens|. +size_t SumOfTokenLengths(const std::vector& tokens) { + size_t total_length = 0; + for (const auto& token : tokens) { + total_length += token.size(); + } + return total_length; +} + +// Special case of StringJoin where the glue is empty. +std::string StringJoinWithoutGlue(const std::vector& tokens) { + if (tokens.empty()) return std::string(); + const size_t expected_size = SumOfTokenLengths(tokens); + std::string result; + result.reserve(expected_size); + result = tokens.front(); + for (size_t i = 1; i < tokens.size(); i++) { + result.append(tokens[i]); + } + return result; +} + +std::string StringJoinCommon(const std::vector& tokens, + const std::string_view& glue) { + if (tokens.empty()) return std::string(); + if (tokens.size() == 1) return tokens.front(); + if (glue.empty()) return StringJoinWithoutGlue(tokens); + + // Total length of tokens + the glue length between each token. + const size_t expected_size = + SumOfTokenLengths(tokens) + ((tokens.size() - 1) * glue.size()); + std::string result; + result.reserve(expected_size); + result = tokens.front(); + for (size_t i = 1; i < tokens.size(); i++) { + result.append(glue); + result.append(tokens[i]); + } + return result; +} + +size_t StringCountCommon(const std::string& haystack, + const std::string_view& needle) { + // Special case. String libraries count the occurrence of an empty + // string as the length of the haystack + 1. This library does the + // same to remain predictable to those familiar with other string + // libraries. + if (needle.empty()) return haystack.size() + 1; + if (haystack.size() < needle.size()) return 0; + size_t count = 0; + size_t pos = 0; + while ((pos = haystack.find(needle, pos)) != kNotFound) { + count++; + pos += needle.size(); + } + return count; +} + +bool StringContainsCommon(const std::string& haystack, + const std::string_view& needle) { + if (needle.empty()) return true; + if (haystack.size() < needle.size()) return false; + return haystack.find(needle) != kNotFound; +} + +bool StringStartsWithCommon(const std::string& haystack, + const std::string_view& needle) { + if (needle.empty()) return true; + if (haystack.size() < needle.size()) return false; + return std::equal(haystack.begin(), haystack.begin() + needle.size(), + needle.begin(), needle.end()); +} + +bool StringEndsWithCommon(const std::string& haystack, + const std::string_view& needle) { + if (haystack.size() < needle.size()) return false; + return std::equal(haystack.rbegin(), haystack.rbegin() + needle.size(), + needle.rbegin(), needle.rend()); +} +} // namespace + +std::vector StringSplit(const std::string& s, char delim) { + return StringSplitCommon(s, std::string_view(&delim, 1)); +} + +std::vector StringSplit(const std::string& s, + const std::string& delim) { + return StringSplitCommon(s, std::string_view(delim)); +} + +std::vector StringSplit(const std::string& s, const char* delim) { + if (delim == nullptr) { + LOGE("Input |delim| is null"); + return {}; + } + return StringSplitCommon(s, std::string_view(delim)); +} + +std::string StringJoin(const std::vector& tokens, char glue) { + return StringJoinCommon(tokens, std::string_view(&glue, 1)); +} + +std::string StringJoin(const std::vector& tokens, + const std::string& glue) { + return StringJoinCommon(tokens, std::string_view(glue)); +} + +std::string StringJoin(const std::vector& tokens, + const char* glue) { + if (glue == nullptr) { + LOGE("Input |glue| is empty"); + return std::string(); + } + return StringJoinCommon(tokens, std::string_view(glue)); +} + +size_t StringCount(const std::string& haystack, char needle) { + if (haystack.empty()) return 0; + return std::count(haystack.begin(), haystack.end(), needle); +} + +size_t StringCount(const std::string& haystack, const std::string& needle) { + return StringCountCommon(haystack, std::string_view(needle)); +} + +size_t StringCount(const std::string& haystack, const char* needle) { + if (needle == nullptr) { + LOGE("Input |needle| is null"); + return 0; + } + return StringCountCommon(haystack, std::string_view(needle)); +} + +bool StringContains(const std::string& haystack, char needle) { + if (haystack.empty()) return false; + return haystack.find(needle) != kNotFound; +} + +bool StringContains(const std::string& haystack, const std::string& needle) { + return StringContainsCommon(haystack, std::string_view(needle)); +} + +bool StringContains(const std::string& haystack, const char* needle) { + if (needle == nullptr) { + LOGE("Input |needle| is null"); + return false; + } + return StringContainsCommon(haystack, std::string_view(needle)); +} + +bool StringStartsWith(const std::string& haystack, char needle) { + if (haystack.empty()) return false; + return haystack.front() == needle; +} + +bool StringStartsWith(const std::string& haystack, const std::string& needle) { + return StringStartsWithCommon(haystack, std::string_view(needle)); +} + +bool StringStartsWith(const std::string& haystack, const char* needle) { + if (needle == nullptr) { + LOGE("Input |needle| is null"); + return false; + } + return StringStartsWithCommon(haystack, std::string_view(needle)); +} + +bool StringEndsWith(const std::string& haystack, char needle) { + if (haystack.empty()) return false; + return haystack.back() == needle; +} + +bool StringEndsWith(const std::string& haystack, const std::string& needle) { + return StringEndsWithCommon(haystack, std::string_view(needle)); +} + +bool StringEndsWith(const std::string& haystack, const char* needle) { + if (needle == nullptr) { + LOGE("Input |needle| is null"); + return false; + } + return StringEndsWithCommon(haystack, std::string_view(needle)); +} + +std::string StringTrim(const std::string& s) { + if (s.empty()) return std::string(); + const size_t start = s.find_first_not_of(kWhiteSpace); + const size_t end = s.find_last_not_of(kWhiteSpace); + if (start == kNotFound || end == kNotFound) + return std::string(); // All white space. + return s.substr(start, end - start + 1); +} + +bool StringVecContains(const std::vector& haystack, + const std::string& needle) { + if (haystack.empty()) return false; + return std::find(haystack.begin(), haystack.end(), needle) != haystack.end(); +} +} // namespace wvutil diff --git a/util/src/time_struct.h b/util/src/time_struct.h new file mode 100644 index 0000000..b3927e4 --- /dev/null +++ b/util/src/time_struct.h @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#ifndef WVCDM_UTIL_INTERNAL_TIME_STRUCT_H_ +#define WVCDM_UTIL_INTERNAL_TIME_STRUCT_H_ + +// This is an internal header and should not be included outside +// of the CDM util library. +#ifndef WVCDM_ALLOW_TIME_STRUCT_INCLUDE +# error Internal header, do not include +#endif // WVCDM_ALLOW_TIME_STRUCT_INCLUDE + +#include + +#include "wv_class_utils.h" + +namespace wvutil { +namespace internal { +// Wrapper class around a struct tm from the header. +// Allows for easy conversion from how the C struct tm represents +// values to how the DateTime represents values. +class TimeStruct { + public: + // struct tm offsets. + static constexpr int kYearOffset = 1900; + static constexpr int kMonthOffset = 1; + static constexpr int kDayOfWeekOffset = 1; + static constexpr int kDayOfYearOffset = 1; + + constexpr TimeStruct() = default; + WVCDM_CONSTEXPR_DEFAULT_COPY_AND_MOVE(TimeStruct); + + constexpr const struct tm& time_parts() const { return time_parts_; } + constexpr void set_time_parts(const struct tm& tp) { time_parts_ = tp; } + + // Date fields. + constexpr int year() const { return time_parts_.tm_year + kYearOffset; } + constexpr void set_year(int year) { + time_parts_.tm_year = year - kYearOffset; + } + + constexpr int month() const { return time_parts_.tm_mon + kMonthOffset; } + constexpr void set_month(int month) { + time_parts_.tm_mon = month - kMonthOffset; + } + + constexpr int day_of_month() const { return time_parts_.tm_mday; } + constexpr void set_day_of_month(int day_of_month) { + time_parts_.tm_mday = day_of_month; + } + + // Sunday = 1, Saturday = 7 + constexpr int day_of_week() const { + return time_parts_.tm_wday + kDayOfWeekOffset; + } + constexpr void set_day_of_week(int day_of_week) { + time_parts_.tm_wday = day_of_week - kDayOfWeekOffset; + } + // January 1st = 1, December 31st = 365/366 + constexpr int day_of_year() const { + return time_parts_.tm_yday + kDayOfYearOffset; + } + constexpr void set_day_of_year(int day_of_year) { + time_parts_.tm_yday = day_of_year - kDayOfYearOffset; + } + + // Time fields. + constexpr int hour() const { return time_parts_.tm_hour; } + constexpr void set_hour(int hour) { time_parts_.tm_hour = hour; } + constexpr int minute() const { return time_parts_.tm_min; } + constexpr void set_minute(int minute) { time_parts_.tm_min = minute; } + constexpr int second() const { + // Handle case of leap second. + return (time_parts_.tm_sec >= 60) ? 59 : time_parts_.tm_sec; + } + constexpr void set_second(int second) { time_parts_.tm_sec = second; } + + private: + struct tm time_parts_ = {}; +}; // class TimeStruct +} // namespace internal +} // namespace wvutil +#endif // WVCDM_UTIL_INTERNAL_TIME_STRUCT_H_ diff --git a/util/src/wv_date_time.cpp b/util/src/wv_date_time.cpp new file mode 100644 index 0000000..98c7f82 --- /dev/null +++ b/util/src/wv_date_time.cpp @@ -0,0 +1,195 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "wv_date_time.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "log.h" +#include "wv_duration.h" +#include "wv_timestamp.h" + +#define WVCDM_ALLOW_TIME_STRUCT_INCLUDE +#include "time_struct.h" +#undef WVCDM_ALLOW_TIME_STRUCT_INCLUDE + +namespace wvutil { +namespace { +constexpr uint64_t kMaxCTime = + static_cast(std::numeric_limits::max()); +} // namespace + +// static +DateTime DateTime::FromTimestamp(const Timestamp& timestamp) { + if (!timestamp.IsSet()) { + LOGD("Timestamp is unset"); + return DateTime(); + } + return DateTime(timestamp); +} + +// private +DateTime::DateTime(const Timestamp& timestamp) { UpdateTimestamp(timestamp); } + +// private +void DateTime::UpdateTimestamp(const Timestamp& new_timestamp) { + if (!new_timestamp.IsSet()) { + Clear(); + return; + } + + const uint64_t epoch_seconds = + static_cast(new_timestamp.epoch_seconds().count()); + if (epoch_seconds > kMaxCTime) { + LOGE("Epoch seconds out of range of time_t: %zu", + static_cast(epoch_seconds)); + Clear(); + return; + } + + // Convert time point to time parts. + const time_t time_point = static_cast(epoch_seconds); + struct tm tp = {}; + // gmtime_r() is part POSIX.1-1996 and is safe to assume most + // standard C libraries will include it. + if (::gmtime_r(&time_point, &tp) == nullptr) { + const int saved_errno = errno; + if (saved_errno == EOVERFLOW) { + LOGE("Overflow when converting to time parts: epoch_seconds = %zu", + static_cast(epoch_seconds)); + } else { + LOGE( + "Failed to convert time point to time parts: " + "epoch_seconds = %zu, errno = %d", + static_cast(epoch_seconds), saved_errno); + } + Clear(); + return; + } + + internal::TimeStruct tm; + tm.set_time_parts(tp); + timestamp_ = new_timestamp; + year_ = tm.year(); + month_ = tm.month(); + day_ = tm.day_of_month(); + day_of_year_ = tm.day_of_year(); + day_of_week_ = tm.day_of_week(); +} + +DateTime DateTime::operator+(const Duration& duration) const { + if (!IsSet()) { + LOGD("DateTime unset: duration = %s", duration.ToString().c_str()); + return DateTime(); + } + if (duration.IsZero()) return *this; + const Timestamp new_timestamp = timestamp_ + duration; + if (!new_timestamp.IsSet()) { + LOGD("Addition overflow: duration = %s", duration.ToString().c_str()); + return DateTime(); + } + return DateTime(new_timestamp); +} + +DateTime DateTime::operator-(const Duration& duration) const { + if (!IsSet()) { + LOGD("DateTime unset: duration = %s", duration.ToString().c_str()); + return DateTime(); + } + if (duration.IsZero()) return *this; + const Timestamp new_timestamp = timestamp_ - duration; + if (!new_timestamp.IsSet()) { + LOGD("Subtraction overflow: duration = %s", duration.ToString().c_str()); + return DateTime(); + } + return DateTime(new_timestamp); +} + +Duration DateTime::operator-(const DateTime& other) const { + if (!IsSet() || !other.IsSet()) { + LOGD("Cannot get duration between unset DateTime"); + return Duration::Zero(); + } + return timestamp_ - other.timestamp_; +} + +DateTime& DateTime::operator+=(const Duration& duration) { + if (!IsSet()) { + LOGD("DateTime unset: duration = %s", duration.ToString().c_str()); + return *this; + } + if (duration.IsZero()) return *this; + const Timestamp new_timestamp = timestamp_ + duration; + if (!new_timestamp.IsSet()) { + LOGD("Addition overflow: duration = %s", duration.ToString().c_str()); + Clear(); + return *this; + } + UpdateTimestamp(new_timestamp); + return *this; +} + +DateTime& DateTime::operator-=(const Duration& duration) { + if (!IsSet()) { + LOGD("DateTime unset: duration = %s", duration.ToString().c_str()); + return *this; + } + if (duration.IsZero()) return *this; + const Timestamp new_timestamp = timestamp_ - duration; + if (!new_timestamp.IsSet()) { + LOGD("Subtraction overflow: duration = %s", duration.ToString().c_str()); + Clear(); + return *this; + } + UpdateTimestamp(new_timestamp); + return *this; +} + +bool DateTime::PrintTo(std::ostream* out) const { + if (out == nullptr) { + LOGE("Output stream is null"); + return false; + } + if (!IsSet()) { + *out << ""; + return true; + } + + const std::ios_base::fmtflags original_flags(out->flags()); + *out << std::dec << std::setfill('0'); + // Date + *out << std::setw(4) << year(); + *out << '-' << std::setw(2) << month(); + *out << '-' << std::setw(2) << day(); + // Time + *out << 'T' << std::setw(2) << hour(); + *out << ':' << std::setw(2) << minute(); + *out << ':' << std::setw(2) << second(); + const uint32_t ms = millisecond(); + if (ms != 0) { + // For default printer, only include MS if non-zero. + *out << '.' << std::setw(3) << ms; + } + *out << 'Z'; + out->flags(original_flags); + return true; +} + +std::string DateTime::ToString() const { + std::ostringstream out; + PrintTo(&out); + return out.str(); +} + +void PrintTo(const DateTime& date_time, std::ostream* out) { + if (out == nullptr) return; + date_time.PrintTo(out); +} +} // namespace wvutil diff --git a/util/src/wv_duration.cpp b/util/src/wv_duration.cpp new file mode 100644 index 0000000..ddfd32a --- /dev/null +++ b/util/src/wv_duration.cpp @@ -0,0 +1,71 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "wv_duration.h" + +#include + +#include +#include +#include +#include + +#include "log.h" + +namespace wvutil { +namespace { +// Time unit representations. +constexpr char kMillisUnit[] = "ms"; +constexpr char kSecondsUnit[] = "s"; +constexpr char kMinutesUnit[] = "m"; +constexpr char kHoursUnit[] = "h"; +constexpr char kDaysUnit[] = "d"; +// Common representation of a zero duration. +constexpr char kZeroRep[] = "0s"; +} // namespace + +bool Duration::PrintTo(std::ostream* out) const { + if (out == nullptr) { + LOGE("Output stream is null"); + return false; + } + if (IsZero()) { + *out << kZeroRep; + return true; + } + if (IsNegative()) { + *out << '-'; + return GetAbsolute().PrintTo(out); + } + const std::ios_base::fmtflags original_flags(out->flags()); + if (days() != 0) { + *out << days() << kDaysUnit; + } + if (hours() != 0) { + *out << hours() << kHoursUnit; + } + if (minutes() != 0) { + *out << minutes() << kMinutesUnit; + } + if (seconds() != 0) { + *out << seconds() << kSecondsUnit; + } + if (milliseconds() != 0) { + *out << milliseconds() << kMillisUnit; + } + out->flags(original_flags); + return true; +} + +std::string Duration::ToString() const { + std::ostringstream out; + PrintTo(&out); + return out.str(); +} + +// GTest printer. +void PrintTo(const Duration& duration, std::ostream* out) { + if (out == nullptr) return; + duration.PrintTo(out); +} +} // namespace wvutil diff --git a/util/test/buffer_reader_test.cpp b/util/test/buffer_reader_test.cpp new file mode 100644 index 0000000..d285222 --- /dev/null +++ b/util/test/buffer_reader_test.cpp @@ -0,0 +1,841 @@ +// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "buffer_reader.h" + +#include + +#include +#include + +#include + +namespace wvutil { + +class BufferReaderTest : public testing::Test { + public: + template + void WriteToBuffer(uint8_t* buffer, T v) { + for (size_t i = 0; i < sizeof(T); ++i) { + size_t insert_at = (sizeof(T) - i) - 1; // reverse the order of i + size_t shift_amount = 8 * i; + + buffer[insert_at] = static_cast((v >> shift_amount) & 0xff); + } + } + + // populate and validate data by cycling through the alphabet + // (lower case) so that it will work for strings and raw bytes + + void PopulateData(uint8_t* dest, size_t byte_count) { + for (size_t i = 0; i < byte_count; i++) { + dest[i] = static_cast(i % 26 + 'a'); + } + } + + bool ValidateData(const uint8_t* data, size_t byte_count) { + for (size_t i = 0; i < byte_count; i++) { + if (data[i] != static_cast(i % 26 + 'a')) { + return false; + } + } + + return true; + } + + bool ValidateReader(const BufferReader& reader, + const uint8_t* expected_address, size_t expected_size, + size_t expected_position) { + return reader.data() == expected_address && + reader.size() == expected_size && reader.pos() == expected_position; + } + + bool CheckRead1(uint8_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + uint8_t read; + + return reader.Read1(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead2(uint16_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + uint16_t read; + + return reader.Read2(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead2s(int16_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + int16_t read; + + return reader.Read2s(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead4(uint32_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + uint32_t read; + + return reader.Read4(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead4s(int32_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + int32_t read; + + return reader.Read4s(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead8(uint64_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + uint64_t read; + + return reader.Read8(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead8s(int64_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + int64_t read; + + return reader.Read8s(&read) && input == read && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead4Into8(uint32_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + uint64_t read; + return reader.Read4Into8(&read) && read == input && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } + + bool CheckRead4sInto8s(int32_t input) { + uint8_t raw_data[sizeof(input)]; + WriteToBuffer(raw_data, input); + + BufferReader reader(raw_data, sizeof(raw_data)); + + int64_t read; + return reader.Read4sInto8s(&read) && read == input && + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(input)); + } +}; + +TEST_F(BufferReaderTest, InitializeGoodDataAndGoodSize) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, InitializeGoodDataAndNoSize) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, InitializeNoDataNoSize) { + BufferReader reader(nullptr, 0); + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, InitializeNoDataBadSize) { + BufferReader reader(nullptr, 16); + + // Buffer reader should default to a size of 0 when given + // NULL data to ensure no reading of bad data + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, HasBytesWithBytes) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + // the reader should have enough bytes from 0 to the size of the buffer + for (size_t i = 0; i <= sizeof(raw_data); i++) { + ASSERT_TRUE(reader.HasBytes(i)); + } + + ASSERT_FALSE(reader.HasBytes(sizeof(raw_data) + 1)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, HasBytesWithEmptyBuffer) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + ASSERT_FALSE(reader.HasBytes(1)); + ASSERT_TRUE(reader.HasBytes(0)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, HasBytesWithNullBuffer) { + BufferReader reader(nullptr, 8); + + ASSERT_FALSE(reader.HasBytes(1)); + ASSERT_TRUE(reader.HasBytes(0)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, HasBytesAfterAllRead) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + for (size_t i = 0; i < sizeof(raw_data); i++) { + uint8_t read; + ASSERT_TRUE(reader.Read1(&read)); + } + + ASSERT_FALSE(reader.HasBytes(1)); + ASSERT_TRUE(reader.HasBytes(0)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE( + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(raw_data))); +} + +TEST_F(BufferReaderTest, Read1LargeNumber) { ASSERT_TRUE(CheckRead1(0xFF)); } + +TEST_F(BufferReaderTest, Read1SmallNumber) { ASSERT_TRUE(CheckRead1(0x0F)); } + +TEST_F(BufferReaderTest, Read1Zero) { ASSERT_TRUE(CheckRead1(0)); } + +TEST_F(BufferReaderTest, Read1WithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + uint8_t read; + ASSERT_FALSE(reader.Read1(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read1WithNullBuffer) { + BufferReader reader(nullptr, 16); + + uint8_t read; + ASSERT_FALSE(reader.Read1(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read1WithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read1(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read2LargeNumber) { ASSERT_TRUE(CheckRead2(30000)); } + +TEST_F(BufferReaderTest, Read2SmallNumber) { ASSERT_TRUE(CheckRead2(10)); } + +TEST_F(BufferReaderTest, Read2Zero) { ASSERT_TRUE(CheckRead2(0)); } + +TEST_F(BufferReaderTest, Read2WithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + uint16_t read; + ASSERT_FALSE(reader.Read2(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read2WithNullBuffer) { + BufferReader reader(nullptr, 16); + + uint16_t read; + ASSERT_FALSE(reader.Read2(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read2WithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read2(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read2sLargePositive) { + ASSERT_TRUE(CheckRead2s(30000)); +} + +TEST_F(BufferReaderTest, Read2sSmallPositive) { ASSERT_TRUE(CheckRead2s(10)); } + +TEST_F(BufferReaderTest, Read2sZero) { ASSERT_TRUE(CheckRead2s(0)); } + +TEST_F(BufferReaderTest, Read2sSmallNegative) { ASSERT_TRUE(CheckRead2s(-10)); } + +TEST_F(BufferReaderTest, Read2sLargeNegative) { + ASSERT_TRUE(CheckRead2s(-30000)); +} + +TEST_F(BufferReaderTest, Read2sWithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + int16_t read; + ASSERT_FALSE(reader.Read2s(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read2sWithNullBuffer) { + BufferReader reader(nullptr, 16); + + int16_t read; + ASSERT_FALSE(reader.Read2s(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read2sWithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read2s(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read4LargeNumber) { + // a number near uint32's max value + ASSERT_TRUE(CheckRead4(2000000000)); +} + +TEST_F(BufferReaderTest, Read4SmallNumber) { ASSERT_TRUE(CheckRead4(10)); } + +TEST_F(BufferReaderTest, Read4Zero) { ASSERT_TRUE(CheckRead4(0)); } + +TEST_F(BufferReaderTest, Read4WithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + uint32_t read; + ASSERT_FALSE(reader.Read4(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read4WithNullBuffer) { + BufferReader reader(nullptr, 16); + + uint32_t read; + ASSERT_FALSE(reader.Read4(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read4WithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read4(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read4sLargePositive) { + // a number near int32's max value + ASSERT_TRUE(CheckRead4s(2000000000)); +} + +TEST_F(BufferReaderTest, Read4sSmallPositive) { ASSERT_TRUE(CheckRead4s(10)); } + +TEST_F(BufferReaderTest, Read4sZero) { ASSERT_TRUE(CheckRead4s(0)); } + +TEST_F(BufferReaderTest, Read4sSmallNegative) { ASSERT_TRUE(CheckRead4s(-10)); } + +TEST_F(BufferReaderTest, Read4sLargeNegative) { + // a number near int32's max negative value + ASSERT_TRUE(CheckRead4s(-2000000000)); +} + +TEST_F(BufferReaderTest, Read4sWithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + int32_t read; + ASSERT_FALSE(reader.Read4s(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read4sWithNullBuffer) { + BufferReader reader(nullptr, 16); + + int32_t read; + ASSERT_FALSE(reader.Read4s(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read4sWithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read4s(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read8LargeNumber) { + // a number near uint64's max value + ASSERT_TRUE(CheckRead8(9000000000000000000)); +} + +TEST_F(BufferReaderTest, Read8SmallNumber) { ASSERT_TRUE(CheckRead8(10)); } + +TEST_F(BufferReaderTest, Read8Zero) { ASSERT_TRUE(CheckRead8(0)); } + +TEST_F(BufferReaderTest, Read8WithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + uint64_t read; + ASSERT_FALSE(reader.Read8(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read8WithNullBuffer) { + BufferReader reader(nullptr, 16); + + uint64_t read; + ASSERT_FALSE(reader.Read8(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read8WithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read8(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read8sLargePositive) { + // a number near int64's max value + ASSERT_TRUE(CheckRead8s(9000000000000000000)); +} + +TEST_F(BufferReaderTest, Read8sSmallPositive) { ASSERT_TRUE(CheckRead8s(10)); } + +TEST_F(BufferReaderTest, Read8sZero) { ASSERT_TRUE(CheckRead8s(0)); } + +TEST_F(BufferReaderTest, Read8sSmallNegative) { ASSERT_TRUE(CheckRead8s(-10)); } + +TEST_F(BufferReaderTest, Read8sLargeNegative) { + // a number near int64's max negative value + ASSERT_TRUE(CheckRead8s(-9000000000000000000)); +} + +TEST_F(BufferReaderTest, Read8sWithNoData) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, 0); + + int64_t read; + ASSERT_FALSE(reader.Read8s(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, 0, 0)); +} + +TEST_F(BufferReaderTest, Read8sWithNullBuffer) { + BufferReader reader(nullptr, 16); + + int64_t read; + ASSERT_FALSE(reader.Read8s(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read8sWithNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read8s(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, ReadString) { + uint8_t raw_data[5]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + std::string read; + ASSERT_TRUE(reader.ReadString(&read, sizeof(raw_data))); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(read.length() == sizeof(raw_data)); + ASSERT_TRUE(ValidateData((const uint8_t*)read.c_str(), read.length())); + + ASSERT_TRUE( + ValidateReader(reader, raw_data, sizeof(raw_data), sizeof(raw_data))); +} + +TEST_F(BufferReaderTest, ReadStringNullSource) { + BufferReader reader(nullptr, 5); + + std::string read; + ASSERT_FALSE(reader.ReadString(&read, 5)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, ReadStringNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.ReadString(nullptr, 5)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, ReadStringZeroCount) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + std::string read; + ASSERT_TRUE(reader.ReadString(&read, 0)); + + ASSERT_TRUE(0 == read.length()); + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, ReadStringTooLarge) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + std::string read; + ASSERT_FALSE(reader.ReadString(&read, sizeof(raw_data) * 2)); + + ASSERT_TRUE(0 == read.length()); + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, ReadVector) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + std::vector read; + + ASSERT_TRUE(reader.ReadVec(&read, 4)); + + ASSERT_TRUE(read.size() == 4); + + for (size_t i = 0; i < 4; i++) { + ASSERT_TRUE(raw_data[i] == read[i]); + } + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 4)); +} + +TEST_F(BufferReaderTest, ReadVectorTooLarge) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + std::vector read; + + ASSERT_FALSE(reader.ReadVec(&read, sizeof(raw_data) * 2)); + + ASSERT_TRUE(0 == read.size()); + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, ReadVectorNullSource) { + BufferReader reader(nullptr, 16); + + std::vector read; + ASSERT_FALSE(reader.ReadVec(&read, 4)); + + ASSERT_TRUE(0 == read.size()); + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, ReadVectorNullReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.ReadVec(nullptr, 4)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, ReadVectorNone) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + std::vector read; + ASSERT_TRUE(reader.ReadVec(&read, 0)); + + ASSERT_TRUE(0 == read.size()); + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read4Into84Bytes) { + ASSERT_TRUE(CheckRead4Into8(0xFFFFFF)); +} + +TEST_F(BufferReaderTest, Read4Into83Bytes) { + ASSERT_TRUE(CheckRead4Into8(0xFFFF)); +} + +TEST_F(BufferReaderTest, Read4Into82Bytes) { + ASSERT_TRUE(CheckRead4Into8(0xFF)); +} + +TEST_F(BufferReaderTest, Read4Into8Zero) { ASSERT_TRUE(CheckRead4Into8(0)); } + +TEST_F(BufferReaderTest, Read4Into8NullSource) { + BufferReader reader(nullptr, 4); + + uint64_t read; + ASSERT_FALSE(reader.Read4Into8(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read4Into8TooLittleData) { + uint8_t raw_data[2]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + uint64_t read; + ASSERT_FALSE(reader.Read4Into8(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read4Into8NoReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read4Into8(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read4sInto8s4Bytes) { + ASSERT_TRUE(CheckRead4sInto8s(0x0FFFFFFF)); +} + +TEST_F(BufferReaderTest, Read4sInto8s3Bytes) { + ASSERT_TRUE(CheckRead4sInto8s(0xFFFFFF)); +} + +TEST_F(BufferReaderTest, Read4sInto8s2Bytes) { + ASSERT_TRUE(CheckRead4sInto8s(0xFFFF)); +} + +TEST_F(BufferReaderTest, Read4sInto8s1Bytes) { + ASSERT_TRUE(CheckRead4sInto8s(0xFF)); +} + +TEST_F(BufferReaderTest, Read4sInto8sZero) { + ASSERT_TRUE(CheckRead4sInto8s(0)); +} + +TEST_F(BufferReaderTest, Read4sInto8sNegative) { + ASSERT_TRUE(CheckRead4sInto8s(-100)); +} + +TEST_F(BufferReaderTest, Read4sInto8sNullSource) { + BufferReader reader(nullptr, 4); + + int64_t read; + ASSERT_FALSE(reader.Read4sInto8s(&read)); + + ASSERT_TRUE(ValidateReader(reader, nullptr, 0, 0)); +} + +TEST_F(BufferReaderTest, Read4sInto8sTooLittleData) { + uint8_t raw_data[2]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + int64_t read; + ASSERT_FALSE(reader.Read4sInto8s(&read)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, Read4sInto8sNoReturn) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.Read4sInto8s(nullptr)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, SkipBytesNone) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_TRUE(reader.SkipBytes(0)); + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} + +TEST_F(BufferReaderTest, SkipBytes) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_TRUE(reader.SkipBytes(4)); + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 4)); +} + +TEST_F(BufferReaderTest, SkipBytesTooLarge) { + uint8_t raw_data[16]; + PopulateData(raw_data, sizeof(raw_data)); + + BufferReader reader(raw_data, sizeof(raw_data)); + + ASSERT_FALSE(reader.SkipBytes(sizeof(raw_data) * 2)); + + ASSERT_TRUE(ValidateData(raw_data, sizeof(raw_data))); + ASSERT_TRUE(ValidateReader(reader, raw_data, sizeof(raw_data), 0)); +} +} // namespace wvutil diff --git a/util/test/hls_attribute_list_unittest.cpp b/util/test/hls_attribute_list_unittest.cpp new file mode 100644 index 0000000..1d789c2 --- /dev/null +++ b/util/test/hls_attribute_list_unittest.cpp @@ -0,0 +1,1385 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "hls_attribute_list.h" + +#include + +#include +#include +#include +#include + +#include + +#include "string_conversions.h" +#include "string_utils.h" + +namespace wvutil { +namespace test { +namespace { +constexpr uint64_t kZeroInt = 0; + +constexpr uint64_t kMaxInt = std::numeric_limits::max(); +constexpr char kMaxIntString[] = "18446744073709551615"; +// FYI: 2^52 is the largest float value with whole integer +// precision. +constexpr double kMaxIntFloat = 9007199254740992.0; +constexpr uint64_t kMaxIntFloatInt = 9007199254740992; +constexpr char kMaxIntFloatString[] = "9007199254740992"; + +} // namespace + +TEST(HlsAttributeListTest, IsValidName) { + // Valid attribute names. + EXPECT_TRUE(HlsAttributeList::IsValidName("METHOD")); + EXPECT_TRUE(HlsAttributeList::IsValidName("START-DATE")); + EXPECT_TRUE(HlsAttributeList::IsValidName("SCTE35-CMD")); + EXPECT_TRUE(HlsAttributeList::IsValidName("SCTE35-IN")); + EXPECT_TRUE(HlsAttributeList::IsValidName("GROUP-ID")); + EXPECT_TRUE(HlsAttributeList::IsValidName("X-WIDEVINE-1337")); + + // Unlikely but still valid. + EXPECT_TRUE(HlsAttributeList::IsValidName("-")); + EXPECT_TRUE(HlsAttributeList::IsValidName("-----------")); + EXPECT_TRUE(HlsAttributeList::IsValidName("123456789")); + EXPECT_TRUE(HlsAttributeList::IsValidName("-0")); + + // Invalid attribute names. + EXPECT_FALSE(HlsAttributeList::IsValidName("")); + EXPECT_FALSE(HlsAttributeList::IsValidName("lower-case")); + EXPECT_FALSE(HlsAttributeList::IsValidName("sOME-LOWER-CASE")); + EXPECT_FALSE(HlsAttributeList::IsValidName("SOME-LOWER-CASe")); + EXPECT_FALSE(HlsAttributeList::IsValidName("sOME-LOwER-CASE")); + EXPECT_FALSE(HlsAttributeList::IsValidName("METHOD=NONE")); + EXPECT_FALSE(HlsAttributeList::IsValidName("WHITE SPACE")); + EXPECT_FALSE(HlsAttributeList::IsValidName(" WHITE-SPACE")); + EXPECT_FALSE(HlsAttributeList::IsValidName("WHITE-SPACE ")); + EXPECT_FALSE(HlsAttributeList::IsValidName("NO,COMMA")); + EXPECT_FALSE(HlsAttributeList::IsValidName("NO\"QUOTE")); + EXPECT_FALSE(HlsAttributeList::IsValidName("NO\nLF")); + EXPECT_FALSE(HlsAttributeList::IsValidName("NO\rCR")); +} + +TEST(HlsAttributeListTest, IsValidEnumStringValue) { + // Valid from HLS standard. + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("NONE")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("AES-128")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("SAMPLE-AES")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("CLOSED-CAPTIONS")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("YES")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("TYPE-0")); + // Other valid values. + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("lower-case")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("------")); + // All punctuation except commas and double quotes. + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("!@#$%^&*()")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("-=_+{}[]\\|")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue(";:'<>./?~`")); + // Values that look like other types. + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("1234")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("-12.3")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("0x1234")); + EXPECT_TRUE(HlsAttributeList::IsValidEnumStringValue("1920x1080")); + + // Invalid + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("WHITE SPACE")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue(" WHITESPACE")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("WHITESPACE ")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("NO,COMMAS")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("NO\"QUOTE")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("NO\nLF")); + EXPECT_FALSE(HlsAttributeList::IsValidEnumStringValue("NO\rCR")); +} + +TEST(HlsAttributeListTest, IsValidQuotedStringValue) { + // Valid from standard. + EXPECT_TRUE(HlsAttributeList::IsValidQuotedStringValue( + "uri:text/plain;base64,SGVsbG8sIFdvcmxkIQ==")); + EXPECT_TRUE(HlsAttributeList::IsValidQuotedStringValue("com.widevine")); + EXPECT_TRUE(HlsAttributeList::IsValidQuotedStringValue("1/2/5")); + EXPECT_TRUE( + HlsAttributeList::IsValidQuotedStringValue("2024-07-31T11:43:55-12:00")); + + EXPECT_TRUE(HlsAttributeList::IsValidQuotedStringValue("")); + EXPECT_TRUE( + HlsAttributeList::IsValidQuotedStringValue("white space, and commas")); + + EXPECT_FALSE(HlsAttributeList::IsValidQuotedStringValue("NO\"QUOTE")); + EXPECT_FALSE(HlsAttributeList::IsValidQuotedStringValue("NO\nLF")); + EXPECT_FALSE(HlsAttributeList::IsValidQuotedStringValue("NO\rCR")); +} + +TEST(HlsAttributeListTest, EmptyList) { + const HlsAttributeList list; + constexpr size_t kZero = 0; + EXPECT_EQ(list.Count(), kZero); + EXPECT_TRUE(list.IsEmpty()); + const std::vector kEmptyVector; + EXPECT_EQ(list.GetNames(), kEmptyVector); + EXPECT_FALSE(list.Contains("METHOD")); +} + +TEST(HlsAttributeListTest, SetEnumString) { + HlsAttributeList list; + EXPECT_TRUE(list.SetEnumString("METHOD", "NONE")); + EXPECT_TRUE(list.SetEnumString("TYPE", "AUDIO")); + EXPECT_TRUE(list.SetEnumString("DEFAULT", "NO")); + // Overwrite. + EXPECT_TRUE(list.SetEnumString("TYPE", "VIDEO")); + + EXPECT_EQ(list.Count(), static_cast(3)); + EXPECT_TRUE(list.Contains("METHOD")); + EXPECT_FALSE(list.Contains("AUDIO")); + + const std::vector kExpectedNames = {"DEFAULT", "METHOD", "TYPE"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + std::string value; + EXPECT_TRUE(list.GetEnumString("METHOD", &value)); + EXPECT_EQ(value, "NONE"); + EXPECT_TRUE(list.GetEnumString("TYPE", &value)); + EXPECT_EQ(value, "VIDEO"); + EXPECT_TRUE(list.GetEnumString("DEFAULT", &value)); + EXPECT_EQ(value, "NO"); + + EXPECT_FALSE(list.GetEnumString("INSTREAM-ID", &value)); +} + +TEST(HlsAttributeListTest, SetEnumString_Invalid) { + HlsAttributeList list; + + EXPECT_FALSE(list.SetEnumString("", "NONE")); + EXPECT_FALSE(list.SetEnumString("BAD NAME", "NONE")); + EXPECT_FALSE(list.SetEnumString("bad-name", "NONE")); + + EXPECT_FALSE(list.SetEnumString("METHOD", "")); + EXPECT_FALSE(list.SetEnumString("METHOD", "BAD VALUE")); + EXPECT_FALSE(list.SetEnumString("METHOD", "BAD,VALUE")); + EXPECT_FALSE(list.SetEnumString("METHOD", "BAD\"VALUE")); + + EXPECT_TRUE(list.IsEmpty()); +} + +TEST(HlsAttributeListTest, SetQuotedString) { + HlsAttributeList list; + const std::string kUriValue = "uri:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="; + EXPECT_TRUE(list.SetQuotedString("URI", kUriValue)); + EXPECT_TRUE(list.SetQuotedString("VERSION", "")); + EXPECT_TRUE(list.SetQuotedString("MESSAGE", "Hello, world!")); + + EXPECT_EQ(list.Count(), static_cast(3)); + EXPECT_TRUE(list.Contains("VERSION")); + EXPECT_FALSE(list.Contains("AUDIO")); + + const std::vector kExpectedNames = {"MESSAGE", "URI", "VERSION"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + std::string value; + EXPECT_TRUE(list.GetQuotedString("URI", &value)); + EXPECT_EQ(value, kUriValue); + EXPECT_TRUE(list.GetQuotedString("VERSION", &value)); + EXPECT_EQ(value, ""); + EXPECT_TRUE(list.GetQuotedString("MESSAGE", &value)); + EXPECT_EQ(value, "Hello, world!"); + + EXPECT_FALSE(list.GetQuotedString("INSTREAM-ID", &value)); +} + +TEST(HlsAttributeListTest, SetQuotedString_Invalid) { + HlsAttributeList list; + EXPECT_FALSE(list.SetQuotedString("", "NONE")); + EXPECT_FALSE(list.SetQuotedString("BAD NAME", "NONE")); + EXPECT_FALSE(list.SetQuotedString("bad-name", "NONE")); + + EXPECT_FALSE(list.SetQuotedString("METHOD", "BAD\rVALUE")); + EXPECT_FALSE(list.SetQuotedString("METHOD", "BAD\nVALUE")); + EXPECT_FALSE(list.SetQuotedString("METHOD", "BAD\"VALUE")); + + EXPECT_TRUE(list.IsEmpty()); +} + +TEST(HlsAttributeListTest, SetHexSequenceAsString) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetHexSequence("IV", "some IV value")); + EXPECT_TRUE(list.SetHexSequence("OTHERWISE-FORBIDDEN", "\r\n \"\b")); + EXPECT_TRUE(list.SetHexSequence("SMALL", "-")); + + EXPECT_EQ(list.Count(), static_cast(3)); + EXPECT_TRUE(list.Contains("IV")); + EXPECT_FALSE(list.Contains("AUDIO")); + + const std::vector kExpectedNames = {"IV", "OTHERWISE-FORBIDDEN", + "SMALL"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + std::string value; + EXPECT_TRUE(list.GetHexSequence("IV", &value)); + EXPECT_EQ(value, "some IV value"); + EXPECT_TRUE(list.GetHexSequence("OTHERWISE-FORBIDDEN", &value)); + EXPECT_EQ(value, "\r\n \"\b"); + EXPECT_TRUE(list.GetHexSequence("SMALL", &value)); + EXPECT_EQ(value, "-"); + + EXPECT_FALSE(list.GetHexSequence("INSTREAM-ID", &value)); +} + +TEST(HlsAttributeListTest, SetHexSequenceAsString_Invalid) { + HlsAttributeList list; + EXPECT_FALSE(list.SetHexSequence("", "NONE")); + EXPECT_FALSE(list.SetHexSequence("BAD NAME", "NONE")); + EXPECT_FALSE(list.SetHexSequence("bad-name", "NONE")); + + EXPECT_FALSE(list.SetHexSequence("METHOD", "")); + EXPECT_TRUE(list.IsEmpty()); +} + +TEST(HlsAttributeListTest, SetHexSequenceAsVector) { + HlsAttributeList list; + + const std::vector kIvValue(16, 0x33); + EXPECT_TRUE(list.SetHexSequence("IV", kIvValue)); + const std::vector kSmallValue(1, 0xff); + EXPECT_TRUE(list.SetHexSequence("SMALL", kSmallValue)); + + EXPECT_EQ(list.Count(), static_cast(2)); + EXPECT_TRUE(list.Contains("IV")); + EXPECT_FALSE(list.Contains("AUDIO")); + + const std::vector kExpectedNames = {"IV", "SMALL"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + std::vector value; + EXPECT_TRUE(list.GetHexSequence("IV", &value)); + EXPECT_EQ(value, kIvValue); + EXPECT_TRUE(list.GetHexSequence("SMALL", &value)); + EXPECT_EQ(value, kSmallValue); + + EXPECT_FALSE(list.GetHexSequence("INSTREAM-ID", &value)); +} + +TEST(HlsAttributeListTest, SetHexSequenceAsVector_Invalid) { + HlsAttributeList list; + const std::vector kValidValue(4, 0x20); + EXPECT_FALSE(list.SetHexSequence("", kValidValue)); + EXPECT_FALSE(list.SetHexSequence("BAD NAME", kValidValue)); + EXPECT_FALSE(list.SetHexSequence("bad-name", kValidValue)); + + const std::vector kEmptyVector; + EXPECT_FALSE(list.SetHexSequence("METHOD", kEmptyVector)); + EXPECT_TRUE(list.IsEmpty()); +} + +TEST(HlsAttributeListTest, SetInteger) { + HlsAttributeList list; + const uint64_t kInStreamId = 32; + EXPECT_TRUE(list.SetInteger("INSTREAM-ID", kInStreamId)); + EXPECT_TRUE(list.SetInteger("ZERO", kZeroInt)); + EXPECT_TRUE(list.SetInteger("MAX", kMaxInt)); + + EXPECT_EQ(list.Count(), static_cast(3)); + EXPECT_TRUE(list.Contains("ZERO")); + EXPECT_FALSE(list.Contains("AUDIO")); + + const std::vector kExpectedNames = {"INSTREAM-ID", "MAX", + "ZERO"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + uint64_t value = 0; + EXPECT_TRUE(list.GetInteger("INSTREAM-ID", &value)); + EXPECT_EQ(value, kInStreamId); + EXPECT_TRUE(list.GetInteger("ZERO", &value)); + EXPECT_EQ(value, kZeroInt); + EXPECT_TRUE(list.GetInteger("MAX", &value)); + EXPECT_EQ(value, kMaxInt); + + EXPECT_FALSE(list.GetInteger("VERSION", &value)); +} + +TEST(HlsAttributeListTest, SetInteger_Invalid) { + HlsAttributeList list; + const uint64_t kValidValue = 1337; + EXPECT_FALSE(list.SetInteger("", kValidValue)); + EXPECT_FALSE(list.SetInteger("BAD NAME", kValidValue)); + EXPECT_FALSE(list.SetInteger("bad-name", kValidValue)); +} + +TEST(HlsAttributeListTest, SetFloat) { + HlsAttributeList list; + const double kFrameRate = 29.97; + EXPECT_TRUE(list.SetFloat("FRAME-RATE", kFrameRate)); + const double kTimeOffset = -5.5; + EXPECT_TRUE(list.SetFloat("TIME-OFFSET", kTimeOffset)); + const double kZero = 0.0; + EXPECT_TRUE(list.SetFloat("ZERO", kZero)); + + EXPECT_EQ(list.Count(), static_cast(3)); + EXPECT_TRUE(list.Contains("TIME-OFFSET")); + EXPECT_FALSE(list.Contains("AUDIO")); + + const std::vector kExpectedNames = {"FRAME-RATE", "TIME-OFFSET", + "ZERO"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + double value = 0; + EXPECT_TRUE(list.GetFloat("FRAME-RATE", &value)); + EXPECT_EQ(value, kFrameRate); + EXPECT_TRUE(list.GetFloat("ZERO", &value)); + EXPECT_EQ(value, kZero); + EXPECT_TRUE(list.GetFloat("TIME-OFFSET", &value)); + EXPECT_EQ(value, kTimeOffset); + + EXPECT_FALSE(list.GetFloat("VERSION", &value)); +} + +TEST(HlsAttributeListTest, SetFloat_Invalid) { + HlsAttributeList list; + const double kValidValue = 1337.0; + EXPECT_FALSE(list.SetFloat("", kValidValue)); + EXPECT_FALSE(list.SetFloat("BAD NAME", kValidValue)); + EXPECT_FALSE(list.SetFloat("bad-name", kValidValue)); +} + +TEST(HlsAttributeListTest, SetResolution) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetResolution("RESOLUTION", 1920, 1080)); + EXPECT_TRUE(list.SetResolution("SVGA-RESOLUTION", 800, 600)); + EXPECT_TRUE(list.SetResolution("NULL-RESOLUTION", 0, 0)); + + const std::vector kExpectedNames = { + "NULL-RESOLUTION", "RESOLUTION", "SVGA-RESOLUTION"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + uint64_t width = 0; + uint64_t height = 0; + EXPECT_TRUE(list.GetResolution("RESOLUTION", &width, &height)); + EXPECT_EQ(width, static_cast(1920)); + EXPECT_EQ(height, static_cast(1080)); + EXPECT_TRUE(list.GetResolution("NULL-RESOLUTION", &width, &height)); + EXPECT_EQ(width, kZeroInt); + EXPECT_EQ(height, kZeroInt); + EXPECT_TRUE(list.GetResolution("SVGA-RESOLUTION", &width, &height)); + EXPECT_EQ(width, static_cast(800)); + EXPECT_EQ(height, static_cast(600)); + + EXPECT_FALSE(list.GetResolution("VERSION", &width, &height)); +} + +TEST(HlsAttributeListTest, SetResolution_Invalid) { + HlsAttributeList list; + EXPECT_FALSE(list.SetResolution("", 1920, 1080)); + EXPECT_FALSE(list.SetResolution("BAD NAME", 1920, 1080)); + EXPECT_FALSE(list.SetResolution("bad-name", 1920, 1080)); +} + +TEST(HlsAttributeListTest, Remove) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetEnumString("METHOD", "SAMPLE-AES")); + EXPECT_TRUE(list.SetQuotedString("VERSION", "1/2/5")); + const std::vector kIv(16, 0x42); + EXPECT_TRUE(list.SetHexSequence("IV", kIv)); + const uint64_t kInStreamId = 32; + EXPECT_TRUE(list.SetInteger("INSTREAM-ID", kInStreamId)); + const double kFrameRate = 29.97; + EXPECT_TRUE(list.SetFloat("FRAME-RATE", kFrameRate)); + const uint64_t kWidth = 1920; + const uint64_t kHeight = 1080; + EXPECT_TRUE(list.SetResolution("RESOLUTION", kWidth, kHeight)); + + const std::vector kExpectedInitialNames = { + "FRAME-RATE", "INSTREAM-ID", "IV", "METHOD", "RESOLUTION", "VERSION"}; + EXPECT_EQ(list.GetNames(), kExpectedInitialNames); + + EXPECT_TRUE(list.Remove("IV")); + EXPECT_TRUE(list.Remove("VERSION")); + EXPECT_FALSE(list.Remove("TIME-OFFSET")); + + EXPECT_EQ(list.Count(), static_cast(4)); + const std::vector kExpectedFinalNames = { + "FRAME-RATE", "INSTREAM-ID", "METHOD", "RESOLUTION"}; + EXPECT_EQ(list.GetNames(), kExpectedFinalNames); + + std::string string_value; + EXPECT_TRUE(list.GetEnumString("METHOD", &string_value)); + EXPECT_EQ(string_value, "SAMPLE-AES"); + EXPECT_FALSE(list.GetEnumString("VERSION", &string_value)); + std::vector vec_value; + EXPECT_FALSE(list.GetHexSequence("IV", &vec_value)); + uint64_t int_value = 0; + EXPECT_TRUE(list.GetInteger("INSTREAM-ID", &int_value)); + EXPECT_EQ(int_value, kInStreamId); +} + +TEST(HlsAttributeListTest, AmbiguousValues_EnumStringAsInteger) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetEnumString("POSITIVE-INT", "1234")); + // Max u64 2^64-1 + EXPECT_TRUE(list.SetEnumString("MAX-INT", kMaxIntString)); + EXPECT_TRUE(list.SetEnumString("BIG-ZERO-INT", "00000000000000000000")); + // Over max, 2^64 + EXPECT_TRUE(list.SetEnumString("OVER-MAX-INT", "18446744073709551616")); + // Integers cannot be negative + EXPECT_TRUE(list.SetEnumString("NEGATIVE-INT", "-1337")); + // Integers cannot use positive sign. + EXPECT_TRUE(list.SetEnumString("BAD-SIGN", "+42")); + EXPECT_TRUE(list.SetEnumString("NOT-INT", "OTHER")); + + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("MAX-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("MAX-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("BIG-ZERO-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("BIG-ZERO-INT", HlsAttributeList::kIntegerType)); + + EXPECT_TRUE(list.IsType("OVER-MAX-INT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("OVER-MAX-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("BAD-SIGN", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("BAD-SIGN", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("NEGATIVE-INT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NEGATIVE-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("NOT-INT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NOT-INT", HlsAttributeList::kIntegerType)); + + uint64_t value = 0; + EXPECT_TRUE(list.GetInteger("POSITIVE-INT", &value)); + EXPECT_EQ(value, static_cast(1234)); + EXPECT_TRUE(list.GetInteger("MAX-INT", &value)); + EXPECT_EQ(value, kMaxInt); + EXPECT_TRUE(list.GetInteger("BIG-ZERO-INT", &value)); + EXPECT_EQ(value, kZeroInt); + + EXPECT_FALSE(list.GetInteger("OVER-MAX-INT", &value)); + EXPECT_FALSE(list.GetInteger("NEGATIVE-INT", &value)); + EXPECT_FALSE(list.GetInteger("BAD-SIGN", &value)); + EXPECT_FALSE(list.GetInteger("NOT-INT", &value)); +} + +TEST(HlsAttributeListTest, AmbiguousValues_EnumStringAsFloat) { + HlsAttributeList list; + EXPECT_TRUE(list.SetEnumString("POSITIVE-INT", "1234")); + EXPECT_TRUE(list.SetEnumString("NEGATIVE-INT", "-1337")); + EXPECT_TRUE(list.SetEnumString("POSITIVE-FLOAT", "29.97")); + EXPECT_TRUE(list.SetEnumString("NEGATIVE-FLOAT", "-180.0")); + EXPECT_TRUE(list.SetEnumString("MAX-INT-FLOAT", kMaxIntFloatString)); + EXPECT_TRUE(list.SetEnumString("NO-FRAC", "12.")); + EXPECT_TRUE(list.SetEnumString("NO-INT", ".345")); + // Floats cannot have positive signs. + EXPECT_TRUE(list.SetEnumString("BAD-SIGN", "+12.345")); + EXPECT_TRUE(list.SetEnumString("NOT-FLOAT", "OTHER")); + + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("NEGATIVE-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("NEGATIVE-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("POSITIVE-FLOAT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("POSITIVE-FLOAT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("NEGATIVE-FLOAT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("NEGATIVE-FLOAT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("MAX-INT-FLOAT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("MAX-INT-FLOAT", HlsAttributeList::kFloatType)); + + EXPECT_TRUE(list.IsType("NO-FRAC", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NO-FRAC", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("NO-INT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NO-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("BAD-SIGN", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("BAD-SIGN", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("NOT-FLOAT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NOT-FLOAT", HlsAttributeList::kFloatType)); + + double value = 0.0; + EXPECT_TRUE(list.GetFloat("POSITIVE-INT", &value)); + EXPECT_EQ(value, 1234.0); + EXPECT_TRUE(list.GetFloat("NEGATIVE-INT", &value)); + EXPECT_EQ(value, -1337.0); + EXPECT_TRUE(list.GetFloat("POSITIVE-FLOAT", &value)); + EXPECT_EQ(value, 29.97); + EXPECT_TRUE(list.GetFloat("NEGATIVE-FLOAT", &value)); + EXPECT_EQ(value, -180.0); + EXPECT_TRUE(list.GetFloat("MAX-INT-FLOAT", &value)); + EXPECT_EQ(value, kMaxIntFloat); + + EXPECT_FALSE(list.GetFloat("NO-FRAC", &value)); + EXPECT_FALSE(list.GetFloat("NO-INT", &value)); + EXPECT_FALSE(list.GetFloat("BAD-SIGN", &value)); + EXPECT_FALSE(list.GetFloat("NOT-FLOAT", &value)); +} + +TEST(HlsAttributeListTest, AmbiguousValues_EnumStringAsHexSequence) { + HlsAttributeList list; + // Hex of "Hello, World!" + EXPECT_TRUE(list.SetEnumString("VALID-HEX", "0x48656C6C6F2C20576F726C6421")); + // Hex does not need to be even length. + EXPECT_TRUE(list.SetEnumString("ODD-HEX", "0X1020304")); + // Lower case is not valid HLS hex. + EXPECT_TRUE( + list.SetEnumString("INVALID-HEX", "0x48656c6c6f2c20576f726c6421")); + EXPECT_TRUE(list.SetEnumString("NULL-HEX", "0X")); + EXPECT_TRUE(list.SetEnumString("NOT-HEX", "OTHER")); + + EXPECT_TRUE(list.IsType("VALID-HEX", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("VALID-HEX", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("ODD-HEX", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("ODD-HEX", HlsAttributeList::kHexSequenceType)); + + EXPECT_TRUE(list.IsType("INVALID-HEX", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("INVALID-HEX", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("NULL-HEX", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NULL-HEX", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("NOT-HEX", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NOT-HEX", HlsAttributeList::kHexSequenceType)); + + std::string string_value; + EXPECT_TRUE(list.GetHexSequence("VALID-HEX", &string_value)); + EXPECT_EQ(string_value, "Hello, World!"); + std::vector vec_value; + EXPECT_TRUE(list.GetHexSequence("ODD-HEX", &vec_value)); + const std::vector kOddHexValue = {1, 2, 3, 4}; + EXPECT_EQ(vec_value, kOddHexValue); + + EXPECT_FALSE(list.GetHexSequence("INVALID-HEX", &string_value)); + EXPECT_FALSE(list.GetHexSequence("NULL-HEX", &string_value)); + EXPECT_FALSE(list.GetHexSequence("NOT-HEX", &string_value)); + + EXPECT_FALSE(list.GetHexSequence("INVALID-HEX", &vec_value)); + EXPECT_FALSE(list.GetHexSequence("NULL-HEX", &vec_value)); + EXPECT_FALSE(list.GetHexSequence("NOT-HEX", &vec_value)); +} + +TEST(HlsAttributeListTest, AmbiguousValues_EnumStringAsResolution) { + HlsAttributeList list; + EXPECT_TRUE(list.SetEnumString("VALID-RES", "1920x1080")); + EXPECT_TRUE(list.SetEnumString("ZERO-RES", "0x0")); + const std::string kMaxIntResString = + std::to_string(kMaxInt) + "x" + std::to_string(kMaxInt); + EXPECT_TRUE(list.SetEnumString("MAX-RES", kMaxIntResString)); + // Resolution separator must be lower 'x'. + EXPECT_TRUE(list.SetEnumString("CAP-SEP-RES", "1920X1080")); + EXPECT_TRUE(list.SetEnumString("MISSING-WIDTH", "x1080")); + EXPECT_TRUE(list.SetEnumString("MISSING-HEIGHT", "1920x")); + EXPECT_TRUE(list.SetEnumString("BAD-WIDTH", "BADx1080")); + EXPECT_TRUE(list.SetEnumString("BAD-HEIGHT", "1920xBAD")); + EXPECT_TRUE(list.SetEnumString("NOT-RES", "OTHER")); + + EXPECT_TRUE(list.IsType("VALID-RES", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("VALID-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kResolutionType)); + + EXPECT_TRUE(list.IsType("CAP-SEP-RES", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("CAP-SEP-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("MISSING-WIDTH", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("MISSING-WIDTH", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("MISSING-HEIGHT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE( + list.IsType("MISSING-HEIGHT", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("BAD-WIDTH", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("BAD-WIDTH", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("BAD-HEIGHT", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("BAD-HEIGHT", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("NOT-RES", HlsAttributeList::kEnumStringType)); + EXPECT_FALSE(list.IsType("NOT-RES", HlsAttributeList::kResolutionType)); + + uint64_t width = 0; + uint64_t height = 0; + EXPECT_TRUE(list.GetResolution("VALID-RES", &width, &height)); + EXPECT_EQ(width, static_cast(1920)); + EXPECT_EQ(height, static_cast(1080)); + EXPECT_TRUE(list.GetResolution("ZERO-RES", &width, &height)); + EXPECT_EQ(width, kZeroInt); + EXPECT_EQ(height, kZeroInt); + EXPECT_TRUE(list.GetResolution("MAX-RES", &width, &height)); + EXPECT_EQ(width, kMaxInt); + EXPECT_EQ(height, kMaxInt); + + EXPECT_FALSE(list.GetResolution("CAP-SEP-RES", &width, &height)); + EXPECT_FALSE(list.GetResolution("MISSING-WIDTH", &width, &height)); + EXPECT_FALSE(list.GetResolution("MISSING-HEIGHT", &width, &height)); + EXPECT_FALSE(list.GetResolution("BAD-WIDTH", &width, &height)); + EXPECT_FALSE(list.GetResolution("BAD-HEIGHT", &width, &height)); + EXPECT_FALSE(list.GetResolution("NOT-RES", &width, &height)); +} + +// Note: All valid hex sequences are also valid enum strings. +TEST(HlsAttributeListTest, AmbiguousValues_HexSequenceAsEnumString) { + HlsAttributeList list; + + const std::vector kValueData = a2b_hex("DEADBEAF"); + EXPECT_TRUE(list.SetHexSequence("VALUE", kValueData)); + + EXPECT_TRUE(list.IsType("VALUE", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("VALUE", HlsAttributeList::kEnumStringType)); + + std::string value; + EXPECT_TRUE(list.GetEnumString("VALUE", &value)); + // Our implementation will use "0x" prefix. + EXPECT_EQ(value, "0xDEADBEAF"); +} + +TEST(HlsAttributeListTest, AmbiguousValues_HexSequenceAsResolution) { + HlsAttributeList list; + + const std::vector kValidResData = a2b_hex("1920"); + EXPECT_TRUE(list.SetHexSequence("VALID-RES", kValidResData)); + const std::vector kZeroResData = {0}; + EXPECT_TRUE(list.SetHexSequence("ZERO-RES", kZeroResData)); + // Max integer length is 20 (10 bytes -> 20 hex digits) + const std::vector kBigZeroResData(10, 0); + EXPECT_TRUE(list.SetHexSequence("BIG-ZERO-RES", kBigZeroResData)); + const std::vector kTooBigZeroResData(11, 0); + EXPECT_TRUE(list.SetHexSequence("TOO-BIG-ZERO-RES", kTooBigZeroResData)); + // Max integer as hex. + const std::vector kMaxResData = a2b_hex(kMaxIntString); + EXPECT_TRUE(list.SetHexSequence("MAX-RES", kMaxResData)); + // Over max as hex. + const std::vector kOverMaxResData = a2b_hex("18446744073709551616"); + EXPECT_TRUE(list.SetHexSequence("OVER-MAX-RES", kOverMaxResData)); + // Resolution must be decimal, not hex. + const std::vector kNotDecimalResData = a2b_hex("B000"); + EXPECT_TRUE(list.SetHexSequence("NOT-DEC-RES", kNotDecimalResData)); + + EXPECT_TRUE(list.IsType("VALID-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("VALID-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("BIG-ZERO-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("BIG-ZERO-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kResolutionType)); + + EXPECT_TRUE( + list.IsType("TOO-BIG-ZERO-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_FALSE( + list.IsType("TOO-BIG-ZERO-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("OVER-MAX-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_FALSE(list.IsType("OVER-MAX-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("NOT-DEC-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_FALSE(list.IsType("NOT-DEC-RES", HlsAttributeList::kResolutionType)); + + uint64_t width = 0; + uint64_t height = 0; + EXPECT_TRUE(list.GetResolution("VALID-RES", &width, &height)); + EXPECT_EQ(width, kZeroInt); + EXPECT_EQ(height, static_cast(1920)); + EXPECT_TRUE(list.GetResolution("ZERO-RES", &width, &height)); + EXPECT_EQ(width, kZeroInt); + EXPECT_EQ(height, kZeroInt); + EXPECT_TRUE(list.GetResolution("BIG-ZERO-RES", &width, &height)); + EXPECT_EQ(width, kZeroInt); + EXPECT_EQ(height, kZeroInt); + EXPECT_TRUE(list.GetResolution("MAX-RES", &width, &height)); + EXPECT_EQ(width, kZeroInt); + EXPECT_EQ(height, kMaxInt); + + EXPECT_FALSE(list.GetResolution("TOO-BIG-ZERO-RES", &width, &height)); + EXPECT_FALSE(list.GetResolution("OVER-MAX-RES", &width, &height)); + EXPECT_FALSE(list.GetResolution("NOT-DEC-RES", &width, &height)); +} + +TEST(HlsAttributeListTest, AmbiguousValues_FloatAsInteger) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetFloat("ZERO-INT", 0.0)); + EXPECT_TRUE(list.SetFloat("POSITIVE-INT", 123.0)); + EXPECT_TRUE(list.SetFloat("NOT-WHOLE", 123.456)); + EXPECT_TRUE(list.SetFloat("FRAC-ONLY", 0.456)); + EXPECT_TRUE(list.SetFloat("NEGATIVE-INT", -10.0)); + EXPECT_TRUE(list.SetFloat("MAX-FLOAT-INT", kMaxIntFloat)); + + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("MAX-FLOAT-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("MAX-FLOAT-INT", HlsAttributeList::kIntegerType)); + + EXPECT_TRUE(list.IsType("NOT-WHOLE", HlsAttributeList::kFloatType)); + EXPECT_FALSE(list.IsType("NOT-WHOLE", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("FRAC-ONLY", HlsAttributeList::kFloatType)); + EXPECT_FALSE(list.IsType("FRAC-ONLY", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("NEGATIVE-INT", HlsAttributeList::kFloatType)); + EXPECT_FALSE(list.IsType("NEGATIVE-INT", HlsAttributeList::kIntegerType)); + + uint64_t value = 0; + EXPECT_TRUE(list.GetInteger("ZERO-INT", &value)); + EXPECT_EQ(value, kZeroInt); + EXPECT_TRUE(list.GetInteger("POSITIVE-INT", &value)); + EXPECT_EQ(value, static_cast(123)); + EXPECT_TRUE(list.GetInteger("MAX-FLOAT-INT", &value)); + EXPECT_EQ(value, kMaxIntFloatInt); + + EXPECT_FALSE(list.GetInteger("NOT-WHOLE", &value)); + EXPECT_FALSE(list.GetInteger("FRAC-ONLY", &value)); + EXPECT_FALSE(list.GetInteger("NEGATIVE-INT", &value)); +} + +// Note: All floating point values are valid enum strings. +TEST(HlsAttributeListTest, AmbiguousValues_FloatAsEnumString) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetFloat("ZERO-INT", 0.0)); + EXPECT_TRUE(list.SetFloat("POSITIVE-INT", 123.0)); + EXPECT_TRUE(list.SetFloat("NOT-WHOLE", 123.456)); + EXPECT_TRUE(list.SetFloat("FRAC-ONLY", 0.456)); + EXPECT_TRUE(list.SetFloat("NEGATIVE-INT", -10.0)); + EXPECT_TRUE(list.SetFloat("MAX-FLOAT-INT", kMaxIntFloat)); + + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("POSITIVE-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("MAX-FLOAT-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("MAX-FLOAT-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("NOT-WHOLE", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("NOT-WHOLE", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("FRAC-ONLY", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("FRAC-ONLY", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("NEGATIVE-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("NEGATIVE-INT", HlsAttributeList::kEnumStringType)); + + std::string value; + EXPECT_TRUE(list.GetEnumString("ZERO-INT", &value)); + EXPECT_EQ(value, "0"); + EXPECT_TRUE(list.GetEnumString("POSITIVE-INT", &value)); + EXPECT_EQ(value, "123"); + EXPECT_TRUE(list.GetEnumString("NOT-WHOLE", &value)); + EXPECT_EQ(value, "123.456"); + EXPECT_TRUE(list.GetEnumString("FRAC-ONLY", &value)); + EXPECT_EQ(value, "0.456"); + EXPECT_TRUE(list.GetEnumString("NEGATIVE-INT", &value)); + EXPECT_EQ(value, "-10"); + EXPECT_TRUE(list.GetEnumString("MAX-FLOAT-INT", &value)); + EXPECT_EQ(value, kMaxIntFloatString); +} + +// Note: All integers values are valid enum strings. +TEST(HlsAttributeListTest, AmbiguousValues_IntegerAsEnumString) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetInteger("ZERO-INT", kZeroInt)); + EXPECT_TRUE(list.SetInteger("SOME-INT", 123456)); + EXPECT_TRUE(list.SetInteger("MAX-INT", kMaxInt)); + + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("SOME-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("SOME-INT", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("MAX-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("MAX-INT", HlsAttributeList::kEnumStringType)); + + std::string value; + EXPECT_TRUE(list.GetEnumString("ZERO-INT", &value)); + EXPECT_EQ(value, "0"); + EXPECT_TRUE(list.GetEnumString("SOME-INT", &value)); + EXPECT_EQ(value, "123456"); + EXPECT_TRUE(list.GetEnumString("MAX-INT", &value)); + EXPECT_EQ(value, kMaxIntString); +} + +// Note: All integer values appear as valid float values; +// however, double-precision floats have limited range. +TEST(HlsAttributeListTest, AmbiguousValues_IntegerAsFloat) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetInteger("ZERO-INT", kZeroInt)); + EXPECT_TRUE(list.SetInteger("SOME-INT", 123456)); + EXPECT_TRUE(list.SetInteger("MAX-INT-FLOAT", kMaxIntFloat)); + + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("ZERO-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("SOME-INT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("SOME-INT", HlsAttributeList::kFloatType)); + EXPECT_TRUE(list.IsType("MAX-INT-FLOAT", HlsAttributeList::kIntegerType)); + EXPECT_TRUE(list.IsType("MAX-INT-FLOAT", HlsAttributeList::kFloatType)); + + double value; + EXPECT_TRUE(list.GetFloat("ZERO-INT", &value)); + EXPECT_EQ(value, 0.0); + EXPECT_TRUE(list.GetFloat("SOME-INT", &value)); + EXPECT_EQ(value, 123456); + EXPECT_TRUE(list.GetFloat("MAX-INT-FLOAT", &value)); + EXPECT_EQ(value, kMaxIntFloat); +} + +// Note: All resolutions are valid enum strings. +TEST(HlsAttributeListTest, AmbiguousValues_ResolutionAsEnumString) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetResolution("ZERO-RES", kZeroInt, kZeroInt)); + EXPECT_TRUE(list.SetResolution("RESOLUTION", 1920, 1080)); + EXPECT_TRUE(list.SetResolution("MAX-RES", kMaxInt, kMaxInt)); + + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("RESOLUTION", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("RESOLUTION", HlsAttributeList::kEnumStringType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kEnumStringType)); + + std::string value; + EXPECT_TRUE(list.GetEnumString("ZERO-RES", &value)); + EXPECT_EQ(value, "0x0"); + EXPECT_TRUE(list.GetEnumString("RESOLUTION", &value)); + EXPECT_EQ(value, "1920x1080"); + EXPECT_TRUE(list.GetEnumString("MAX-RES", &value)); + const std::string kMaxIntResString = + std::to_string(kMaxInt) + "x" + std::to_string(kMaxInt); + EXPECT_EQ(value, kMaxIntResString); +} + +TEST(HlsAttributeListTest, AmbiguousValues_ResolutionAsHexSequence) { + HlsAttributeList list; + + // Note: Only appears are resolution if width is zero. + EXPECT_TRUE(list.SetResolution("ZERO-RES", kZeroInt, kZeroInt)); + EXPECT_TRUE(list.SetResolution("RESOLUTION", kZeroInt, 1080)); + EXPECT_TRUE(list.SetResolution("MAX-RES", kZeroInt, kMaxInt)); + EXPECT_TRUE(list.SetResolution("NON-ZERO-WIDTH", 1920, 1080)); + + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("ZERO-RES", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("RESOLUTION", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("RESOLUTION", HlsAttributeList::kHexSequenceType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kResolutionType)); + EXPECT_TRUE(list.IsType("MAX-RES", HlsAttributeList::kHexSequenceType)); + + EXPECT_TRUE(list.IsType("NON-ZERO-WIDTH", HlsAttributeList::kResolutionType)); + EXPECT_FALSE( + list.IsType("NON-ZERO-WIDTH", HlsAttributeList::kHexSequenceType)); + + std::vector value; + EXPECT_TRUE(list.GetHexSequence("ZERO-RES", &value)); + const std::vector kZeroResData = {0}; + EXPECT_EQ(value, kZeroResData); + EXPECT_TRUE(list.GetHexSequence("RESOLUTION", &value)); + const std::vector kResData = {0x10, 0x80}; + EXPECT_EQ(value, kResData); + EXPECT_TRUE(list.GetHexSequence("MAX-RES", &value)); + const std::vector kMaxResData = a2b_hex(kMaxIntString); + EXPECT_EQ(value, kMaxResData); + + EXPECT_FALSE(list.GetHexSequence("NON-ZERO-WIDTH", &value)); +} + +// Quoted strings are always unambiguous. +TEST(HlsAttributeListTest, UnambiguousValues_QuotedStringAsAnything) { + HlsAttributeList list; + + // Settings to quoted values of otherwise valid types. + EXPECT_TRUE(list.SetQuotedString("INT-LIKE", "1234")); + EXPECT_TRUE(list.SetQuotedString("FLOAT-LIKE", "-12.34")); + EXPECT_TRUE(list.SetQuotedString("ENUM-LIKE", "VALUE")); + EXPECT_TRUE(list.SetQuotedString("RES-LIKE", "1920x1080")); + EXPECT_TRUE(list.SetQuotedString("HEX-LIKE", "0xDEADBEAF")); + + const std::vector kNonQuotedStringTypes = { + HlsAttributeList::kIntegerType, HlsAttributeList::kFloatType, + HlsAttributeList::kEnumStringType, HlsAttributeList::kResolutionType, + HlsAttributeList::kHexSequenceType}; + + const std::vector names = list.GetNames(); + const std::vector kExpectedNames = { + "ENUM-LIKE", "FLOAT-LIKE", "HEX-LIKE", "INT-LIKE", "RES-LIKE"}; + // Validity check, if this fails, something other that what + // this test is checking for has failed. + ASSERT_EQ(names, kExpectedNames); + + for (const auto& name : names) { + EXPECT_TRUE(list.IsType(name, HlsAttributeList::kQuotedStringType)) + << "name = " << name; + for (const auto& type : kNonQuotedStringTypes) { + EXPECT_FALSE(list.IsType(name, type)) + << "name = " << name + << ", type = " << HlsAttributeList::ValueTypeToString(type); + } + } + + // Accessor test for only similar looking types. + uint64_t int_value = 0; + EXPECT_FALSE(list.GetInteger("INT-LIKE", &int_value)); + double float_value = 0.0; + EXPECT_FALSE(list.GetFloat("FLOAT-LIKE", &float_value)); + std::string string_value; + EXPECT_FALSE(list.GetEnumString("ENUM-LIKE", &string_value)); + EXPECT_FALSE(list.GetHexSequence("HEX-LIKE", &string_value)); + std::vector vec_value; + EXPECT_FALSE(list.GetHexSequence("HEX-LIKE", &vec_value)); + uint64_t other_int_value = 0; + EXPECT_FALSE(list.GetResolution("RES-LIKE", &int_value, &other_int_value)); +} + +// Nothing can appear as a quoted string. +TEST(HlsAttributeListTest, UnambiguousValues_AnythingAsQuotedString) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetInteger("INT", 1234)); + EXPECT_TRUE(list.SetHexSequence("HEX", "data")); + EXPECT_TRUE(list.SetEnumString("ENUM", "VALUE")); + EXPECT_TRUE(list.SetFloat("FLOAT", 12.34)); + EXPECT_TRUE(list.SetResolution("RESOLUTION", 1920, 1080)); + + const std::vector kNonQuotedStringTypes = { + HlsAttributeList::kIntegerType, HlsAttributeList::kFloatType, + HlsAttributeList::kEnumStringType, HlsAttributeList::kResolutionType, + HlsAttributeList::kHexSequenceType}; + + const std::vector names = list.GetNames(); + const std::vector kExpectedNames = {"ENUM", "FLOAT", "HEX", + "INT", "RESOLUTION"}; + // Validity check, if this fails, something other that what + // this test is checking for has failed. + ASSERT_EQ(names, kExpectedNames); + + for (const auto& name : names) { + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kQuotedStringType)) + << "name = " << name; + std::string value; + EXPECT_FALSE(list.GetQuotedString(name, &value)); + } +} + +// Integers can never appear are hex sequences or resolutions. +// Note: Also quoted strings, tested above. +TEST(HlsAttributeListTest, UnambiguousValues_Integer) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetInteger("ZERO-INT", kZeroInt)); + EXPECT_TRUE(list.SetInteger("SOME-INT", 1234)); + EXPECT_TRUE(list.SetInteger("MAX-INT", kMaxInt)); + + const std::vector names = list.GetNames(); + ASSERT_EQ(names.size(), static_cast(3)); + + for (const auto& name : names) { + EXPECT_TRUE(list.IsType(name, HlsAttributeList::kIntegerType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kHexSequenceType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kResolutionType)) + << "name = " << name; + + std::string string_value; + EXPECT_FALSE(list.GetHexSequence(name, &string_value)) << "name = " << name; + std::vector vec_value; + EXPECT_FALSE(list.GetHexSequence(name, &vec_value)) << "name = " << name; + uint64_t width = 0; + uint64_t height = 0; + EXPECT_FALSE(list.GetResolution(name, &width, &height)) + << "name = " << name; + } +} + +// Hex sequences can never appear as integers or floats. +// Note: Also quoted strings, tested above. +TEST(HlsAttributeListTest, UnambiguousValues_HexSequence) { + HlsAttributeList list; + + const std::vector kSmallZeroData = {0}; + EXPECT_TRUE(list.SetHexSequence("SMALL-ZERO", kSmallZeroData)); + const std::vector kBigZeroData(10, 0); + EXPECT_TRUE(list.SetHexSequence("BIG-ZERO", kBigZeroData)); + const std::vector kDecimalData = a2b_hex("1234"); + EXPECT_TRUE(list.SetHexSequence("DEC", kDecimalData)); + const std::vector kHexData = a2b_hex("deadbeaf"); + EXPECT_TRUE(list.SetHexSequence("HEX", kHexData)); + + const std::vector names = list.GetNames(); + ASSERT_EQ(names.size(), static_cast(4)); + + for (const auto& name : names) { + EXPECT_TRUE(list.IsType(name, HlsAttributeList::kHexSequenceType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kIntegerType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kFloatType)) + << "name = " << name; + + uint64_t int_value = 0; + EXPECT_FALSE(list.GetInteger(name, &int_value)) << "name = " << name; + double float_value = 0.0; + EXPECT_FALSE(list.GetFloat(name, &float_value)) << "name = " << name; + } +} + +// Floats can never appear as hex sequences or resolutions. +// Note: Also quoted strings, tested above. +TEST(HlsAttributeListTest, UnambiguousValues_Float) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetFloat("ZERO", 0.0)); + EXPECT_TRUE(list.SetFloat("MAX-INT-FLOAT", kMaxIntFloat)); + EXPECT_TRUE(list.SetFloat("SOME-FLOAT", 12.34)); + + const std::vector names = list.GetNames(); + ASSERT_EQ(names.size(), static_cast(3)); + + for (const auto& name : names) { + EXPECT_TRUE(list.IsType(name, HlsAttributeList::kFloatType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kHexSequenceType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kResolutionType)) + << "name = " << name; + + std::string string_value; + EXPECT_FALSE(list.GetHexSequence(name, &string_value)) << "name = " << name; + std::vector vec_value; + EXPECT_FALSE(list.GetHexSequence(name, &vec_value)) << "name = " << name; + uint64_t width = 0; + uint64_t height = 0; + EXPECT_FALSE(list.GetResolution(name, &width, &height)) + << "name = " << name; + } +} + +// Resolutions can never appear as integers or floats. +// Note: Also quoted strings, tested above. +TEST(HlsAttributeListTest, UnambiguousValues_Resolution) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetResolution("ZERO-RES", kZeroInt, kZeroInt)); + EXPECT_TRUE(list.SetResolution("SOME-RES", 1920, 1080)); + EXPECT_TRUE(list.SetResolution("ZERO-WIDTH-RES", kZeroInt, 1080)); + EXPECT_TRUE(list.SetResolution("ZERO-HEIGHT-RES", 1920, kZeroInt)); + EXPECT_TRUE(list.SetResolution("MAX-RES", kMaxInt, kMaxInt)); + + const std::vector names = list.GetNames(); + ASSERT_EQ(names.size(), static_cast(5)); + + for (const auto& name : names) { + EXPECT_TRUE(list.IsType(name, HlsAttributeList::kResolutionType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kIntegerType)) + << "name = " << name; + EXPECT_FALSE(list.IsType(name, HlsAttributeList::kFloatType)) + << "name = " << name; + + uint64_t int_value = 0; + EXPECT_FALSE(list.GetInteger(name, &int_value)) << "name = " << name; + double float_value = 0.0; + EXPECT_FALSE(list.GetFloat(name, &float_value)) << "name = " << name; + } +} + +TEST(HlsAttributeListTest, Parse_SingleInteger) { + const std::string rep = "INSTREAM-ID=32"; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + const uint64_t kInStreamId = 32; + uint64_t value = 0; + EXPECT_TRUE(list.GetInteger("INSTREAM-ID", &value)); + EXPECT_EQ(value, kInStreamId); +} + +TEST(HlsAttributeListTest, Parse_SingleFloat) { + const std::string rep = "FRAME-RATE=29.97"; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + const double kFrameRate = 29.97; + double value = 0; + EXPECT_TRUE(list.GetFloat("FRAME-RATE", &value)); + EXPECT_EQ(value, kFrameRate); +} + +TEST(HlsAttributeListTest, Parse_SingleEnumString) { + const std::string rep = "METHOD=AES-128"; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + const std::string kMethod = "AES-128"; + std::string value; + EXPECT_TRUE(list.GetEnumString("METHOD", &value)); + EXPECT_EQ(value, kMethod); +} + +TEST(HlsAttributeListTest, Parse_SingleQuotedString) { + const std::string rep = "VERSION=\"1/2/5\""; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + const std::string kVersion = "1/2/5"; + std::string value; + EXPECT_TRUE(list.GetQuotedString("VERSION", &value)); + EXPECT_EQ(value, kVersion); +} + +TEST(HlsAttributeListTest, Parse_SingleEmptyQuotedString) { + const std::string rep = "KEYFORMAT=\"\""; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + std::string value = "NOT-EMPTY"; + EXPECT_TRUE(list.GetQuotedString("KEYFORMAT", &value)); + EXPECT_EQ(value, ""); +} + +TEST(HlsAttributeListTest, Parse_SingleHexSequence) { + const std::string rep = "IV=0xDEADBEAF"; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + const std::string kIvString = a2bs_hex("DEADBEAF"); + const std::vector kIvVector(kIvString.begin(), kIvString.end()); + std::string string_value; + EXPECT_TRUE(list.GetHexSequence("IV", &string_value)); + EXPECT_EQ(string_value, kIvString); + std::vector vec_value; + EXPECT_TRUE(list.GetHexSequence("IV", &vec_value)); + EXPECT_EQ(vec_value, kIvVector); +} + +TEST(HlsAttributeListTest, Parse_SingleResolution) { + const std::string rep = "RESOLUTION=1920x1080"; + HlsAttributeList list; + + EXPECT_TRUE(list.Parse(rep)); + EXPECT_FALSE(list.IsEmpty()); + EXPECT_EQ(list.Count(), static_cast(1)); + + const uint64_t kWidth = 1920; + const uint64_t kHeight = 1080; + uint64_t width = 0; + uint64_t height = 0; + EXPECT_TRUE(list.GetResolution("RESOLUTION", &width, &height)); + EXPECT_EQ(width, kWidth); + EXPECT_EQ(height, kHeight); +} + +TEST(HlsAttributeListTest, Parse_OneOfEach) { + const uint64_t kInStreamId = 32; + const double kFrameRate = 29.97; + const std::string kIvString = a2bs_hex("DEADBEAF"); + const uint64_t kWidth = 1920; + const uint64_t kHeight = 1080; + std::ostringstream hls_stream; + hls_stream << "INSTREAM-ID=32,"; + hls_stream << "FRAME-RATE=29.97,"; + hls_stream << "METHOD=AES-128,"; + hls_stream << "VERSION=\"1/2/5\","; + hls_stream << "IV=0xDEADBEAF,"; + hls_stream << "RESOLUTION=1920x1080"; + const std::string rep = hls_stream.str(); + + HlsAttributeList list; + ASSERT_TRUE(list.Parse(rep)); + + EXPECT_EQ(list.Count(), static_cast(6)); + const std::vector kExpectedNames = { + "FRAME-RATE", "INSTREAM-ID", "IV", "METHOD", "RESOLUTION", "VERSION", + }; + + uint64_t int_value = 0; + EXPECT_TRUE(list.GetInteger("INSTREAM-ID", &int_value)); + EXPECT_EQ(int_value, kInStreamId); + double float_value = 0.0; + EXPECT_TRUE(list.GetFloat("FRAME-RATE", &float_value)); + EXPECT_EQ(float_value, kFrameRate); + std::string string_value; + EXPECT_TRUE(list.GetEnumString("METHOD", &string_value)); + EXPECT_EQ(string_value, "AES-128"); + EXPECT_TRUE(list.GetQuotedString("VERSION", &string_value)); + EXPECT_EQ(string_value, "1/2/5"); + EXPECT_TRUE(list.GetHexSequence("IV", &string_value)); + EXPECT_EQ(string_value, kIvString); + uint64_t other_int_value = 0; + EXPECT_TRUE(list.GetResolution("RESOLUTION", &int_value, &other_int_value)); + EXPECT_EQ(int_value, kWidth); + EXPECT_EQ(other_int_value, kHeight); +} + +TEST(HlsAttributeListTest, Parse_KeyLike) { + // #EXT-X-KEY + const std::string uri = "data:text/plain;base64,SSdtIGEga2V5IQ=="; + std::ostringstream hls_stream; + hls_stream << "METHOD=AES-128,"; + hls_stream << "URI=\"" << uri << "\","; + hls_stream << "IV=0xDEADBEAFDEADBEAFDEADBEAFDEADBEAF,"; + hls_stream << "KEYFORMAT=\"com.widevine\","; + hls_stream << "KEYFORMATVERSIONS=\"1/2/5\""; + const std::string rep = hls_stream.str(); + + HlsAttributeList list; + ASSERT_TRUE(list.Parse(rep)); + + EXPECT_EQ(list.Count(), static_cast(5)); + const std::vector kExpectedNames = { + "IV", "KEYFORMAT", "KEYFORMATVERSIONS", "METHOD", "URI"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + std::string string_value; + EXPECT_TRUE(list.GetEnumString("METHOD", &string_value)); + EXPECT_EQ(string_value, "AES-128"); + EXPECT_TRUE(list.GetQuotedString("URI", &string_value)); + EXPECT_EQ(string_value, uri); + const std::string iv = a2bs_hex("DEADBEAFDEADBEAFDEADBEAFDEADBEAF"); + EXPECT_TRUE(list.GetHexSequence("IV", &string_value)); + EXPECT_EQ(string_value, iv); + EXPECT_TRUE(list.GetQuotedString("KEYFORMAT", &string_value)); + EXPECT_EQ(string_value, "com.widevine"); + EXPECT_TRUE(list.GetQuotedString("KEYFORMATVERSIONS", &string_value)); + EXPECT_EQ(string_value, "1/2/5"); +} + +TEST(HlsAttributeListTest, Parse_MediaLike) { + // #EXT-X-MEDIA + const std::string uri = "https://media.src/playlist"; + const std::string group_id = "The Best Group"; + const std::string language = "en"; + const std::string associated_language = "fr"; + const std::string name = "The Best Captioning"; + const std::string characteristics = "public.easy-to-read"; + std::ostringstream hls_stream; + hls_stream << "TYPE=CLOSED-CAPTIONS,"; + hls_stream << "URI=\"" << uri << "\","; + hls_stream << "GROUP-ID=\"" << group_id << "\","; + hls_stream << "LANGUAGE=\"" << language << "\","; + hls_stream << "ASSOC-LANGUAGE=\"" << associated_language << "\","; + hls_stream << "NAME=\"" << name << "\","; + hls_stream << "DEFAULT=YES,"; + hls_stream << "AUTOSELECT=NO,"; + hls_stream << "FORCED=NO,"; + hls_stream << "INSTREAM-ID=CC1,"; + hls_stream << "CHARACTERISTICS=\"" << characteristics << "\","; + hls_stream << "CHANNELS=\"1/2/4/7\""; + const std::string rep = hls_stream.str(); + + HlsAttributeList list; + ASSERT_TRUE(list.Parse(rep)); + + EXPECT_EQ(list.Count(), static_cast(12)); + const std::vector kExpectedNames = { + "ASSOC-LANGUAGE", "AUTOSELECT", "CHANNELS", "CHARACTERISTICS", + "DEFAULT", "FORCED", "GROUP-ID", "INSTREAM-ID", + "LANGUAGE", "NAME", "TYPE", "URI"}; + EXPECT_EQ(list.GetNames(), kExpectedNames); + + std::string string_value; + EXPECT_TRUE(list.GetEnumString("TYPE", &string_value)); + EXPECT_EQ(string_value, "CLOSED-CAPTIONS"); + EXPECT_TRUE(list.GetQuotedString("URI", &string_value)); + EXPECT_EQ(string_value, uri); + EXPECT_TRUE(list.GetQuotedString("GROUP-ID", &string_value)); + EXPECT_EQ(string_value, group_id); + EXPECT_TRUE(list.GetQuotedString("LANGUAGE", &string_value)); + EXPECT_EQ(string_value, language); + EXPECT_TRUE(list.GetQuotedString("ASSOC-LANGUAGE", &string_value)); + EXPECT_EQ(string_value, associated_language); + EXPECT_TRUE(list.GetQuotedString("NAME", &string_value)); + EXPECT_EQ(string_value, name); + EXPECT_TRUE(list.GetEnumString("DEFAULT", &string_value)); + EXPECT_EQ(string_value, "YES"); + EXPECT_TRUE(list.GetEnumString("AUTOSELECT", &string_value)); + EXPECT_EQ(string_value, "NO"); + string_value.clear(); + EXPECT_TRUE(list.GetEnumString("FORCED", &string_value)); + EXPECT_EQ(string_value, "NO"); + EXPECT_TRUE(list.GetEnumString("INSTREAM-ID", &string_value)); + EXPECT_EQ(string_value, "CC1"); + EXPECT_TRUE(list.GetQuotedString("CHARACTERISTICS", &string_value)); + EXPECT_EQ(string_value, characteristics); + EXPECT_TRUE(list.GetQuotedString("CHANNELS", &string_value)); + EXPECT_EQ(string_value, "1/2/4/7"); +} + +TEST(HlsAttributeListTest, BadParse_MissingValue) { + HlsAttributeList list; + EXPECT_FALSE(list.Parse("A=,B=2,C=3")); + EXPECT_FALSE(list.Parse("A=1,B=,C=3")); + EXPECT_FALSE(list.Parse("A=1,B=2,C=")); +} + +TEST(HlsAttributeListTest, BadParse_BadWhitespace) { + HlsAttributeList list; + EXPECT_FALSE(list.Parse("KEY=THIS VALUE")); + EXPECT_FALSE(list.Parse("THIS KEY=1")); + EXPECT_FALSE(list.Parse(" A=1,B=2")); + EXPECT_FALSE(list.Parse("A =1,B=2")); + EXPECT_FALSE(list.Parse("A= 1,B=2")); + EXPECT_FALSE(list.Parse("A=1 ,B=2")); + EXPECT_FALSE(list.Parse("A=1, B=2")); + EXPECT_FALSE(list.Parse("A=1,B =2")); + EXPECT_FALSE(list.Parse("A=1,B= 2")); + EXPECT_FALSE(list.Parse("A=1,B=2 ")); +} + +TEST(HlsAttributeListTest, BadParse_BadQuotes) { + HlsAttributeList list; + EXPECT_FALSE(list.Parse("\"KEY\"=VALUE")); + EXPECT_FALSE(list.Parse("KEY=\"VALUE")); + EXPECT_FALSE(list.Parse("KEY=VALUE\"")); + EXPECT_FALSE(list.Parse("KEY=\"VALUE\"AND")); +} + +TEST(HlsAttributeListTest, Serialize) { + HlsAttributeList list; + + EXPECT_TRUE(list.SetInteger("INSTREAM-ID", 32)); + EXPECT_TRUE(list.SetFloat("FRAME-RATE", 29.97)); + EXPECT_TRUE(list.SetEnumString("METHOD", "AES-128")); + EXPECT_TRUE(list.SetQuotedString("VERSION", "1/2/5")); + EXPECT_TRUE(list.SetHexSequence("IV", a2b_hex("DEADBEAF"))); + EXPECT_TRUE(list.SetResolution("RESOLUTION", 1920, 1080)); + + // Note: The HlsAttributeList class does not guarantee any particular + // order when serializing to a string. + const std::vector kExpectedParts = { + "FRAME-RATE=29.97", "INSTREAM-ID=32", + // Note: HlsAttributeList uses lower prefix "0x" for hex sequences. + "IV=0xDEADBEAF", "METHOD=AES-128", "RESOLUTION=1920x1080", + "VERSION=\"1/2/5\""}; + const std::string result = list.Serialize(); + ASSERT_FALSE(result.empty()); + + size_t expected_length = kExpectedParts.size() - 1; // Start with commas. + for (const auto& part : kExpectedParts) { + ASSERT_TRUE(StringContains(result, part)) << "result: " << result; + expected_length += part.size(); + } + EXPECT_EQ(expected_length, result.size()) << "result: " << result; +} + +TEST(HlsAttributeListTest, Serialize_Empty) { + HlsAttributeList list; + EXPECT_EQ(list.Serialize(), ""); +} + +TEST(HlsAttributeListTest, Serialize_One) { + HlsAttributeList list; + EXPECT_TRUE(list.SetInteger("INSTREAM-ID", 32)); + EXPECT_EQ(list.Serialize(), "INSTREAM-ID=32"); +} + +TEST(HlsAttributeListTest, Serialize_QuotedEmpty) { + HlsAttributeList list; + EXPECT_TRUE(list.SetQuotedString("VERSION", "")); + EXPECT_EQ(list.Serialize(), "VERSION=\"\""); +} +} // namespace test +} // namespace wvutil diff --git a/util/test/string_utils_unittest.cpp b/util/test/string_utils_unittest.cpp new file mode 100644 index 0000000..1663cf7 --- /dev/null +++ b/util/test/string_utils_unittest.cpp @@ -0,0 +1,1890 @@ +// Copyright 2024 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "string_utils.h" + +#include +#include + +#include + +namespace wvutil { + +// std::vector StringSplit(const std::string& s, char delim); + +TEST(StringUtilsTest, StringSplit_CharDelim_EmptyText) { + const std::vector kEmptyVec; + const std::string kEmptyString; + constexpr char kDelimiter = ','; + EXPECT_EQ(StringSplit(kEmptyString, kDelimiter), kEmptyVec); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_SingleCharacterText_NoMatch) { + const std::string kText = "a"; + constexpr char kDelimiter = ','; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_SingleCharacterText_Match) { + const std::string kText = "a"; + constexpr char kDelimiter = 'a'; + const std::vector kExpected = {"", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_NoMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter = ','; + const std::vector kExpected = {"abcdefg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_SingleMidMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter = 'd'; + const std::vector kExpected = {"abc", "efg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_SingleFrontMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter = 'a'; + const std::vector kExpected = {"", "bcdefg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_SingleBackMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter = 'g'; + const std::vector kExpected = {"abcdef", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_MultiInnerMatch) { + const std::string kText = "123.456.789"; + constexpr char kDelimiter = '.'; + const std::vector kExpected = {"123", "456", "789"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_MultiEndsMatch) { + const std::string kText = ".123..456..789."; + constexpr char kDelimiter = '.'; + const std::vector kExpected = {"", "123", "", "456", + "", "789", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_LongText_AllMatch) { + const std::string kText = "aaaaa"; + constexpr char kDelimiter = 'a'; + const std::vector kExpected = {"", "", "", "", "", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CharDelim_NullByte) { + using namespace std::string_literals; + // Ensure that split can handle null bytes within C++ strings. + const std::string kText = "Hello\0World!"s; + constexpr char kDelimiter = '\0'; + const std::vector kExpected = {"Hello", "World!"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +// std::vector StringSplit(const std::string& s, +// const std::string& delim); + +TEST(StringUtilsTest, StringSplit_StringDelim_EmptyText_EmptyDelim) { + const std::string kEmptyString; + const std::vector kExpected = {}; + EXPECT_EQ(StringSplit(kEmptyString, kEmptyString), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_EmptyText_NoMatch) { + const std::string kText = ""; + const std::string kDelimiter = "a"; + const std::vector kExpected = {}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_SingleCharText_EmptyDelim) { + const std::string kText = "a"; + const std::string kDelimiter = ""; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_SingleCharText_NoMatch) { + const std::string kText = "a"; + const std::string kDelimiter = ","; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_SingleCharText_NoMatch2) { + const std::string kText = "a"; + const std::string kDelimiter = "bcd"; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_SingleCharText_Match) { + const std::string kText = "a"; + const std::string kDelimiter = "a"; + const std::vector kExpected = {"", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_EmptyDelim) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = ""; + const std::vector kExpected = {"a", "b", "c", "d", + "e", "f", "g"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_NoMatch) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "edc"; + const std::vector kExpected = {kText}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_SingleMidMatch) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "d"; + const std::vector kExpected = {"abc", "efg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_SingleMidMatch2) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "cde"; + const std::vector kExpected = {"ab", "fg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_SingleFrontMatch) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "a"; + const std::vector kExpected = {"", "bcdefg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_SingleFrontMatch2) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "abc"; + const std::vector kExpected = {"", "defg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_SingleBackMatch) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "g"; + const std::vector kExpected = {"abcdef", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_SingleBackMatch2) { + const std::string kText = "abcdefg"; + const std::string kDelimiter = "efg"; + const std::vector kExpected = {"abcd", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_MultiInnerMatch) { + const std::string kText = "123.456.789"; + const std::string kDelimiter = "."; + const std::vector kExpected = {"123", "456", "789"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_MultiInnerMatch2) { + const std::string kText = "123, 456, 789"; + const std::string kDelimiter = ", "; + const std::vector kExpected = {"123", "456", "789"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_MultiEndsMatch) { + const std::string kText = ".123.456..789."; + const std::string kDelimiter = "."; + const std::vector kExpected = {"", "123", "456", "", "789", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_MultiEndsMatch2) { + const std::string kText = ", 123, 456, 789, "; + const std::string kDelimiter = ", "; + const std::vector kExpected = {"", "123", "456", "789", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_AllMatch) { + const std::string kText = "aaaaa"; + const std::string kDelimiter = "a"; + const std::vector kExpected = {"", "", "", "", "", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_StringDelim_LongText_AllMatch2) { + const std::string kText = "123123123"; + const std::string kDelimiter = "123"; + const std::vector kExpected = {"", "", "", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +// std::vector StringSplit(const std::string& s, +// const char* delim); + +TEST(StringUtilsTest, StringSplit_CStringDelim_NullDelim) { + const std::string kText = "a"; + const std::vector kExpected = {}; + EXPECT_EQ(StringSplit(kText, nullptr), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_EmptyText_EmptyDelim) { + const std::string kText; + constexpr char kDelimiter[] = ""; + const std::vector kExpected = {}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_EmptyText_NoMatch) { + const std::string kText = ""; + constexpr char kDelimiter[] = "a"; + const std::vector kExpected = {}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_SingleCharText_EmptyDelim) { + const std::string kText = "a"; + constexpr char kDelimiter[] = ""; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_SingleCharText_NoMatch) { + const std::string kText = "a"; + constexpr char kDelimiter[] = ","; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_SingleCharText_NoMatch2) { + const std::string kText = "a"; + constexpr char kDelimiter[] = "bcd"; + const std::vector kExpected = {"a"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_SingleCharText_Match) { + const std::string kText = "a"; + constexpr char kDelimiter[] = "a"; + const std::vector kExpected = {"", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_EmptyDelim) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = ""; + const std::vector kExpected = {"a", "b", "c", "d", + "e", "f", "g"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_NoMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "edc"; + const std::vector kExpected = {kText}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_SingleMidMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "d"; + const std::vector kExpected = {"abc", "efg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_SingleMidMatch2) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "cde"; + const std::vector kExpected = {"ab", "fg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_SingleFrontMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "a"; + const std::vector kExpected = {"", "bcdefg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_SingleFrontMatch2) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "abc"; + const std::vector kExpected = {"", "defg"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_SingleBackMatch) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "g"; + const std::vector kExpected = {"abcdef", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_SingleBackMatch2) { + const std::string kText = "abcdefg"; + constexpr char kDelimiter[] = "efg"; + const std::vector kExpected = {"abcd", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_MultiInnerMatch) { + const std::string kText = "123.456.789"; + constexpr char kDelimiter[] = "."; + const std::vector kExpected = {"123", "456", "789"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_MultiInnerMatch2) { + const std::string kText = "123, 456, 789"; + constexpr char kDelimiter[] = ", "; + const std::vector kExpected = {"123", "456", "789"}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_MultiEndsMatch) { + const std::string kText = ".123.456..789."; + constexpr char kDelimiter[] = "."; + const std::vector kExpected = {"", "123", "456", "", "789", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_MultiEndsMatch2) { + const std::string kText = ", 123, 456, 789, "; + constexpr char kDelimiter[] = ", "; + const std::vector kExpected = {"", "123", "456", "789", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_AllMatch) { + const std::string kText = "aaaaa"; + constexpr char kDelimiter[] = "a"; + const std::vector kExpected = {"", "", "", "", "", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +TEST(StringUtilsTest, StringSplit_CStringDelim_LongText_AllMatch2) { + const std::string kText = "123123123"; + constexpr char kDelimiter[] = "123"; + const std::vector kExpected = {"", "", "", ""}; + EXPECT_EQ(StringSplit(kText, kDelimiter), kExpected); +} + +// std::string StringJoin(const std::vector& tokens, char glue); + +TEST(StringUtilsTest, StringJoin_CharGlue_NoTokens) { + const std::vector kTokens = {}; + constexpr char kGlue = ','; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CharGlue_SingleEmptyToken) { + const std::vector kTokens = {""}; + constexpr char kGlue = ','; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CharGlue_SingleToken) { + const std::vector kTokens = {"a"}; + constexpr char kGlue = ','; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CharGlue_SingleToken2) { + const std::vector kTokens = {"abcdefg"}; + constexpr char kGlue = ','; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CharGlue_ManyEmptyTokens) { + const std::vector kTokens = {"", "", ""}; + constexpr char kGlue = ','; + const std::string kExpected = ",,"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CharGlue_ManyTokens) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + constexpr char kGlue = ','; + const std::string kExpected = "a,bcd,,efg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +// std::string StringJoin(const std::vector& tokens, +// const std::string& glue = ""); + +TEST(StringUtilsTest, StringJoin_StringGlue_NoTokens_EmptyGlue) { + const std::vector kTokens = {}; + const std::string kGlue = ""; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleEmptyToken_EmptyGlue) { + const std::vector kTokens = {""}; + const std::string kGlue = ""; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleToken_EmptyGlue) { + const std::vector kTokens = {"a"}; + const std::string kGlue = ""; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleToken2_EmptyGlue) { + const std::vector kTokens = {"abcdefg"}; + const std::string kGlue = ""; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_ManyEmptyTokens_EmptyGlue) { + const std::vector kTokens = {"", "", ""}; + const std::string kGlue = ""; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_ManyTokens_EmptyGlue) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + const std::string kGlue = ""; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_NoTokens_SingleCharGlue) { + const std::vector kTokens = {}; + const std::string kGlue = ","; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleEmptyToken_SingleCharGlue) { + const std::vector kTokens = {""}; + const std::string kGlue = ","; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleToken_SingleCharGlue) { + const std::vector kTokens = {"a"}; + const std::string kGlue = ","; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleToken2_SingleCharGlue) { + const std::vector kTokens = {"abcdefg"}; + const std::string kGlue = ","; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_ManyEmptyTokens_SingleCharGlue) { + const std::vector kTokens = {"", "", ""}; + const std::string kGlue = ","; + const std::string kExpected = ",,"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_ManyTokens_SingleCharGlue) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + const std::string kGlue = ","; + const std::string kExpected = "a,bcd,,efg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_NoTokens_MultiCharGlue) { + const std::vector kTokens = {}; + const std::string kGlue = ", "; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleEmptyToken_MultiCharGlue) { + const std::vector kTokens = {""}; + const std::string kGlue = ", "; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleToken_MultiCharGlue) { + const std::vector kTokens = {"a"}; + const std::string kGlue = ", "; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_SingleToken2_MultiCharGlue) { + const std::vector kTokens = {"abcdefg"}; + const std::string kGlue = ", "; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_ManyEmptyTokens_MultiCharGlue) { + const std::vector kTokens = {"", "", ""}; + const std::string kGlue = ", "; + const std::string kExpected = ", , "; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_StringGlue_ManyTokens_MultiCharGlue) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + const std::string kGlue = ", "; + const std::string kExpected = "a, bcd, , efg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +// std::string StringJoin(const std::vector& tokens, +// const char* glue); + +TEST(StringUtilsTest, StringJoin_CStringGlue_NullGlue) { + const std::vector kTokens = {"123", "456"}; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, nullptr), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_NoTokens_EmptyGlue) { + const std::vector kTokens = {}; + constexpr char kGlue[] = ""; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleEmptyToken_EmptyGlue) { + const std::vector kTokens = {""}; + constexpr char kGlue[] = ""; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleToken_EmptyGlue) { + const std::vector kTokens = {"a"}; + constexpr char kGlue[] = ""; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleToken2_EmptyGlue) { + const std::vector kTokens = {"abcdefg"}; + constexpr char kGlue[] = ""; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_ManyEmptyTokens_EmptyGlue) { + const std::vector kTokens = {"", "", ""}; + constexpr char kGlue[] = ""; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_ManyTokens_EmptyGlue) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + constexpr char kGlue[] = ""; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_NoTokens_SingleCharGlue) { + const std::vector kTokens = {}; + constexpr char kGlue[] = ","; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleEmptyToken_SingleCharGlue) { + const std::vector kTokens = {""}; + constexpr char kGlue[] = ","; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleToken_SingleCharGlue) { + const std::vector kTokens = {"a"}; + constexpr char kGlue[] = ","; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleToken2_SingleCharGlue) { + const std::vector kTokens = {"abcdefg"}; + constexpr char kGlue[] = ","; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_ManyEmptyTokens_SingleCharGlue) { + const std::vector kTokens = {"", "", ""}; + constexpr char kGlue[] = ","; + const std::string kExpected = ",,"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_ManyTokens_SingleCharGlue) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + constexpr char kGlue[] = ","; + const std::string kExpected = "a,bcd,,efg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_NoTokens_MultiCharGlue) { + const std::vector kTokens = {}; + constexpr char kGlue[] = ", "; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleEmptyToken_MultiCharGlue) { + const std::vector kTokens = {""}; + constexpr char kGlue[] = ", "; + const std::string kExpected = ""; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleToken_MultiCharGlue) { + const std::vector kTokens = {"a"}; + constexpr char kGlue[] = ", "; + const std::string kExpected = "a"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_SingleToken2_MultiCharGlue) { + const std::vector kTokens = {"abcdefg"}; + constexpr char kGlue[] = ", "; + const std::string kExpected = "abcdefg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_ManyEmptyTokens_MultiCharGlue) { + const std::vector kTokens = {"", "", ""}; + constexpr char kGlue[] = ", "; + const std::string kExpected = ", , "; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +TEST(StringUtilsTest, StringJoin_CStringGlue_ManyTokens_MultiCharGlue) { + const std::vector kTokens = {"a", "bcd", "", "efg"}; + constexpr char kGlue[] = ", "; + const std::string kExpected = "a, bcd, , efg"; + EXPECT_EQ(StringJoin(kTokens, kGlue), kExpected); +} + +// size_t StringCount(const std::string& haystack, char needle); + +TEST(StringUtilsTest, StringCount_CharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle = '.'; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CharNeedle_SingleCharHaystack_NoMatch) { + const std::string kHaystack = "a"; + constexpr char kNeedle = '.'; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CharNeedle_SingleCharHaystack_Match) { + const std::string kHaystack = "a"; + constexpr char kNeedle = 'a'; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CharNeedle_LongHaystack_NoMatch) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = '.'; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CharNeedle_LongHaystack_SingleMatch) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'd'; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CharNeedle_LongHaystack_ManyMatch) { + const std::string kHaystack = "123.456.789"; + constexpr char kNeedle = '.'; + constexpr size_t kExpected = 2; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CharNeedle_LongHaystack_AllMatch) { + const std::string kHaystack = "aaaaa"; + constexpr char kNeedle = 'a'; + constexpr size_t kExpected = 5; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +// size_t StringCount(const std::string& haystack, const std::string& needle); + +TEST(StringUtilsTest, StringCount_StringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = ""; + constexpr size_t kExpected = 1; // Special case + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_StringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + const std::string kNeedle = ""; + constexpr size_t kExpected = 2; // Special case + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_StringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "1234567"; + const std::string kNeedle = ""; + constexpr size_t kExpected = 8; // Special case + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_StringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "."; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_SingleCharNeedle_SingleCharHaystack_NoMatch) { + const std::string kHaystack = "a"; + const std::string kNeedle = "."; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_SingleCharNeedle_SingleCharHaystack_Match) { + const std::string kHaystack = "a"; + const std::string kNeedle = "a"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_SingleCharNeedle_LongHaystack_NoMatch) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "."; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_SingleCharNeedle_LongHaystack_SingleMatch) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "d"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_SingleCharNeedle_LongHaystack_ManyMatch) { + const std::string kHaystack = "123.456.789"; + const std::string kNeedle = "."; + constexpr size_t kExpected = 2; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_SingleCharNeedle_LongHaystack_AllMatch) { + const std::string kHaystack = "aaaaa"; + const std::string kNeedle = "a"; + constexpr size_t kExpected = 5; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_StringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "123"; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_MultiCharNeedle_SingleCharHaystack_NoMatch) { + const std::string kHaystack = "a"; + const std::string kNeedle = "abc"; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_MultiCharNeedle_LongHaystack_NoMatch) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "123"; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "abcdefg"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_MultiCharNeedle_LongHaystack_SingleMatch) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "cde"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_MultiCharNeedle_LongHaystack_ManyMatch) { + const std::string kHaystack = "123.123.123"; + const std::string kNeedle = "123"; + constexpr size_t kExpected = 3; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_StringNeedle_MultiCharNeedle_LongHaystack_AllMatch) { + const std::string kHaystack = "aaaaaa"; + const std::string kNeedle = "aa"; + constexpr size_t kExpected = 3; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +// size_t StringCount(const std::string& haystack, const char* needle); + +TEST(StringUtilsTest, StringCount_CStringNeedle_NullNeedle) { + const std::string kHaystack = ""; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, nullptr), kExpected); +} + +TEST(StringUtilsTest, StringCount_CStringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle[] = ""; + constexpr size_t kExpected = 1; // Special case + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = ""; + constexpr size_t kExpected = 2; // Special case + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CStringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "1234567"; + constexpr char kNeedle[] = ""; + constexpr size_t kExpected = 8; // Special case + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle[] = "."; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_SingleCharHaystack_NoMatch) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = "."; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_SingleCharHaystack_Match) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = "a"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_LongHaystack_NoMatch) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "."; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "abcdefg"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_LongHaystack_SingleMatch) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "d"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_LongHaystack_ManyMatch) { + const std::string kHaystack = "123.456.789"; + constexpr char kNeedle[] = "."; + constexpr size_t kExpected = 2; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_SingleCharNeedle_LongHaystack_AllMatch) { + const std::string kHaystack = "aaaaa"; + constexpr char kNeedle[] = "a"; + constexpr size_t kExpected = 5; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, StringCount_CStringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle[] = "123"; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_MultiCharNeedle_SingleCharHaystack_NoMatch) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = "abc"; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_MultiCharNeedle_LongHaystack_NoMatch) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "123"; + constexpr size_t kExpected = 0; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_MultiCharNeedle_LongHaystack_SingleMatch) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "cde"; + constexpr size_t kExpected = 1; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_MultiCharNeedle_LongHaystack_ManyMatch) { + const std::string kHaystack = "123.123.123"; + constexpr char kNeedle[] = "123"; + constexpr size_t kExpected = 3; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +TEST(StringUtilsTest, + StringCount_CStringNeedle_MultiCharNeedle_LongHaystack_AllMatch) { + const std::string kHaystack = "aaaaaa"; + constexpr char kNeedle[] = "aa"; + constexpr size_t kExpected = 3; + EXPECT_EQ(StringCount(kHaystack, kNeedle), kExpected); +} + +// bool StringContains(const std::string& haystack, char needle); + +TEST(StringUtilsTest, StringContains_CharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + constexpr char kNeedle = 'a'; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CharNeedle_LongHaystack_FrontFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'a'; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CharNeedle_LongHaystack_BackFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'g'; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CharNeedle_LongHaystack_MidFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'd'; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +// bool StringContains(const std::string& haystack, const std::string& needle); + +TEST(StringUtilsTest, StringContains_StringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_StringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "."; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const std::string kNeedle = "."; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_SingleCharHaystack_Equal) { + const std::string kHaystack = "a"; + const std::string kNeedle = "a"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "."; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_LongHaystack_FrontFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "a"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_LongHaystack_BackFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "g"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_SingleCharNeedle_LongHaystack_MidFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "d"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const std::string kNeedle = "abc"; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abc"; + const std::string kNeedle = "abc"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_LongHaystack_FrontFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "abc"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_LongHaystack_BackFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "efg"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_StringNeedle_MultiCharNeedle_LongHaystack_MidFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "cde"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +// bool StringContains(const std::string& haystack, const char* needle); + +TEST(StringUtilsTest, StringContains_CStringNeedle_NullNeedle) { + const std::string kHaystack = "abcdefg"; + EXPECT_FALSE(StringContains(kHaystack, nullptr)); +} + +TEST(StringUtilsTest, StringContains_CStringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringContains_CStringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle[] = "."; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = "."; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_SingleCharHaystack_Equal) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = "a"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "."; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_LongHaystack_FrontFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "a"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_LongHaystack_BackFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "g"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_SingleCharNeedle_LongHaystack_MidFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "d"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle[] = "123"; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + constexpr char kNeedle[] = "abc"; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "123"; + EXPECT_FALSE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abc"; + constexpr char kNeedle[] = "abc"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_LongHaystack_FrontFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "abc"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_LongHaystack_BackFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "efg"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringContains_CStringNeedle_MultiCharNeedle_LongHaystack_MidFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle[] = "cde"; + EXPECT_TRUE(StringContains(kHaystack, kNeedle)); +} + +// bool StringStartsWith(const std::string& haystack, char needle); + +TEST(StringUtilsTest, StringStartsWith_CharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringStartsWith_CharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringStartsWith_CharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + constexpr char kNeedle = 'a'; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringStartsWith_CharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CharNeedle_LongHaystack_NotFoundAtStart) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'b'; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringStartsWith_CharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'a'; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +// bool StringStartsWith(const std::string& haystack, +// const std::string& needle); + +TEST(StringUtilsTest, StringStartsWith_StringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringStartsWith_StringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "."; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_StringNeedle_SingleCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const std::string kNeedle = "."; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_SingleCharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + const std::string kNeedle = "a"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_SingleCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "."; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_StringNeedle_SingleCharNeedle_LongHaystack_NotFoundAtStart) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "b"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_SingleCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "a"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_StringNeedle_MultiCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const std::string kNeedle = "abc"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_MultiCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_StringNeedle_MultiCharNeedle_LongHaystack_NotFoundAtStart) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "bcd"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abc"; + const std::string kNeedle = "abc"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_StringNeedle_MultiCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "abc"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +// bool StringStartsWith(const std::string& haystack, const char* needle); + +TEST(StringUtilsTest, StringStartsWith_CStringNeedle_NullNeedle) { + const std::string kHaystack = "abcdefg"; + EXPECT_FALSE(StringStartsWith(kHaystack, nullptr)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + const char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringStartsWith_CStringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const char kNeedle[] = "."; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_CStringNeedle_SingleCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const char kNeedle[] = "."; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_SingleCharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + const char kNeedle[] = "a"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_SingleCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "."; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_CStringNeedle_SingleCharNeedle_LongHaystack_NotFoundAtStart) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "b"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_SingleCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "a"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const char kNeedle[] = "123"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_CStringNeedle_MultiCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const char kNeedle[] = "abc"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_MultiCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "123"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringStartsWith_CStringNeedle_MultiCharNeedle_LongHaystack_NotFoundAtStart) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "bcd"; + EXPECT_FALSE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abc"; + const char kNeedle[] = "abc"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringStartsWith_CStringNeedle_MultiCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "abc"; + EXPECT_TRUE(StringStartsWith(kHaystack, kNeedle)); +} + +// bool StringEndsWith(const std::string& haystack, char needle); + +TEST(StringUtilsTest, StringEndsWith_CharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_CharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_CharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + constexpr char kNeedle = 'a'; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_CharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = '.'; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_CharNeedle_LongHaystack_NotFoundAtEnd) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'f'; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_CharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + constexpr char kNeedle = 'g'; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +// bool StringEndsWith(const std::string& haystack, +// const std::string& needle); + +TEST(StringUtilsTest, StringEndsWith_StringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_StringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = ""; // Special case + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "."; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_SingleCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const std::string kNeedle = "."; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_SingleCharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + const std::string kNeedle = "a"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_SingleCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "."; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_SingleCharNeedle_LongHaystack_NotFoundAtEnd) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "f"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_SingleCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "g"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_MultiCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "c"; + const std::string kNeedle = "abc"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_MultiCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_MultiCharNeedle_LongHaystack_NotFoundAtEnd) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "def"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abc"; + const std::string kNeedle = "abc"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_StringNeedle_MultiCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const std::string kNeedle = "efg"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +// bool StringEndsWith(const std::string& haystack, const char* needle); + +TEST(StringUtilsTest, StringEndsWith_CStringNeedle_NullNeedle) { + const std::string kHaystack = "abcdefg"; + EXPECT_FALSE(StringEndsWith(kHaystack, nullptr)); +} + +TEST(StringUtilsTest, StringEndsWith_CStringNeedle_EmptyNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_EmptyNeedle_SingleCharHaystack) { + const std::string kHaystack = "a"; + const char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringEndsWith_CStringNeedle_EmptyNeedle_LongHaystack) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = ""; // Special case + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_SingleCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const char kNeedle[] = "."; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST( + StringUtilsTest, + StringEndsWith_CStringNeedle_SingleCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const char kNeedle[] = "."; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_SingleCharNeedle_SingleCharHaystack_Found) { + const std::string kHaystack = "a"; + const char kNeedle[] = "a"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_SingleCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "."; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_SingleCharNeedle_LongHaystack_NotFoundAtEnd) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "f"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_SingleCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "g"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_MultiCharNeedle_EmptyHaystack) { + const std::string kHaystack = ""; + const char kNeedle[] = "123"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_MultiCharNeedle_SingleCharHaystack_NotFound) { + const std::string kHaystack = "a"; + const char kNeedle[] = "abc"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_MultiCharNeedle_LongHaystack_NotFound) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "123"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_MultiCharNeedle_LongHaystack_NotFoundAtEnd) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "def"; + EXPECT_FALSE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_MultiCharNeedle_LongHaystack_Equal) { + const std::string kHaystack = "abc"; + const char kNeedle[] = "abc"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, + StringEndsWith_CStringNeedle_MultiCharNeedle_LongHaystack_Found) { + const std::string kHaystack = "abcdefg"; + const char kNeedle[] = "efg"; + EXPECT_TRUE(StringEndsWith(kHaystack, kNeedle)); +} + +// std::string StringTrim(const std::string& s); + +TEST(StringUtilsTest, StringTrim_Empty) { EXPECT_EQ(StringTrim(""), ""); } + +TEST(StringUtilsTest, StringTrim_NoWhitespace) { + EXPECT_EQ(StringTrim("abcdefg"), "abcdefg"); +} + +TEST(StringUtilsTest, StringTrim_LeadingWhitespace) { + EXPECT_EQ(StringTrim("\t\n \rabcdefg"), "abcdefg"); +} + +TEST(StringUtilsTest, StringTrim_TrailingWhitespace) { + EXPECT_EQ(StringTrim("abcdefg\t\n \r"), "abcdefg"); +} + +TEST(StringUtilsTest, StringTrim_LeadingAndTrailingWhitespace) { + EXPECT_EQ(StringTrim("\t\n \rabcdefg\t\n \r"), "abcdefg"); +} + +TEST(StringUtilsTest, StringTrime_AllWhitespace) { + EXPECT_EQ(StringTrim("\t\n \r"), ""); +} + +// bool StringVecContains(const std::vector& haystack, +// const std::string& needle); + +TEST(StringUtilsTest, StringVecContains_EmptyHaystack) { + const std::vector kHaystack = {}; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringVecContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringVecContains_NotFound) { + const std::vector kHaystack = {"456", "789"}; + const std::string kNeedle = "123"; + EXPECT_FALSE(StringVecContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringVecContains_FoundAtFront) { + const std::vector kHaystack = {"123", "456", "789"}; + const std::string kNeedle = "123"; + EXPECT_TRUE(StringVecContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringVecContains_FoundAtBack) { + const std::vector kHaystack = {"123", "456", "789"}; + const std::string kNeedle = "789"; + EXPECT_TRUE(StringVecContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringVecContains_FoundAtMid) { + const std::vector kHaystack = {"123", "456", "789"}; + const std::string kNeedle = "456"; + EXPECT_TRUE(StringVecContains(kHaystack, kNeedle)); +} + +TEST(StringUtilsTest, StringVecContains_FoundEmpty) { + const std::vector kHaystack = {"123", "", "789"}; + const std::string kNeedle = ""; + EXPECT_TRUE(StringVecContains(kHaystack, kNeedle)); +} +} // namespace wvutil diff --git a/util/test/test_clock.cpp b/util/test/test_clock.cpp index 61fec1f..d512731 100644 --- a/util/test/test_clock.cpp +++ b/util/test/test_clock.cpp @@ -5,11 +5,14 @@ // Clock - A fake clock just for running tests. This is used when running // OEMCrypto unit tests. It is not used when tests include the CE CDM source // code because that uses the clock in cdm/test_host.cpp instead. +#include #include #include "clock.h" #include "test_sleep.h" +#include "wv_duration.h" +#include "wv_timestamp.h" namespace wvutil { @@ -19,16 +22,19 @@ class FakeClock : public TestSleep::CallBack { public: FakeClock() { auto now = std::chrono::system_clock().now(); - now_ = now.time_since_epoch() / std::chrono::milliseconds(1); + now_ = Timestamp::FromUnixMilliseconds( + std::chrono::floor(now.time_since_epoch())); TestSleep::AddCallback(this); } ~FakeClock() { TestSleep::RemoveCallback(this); } - void ElapseTime(int64_t milliseconds) { now_ += milliseconds; } + void ElapseTime(int64_t milliseconds) override { + now_ += Duration::FromMilliseconds(milliseconds); + } - int64_t now() const { return now_; } + Timestamp now() const { return now_; } private: - int64_t now_; + Timestamp now_; }; FakeClock* g_fake_clock = nullptr; @@ -38,7 +44,12 @@ FakeClock* g_fake_clock = nullptr; int64_t Clock::GetCurrentTime() { TestSleep::SyncFakeClock(); if (g_fake_clock == nullptr) g_fake_clock = new FakeClock(); - return g_fake_clock->now() / 1000; + return static_cast(g_fake_clock->now().epoch_seconds().count()); } +Timestamp Clock::GetCurrentTimestamp() { + TestSleep::SyncFakeClock(); + if (g_fake_clock == nullptr) g_fake_clock = new FakeClock(); + return g_fake_clock->now(); +} } // namespace wvutil diff --git a/util/test/wv_date_time_unittest.cpp b/util/test/wv_date_time_unittest.cpp new file mode 100644 index 0000000..367210a --- /dev/null +++ b/util/test/wv_date_time_unittest.cpp @@ -0,0 +1,464 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "wv_date_time.h" + +#include + +// #include +#include + +#include "wv_duration.h" +#include "wv_timestamp.h" + +namespace wvutil { +namespace test { +namespace { +// This is a Unix time chosen for testing. +// Friday, April 19th, 2024 13:36:18.507 +// April 19th is the 110th day of the year when a leap year. +constexpr uint64_t kTestUnixTimeS = 1713533778; +constexpr uint64_t kTestUnixTimeMs = 1713533778507; + +constexpr Seconds kTestTimeS = Seconds(kTestUnixTimeS); +constexpr Milliseconds kTestTimeMs = Milliseconds(kTestUnixTimeMs); + +constexpr uint32_t kTestYear = 2024; +constexpr uint32_t kTestMonth = 4; // April +constexpr uint32_t kTestDay = 19; +constexpr uint32_t kTestHour = 13; +constexpr uint32_t kTestMinute = 36; +constexpr uint32_t kTestSecond = 18; +constexpr uint32_t kTestMs = 507; +constexpr uint32_t kTestDayOfWeek = 6; // Friday +constexpr uint32_t kTestDayOfYear = 110; + +constexpr int64_t kOneMinuteS = 60; +constexpr int64_t kOneHourS = kOneMinuteS * 60; +constexpr int64_t kOneDayS = kOneHourS * 24; +constexpr int64_t kOneCommonYearS = kOneDayS * 365; +constexpr int64_t kOneLeapYearS = kOneDayS * 366; + +constexpr int64_t kOneSecondMs = 1000; +constexpr int64_t kOneMinuteMs = kOneSecondMs * 60; +constexpr int64_t kOneHourMs = kOneMinuteMs * 60; +constexpr int64_t kOneDayMs = kOneHourMs * 24; + +uint32_t DayOfWeekAddition(uint32_t day_of_week, uint32_t inc) { + return ((day_of_week - 1 + inc) % 7) + 1; +} + +uint32_t DayOfWeekSubtraction(uint32_t day_of_week, uint32_t dec) { + const uint32_t reverse_day_of_week = (8 - day_of_week); + const uint32_t reverse_new_day_of_week = + DayOfWeekAddition(reverse_day_of_week, dec); + return (8 - reverse_new_day_of_week); +} +} // namespace + +TEST(WvDateTimeUtilTest, FromUnixSeconds) { + const DateTime datetime = DateTime::FromUnixSeconds(kTestUnixTimeS, kTestMs); + ASSERT_TRUE(datetime.IsSet()); + EXPECT_EQ(datetime.epoch_seconds(), kTestTimeS); + EXPECT_EQ(datetime.epoch_milliseconds(), kTestTimeMs); + + EXPECT_EQ(datetime.year(), kTestYear); + EXPECT_EQ(datetime.month(), kTestMonth); + EXPECT_EQ(datetime.day(), kTestDay); + EXPECT_EQ(datetime.day_of_year(), kTestDayOfYear); + EXPECT_EQ(datetime.day_of_week(), kTestDayOfWeek); + + EXPECT_EQ(datetime.hour(), kTestHour); + EXPECT_EQ(datetime.minute(), kTestMinute); + EXPECT_EQ(datetime.second(), kTestSecond); + EXPECT_EQ(datetime.millisecond(), kTestMs); +} + +TEST(WvDateTimeUtilTest, FromUnixMilliseconds) { + const DateTime datetime = DateTime::FromUnixMilliseconds(kTestUnixTimeMs); + ASSERT_TRUE(datetime.IsSet()); + EXPECT_EQ(datetime.epoch_seconds(), kTestTimeS); + EXPECT_EQ(datetime.epoch_milliseconds(), kTestTimeMs); + + EXPECT_EQ(datetime.year(), kTestYear); + EXPECT_EQ(datetime.month(), kTestMonth); + EXPECT_EQ(datetime.day(), kTestDay); + EXPECT_EQ(datetime.day_of_year(), kTestDayOfYear); + EXPECT_EQ(datetime.day_of_week(), kTestDayOfWeek); + + EXPECT_EQ(datetime.hour(), kTestHour); + EXPECT_EQ(datetime.minute(), kTestMinute); + EXPECT_EQ(datetime.second(), kTestSecond); + EXPECT_EQ(datetime.millisecond(), kTestMs); +} + +TEST(WvDateTimeUtilTest, Min) { + // Thursday, January 1st, 1970 00:00:00.001 + const DateTime datetime = DateTime::Min(); + ASSERT_TRUE(datetime.IsSet()); + EXPECT_EQ(datetime.epoch_seconds(), Seconds(0)); + EXPECT_EQ(datetime.epoch_milliseconds(), Milliseconds(1)); + + EXPECT_EQ(datetime.year(), 1970u); + EXPECT_EQ(datetime.month(), 1u); + EXPECT_EQ(datetime.day(), 1u); + EXPECT_EQ(datetime.day_of_year(), 1u); + EXPECT_EQ(datetime.day_of_week(), 5u); + + EXPECT_EQ(datetime.hour(), 0u); + EXPECT_EQ(datetime.minute(), 0u); + EXPECT_EQ(datetime.second(), 0u); + EXPECT_EQ(datetime.millisecond(), 1u); +} + +TEST(WvDateTimeUtilTest, Max) { + constexpr int64_t kMaxEpochTimeS = 253402300799; + constexpr int64_t kMaxEpochTimeMs = (kMaxEpochTimeS * 1000) + 999; + // Friday, December 31st, 9999 23:59:59.999 + const DateTime datetime = DateTime::Max(); + ASSERT_TRUE(datetime.IsSet()); + EXPECT_EQ(datetime.epoch_seconds(), Seconds(kMaxEpochTimeS)); + EXPECT_EQ(datetime.epoch_milliseconds(), Milliseconds(kMaxEpochTimeMs)); + + EXPECT_EQ(datetime.year(), 9999u); + EXPECT_EQ(datetime.month(), 12u); + EXPECT_EQ(datetime.day(), 31u); + EXPECT_EQ(datetime.day_of_year(), 365u); + EXPECT_EQ(datetime.day_of_week(), 6u); + + EXPECT_EQ(datetime.hour(), 23u); + EXPECT_EQ(datetime.minute(), 59u); + EXPECT_EQ(datetime.second(), 59u); + EXPECT_EQ(datetime.millisecond(), 999u); +} + +TEST(WvDateTimeUtilTest, Clear) { + DateTime datetime = DateTime::FromUnixMilliseconds(kTestUnixTimeMs); + ASSERT_TRUE(datetime.IsSet()); + datetime.Clear(); + + EXPECT_FALSE(datetime.IsSet()); + EXPECT_EQ(datetime.year(), 0u); + EXPECT_EQ(datetime.month(), 0u); + EXPECT_EQ(datetime.day(), 0u); + EXPECT_EQ(datetime.day_of_year(), 0u); + EXPECT_EQ(datetime.day_of_week(), 0u); + + EXPECT_EQ(datetime.hour(), 0u); + EXPECT_EQ(datetime.minute(), 0u); + EXPECT_EQ(datetime.second(), 0u); + EXPECT_EQ(datetime.millisecond(), 0u); +} + +TEST(WvDateTimeUtilTest, Comparison) { + const DateTime datetime_a = DateTime::Min(); + const DateTime datetime_b = + DateTime::FromUnixMilliseconds(kTestUnixTimeMs - kOneDayMs); + const DateTime datetime_c = DateTime::FromUnixMilliseconds(kTestUnixTimeMs); + const DateTime datetime_d = + DateTime::FromUnixMilliseconds(kTestUnixTimeMs + 1); + const DateTime datetime_e = DateTime::Max(); + + // Equality + EXPECT_EQ(datetime_a, datetime_a); + EXPECT_EQ(datetime_b, datetime_b); + EXPECT_EQ(datetime_c, datetime_c); + EXPECT_EQ(datetime_d, datetime_d); + EXPECT_EQ(datetime_e, datetime_e); + + // Inequality. + EXPECT_NE(datetime_a, datetime_b); + EXPECT_NE(datetime_a, datetime_c); + EXPECT_NE(datetime_a, datetime_d); + EXPECT_NE(datetime_a, datetime_e); + EXPECT_NE(datetime_b, datetime_c); + EXPECT_NE(datetime_b, datetime_d); + EXPECT_NE(datetime_b, datetime_e); + EXPECT_NE(datetime_c, datetime_d); + EXPECT_NE(datetime_c, datetime_e); + EXPECT_NE(datetime_d, datetime_e); + + // Less than + EXPECT_LT(datetime_a, datetime_b); + EXPECT_LT(datetime_a, datetime_c); + EXPECT_LT(datetime_a, datetime_d); + EXPECT_LT(datetime_a, datetime_e); + EXPECT_LT(datetime_b, datetime_c); + EXPECT_LT(datetime_b, datetime_d); + EXPECT_LT(datetime_b, datetime_e); + EXPECT_LT(datetime_c, datetime_d); + EXPECT_LT(datetime_c, datetime_e); + EXPECT_LT(datetime_d, datetime_e); + + // Less than or equal. + EXPECT_LE(datetime_a, datetime_a); + EXPECT_LE(datetime_a, datetime_b); + EXPECT_LE(datetime_a, datetime_c); + EXPECT_LE(datetime_a, datetime_d); + EXPECT_LE(datetime_a, datetime_e); + EXPECT_LE(datetime_b, datetime_b); + EXPECT_LE(datetime_b, datetime_c); + EXPECT_LE(datetime_b, datetime_d); + EXPECT_LE(datetime_b, datetime_e); + EXPECT_LE(datetime_c, datetime_c); + EXPECT_LE(datetime_c, datetime_d); + EXPECT_LE(datetime_c, datetime_e); + EXPECT_LE(datetime_d, datetime_d); + EXPECT_LE(datetime_d, datetime_e); + EXPECT_LE(datetime_e, datetime_e); + + // Greater than + EXPECT_GT(datetime_b, datetime_a); + EXPECT_GT(datetime_c, datetime_a); + EXPECT_GT(datetime_d, datetime_a); + EXPECT_GT(datetime_e, datetime_a); + EXPECT_GT(datetime_c, datetime_b); + EXPECT_GT(datetime_d, datetime_b); + EXPECT_GT(datetime_e, datetime_b); + EXPECT_GT(datetime_d, datetime_c); + EXPECT_GT(datetime_e, datetime_c); + EXPECT_GT(datetime_e, datetime_d); + + // Greater than or equal. + EXPECT_GE(datetime_a, datetime_a); + EXPECT_GE(datetime_b, datetime_a); + EXPECT_GE(datetime_c, datetime_a); + EXPECT_GE(datetime_d, datetime_a); + EXPECT_GE(datetime_e, datetime_a); + EXPECT_GE(datetime_b, datetime_b); + EXPECT_GE(datetime_c, datetime_b); + EXPECT_GE(datetime_d, datetime_b); + EXPECT_GE(datetime_e, datetime_b); + EXPECT_GE(datetime_c, datetime_c); + EXPECT_GE(datetime_d, datetime_c); + EXPECT_GE(datetime_e, datetime_c); + EXPECT_GE(datetime_d, datetime_d); + EXPECT_GE(datetime_e, datetime_d); + EXPECT_GE(datetime_e, datetime_e); +} + +TEST(WvDateTimeUtilTest, Addition_WithDuration) { + const DateTime start = DateTime::FromUnixMilliseconds(kTestUnixTimeMs); + ASSERT_TRUE(start.IsSet()); + + const Duration one_day = Duration::FromMilliseconds(kOneDayMs); + const DateTime tomorrow = start + one_day; + ASSERT_TRUE(tomorrow.IsSet()); + + EXPECT_EQ(tomorrow.epoch_seconds(), kTestTimeS + Seconds(kOneDayS)); + EXPECT_EQ(tomorrow.epoch_milliseconds(), + kTestTimeMs + Milliseconds(kOneDayMs)); + + EXPECT_EQ(tomorrow.year(), kTestYear); + EXPECT_EQ(tomorrow.month(), kTestMonth); + EXPECT_EQ(tomorrow.day(), kTestDay + 1); + EXPECT_EQ(tomorrow.day_of_year(), kTestDayOfYear + 1); + EXPECT_EQ(tomorrow.day_of_week(), DayOfWeekAddition(kTestDayOfWeek, 1)); + + EXPECT_EQ(tomorrow.hour(), kTestHour); + EXPECT_EQ(tomorrow.minute(), kTestMinute); + EXPECT_EQ(tomorrow.second(), kTestSecond); + EXPECT_EQ(tomorrow.millisecond(), kTestMs); + + // Note: This is 30 days for April to May. + const Duration one_month = Duration::FromSeconds(kOneDayS * 30); + const DateTime next_month = start + one_month; + EXPECT_EQ(next_month.year(), kTestYear); + EXPECT_EQ(next_month.month(), kTestMonth + 1); + EXPECT_EQ(next_month.day(), kTestDay); + EXPECT_EQ(next_month.day_of_year(), kTestDayOfYear + 30); + EXPECT_EQ(next_month.day_of_week(), DayOfWeekAddition(kTestDayOfWeek, 30)); + + EXPECT_EQ(next_month.hour(), kTestHour); + EXPECT_EQ(next_month.minute(), kTestMinute); + EXPECT_EQ(next_month.second(), kTestSecond); + EXPECT_EQ(next_month.millisecond(), kTestMs); + + // Note: This should roll over a day. + const uint32_t day_inc_11h = (kTestHour + 11) / 24; + const Duration eleven_hours = Duration::FromSeconds(kOneHourS * 11); + const DateTime hours_later = start + eleven_hours; + EXPECT_EQ(hours_later.year(), kTestYear); + EXPECT_EQ(hours_later.month(), kTestMonth); + EXPECT_EQ(hours_later.day(), kTestDay + day_inc_11h); + EXPECT_EQ(hours_later.day_of_year(), kTestDayOfYear + day_inc_11h); + EXPECT_EQ(hours_later.day_of_week(), + DayOfWeekAddition(kTestDayOfWeek, day_inc_11h)); + + const uint32_t new_hour_11h = (kTestHour + 11) % 24; + EXPECT_EQ(hours_later.hour(), new_hour_11h); + EXPECT_EQ(hours_later.minute(), kTestMinute); + EXPECT_EQ(hours_later.second(), kTestSecond); + EXPECT_EQ(hours_later.millisecond(), kTestMs); + + const Duration five_years = + Duration::FromSeconds(kOneCommonYearS * 4 + kOneLeapYearS); + const DateTime fewer_years_later = start + five_years; + EXPECT_EQ(fewer_years_later.year(), kTestYear + 5); + EXPECT_EQ(fewer_years_later.month(), kTestMonth); + EXPECT_EQ(fewer_years_later.day(), kTestDay); + + EXPECT_EQ(fewer_years_later.hour(), kTestHour); + EXPECT_EQ(fewer_years_later.minute(), kTestMinute); + EXPECT_EQ(fewer_years_later.second(), kTestSecond); + EXPECT_EQ(fewer_years_later.millisecond(), kTestMs); + + const DateTime same_time = start + Duration::Zero(); + EXPECT_EQ(start, same_time); +} + +TEST(WvDateTimeUtilTest, Addition_WithDuration_Invalid) { + // Addition with an unset date time is not allowed. + const DateTime unset; + ASSERT_FALSE(unset.IsSet()); + const Duration one_month = Duration::FromSeconds(kOneDayS * 30); + DateTime result = unset + one_month; + EXPECT_FALSE(result.IsSet()); + + // Addition that causes overflow is not allowed. + result = DateTime::Max() + Duration::FromMilliseconds(1); + EXPECT_FALSE(result.IsSet()); + + result = DateTime::Min() + Duration::FromSeconds(kOneCommonYearS * 10000); + EXPECT_FALSE(result.IsSet()); +} + +TEST(WvDateTimeUtilTest, Subtraction_WithDuration) { + const DateTime start = DateTime::FromUnixMilliseconds(kTestUnixTimeMs); + ASSERT_TRUE(start.IsSet()); + + const Duration one_day = Duration::FromMilliseconds(kOneDayMs); + const DateTime yesterday = start - one_day; + ASSERT_TRUE(yesterday.IsSet()); + + EXPECT_EQ(yesterday.epoch_seconds(), Seconds(kTestUnixTimeS - kOneDayS)); + EXPECT_EQ(yesterday.epoch_milliseconds(), + Milliseconds(kTestUnixTimeMs - kOneDayMs)); + + EXPECT_EQ(yesterday.year(), kTestYear); + EXPECT_EQ(yesterday.month(), kTestMonth); + EXPECT_EQ(yesterday.day(), kTestDay - 1); + EXPECT_EQ(yesterday.day_of_year(), kTestDayOfYear - 1); + EXPECT_EQ(yesterday.day_of_week(), DayOfWeekSubtraction(kTestDayOfWeek, 1)); + + EXPECT_EQ(yesterday.hour(), kTestHour); + EXPECT_EQ(yesterday.minute(), kTestMinute); + EXPECT_EQ(yesterday.second(), kTestSecond); + EXPECT_EQ(yesterday.millisecond(), kTestMs); + + // Note: This is a 31 days for March. + const Duration one_month = Duration::FromSeconds(kOneDayS * 31); + const DateTime previous_month = start - one_month; + EXPECT_EQ(previous_month.year(), kTestYear); + EXPECT_EQ(previous_month.month(), kTestMonth - 1); + EXPECT_EQ(previous_month.day(), kTestDay); + EXPECT_EQ(previous_month.day_of_year(), kTestDayOfYear - 31); + EXPECT_EQ(previous_month.day_of_week(), + DayOfWeekSubtraction(kTestDayOfWeek, 31)); + + EXPECT_EQ(previous_month.hour(), kTestHour); + EXPECT_EQ(previous_month.minute(), kTestMinute); + EXPECT_EQ(previous_month.second(), kTestSecond); + EXPECT_EQ(previous_month.millisecond(), kTestMs); + + const uint32_t day_dec_14h = (kTestHour < 14) ? 1 : 0; + const Duration fourteen_hours = Duration::FromSeconds(kOneHourS * 14); + const DateTime hours_earlier = start - fourteen_hours; + EXPECT_EQ(yesterday.year(), kTestYear); + EXPECT_EQ(yesterday.month(), kTestMonth); + EXPECT_EQ(yesterday.day(), kTestDay - day_dec_14h); + EXPECT_EQ(yesterday.day_of_year(), kTestDayOfYear - day_dec_14h); + EXPECT_EQ(yesterday.day_of_week(), + DayOfWeekSubtraction(kTestDayOfWeek, day_dec_14h)); + + const uint32_t new_hour_14h = + (kTestHour < 14) ? (kTestHour + 10) : kTestHour - 14; + EXPECT_EQ(hours_earlier.hour(), new_hour_14h); + EXPECT_EQ(hours_earlier.minute(), kTestMinute); + EXPECT_EQ(hours_earlier.second(), kTestSecond); + EXPECT_EQ(hours_earlier.millisecond(), kTestMs); + + const Duration five_years = + Duration::FromSeconds(kOneCommonYearS * 3 + kOneLeapYearS * 2); + const DateTime few_years_earlier = start - five_years; + EXPECT_EQ(few_years_earlier.year(), kTestYear - 5); + EXPECT_EQ(few_years_earlier.month(), kTestMonth); + EXPECT_EQ(few_years_earlier.day(), kTestDay); + + EXPECT_EQ(few_years_earlier.hour(), kTestHour); + EXPECT_EQ(few_years_earlier.minute(), kTestMinute); + EXPECT_EQ(few_years_earlier.second(), kTestSecond); + EXPECT_EQ(few_years_earlier.millisecond(), kTestMs); + + const DateTime same_time = start - Duration::Zero(); + EXPECT_EQ(start, same_time); +} + +TEST(WvDateTimeUtilTest, Subtraction_WithDuration_Invalid) { + // Subtract with an unset date time is not allowed. + const DateTime unset; + ASSERT_FALSE(unset.IsSet()); + const Duration one_month = Duration::FromSeconds(kOneDayS * 30); + DateTime result = unset - one_month; + EXPECT_FALSE(result.IsSet()); + + // Subtract that causes overflow is not allowed. + result = DateTime::Min() - Duration::FromMilliseconds(1); + EXPECT_FALSE(result.IsSet()); + + result = DateTime::Max() - Duration::FromSeconds(kOneCommonYearS * 10000); + EXPECT_FALSE(result.IsSet()); +} + +TEST(WvDateTimeUtilTest, Difference) { + const Duration expected_diff = + Duration(Hours(15) + Minutes(6) + Seconds(30) + Milliseconds(123)); + const DateTime datetime_a = + DateTime::FromUnixSeconds(kTestUnixTimeS, kTestMs); + const DateTime datetime_b = DateTime::FromUnixMilliseconds( + datetime_a.epoch_milliseconds() + expected_diff.total_milliseconds()); + + const Duration diff_ab = datetime_b - datetime_a; + EXPECT_EQ(diff_ab, expected_diff); + + const Duration diff_ba = datetime_a - datetime_b; + EXPECT_EQ(diff_ba, -expected_diff); + + // Very big diff + const Duration diff_min_max = DateTime::Max() - DateTime::Min(); + EXPECT_TRUE(diff_min_max.IsPositive()); + const Duration diff_max_min = DateTime::Min() - DateTime::Max(); + EXPECT_TRUE(diff_max_min.IsNegative()); + EXPECT_EQ(diff_min_max, -diff_max_min); + + // No difference. + const Duration diff_aa = datetime_a - datetime_a; + EXPECT_TRUE(diff_aa.IsZero()); +} + +TEST(WvDateTimeUtilTest, Difference_Invalid) { + const DateTime datetime = DateTime::FromUnixSeconds(kTestUnixTimeS, kTestMs); + const DateTime unset; + + Duration diff = datetime - unset; + EXPECT_TRUE(diff.IsZero()); + + diff = unset - datetime; + EXPECT_TRUE(diff.IsZero()); +} + +TEST(WvDateTimeUtilTest, ToString) { + // Default should ISO date times with short timezones, and milliseconds + // only printed if non-zero. + DateTime datetime = DateTime::FromUnixSeconds(kTestUnixTimeS, kTestMs); + EXPECT_EQ(datetime.ToString(), "2024-04-19T13:36:18.507Z"); + datetime = DateTime::FromUnixSeconds(kTestUnixTimeS); + EXPECT_EQ(datetime.ToString(), "2024-04-19T13:36:18Z"); + + EXPECT_EQ(DateTime::Min().ToString(), "1970-01-01T00:00:00.001Z"); + EXPECT_EQ(DateTime::Max().ToString(), "9999-12-31T23:59:59.999Z"); + + EXPECT_EQ(DateTime().ToString(), ""); +} +} // namespace test +} // namespace wvutil diff --git a/util/test/wv_duration_unittest.cpp b/util/test/wv_duration_unittest.cpp new file mode 100644 index 0000000..78b80f9 --- /dev/null +++ b/util/test/wv_duration_unittest.cpp @@ -0,0 +1,1158 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "wv_duration.h" + +#include + +#include + +#include + +namespace wvutil { +namespace test { +namespace { +// Milliseconds per second. +constexpr int64_t kMsPerS = 1000; +// Seconds per minute. +constexpr int64_t kSPerM = 60; +// Minutes per hour. +constexpr int64_t kMPerH = 60; +// Hours per day. +constexpr int64_t kHPerD = 24; + +// Milliseconds per minute. +constexpr int64_t kMsPerM = kMsPerS * kSPerM; +// Milliseconds per hour. +constexpr int64_t kMsPerH = kMsPerM * kMPerH; +// Milliseconds per day. +constexpr int64_t kMsPerD = kMsPerH * kHPerD; + +// Seconds per hour. +constexpr int64_t kSPerH = kSPerM * kMPerH; +// Seconds per day. +constexpr int64_t kSPerD = kSPerH * kHPerD; + +// Minutes per day. +constexpr int64_t kMPerD = kMPerH * kHPerD; +} // namespace + +// Note: Using "WvDurationUtilTest" instead of "DurationTest" as +// it be confused with policy duration tests. + +TEST(WvDurationUtilTest, SpecialInitializers) { + const Duration zero = Duration::Zero(); + EXPECT_TRUE(zero.IsZero()); + EXPECT_FALSE(zero.IsNegative()); + EXPECT_FALSE(zero.IsPositive()); + EXPECT_TRUE(zero.IsNonNegative()); + EXPECT_EQ(zero.total_milliseconds().count(), 0); + EXPECT_EQ(zero.total_seconds().count(), 0); + EXPECT_EQ(zero.total_minutes().count(), 0); + EXPECT_EQ(zero.total_hours().count(), 0); + EXPECT_EQ(zero.total_days().count(), 0); + + const Duration positive_seconds = Duration::FromSeconds(104000); + EXPECT_FALSE(positive_seconds.IsZero()); + EXPECT_FALSE(positive_seconds.IsNegative()); + EXPECT_TRUE(positive_seconds.IsPositive()); + EXPECT_TRUE(positive_seconds.IsNonNegative()); + // Totals + EXPECT_EQ(positive_seconds.total_milliseconds().count(), 104000000); + EXPECT_EQ(positive_seconds.total_seconds().count(), 104000); + EXPECT_EQ(positive_seconds.total_minutes().count(), 1733); + EXPECT_EQ(positive_seconds.total_hours().count(), 28); + EXPECT_EQ(positive_seconds.total_days().count(), 1); + // Components + EXPECT_EQ(positive_seconds.milliseconds(), 0u); + EXPECT_EQ(positive_seconds.seconds(), 20u); + EXPECT_EQ(positive_seconds.minutes(), 53u); + EXPECT_EQ(positive_seconds.hours(), 4u); + EXPECT_EQ(positive_seconds.days(), 1u); + + const Duration positive_milliseconds = Duration::FromMilliseconds(104000123); + EXPECT_FALSE(positive_milliseconds.IsZero()); + EXPECT_FALSE(positive_milliseconds.IsNegative()); + EXPECT_TRUE(positive_milliseconds.IsPositive()); + EXPECT_TRUE(positive_milliseconds.IsNonNegative()); + // Totals + EXPECT_EQ(positive_milliseconds.total_milliseconds().count(), 104000123); + EXPECT_EQ(positive_milliseconds.total_seconds().count(), 104000); + EXPECT_EQ(positive_milliseconds.total_minutes().count(), 1733); + EXPECT_EQ(positive_milliseconds.total_hours().count(), 28); + EXPECT_EQ(positive_milliseconds.total_days().count(), 1); + // Components + EXPECT_EQ(positive_milliseconds.milliseconds(), 123u); + EXPECT_EQ(positive_milliseconds.seconds(), 20u); + EXPECT_EQ(positive_milliseconds.minutes(), 53u); + EXPECT_EQ(positive_milliseconds.hours(), 4u); + EXPECT_EQ(positive_milliseconds.days(), 1u); + + const Duration negative_seconds = Duration::FromSeconds(-104000); + EXPECT_FALSE(negative_seconds.IsZero()); + EXPECT_TRUE(negative_seconds.IsNegative()); + EXPECT_FALSE(negative_seconds.IsPositive()); + EXPECT_FALSE(negative_seconds.IsNonNegative()); + // Totals + EXPECT_EQ(negative_seconds.total_milliseconds().count(), -104000000); + EXPECT_EQ(negative_seconds.total_seconds().count(), -104000); + EXPECT_EQ(negative_seconds.total_minutes().count(), -1733); + EXPECT_EQ(negative_seconds.total_hours().count(), -28); + EXPECT_EQ(negative_seconds.total_days().count(), -1); + // Components (presented as positives) + EXPECT_EQ(negative_seconds.milliseconds(), 0u); + EXPECT_EQ(negative_seconds.seconds(), 20u); + EXPECT_EQ(negative_seconds.minutes(), 53u); + EXPECT_EQ(negative_seconds.hours(), 4u); + EXPECT_EQ(negative_seconds.days(), 1u); +} + +TEST(WvDurationUtilTest, GetAbsolute) { + const Duration positive = Duration::FromMilliseconds(4000123); + const Duration negative = Duration::FromMilliseconds(-4000123); + const Duration zero = Duration::Zero(); + EXPECT_EQ(negative.GetAbsolute(), positive); + EXPECT_EQ(positive.GetAbsolute(), positive); + EXPECT_EQ(zero.GetAbsolute(), zero); +} + +TEST(WvDurationUtilTest, GetTruncate) { + // Truncate a positive value. + const Duration positive_milliseconds = Duration::FromMilliseconds(104000123); + // Expected values of |positive_milliseconds| + const Duration positive_seconds = Duration::FromSeconds(104000); + const Duration positive_minutes = Duration::FromSeconds(1733 * 60); + const Duration positive_hours = Duration::FromSeconds(28 * 60 * 60); + const Duration positive_days = Duration::FromSeconds(1 * 24 * 60 * 60); + EXPECT_EQ(positive_milliseconds.GetTruncateBySeconds(), positive_seconds); + EXPECT_EQ(positive_milliseconds.GetTruncateByMinutes(), positive_minutes); + EXPECT_EQ(positive_milliseconds.GetTruncateByHours(), positive_hours); + EXPECT_EQ(positive_milliseconds.GetTruncateByDays(), positive_days); + + // Truncate a negative value. + const Duration negative_milliseconds = Duration::FromMilliseconds(-104000123); + // Expected values of |negative_milliseconds| + const Duration negative_seconds = Duration::FromSeconds(-104000); + const Duration negative_minutes = Duration::FromSeconds(-(1733 * 60)); + const Duration negative_hours = Duration::FromSeconds(-(28 * 60 * 60)); + const Duration negative_days = Duration::FromSeconds(-(1 * 24 * 60 * 60)); + EXPECT_EQ(negative_milliseconds.GetTruncateBySeconds(), negative_seconds); + EXPECT_EQ(negative_milliseconds.GetTruncateByMinutes(), negative_minutes); + EXPECT_EQ(negative_milliseconds.GetTruncateByHours(), negative_hours); + EXPECT_EQ(negative_milliseconds.GetTruncateByDays(), negative_days); +} + +TEST(WvDurationUtilTest, Comparison) { + const Duration duration_a = Duration::FromMilliseconds(-4000123); + const Duration duration_b = Duration::FromSeconds(-4000); + const Duration duration_c = Duration::Zero(); + const Duration duration_d = Duration::FromSeconds(4000); + const Duration duration_e = Duration::FromMilliseconds(4000123); + + // Equality + EXPECT_EQ(duration_a, duration_a); + EXPECT_EQ(duration_b, duration_b); + EXPECT_EQ(duration_c, duration_c); + EXPECT_EQ(duration_d, duration_d); + EXPECT_EQ(duration_e, duration_e); + + // Inequality. + EXPECT_NE(duration_a, duration_b); + EXPECT_NE(duration_a, duration_c); + EXPECT_NE(duration_a, duration_d); + EXPECT_NE(duration_a, duration_e); + EXPECT_NE(duration_b, duration_c); + EXPECT_NE(duration_b, duration_d); + EXPECT_NE(duration_b, duration_e); + EXPECT_NE(duration_c, duration_d); + EXPECT_NE(duration_c, duration_e); + EXPECT_NE(duration_d, duration_e); + + // Less than + EXPECT_LT(duration_a, duration_b); + EXPECT_LT(duration_a, duration_c); + EXPECT_LT(duration_a, duration_d); + EXPECT_LT(duration_a, duration_e); + EXPECT_LT(duration_b, duration_c); + EXPECT_LT(duration_b, duration_d); + EXPECT_LT(duration_b, duration_e); + EXPECT_LT(duration_c, duration_d); + EXPECT_LT(duration_c, duration_e); + EXPECT_LT(duration_d, duration_e); + + // Less than or equal. + EXPECT_LE(duration_a, duration_a); + EXPECT_LE(duration_a, duration_b); + EXPECT_LE(duration_a, duration_c); + EXPECT_LE(duration_a, duration_d); + EXPECT_LE(duration_a, duration_e); + EXPECT_LE(duration_b, duration_b); + EXPECT_LE(duration_b, duration_c); + EXPECT_LE(duration_b, duration_d); + EXPECT_LE(duration_b, duration_e); + EXPECT_LE(duration_c, duration_c); + EXPECT_LE(duration_c, duration_d); + EXPECT_LE(duration_c, duration_e); + EXPECT_LE(duration_d, duration_d); + EXPECT_LE(duration_d, duration_e); + EXPECT_LE(duration_e, duration_e); + + // Greater than + EXPECT_GT(duration_b, duration_a); + EXPECT_GT(duration_c, duration_a); + EXPECT_GT(duration_d, duration_a); + EXPECT_GT(duration_e, duration_a); + EXPECT_GT(duration_c, duration_b); + EXPECT_GT(duration_d, duration_b); + EXPECT_GT(duration_e, duration_b); + EXPECT_GT(duration_d, duration_c); + EXPECT_GT(duration_e, duration_c); + EXPECT_GT(duration_e, duration_d); + + // Greater than or equal. + EXPECT_GE(duration_a, duration_a); + EXPECT_GE(duration_b, duration_a); + EXPECT_GE(duration_c, duration_a); + EXPECT_GE(duration_d, duration_a); + EXPECT_GE(duration_e, duration_a); + EXPECT_GE(duration_b, duration_b); + EXPECT_GE(duration_c, duration_b); + EXPECT_GE(duration_d, duration_b); + EXPECT_GE(duration_e, duration_b); + EXPECT_GE(duration_c, duration_c); + EXPECT_GE(duration_d, duration_c); + EXPECT_GE(duration_e, duration_c); + EXPECT_GE(duration_d, duration_d); + EXPECT_GE(duration_e, duration_d); + EXPECT_GE(duration_e, duration_e); +} + +TEST(WvDurationUtilTest, Comparison_WithChronoDuration) { + // 1h6m40s123ms + const Duration duration = Duration::FromMilliseconds(104000123); + + // Equality + EXPECT_EQ(duration, Milliseconds(104000123)); + EXPECT_EQ(duration, Microseconds(104000123000)); + EXPECT_EQ(duration, Nanoseconds(104000123000000)); + // Equality (after truncation) + EXPECT_EQ(duration.GetTruncateByDays(), Days(1)); + EXPECT_EQ(duration.GetTruncateByHours(), Hours(28)); + EXPECT_EQ(duration.GetTruncateByMinutes(), Minutes(1733)); + EXPECT_EQ(duration.GetTruncateBySeconds(), Seconds(104000)); + + // Inequality (slight differences). + EXPECT_NE(duration, Milliseconds(104000124)); + EXPECT_NE(duration, Microseconds(104000122999)); + EXPECT_NE(duration, Nanoseconds(104000123000001)); + // Inequality (equal if truncated) + EXPECT_NE(duration, Days(1)); + EXPECT_NE(duration, Hours(28)); + EXPECT_NE(duration, Minutes(1733)); + EXPECT_NE(duration, Seconds(104000)); + + // Less than (slight differences). + EXPECT_LT(duration, Milliseconds(104000124)); + EXPECT_LT(duration, Microseconds(104000123001)); + EXPECT_LT(duration, Nanoseconds(104000123000001)); + // Less than (larger difference). + EXPECT_LT(duration, Days(2)); + EXPECT_LT(duration, Hours(29)); + EXPECT_LT(duration, Minutes(1734)); + EXPECT_LT(duration, Seconds(104001)); + + // Less than or equal (check equal) + EXPECT_LE(duration, Milliseconds(104000123)); + EXPECT_LE(duration, Microseconds(104000123000)); + EXPECT_LE(duration, Nanoseconds(104000123000000)); + // Less than or equal (slight difference) + EXPECT_LE(duration, Milliseconds(104000124)); + EXPECT_LE(duration, Microseconds(104000123001)); + EXPECT_LE(duration, Nanoseconds(104000123000001)); + // Less than or equal (larger differences). + EXPECT_LE(duration, Days(2)); + EXPECT_LE(duration, Hours(29)); + EXPECT_LE(duration, Minutes(1734)); + EXPECT_LE(duration, Seconds(104001)); + + // Greater than (slight differences). + EXPECT_GT(duration, Milliseconds(104000122)); + EXPECT_GT(duration, Microseconds(104000122999)); + EXPECT_GT(duration, Nanoseconds(104000122999999)); + // Less than (larger difference) + EXPECT_GT(duration, Days(1)); + EXPECT_GT(duration, Hours(28)); + EXPECT_GT(duration, Minutes(1733)); + EXPECT_GT(duration, Seconds(104000)); + + // Greater than or equal (check equal) + EXPECT_GE(duration, Milliseconds(104000123)); + EXPECT_GE(duration, Microseconds(104000123000)); + EXPECT_GE(duration, Nanoseconds(104000123000000)); + // Greater than or equal (slight difference) + EXPECT_GE(duration, Milliseconds(104000122)); + EXPECT_GE(duration, Microseconds(104000122999)); + EXPECT_GE(duration, Nanoseconds(104000122999999)); + // Greater than or equal (larger differences). + EXPECT_GE(duration, Days(1)); + EXPECT_GE(duration, Hours(28)); + EXPECT_GE(duration, Minutes(1733)); + EXPECT_GE(duration, Seconds(104000)); +} + +TEST(WvDurationUtilTest, UnaryPlusMinus) { + const Duration negative_milliseconds = Duration::FromMilliseconds(-4000123); + const Duration negative_seconds = Duration::FromSeconds(-4000); + const Duration zero = Duration::Zero(); + const Duration positive_seconds = Duration::FromSeconds(4000); + const Duration positive_milliseconds = Duration::FromMilliseconds(4000123); + + // + operator + EXPECT_EQ(+negative_milliseconds, negative_milliseconds); + EXPECT_EQ(+negative_seconds, negative_seconds); + EXPECT_EQ(+zero, zero); + EXPECT_EQ(+positive_seconds, positive_seconds); + EXPECT_EQ(+positive_milliseconds, positive_milliseconds); + + // - operator + EXPECT_EQ(-negative_milliseconds, positive_milliseconds); + EXPECT_EQ(-negative_seconds, positive_seconds); + EXPECT_EQ(-zero, zero); + EXPECT_EQ(-positive_seconds, negative_seconds); + EXPECT_EQ(-positive_milliseconds, negative_milliseconds); +} + +TEST(WvDurationUtilTest, Addition) { + const Duration ps = Duration::FromSeconds(4000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(4000123); + const Duration ns = Duration::FromSeconds(-4000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-4000123); + const Duration zero = Duration::Zero(); + + // Positive + positive + EXPECT_EQ(ps + pms, pt); + // Positive + negative + EXPECT_EQ(pt + nms, ps); + EXPECT_EQ(pt + ns, pms); + EXPECT_EQ(pt + nt, zero); + EXPECT_EQ(ps + ns, zero); + EXPECT_EQ(pms + nms, zero); + // Positive + zero + EXPECT_EQ(pt + zero, pt); + EXPECT_EQ(ps + zero, ps); + EXPECT_EQ(pms + zero, pms); + // Negative + negative + EXPECT_EQ(ns + nms, nt); + // Negative + positive + EXPECT_EQ(nt + pms, ns); + EXPECT_EQ(nt + ps, nms); + EXPECT_EQ(nt + pt, zero); + EXPECT_EQ(ns + ps, zero); + EXPECT_EQ(nms + pms, zero); + // Negative + zero + EXPECT_EQ(nt + zero, nt); + EXPECT_EQ(ns + zero, ns); + EXPECT_EQ(nms + zero, nms); + // Zero + zero + EXPECT_EQ(zero + zero, zero); + // Zero + positive + EXPECT_EQ(zero + pt, pt); + EXPECT_EQ(zero + ps, ps); + EXPECT_EQ(zero + pms, pms); + // Zero + negative + EXPECT_EQ(zero + nt, nt); + EXPECT_EQ(zero + ns, ns); + EXPECT_EQ(zero + nms, nms); +} + +TEST(WvDurationUtilTest, Subtraction) { + const Duration ps = Duration::FromSeconds(4000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(4000123); + const Duration ns = Duration::FromSeconds(-4000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-4000123); + const Duration zero = Duration::Zero(); + // Positive - positive + EXPECT_EQ(pt - pms, ps); + EXPECT_EQ(pt - ps, pms); + EXPECT_EQ(pt - pt, zero); + EXPECT_EQ(ps - ps, zero); + EXPECT_EQ(pms - pms, zero); + // Positive - negative + EXPECT_EQ(ps - nms, pt); + EXPECT_EQ(pms - ns, pt); + // Positive - zero + EXPECT_EQ(pt - zero, pt); + EXPECT_EQ(ps - zero, ps); + EXPECT_EQ(pms - zero, pms); + // Negative - negative + EXPECT_EQ(nt - ns, nms); + EXPECT_EQ(nt - nms, ns); + EXPECT_EQ(nt - nt, zero); + EXPECT_EQ(ns - ns, zero); + EXPECT_EQ(nms - nms, zero); + // Negative - positive + EXPECT_EQ(ns - pms, nt); + EXPECT_EQ(nms - ps, nt); + // Negative - zero + EXPECT_EQ(nt - zero, nt); + EXPECT_EQ(ns - zero, ns); + EXPECT_EQ(nms - zero, nms); + // Zero - zero + EXPECT_EQ(zero - zero, zero); + // Zero - positive + EXPECT_EQ(zero - pt, nt); + EXPECT_EQ(zero - ps, ns); + EXPECT_EQ(zero - pms, nms); + // Zero - negative + EXPECT_EQ(zero - nt, pt); + EXPECT_EQ(zero - ns, ps); + EXPECT_EQ(zero - nms, pms); +} + +TEST(WvDurationUtilTest, Addition_WithChronoDuration) { + const Duration ps = Duration::FromSeconds(104000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(104000123); + const Duration ns = Duration::FromSeconds(-104000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-104000123); + const Duration zero = Duration::Zero(); + // Positive + positive (small units). + EXPECT_EQ(ps + Milliseconds(123), pt); + // Positive + positive (large units). + EXPECT_EQ(pms + Seconds(104000), pt); + EXPECT_EQ(pms + Minutes(1733) + Seconds(20), pt); + EXPECT_EQ(pms + Hours(28) + Minutes(53) + Seconds(20), pt); + EXPECT_EQ(pms + Days(1) + Hours(4) + Minutes(53) + Seconds(20), pt); + // Positive + negative (small units) + EXPECT_EQ(pt + Milliseconds(-123), ps); + // Positive + negative (large units) + EXPECT_EQ(pt + Seconds(-104000), pms); + EXPECT_EQ(pt + Minutes(-1733) + Seconds(-20), pms); + EXPECT_EQ(pt + Hours(-28) + Minutes(-53) + Seconds(-20), pms); + EXPECT_EQ(pt + Days(-1) + Hours(-4) + Minutes(-53) + Seconds(-20), pms); + // Positive + zero (all units) + EXPECT_EQ(pt + Milliseconds::zero(), pt); + EXPECT_EQ(pt + Seconds::zero(), pt); + EXPECT_EQ(pt + Minutes::zero(), pt); + EXPECT_EQ(pt + Hours::zero(), pt); + + // Negative + positive (small units). + EXPECT_EQ(nt + Milliseconds(123), ns); + // Negative + positive (large units). + EXPECT_EQ(nt + Seconds(104000), nms); + EXPECT_EQ(nt + Minutes(1733) + Seconds(20), nms); + EXPECT_EQ(nt + Hours(28) + Minutes(53) + Seconds(20), nms); + EXPECT_EQ(nt + Days(1) + Hours(4) + Minutes(53) + Seconds(20), nms); + // Negative + negative (small units) + EXPECT_EQ(ns + Milliseconds(-123), nt); + // Negative + negative (large units) + EXPECT_EQ(nms + Seconds(-104000), nt); + EXPECT_EQ(nms + Minutes(-1733) + Seconds(-20), nt); + EXPECT_EQ(nms + Hours(-28) + Minutes(-53) + Seconds(-20), nt); + EXPECT_EQ(nms + Days(-1) + Hours(-4) + Minutes(-53) + Seconds(-20), nt); + // Negative + zero (all units) + EXPECT_EQ(nt + Milliseconds::zero(), nt); + EXPECT_EQ(nt + Seconds::zero(), nt); + EXPECT_EQ(nt + Minutes::zero(), nt); + EXPECT_EQ(nt + Hours::zero(), nt); + + // Zero + zero (all units) + EXPECT_EQ(zero + Milliseconds::zero(), zero); + EXPECT_EQ(zero + Seconds::zero(), zero); + EXPECT_EQ(zero + Minutes::zero(), zero); + EXPECT_EQ(zero + Hours::zero(), zero); + // Zero + positive (small units) + EXPECT_EQ(zero + Milliseconds(123), pms); + // Zero + positive (large units) + EXPECT_EQ(zero + Seconds(104000), ps); + EXPECT_EQ(zero + Minutes(1733) + Seconds(20), ps); + EXPECT_EQ(zero + Hours(28) + Minutes(53) + Seconds(20), ps); + EXPECT_EQ(zero + Days(1) + Hours(4) + Minutes(53) + Seconds(20), ps); + // Zero + negative (small units) + EXPECT_EQ(zero + Milliseconds(-123), nms); + // Zero + negative (large units) + EXPECT_EQ(zero + Seconds(-104000), ns); + EXPECT_EQ(zero + Minutes(-1733) + Seconds(-20), ns); + EXPECT_EQ(zero + Hours(-28) + Minutes(-53) + Seconds(-20), ns); + EXPECT_EQ(zero + Days(-1) + Hours(-4) + Minutes(-53) + Seconds(-20), ns); +} + +TEST(WvDurationUtilTest, Subtraction_WithChronoDuration) { + const Duration ps = Duration::FromSeconds(4000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(4000123); + const Duration ns = Duration::FromSeconds(-4000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-4000123); + const Duration zero = Duration::Zero(); + + // Positive - positive (small units). + EXPECT_EQ(pt - Milliseconds(123), ps); + // Positive - positive (large units). + EXPECT_EQ(pt - Seconds(4000), pms); + EXPECT_EQ(pt - Minutes(66) - Seconds(40), pms); + EXPECT_EQ(pt - Hours(1) - Minutes(6) - Seconds(40), pms); + // Positive - negative (small units) + EXPECT_EQ(ps - Milliseconds(-123), pt); + // Positive - negative (large units) + EXPECT_EQ(pms - Seconds(-4000), pt); + EXPECT_EQ(pms - Minutes(-66) - Seconds(-40), pt); + EXPECT_EQ(pms - Hours(-1) - Minutes(-6) - Seconds(-40), pt); + // Positive - zero (all units) + EXPECT_EQ(pt - Milliseconds::zero(), pt); + EXPECT_EQ(pt - Seconds::zero(), pt); + EXPECT_EQ(pt - Minutes::zero(), pt); + EXPECT_EQ(pt - Hours::zero(), pt); + + // Negative - positive (small units). + EXPECT_EQ(ns - Milliseconds(123), nt); + // Negative - positive (large units). + EXPECT_EQ(nms - Seconds(4000), nt); + EXPECT_EQ(nms - Minutes(66) - Seconds(40), nt); + EXPECT_EQ(nms - Hours(1) - Minutes(6) - Seconds(40), nt); + // Negative - negative (small units) + EXPECT_EQ(nt - Milliseconds(-123), ns); + // Negative - negative (large units) + EXPECT_EQ(nt - Seconds(-4000), nms); + EXPECT_EQ(nt - Minutes(-66) - Seconds(-40), nms); + EXPECT_EQ(nt - Hours(-1) - Minutes(-6) - Seconds(-40), nms); + // Negative - zero (all units) + EXPECT_EQ(nt - Milliseconds::zero(), nt); + EXPECT_EQ(nt - Seconds::zero(), nt); + EXPECT_EQ(nt - Minutes::zero(), nt); + EXPECT_EQ(nt - Hours::zero(), nt); + + // Zero - zero (all units) + EXPECT_EQ(zero - Milliseconds::zero(), zero); + EXPECT_EQ(zero - Seconds::zero(), zero); + EXPECT_EQ(zero - Minutes::zero(), zero); + EXPECT_EQ(zero - Hours::zero(), zero); + // Zero - positive (small units) + EXPECT_EQ(zero - Milliseconds(123), nms); + // Zero - positive (large units) + EXPECT_EQ(zero - Seconds(4000), ns); + EXPECT_EQ(zero - Minutes(66) - Seconds(40), ns); + EXPECT_EQ(zero - Hours(1) - Minutes(6) - Seconds(40), ns); + // Zero - negative (small units) + EXPECT_EQ(zero - Milliseconds(-123), pms); + // Zero - negative (large units) + EXPECT_EQ(zero - Seconds(-4000), ps); + EXPECT_EQ(zero - Minutes(-66) - Seconds(-40), ps); + EXPECT_EQ(zero - Hours(-1) - Minutes(-6) - Seconds(-40), ps); +} + +TEST(WvDurationUtilTest, Increment) { + const Duration ps = Duration::FromSeconds(4000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(4000123); + const Duration ns = Duration::FromSeconds(-4000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-4000123); + const Duration zero = Duration::Zero(); + + Duration acc = zero; + // zero += positive -> positive + acc += ps; + EXPECT_EQ(acc, ps); + // positive += positive -> positive + acc += pms; + EXPECT_EQ(acc, pt); + // positive += zero -> positive + acc += zero; + EXPECT_EQ(acc, pt); + // positive += negative -> positive + acc += ns; + EXPECT_EQ(acc, pms); + // positive += negative -> zero + acc += nms; + EXPECT_EQ(acc, zero); + // zero += zero -> zero + acc += zero; + EXPECT_EQ(acc, zero); + // zero += negative -> negative + acc += nms; + EXPECT_EQ(acc, nms); + // negative += negative -> negative + acc += ns; + EXPECT_EQ(acc, nt); + // negative += zero -> zero + acc += zero; + EXPECT_EQ(acc, nt); + // negative += positive -> negative + acc += pms; + EXPECT_EQ(acc, ns); + // negative += positive -> zero + acc += ps; + EXPECT_EQ(acc, zero); +} + +TEST(WvDurationUtilTest, Decrement) { + const Duration ps = Duration::FromSeconds(4000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(4000123); + const Duration ns = Duration::FromSeconds(-4000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-4000123); + const Duration zero = Duration::Zero(); + + Duration acc = zero; + // zero -= positive -> negative + acc -= ps; + EXPECT_EQ(acc, ns); + // negative -= positive -> negative + acc -= pms; + EXPECT_EQ(acc, nt); + // negative -= zero -> negative + acc -= zero; + EXPECT_EQ(acc, nt); + // negative -= negative -> negative + acc -= ns; + EXPECT_EQ(acc, nms); + // negative -= negative -> zero + acc -= nms; + EXPECT_EQ(acc, zero); + // zero -= zero -> zero + acc -= zero; + EXPECT_EQ(acc, zero); + // zero -= negative -> positive + acc -= nms; + EXPECT_EQ(acc, pms); + // positive -= negative -> positive + acc -= ns; + EXPECT_EQ(acc, pt); + // positive -= zero -> positive + acc -= zero; + EXPECT_EQ(acc, pt); + // positive -= positive -> positive + acc -= pms; + EXPECT_EQ(acc, ps); + // positive -= positive -> zero + acc -= ps; + EXPECT_EQ(acc, zero); +} + +TEST(WvDurationUtilTest, Increment_WithChronoDuration) { + const Duration ps = Duration::FromSeconds(104000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(104000123); + const Duration ns = Duration::FromSeconds(-104000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-104000123); + const Duration zero = Duration::Zero(); + + Duration acc = zero; + // zero += positive (small units) -> positive + acc += Milliseconds(123); + EXPECT_EQ(acc, pms); + // positive += positive (larger units) -> positive + acc += Days(1); + acc += Hours(4); + acc += Minutes(53); + acc += Seconds(20); + EXPECT_EQ(acc, pt); + // positive += zero (all units) -> positive + acc += Milliseconds::zero(); + EXPECT_EQ(acc, pt); + acc += Seconds::zero(); + EXPECT_EQ(acc, pt); + acc += Minutes::zero(); + EXPECT_EQ(acc, pt); + acc += Hours::zero(); + EXPECT_EQ(acc, pt); + acc += Days::zero(); + EXPECT_EQ(acc, pt); + // positive += negative (small units) -> positive + acc += Milliseconds(-123); + EXPECT_EQ(acc, ps); + // positive += negative (large units) -> zero + acc += Days(-1); + acc += Hours(-4); + acc += Minutes(-53); + acc += Seconds(-20); + EXPECT_EQ(acc, zero); + // zero += zero (all units) -> zero + acc += Milliseconds::zero(); + EXPECT_EQ(acc, zero); + acc += Seconds::zero(); + EXPECT_EQ(acc, zero); + acc += Minutes::zero(); + EXPECT_EQ(acc, zero); + acc += Hours::zero(); + EXPECT_EQ(acc, zero); + acc += Days::zero(); + EXPECT_EQ(acc, zero); + // zero += negative (large units) -> negative + acc += Days(-1); + acc += Hours(-4); + acc += Minutes(-53); + acc += Seconds(-20); + EXPECT_EQ(acc, ns); + // negative += negative (small unit) -> negative + acc += Milliseconds(-123); + EXPECT_EQ(acc, nt); + // negative += zero (all units) -> negative + acc += Milliseconds::zero(); + EXPECT_EQ(acc, nt); + acc += Seconds::zero(); + EXPECT_EQ(acc, nt); + acc += Minutes::zero(); + EXPECT_EQ(acc, nt); + acc += Hours::zero(); + EXPECT_EQ(acc, nt); + acc += Days::zero(); + EXPECT_EQ(acc, nt); + // negative += positive (large units) -> negative + acc += Days(1); + acc += Hours(4); + acc += Minutes(53); + acc += Seconds(20); + EXPECT_EQ(acc, nms); + // negative += positive (small units) -> zero + acc += Milliseconds(123); + EXPECT_EQ(acc, zero); +} + +TEST(WvDurationUtilTest, Decrement_WithChronoDuration) { + const Duration ps = Duration::FromSeconds(104000); + const Duration pms = Duration::FromMilliseconds(123); + const Duration pt = Duration::FromMilliseconds(104000123); + const Duration ns = Duration::FromSeconds(-104000); + const Duration nms = Duration::FromMilliseconds(-123); + const Duration nt = Duration::FromMilliseconds(-104000123); + const Duration zero = Duration::Zero(); + + Duration acc = zero; + // zero -= positive (small units) -> negative + acc -= Milliseconds(123); + EXPECT_EQ(acc, nms); + // negative -= positive (larger units) -> negative + acc -= Days(1); + acc -= Hours(4); + acc -= Minutes(53); + acc -= Seconds(20); + EXPECT_EQ(acc, nt); + // negative -= zero (all units) -> negative + acc -= Milliseconds::zero(); + EXPECT_EQ(acc, nt); + acc -= Seconds::zero(); + EXPECT_EQ(acc, nt); + acc -= Minutes::zero(); + EXPECT_EQ(acc, nt); + acc -= Hours::zero(); + EXPECT_EQ(acc, nt); + acc -= Days::zero(); + EXPECT_EQ(acc, nt); + // negative -= negative (small units) -> negative + acc -= Milliseconds(-123); + EXPECT_EQ(acc, ns); + // negative -= negative (large units) -> zero + acc -= Days(-1); + acc -= Hours(-4); + acc -= Minutes(-53); + acc -= Seconds(-20); + EXPECT_EQ(acc, zero); + // zero -= zero (all units) -> zero + acc -= Milliseconds::zero(); + EXPECT_EQ(acc, zero); + acc -= Seconds::zero(); + EXPECT_EQ(acc, zero); + acc -= Minutes::zero(); + EXPECT_EQ(acc, zero); + acc -= Hours::zero(); + EXPECT_EQ(acc, zero); + acc -= Days::zero(); + EXPECT_EQ(acc, zero); + // zero -= negative (large units) -> positive + acc -= Days(-1); + acc -= Hours(-4); + acc -= Minutes(-53); + acc -= Seconds(-20); + EXPECT_EQ(acc, ps); + // positive -= negative (small unit) -> positive + acc -= Milliseconds(-123); + EXPECT_EQ(acc, pt); + // positive -= zero (all units) -> positive + acc -= Milliseconds::zero(); + EXPECT_EQ(acc, pt); + acc -= Seconds::zero(); + EXPECT_EQ(acc, pt); + acc -= Minutes::zero(); + EXPECT_EQ(acc, pt); + acc -= Hours::zero(); + EXPECT_EQ(acc, pt); + acc -= Days::zero(); + EXPECT_EQ(acc, pt); + // positive -= positive (large units) -> positive + acc -= Days(1); + acc -= Hours(4); + acc -= Minutes(53); + acc -= Seconds(20); + EXPECT_EQ(acc, pms); + // positive -= positive (small units) -> zero + acc -= Milliseconds(123); + EXPECT_EQ(acc, zero); +} + +TEST(WvDurationUtilTest, ToString_Zero) { + EXPECT_EQ(Duration::Zero().ToString(), "0s"); +} + +TEST(WvDurationUtilTest, ToString_Ms) { + // Millis only + EXPECT_EQ(Duration::FromMilliseconds(1).ToString(), "1ms"); + EXPECT_EQ(Duration::FromMilliseconds(12).ToString(), "12ms"); + EXPECT_EQ(Duration::FromMilliseconds(123).ToString(), "123ms"); + EXPECT_EQ(Duration::FromMilliseconds(999).ToString(), "999ms"); +} + +TEST(WvDurationUtilTest, ToString_S) { + // Seconds only + EXPECT_EQ(Duration::FromSeconds(1).ToString(), "1s"); + EXPECT_EQ(Duration::FromSeconds(9).ToString(), "9s"); + EXPECT_EQ(Duration::FromSeconds(30).ToString(), "30s"); + EXPECT_EQ(Duration::FromSeconds(59).ToString(), "59s"); +} + +TEST(WvDurationUtilTest, ToString_M) { + // Minutes only + EXPECT_EQ(Duration::FromSeconds(1 * kSPerM).ToString(), "1m"); + EXPECT_EQ(Duration::FromSeconds(9 * kSPerM).ToString(), "9m"); + EXPECT_EQ(Duration::FromSeconds(30 * kSPerM).ToString(), "30m"); + EXPECT_EQ(Duration::FromSeconds(59 * kSPerM).ToString(), "59m"); +} + +TEST(WvDurationUtilTest, ToString_H) { + // Hours only + EXPECT_EQ(Duration::FromSeconds(1 * kSPerH).ToString(), "1h"); + EXPECT_EQ(Duration::FromSeconds(9 * kSPerH).ToString(), "9h"); + EXPECT_EQ(Duration::FromSeconds(12 * kSPerH).ToString(), "12h"); + EXPECT_EQ(Duration::FromSeconds(23 * kSPerH).ToString(), "23h"); +} + +TEST(WvDurationUtilTest, ToString_D) { + // Days only + EXPECT_EQ(Duration::FromSeconds(1 * kSPerD).ToString(), "1d"); + EXPECT_EQ(Duration::FromSeconds(7 * kSPerD).ToString(), "7d"); + EXPECT_EQ(Duration::FromSeconds(99 * kSPerD).ToString(), "99d"); +} + +TEST(WvDurationUtilTest, ToString_S_Ms) { + // Seconds and millis + EXPECT_EQ(Duration::FromMilliseconds(1001).ToString(), "1s1ms"); + EXPECT_EQ(Duration::FromMilliseconds(1123).ToString(), "1s123ms"); + EXPECT_EQ(Duration::FromMilliseconds(30507).ToString(), "30s507ms"); + EXPECT_EQ(Duration::FromMilliseconds(59999).ToString(), "59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_M_Ms) { + // Minutes and millis + EXPECT_EQ(Duration::FromMilliseconds(kMsPerM + 1).ToString(), "1m1ms"); + EXPECT_EQ(Duration::FromMilliseconds(kMsPerM + 123).ToString(), "1m123ms"); + EXPECT_EQ(Duration::FromMilliseconds(30l * kMsPerM + 507).ToString(), + "30m507ms"); + EXPECT_EQ(Duration::FromMilliseconds(59l * kMsPerM + 999).ToString(), + "59m999ms"); +} + +TEST(WvDurationUtilTest, ToString_H_Ms) { + // Hours and millis + EXPECT_EQ(Duration::FromMilliseconds(kMsPerH + 1).ToString(), "1h1ms"); + EXPECT_EQ(Duration::FromMilliseconds(kMsPerH + 123).ToString(), "1h123ms"); + EXPECT_EQ(Duration::FromMilliseconds(12l * kMsPerH + 507).ToString(), + "12h507ms"); + EXPECT_EQ(Duration::FromMilliseconds(23l * kMsPerH + 999).ToString(), + "23h999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_Ms) { + // Days and millis + EXPECT_EQ(Duration::FromMilliseconds(1l * kMsPerD + 1).ToString(), "1d1ms"); + EXPECT_EQ(Duration::FromMilliseconds(7l * kMsPerD + 123).ToString(), + "7d123ms"); + EXPECT_EQ(Duration::FromMilliseconds(10l * kMsPerD + 507).ToString(), + "10d507ms"); + EXPECT_EQ(Duration::FromMilliseconds(99l * kMsPerD + 507).ToString(), + "99d507ms"); +} + +TEST(WvDurationUtilTest, ToString_M_S) { + // Minutes and seconds + EXPECT_EQ(Duration::FromSeconds(kSPerM + 1).ToString(), "1m1s"); + EXPECT_EQ(Duration::FromSeconds(30 * kSPerM + 30).ToString(), "30m30s"); + EXPECT_EQ(Duration::FromSeconds(59 * kSPerM + 59).ToString(), "59m59s"); +} + +TEST(WvDurationUtilTest, ToString_H_S) { + // Hours and seconds + EXPECT_EQ(Duration::FromSeconds(kSPerH + 1).ToString(), "1h1s"); + EXPECT_EQ(Duration::FromSeconds(12 * kSPerH + 30).ToString(), "12h30s"); + EXPECT_EQ(Duration::FromSeconds(23 * kSPerH + 59).ToString(), "23h59s"); +} + +TEST(WvDurationUtilTest, ToString_D_S) { + // Days and seconds + EXPECT_EQ(Duration::FromSeconds(kSPerD + 1).ToString(), "1d1s"); + EXPECT_EQ(Duration::FromSeconds(7 * kSPerD + 30).ToString(), "7d30s"); + EXPECT_EQ(Duration::FromSeconds(30 * kSPerD + 59).ToString(), "30d59s"); +} + +TEST(WvDurationUtilTest, ToString_H_M) { + // Hours and minutes + EXPECT_EQ(Duration::FromSeconds((kMPerH + 1) * kSPerM).ToString(), "1h1m"); + EXPECT_EQ(Duration::FromSeconds((12 * kMPerH + 30) * kSPerM).ToString(), + "12h30m"); + EXPECT_EQ(Duration::FromSeconds((23 * kMPerH + 59) * kSPerM).ToString(), + "23h59m"); +} + +TEST(WvDurationUtilTest, ToString_D_M) { + // Days and minutes + EXPECT_EQ(Duration::FromSeconds((kMPerD + 1) * kSPerM).ToString(), "1d1m"); + EXPECT_EQ(Duration::FromSeconds((7 * kMPerD + 30) * kSPerM).ToString(), + "7d30m"); + EXPECT_EQ(Duration::FromSeconds((30 * kMPerD + 59) * kSPerM).ToString(), + "30d59m"); +} + +TEST(WvDurationUtilTest, ToString_D_H) { + // Days and hours + EXPECT_EQ(Duration::FromSeconds((kHPerD + 1) * kSPerH).ToString(), "1d1h"); + EXPECT_EQ(Duration::FromSeconds((7 * kHPerD + 12) * kSPerH).ToString(), + "7d12h"); + EXPECT_EQ(Duration::FromSeconds((30 * kHPerD + 23) * kSPerH).ToString(), + "30d23h"); +} + +TEST(WvDurationUtilTest, ToString_M_S_Ms) { + // Minutes, seconds, and millis + EXPECT_EQ(Duration::FromMilliseconds((kSPerM + 1) * kMsPerS + 1).ToString(), + "1m1s1ms"); + EXPECT_EQ( + Duration::FromMilliseconds((6l * kSPerM + 30) * kMsPerS + 123).ToString(), + "6m30s123ms"); + EXPECT_EQ( + Duration::FromMilliseconds((30l * kSPerM + 3) * kMsPerS + 507).ToString(), + "30m3s507ms"); + EXPECT_EQ(Duration::FromMilliseconds((59l * kSPerM + 59) * kMsPerS + 999) + .ToString(), + "59m59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_H_S_Ms) { + // Hours, seconds, and millis + EXPECT_EQ(Duration::FromMilliseconds((kSPerH + 1) * kMsPerS + 1).ToString(), + "1h1s1ms"); + EXPECT_EQ( + Duration::FromMilliseconds((6l * kSPerH + 30) * kMsPerS + 123).ToString(), + "6h30s123ms"); + EXPECT_EQ( + Duration::FromMilliseconds((12l * kSPerH + 3) * kMsPerS + 507).ToString(), + "12h3s507ms"); + EXPECT_EQ(Duration::FromMilliseconds((23l * kSPerH + 59) * kMsPerS + 999) + .ToString(), + "23h59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_S_Ms) { + // Days, seconds, and millis + EXPECT_EQ(Duration::FromMilliseconds((kSPerD + 1) * kMsPerS + 1).ToString(), + "1d1s1ms"); + EXPECT_EQ( + Duration::FromMilliseconds((7 * kSPerD + 30) * kMsPerS + 123).ToString(), + "7d30s123ms"); + EXPECT_EQ( + Duration::FromMilliseconds((30 * kSPerD + 59) * kMsPerS + 999).ToString(), + "30d59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_H_M_Ms) { + // Hours, minutes, and millis + EXPECT_EQ(Duration::FromMilliseconds((kMPerH + 1) * kMsPerM + 1).ToString(), + "1h1m1ms"); + EXPECT_EQ( + Duration::FromMilliseconds((6l * kMPerH + 30) * kMsPerM + 123).ToString(), + "6h30m123ms"); + EXPECT_EQ( + Duration::FromMilliseconds((12l * kMPerH + 3) * kMsPerM + 507).ToString(), + "12h3m507ms"); + EXPECT_EQ(Duration::FromMilliseconds((23l * kMPerH + 59) * kMsPerM + 999) + .ToString(), + "23h59m999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_M_Ms) { + // Days, minutes and millis + EXPECT_EQ(Duration::FromMilliseconds((kMPerD + 1) * kMsPerM + 1).ToString(), + "1d1m1ms"); + EXPECT_EQ( + Duration::FromMilliseconds((7l * kMPerD + 30) * kMsPerM + 123).ToString(), + "7d30m123ms"); + EXPECT_EQ(Duration::FromMilliseconds((30l * kMPerD + 59) * kMsPerM + 999) + .ToString(), + "30d59m999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_H_Ms) { + // Days, hours and millis + EXPECT_EQ(Duration::FromMilliseconds((kHPerD + 1) * kMsPerH + 1).ToString(), + "1d1h1ms"); + EXPECT_EQ( + Duration::FromMilliseconds((7 * kHPerD + 12) * kMsPerH + 123).ToString(), + "7d12h123ms"); + EXPECT_EQ( + Duration::FromMilliseconds((30 * kHPerD + 23) * kMsPerH + 999).ToString(), + "30d23h999ms"); +} + +TEST(WvDurationUtilTest, ToString_H_M_S) { + // Hours, minutes and seconds. + EXPECT_EQ(Duration::FromSeconds((kMPerH + 1) * kSPerM + 1).ToString(), + "1h1m1s"); + EXPECT_EQ(Duration::FromSeconds((6 * kMPerH + 30) * kSPerM + 9).ToString(), + "6h30m9s"); + EXPECT_EQ(Duration::FromSeconds((12 * kMPerH + 3) * kSPerM + 30).ToString(), + "12h3m30s"); + EXPECT_EQ(Duration::FromSeconds((23 * kMPerH + 59) * kSPerM + 59).ToString(), + "23h59m59s"); +} + +TEST(WvDurationUtilTest, ToString_D_M_S) { + // Days, minutes and seconds. + EXPECT_EQ(Duration::FromSeconds((kMPerD + 1) * kSPerM + 1).ToString(), + "1d1m1s"); + EXPECT_EQ(Duration::FromSeconds((7 * kMPerD + 3) * kSPerM + 30).ToString(), + "7d3m30s"); + EXPECT_EQ(Duration::FromSeconds((30 * kMPerD + 59) * kSPerM + 59).ToString(), + "30d59m59s"); +} + +TEST(WvDurationUtilTest, ToString_D_H_S) { + // Days, hours and seconds. + EXPECT_EQ(Duration::FromSeconds((kHPerD + 1) * kSPerH + 1).ToString(), + "1d1h1s"); + EXPECT_EQ(Duration::FromSeconds((7 * kHPerD + 12) * kSPerH + 30).ToString(), + "7d12h30s"); + EXPECT_EQ(Duration::FromSeconds((30 * kHPerD + 23) * kSPerH + 59).ToString(), + "30d23h59s"); +} + +TEST(WvDurationUtilTest, ToString_D_H_M) { + // Days, hours and minutes. + EXPECT_EQ( + Duration::FromSeconds(((kHPerD + 1) * kMPerH + 1) * kSPerM).ToString(), + "1d1h1m"); + EXPECT_EQ(Duration::FromSeconds(((7 * kHPerD + 12) * kMPerH + 30) * kSPerM) + .ToString(), + "7d12h30m"); + EXPECT_EQ(Duration::FromSeconds(((30 * kHPerD + 23) * kMPerH + 59) * kSPerM) + .ToString(), + "30d23h59m"); +} + +TEST(WvDurationUtilTest, ToString_H_M_S_Ms) { + // Hours, minutes, seconds, and millis + EXPECT_EQ( + Duration::FromMilliseconds(((1l * kMPerH + 1) * kSPerM + 1) * kMsPerS + 1) + .ToString(), + "1h1m1s1ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((12l * kMPerH + 7) * kSPerM + 21) * kMsPerS + 60) + .ToString(), + "12h7m21s60ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((3l * kMPerH + 18) * kSPerM + 6) * kMsPerS + 507) + .ToString(), + "3h18m6s507ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((23l * kMPerH + 59) * kSPerM + 59) * kMsPerS + 999) + .ToString(), + "23h59m59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_M_S_Ms) { + // Days, minutes, seconds, and millis + EXPECT_EQ( + Duration::FromMilliseconds(((1l * kMPerD + 1) * kSPerM + 1) * kMsPerS + 1) + .ToString(), + "1d1m1s1ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((7l * kMPerD + 18) * kSPerM + 6) * kMsPerS + 507) + .ToString(), + "7d18m6s507ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((30l * kMPerD + 59) * kSPerM + 59) * kMsPerS + 999) + .ToString(), + "30d59m59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_H_S_Ms) { + // Days, hours, seconds, and millis + EXPECT_EQ( + Duration::FromMilliseconds(((1l * kHPerD + 1) * kSPerH + 1) * kMsPerS + 1) + .ToString(), + "1d1h1s1ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((7l * kHPerD + 12) * kSPerH + 30) * kMsPerS + 507) + .ToString(), + "7d12h30s507ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((30l * kHPerD + 23) * kSPerH + 59) * kMsPerS + 999) + .ToString(), + "30d23h59s999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_H_M_Ms) { + // Days, hours, minutes, and millis + EXPECT_EQ( + Duration::FromMilliseconds(((1l * kHPerD + 1) * kMPerH + 1) * kMsPerM + 1) + .ToString(), + "1d1h1m1ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((7l * kHPerD + 12) * kMPerH + 30) * kMsPerM + 507) + .ToString(), + "7d12h30m507ms"); + EXPECT_EQ(Duration::FromMilliseconds( + ((30l * kHPerD + 23) * kMPerH + 59) * kMsPerM + 999) + .ToString(), + "30d23h59m999ms"); +} + +TEST(WvDurationUtilTest, ToString_D_H_M_S) { + // Days, hours, minutes, and seconds + EXPECT_EQ(Duration::FromSeconds(((1l * kHPerD + 1) * kMPerH + 1) * kSPerM + 1) + .ToString(), + "1d1h1m1s"); + EXPECT_EQ( + Duration::FromSeconds(((7l * kHPerD + 12) * kMPerH + 30) * kSPerM + 22) + .ToString(), + "7d12h30m22s"); + EXPECT_EQ( + Duration::FromSeconds(((30l * kHPerD + 23) * kMPerH + 59) * kSPerM + 59) + .ToString(), + "30d23h59m59s"); +} + +TEST(WvDurationUtilTest, ToString_D_H_M_S_Ms) { + // Days, hours, minutes, seconds and millis + EXPECT_EQ(Duration::FromMilliseconds( + (((1l * kHPerD + 1) * kMPerH + 1) * kSPerM + 1) * kMsPerS + 1) + .ToString(), + "1d1h1m1s1ms"); + EXPECT_EQ( + Duration::FromMilliseconds( + (((7l * kHPerD + 12) * kMPerH + 10) * kSPerM + 20) * kMsPerS + 507) + .ToString(), + "7d12h10m20s507ms"); + EXPECT_EQ( + Duration::FromMilliseconds( + (((30l * kHPerD + 23) * kMPerH + 59) * kSPerM + 59) * kMsPerS + 999) + .ToString(), + "30d23h59m59s999ms"); +} +} // namespace test +} // namespace wvutil diff --git a/util/test/wv_timestamp_unittest.cpp b/util/test/wv_timestamp_unittest.cpp new file mode 100644 index 0000000..f8d8f12 --- /dev/null +++ b/util/test/wv_timestamp_unittest.cpp @@ -0,0 +1,300 @@ +// Copyright 2025 Google LLC. All Rights Reserved. This file and proprietary +// source code may only be used and distributed under the Widevine License +// Agreement. +#include "wv_timestamp.h" + +#include + +#include + +#include + +#include "wv_duration.h" + +namespace wvutil { +namespace test { +namespace { +constexpr uint64_t kRawValidEpochSeconds = 1733526346; +constexpr Seconds kValidEpochSeconds = Seconds(kRawValidEpochSeconds); + +constexpr uint64_t kRawOutOfRangeEpochSeconds = 253402300800; // Year 10000 +constexpr Seconds kOutOfRangeEpochSeconds = Seconds(kRawOutOfRangeEpochSeconds); + +constexpr uint64_t kOneS = 1; +constexpr uint64_t kOneMinuteS = kOneS * 60; +constexpr uint64_t kOneHourS = kOneMinuteS * 60; +constexpr uint64_t kOneDayS = kOneHourS * 24; +} // namespace + +TEST(WvTimestampUtilTest, DefaultConstructor) { + const Timestamp ts; + EXPECT_FALSE(ts.IsSet()); + EXPECT_EQ(ts.epoch_milliseconds(), Milliseconds::zero()); + EXPECT_EQ(ts.epoch_seconds(), Seconds::zero()); +} + +TEST(WvTimestampUtilTest, MinMax) { + const Timestamp max = Timestamp::Max(); + EXPECT_TRUE(max.IsSet()); + const Timestamp min = Timestamp::Min(); + EXPECT_TRUE(min.IsSet()); + EXPECT_LE(min, max); + EXPECT_GE(max, min); +} + +TEST(WvTimestampUtilTest, FromUnixSeconds) { + const Milliseconds epoch_milliseconds = + std::chrono::duration_cast(kValidEpochSeconds); + + const Timestamp ts = Timestamp::FromUnixSeconds(kValidEpochSeconds); + EXPECT_TRUE(ts.IsSet()); + + EXPECT_EQ(ts.epoch_seconds(), kValidEpochSeconds); + EXPECT_EQ(ts.epoch_milliseconds(), epoch_milliseconds); + EXPECT_EQ(ts.milliseconds(), 0u); + + EXPECT_EQ(ts, Timestamp::FromUnixSeconds(kRawValidEpochSeconds)); +} + +TEST(WvTimestampUtilTest, FromUnixSeconds_WithMs) { + constexpr uint32_t raw_milliseconds = 123; + const Milliseconds epoch_milliseconds = + kValidEpochSeconds + Milliseconds(raw_milliseconds); + + const Timestamp ts = + Timestamp::FromUnixSeconds(kValidEpochSeconds, raw_milliseconds); + EXPECT_TRUE(ts.IsSet()); + + EXPECT_EQ(ts.epoch_seconds(), kValidEpochSeconds); + EXPECT_EQ(ts.epoch_milliseconds(), epoch_milliseconds); + EXPECT_EQ(ts.milliseconds(), raw_milliseconds); + + EXPECT_EQ( + ts, Timestamp::FromUnixSeconds(kRawValidEpochSeconds, raw_milliseconds)); +} + +TEST(WvTimestampUtilTest, FromUnixSeconds_SecondsOutOfRange) { + EXPECT_FALSE(Timestamp::FromUnixSeconds(kRawOutOfRangeEpochSeconds).IsSet()); + EXPECT_FALSE(Timestamp::FromUnixSeconds(kOutOfRangeEpochSeconds).IsSet()); + + constexpr Seconds negative_epoch_seconds = Seconds(-1); + EXPECT_FALSE(Timestamp::FromUnixSeconds(negative_epoch_seconds).IsSet()); +} + +TEST(WvTimestampUtilTest, FromUnixSeconds_MsOutOfRange) { + EXPECT_FALSE(Timestamp::FromUnixSeconds(kValidEpochSeconds, 1000).IsSet()); + EXPECT_FALSE(Timestamp::FromUnixSeconds(kRawValidEpochSeconds, 1000).IsSet()); +} + +TEST(WvTimestampUtilTest, FromUnixMilliseconds) { + constexpr uint32_t raw_milliseconds = 123; + constexpr uint64_t raw_total_milliseconds = + (kRawValidEpochSeconds * 1000) + raw_milliseconds; + constexpr Milliseconds epoch_milliseconds = + Milliseconds(raw_total_milliseconds); + + constexpr Timestamp ts = Timestamp::FromUnixMilliseconds(epoch_milliseconds); + + EXPECT_TRUE(ts.IsSet()); + + EXPECT_EQ(ts.epoch_seconds(), kValidEpochSeconds); + EXPECT_EQ(ts.epoch_milliseconds(), epoch_milliseconds); + EXPECT_EQ(ts.milliseconds(), raw_milliseconds); + + EXPECT_EQ(ts, Timestamp::FromUnixMilliseconds(raw_total_milliseconds)); +} + +TEST(WvTimestampUtilTest, FromUnixMilliseconds_OutOfRange) { + constexpr uint64_t raw_total_milliseconds = kRawOutOfRangeEpochSeconds * 1000; + constexpr Milliseconds epoch_milliseconds = + Milliseconds(raw_total_milliseconds); + EXPECT_FALSE(Timestamp::FromUnixMilliseconds(raw_total_milliseconds).IsSet()); + EXPECT_FALSE(Timestamp::FromUnixMilliseconds(epoch_milliseconds).IsSet()); +} + +TEST(WvTimestampUtilTest, Clear) { + Timestamp ts = Timestamp::FromUnixSeconds(kValidEpochSeconds); + EXPECT_TRUE(ts.IsSet()); + ts.Clear(); + EXPECT_FALSE(ts.IsSet()); +} + +TEST(WvTimestampUtilTest, Comparison) { + constexpr Timestamp ts_a = Timestamp::Min(); + constexpr Timestamp ts_b = + Timestamp::FromUnixSeconds(kRawValidEpochSeconds - kOneDayS); + constexpr Timestamp ts_c = Timestamp::FromUnixSeconds(kValidEpochSeconds); + constexpr Timestamp ts_d = + Timestamp::FromUnixSeconds(kValidEpochSeconds, 123); + constexpr Timestamp ts_e = Timestamp::Max(); + + // Equality + EXPECT_EQ(ts_a, ts_a); + EXPECT_EQ(ts_b, ts_b); + EXPECT_EQ(ts_c, ts_c); + EXPECT_EQ(ts_d, ts_d); + EXPECT_EQ(ts_e, ts_e); + + // Inequality. + EXPECT_NE(ts_a, ts_b); + EXPECT_NE(ts_a, ts_c); + EXPECT_NE(ts_a, ts_d); + EXPECT_NE(ts_a, ts_e); + EXPECT_NE(ts_b, ts_c); + EXPECT_NE(ts_b, ts_d); + EXPECT_NE(ts_b, ts_e); + EXPECT_NE(ts_c, ts_d); + EXPECT_NE(ts_c, ts_e); + EXPECT_NE(ts_d, ts_e); + + // Less than + EXPECT_LT(ts_a, ts_b); + EXPECT_LT(ts_a, ts_c); + EXPECT_LT(ts_a, ts_d); + EXPECT_LT(ts_a, ts_e); + EXPECT_LT(ts_b, ts_c); + EXPECT_LT(ts_b, ts_d); + EXPECT_LT(ts_b, ts_e); + EXPECT_LT(ts_c, ts_d); + EXPECT_LT(ts_c, ts_e); + EXPECT_LT(ts_d, ts_e); + + // Less than or equal. + EXPECT_LE(ts_a, ts_a); + EXPECT_LE(ts_a, ts_b); + EXPECT_LE(ts_a, ts_c); + EXPECT_LE(ts_a, ts_d); + EXPECT_LE(ts_a, ts_e); + EXPECT_LE(ts_b, ts_b); + EXPECT_LE(ts_b, ts_c); + EXPECT_LE(ts_b, ts_d); + EXPECT_LE(ts_b, ts_e); + EXPECT_LE(ts_c, ts_c); + EXPECT_LE(ts_c, ts_d); + EXPECT_LE(ts_c, ts_e); + EXPECT_LE(ts_d, ts_d); + EXPECT_LE(ts_d, ts_e); + EXPECT_LE(ts_e, ts_e); + + // Greater than + EXPECT_GT(ts_b, ts_a); + EXPECT_GT(ts_c, ts_a); + EXPECT_GT(ts_d, ts_a); + EXPECT_GT(ts_e, ts_a); + EXPECT_GT(ts_c, ts_b); + EXPECT_GT(ts_d, ts_b); + EXPECT_GT(ts_e, ts_b); + EXPECT_GT(ts_d, ts_c); + EXPECT_GT(ts_e, ts_c); + EXPECT_GT(ts_e, ts_d); + + // Greater than or equal. + EXPECT_GE(ts_a, ts_a); + EXPECT_GE(ts_b, ts_a); + EXPECT_GE(ts_c, ts_a); + EXPECT_GE(ts_d, ts_a); + EXPECT_GE(ts_e, ts_a); + EXPECT_GE(ts_b, ts_b); + EXPECT_GE(ts_c, ts_b); + EXPECT_GE(ts_d, ts_b); + EXPECT_GE(ts_e, ts_b); + EXPECT_GE(ts_c, ts_c); + EXPECT_GE(ts_d, ts_c); + EXPECT_GE(ts_e, ts_c); + EXPECT_GE(ts_d, ts_d); + EXPECT_GE(ts_e, ts_d); + EXPECT_GE(ts_e, ts_e); +} + +TEST(WvTimestampUtilTest, Addition_WithDuration) { + constexpr Timestamp start = Timestamp::FromUnixSeconds(kValidEpochSeconds); + constexpr Duration one_day = Duration::FromSeconds(kOneDayS); + + constexpr Timestamp tomorrow = start + one_day; + constexpr decltype(tomorrow.epoch_seconds().count()) expected_second_count = + kRawValidEpochSeconds + kOneDayS; + EXPECT_EQ(tomorrow.epoch_seconds().count(), expected_second_count); +} + +TEST(WvTimestampUtilTest, Addition_WithDuration_Unset) { + constexpr Timestamp unset; + constexpr Duration one_day = Duration::FromSeconds(kOneDayS); + + constexpr Timestamp result = unset + one_day; + EXPECT_FALSE(result.IsSet()); +} + +TEST(WvTimestampUtilTest, Addition_WithDuration_Overflow) { + Timestamp start = Timestamp::Max(); + constexpr Duration one_day = Duration::FromSeconds(kOneDayS); + + Timestamp result = start + one_day; + EXPECT_FALSE(result.IsSet()); + + start = Timestamp::Min(); + constexpr Duration very_long_duration = + Duration::FromSeconds(kOneDayS * 365 * 10000); + result = start + very_long_duration; + EXPECT_FALSE(result.IsSet()); +} + +TEST(WvTimestampUtilTest, Subtraction_WithDuration) { + constexpr Timestamp start = Timestamp::FromUnixSeconds(kValidEpochSeconds); + constexpr Duration one_day = Duration::FromSeconds(kOneDayS); + + constexpr Timestamp yesterday = start - one_day; + constexpr decltype(yesterday.epoch_seconds().count()) expected_second_count = + kRawValidEpochSeconds - kOneDayS; + EXPECT_EQ(yesterday.epoch_seconds().count(), expected_second_count); +} + +TEST(WvTimestampUtilTest, Subtraction_WithDuration_Unset) { + constexpr Timestamp unset; + constexpr Duration one_day = Duration::FromSeconds(kOneDayS); + + constexpr Timestamp result = unset - one_day; + EXPECT_FALSE(result.IsSet()); +} + +TEST(WvTimestampUtilTest, Subtraction_WithDuration_Underflow) { + Timestamp start = Timestamp::Min(); + constexpr Duration one_day = Duration::FromSeconds(kOneDayS); + + Timestamp result = start - one_day; + EXPECT_FALSE(result.IsSet()); + + start = Timestamp::Max(); + constexpr Duration very_long_duration = + Duration::FromSeconds(kOneDayS * 365 * 10000); + result = start - very_long_duration; + EXPECT_FALSE(result.IsSet()); +} + +TEST(WvTimestampUtilTest, Difference) { + // 4 days, 3 hours, 2 minutes and 1 second. + constexpr int64_t kDifferenceS = + 1 + 2 * kOneMinuteS + 3 * kOneHourS + 4 * kOneDayS; + + constexpr Timestamp ts_a = Timestamp::FromUnixSeconds(kValidEpochSeconds); + constexpr Timestamp ts_b = + Timestamp::FromUnixSeconds(kRawValidEpochSeconds + kDifferenceS); + + constexpr Duration diff_ab = ts_b - ts_a; + EXPECT_EQ(diff_ab.total_seconds().count(), kDifferenceS); + + constexpr Duration diff_ba = ts_a - ts_b; + EXPECT_EQ(diff_ba.total_seconds().count(), -kDifferenceS); +} + +TEST(WvTimestampUtilTest, Difference_Unset) { + constexpr Duration unset_duration; + + constexpr Timestamp ts = Timestamp::FromUnixSeconds(kValidEpochSeconds); + constexpr Timestamp unset; + + EXPECT_EQ(ts - unset, unset_duration); + EXPECT_EQ(unset - ts, unset_duration); + EXPECT_EQ(unset - unset, unset_duration); +} +} // namespace test +} // namespace wvutil