#!/usr/bin/python3
# Copyright 2017 Google LLC. All Rights Reserved.
"""Common test utility functions for OEM certificate generation."""

import datetime
import io

from cryptography import x509
from cryptography.hazmat import backends
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509 import oid

import oem_certificate

_COUNTRY_NAME = 'US'
_STATE_OR_PROVINCE_NAME = 'WA'
_LOCALITY_NAME = 'Kirkland'
_ORGANIZATION_NAME = 'CompanyXYZ'
_ORGANIZATIONAL_UNIT_NAME = 'ContentProtection'

_NOT_VALID_BEFORE = datetime.datetime(2001, 8, 9)
_VALID_DURATION = 100
_LEAF_CERT_VALID_DURATION = 8000
# Validity period passed to oem_certificate.build_certificate for the root
# certificate. The unit is whatever build_certificate expects (presumably
# days) - confirm against oem_certificate.
_ROOT_CERT_VALID_DURATION = 1000
_SYSTEM_ID = 2001
_ROOT_PRIVATE_KEY_PASSPHRASE = b'root_passphrase'


class ArgParseObject(object):
  """A bare object that accepts arbitrary attributes.

  Used in place of a parsed-arguments object so tests can set exactly the
  fields the OEM certificate generator functions read.
  """


def _file_or_buffer(file_obj):
  """Returns file_obj, or a new empty io.BytesIO when file_obj is None."""
  return file_obj if file_obj is not None else io.BytesIO()


def create_root_certificate_and_key():
  """Creates a root certificate and key.

  Returns:
    A (key, certificate) tuple. The key is a freshly generated 3072-bit RSA
    private key. The certificate comes from oem_certificate.build_certificate,
    with subject and issuer both set to CN=root_cert.
  """
  key = rsa.generate_private_key(
      public_exponent=65537, key_size=3072, backend=backends.default_backend())
  subject_name = x509.Name(
      [x509.NameAttribute(oid.NameOID.COMMON_NAME, 'root_cert')])
  # The same name is passed as subject and issuer, and the same key supplies
  # both the public key and (presumably) the signing key, i.e. a self-signed
  # root. NOTE(review): the meaning of the None argument and the trailing
  # True (likely the CA flag) should be confirmed against build_certificate.
  certificate = oem_certificate.build_certificate(
      subject_name, subject_name, None, _NOT_VALID_BEFORE,
      _ROOT_CERT_VALID_DURATION, key.public_key(), key, True)
  return (key, certificate)


def setup_csr_args(country_name=_COUNTRY_NAME,
                   state_or_province_name=_STATE_OR_PROVINCE_NAME,
                   locality_name=_LOCALITY_NAME,
                   organization_name=_ORGANIZATION_NAME,
                   organizational_unit_name=_ORGANIZATIONAL_UNIT_NAME,
                   key_size=4096,
                   output_csr_file=None,
                   output_private_key_file=None,
                   passphrase=None,
                   common_name=None):
  """Sets up arguments to OEM Certificate generator for generating csr.

  Args:
    country_name: Value stored as args.country_name.
    state_or_province_name: Value stored as args.state_or_province_name.
    locality_name: Value stored as args.locality_name.
    organization_name: Value stored as args.organization_name.
    organizational_unit_name: Value stored as args.organizational_unit_name.
    key_size: Value stored as args.key_size.
    output_csr_file: File-like object stored as args.output_csr_file. A new
      io.BytesIO is used when None, so callers can read the result with
      getvalue().
    output_private_key_file: File-like object stored as
      args.output_private_key_file. A new io.BytesIO is used when None.
    passphrase: Value stored as args.passphrase.
    common_name: Value stored as args.common_name.

  Returns:
    An ArgParseObject carrying the fields above.
  """
  args = ArgParseObject()
  args.key_size = key_size
  args.country_name = country_name
  args.state_or_province_name = state_or_province_name
  args.locality_name = locality_name
  args.organization_name = organization_name
  args.organizational_unit_name = organizational_unit_name
  args.common_name = common_name
  args.output_csr_file = _file_or_buffer(output_csr_file)
  args.output_private_key_file = _file_or_buffer(output_private_key_file)
  args.passphrase = passphrase
  return args


def setup_intermediate_cert_args(
    csr_bytes,
    root_key,
    root_certificate,
    not_valid_before=_NOT_VALID_BEFORE,
    valid_duration=_VALID_DURATION,
    system_id=_SYSTEM_ID,
    root_private_key_passphrase=_ROOT_PRIVATE_KEY_PASSPHRASE,
    output_certificate_file=None):
  """Sets up args to OEM Cert generator for generating intermediate cert.

  Args:
    csr_bytes: Serialized CSR, exposed to the generator as args.csr_file
      (an io.BytesIO).
    root_key: Root RSA private key. It is serialized as passphrase-encrypted
      PKCS#8 DER into args.root_private_key_file.
    root_certificate: Root x509 certificate. It is serialized as DER into
      args.root_certificate_file.
    not_valid_before: Value stored as args.not_valid_before.
    valid_duration: Value stored as args.valid_duration.
    system_id: Value stored as args.system_id.
    root_private_key_passphrase: Passphrase used to encrypt the serialized
      root key, also stored as args.root_private_key_passphrase. It must be
      non-empty bytes, as required by
      serialization.BestAvailableEncryption.
    output_certificate_file: File-like object stored as
      args.output_certificate_file. A new io.BytesIO is used when None.

  Returns:
    An ArgParseObject carrying the fields above.
  """
  args = ArgParseObject()
  args.not_valid_before = not_valid_before
  args.valid_duration = valid_duration
  args.system_id = system_id
  args.csr_file = io.BytesIO(csr_bytes)
  args.root_private_key_passphrase = root_private_key_passphrase
  args.output_certificate_file = _file_or_buffer(output_certificate_file)

  # Encrypt the root key with the same passphrase handed to the generator
  # (presumably so it can decrypt the key it reads back from
  # root_private_key_file).
  serialized_private_key = root_key.private_bytes(
      serialization.Encoding.DER,
      format=serialization.PrivateFormat.PKCS8,
      encryption_algorithm=serialization.BestAvailableEncryption(
          args.root_private_key_passphrase))
  serialized_certificate = root_certificate.public_bytes(
      serialization.Encoding.DER)
  args.root_certificate_file = io.BytesIO(serialized_certificate)
  args.root_private_key_file = io.BytesIO(serialized_private_key)
  return args


def setup_leaf_cert_args(intermediate_key_bytes,
                         intermediate_certificate_bytes,
                         key_size=1024,
                         passphrase=None,
                         not_valid_before=_NOT_VALID_BEFORE,
                         valid_duration=_LEAF_CERT_VALID_DURATION,
                         output_certificate_file=None,
                         output_private_key_file=None,
                         intermediate_private_key_passphrase=None):
  """Sets up args to OEM Certificate generator for generating leaf cert.

  Args:
    intermediate_key_bytes: Serialized intermediate private key, exposed as
      args.intermediate_private_key_file (an io.BytesIO).
    intermediate_certificate_bytes: Serialized intermediate certificate,
      exposed as args.intermediate_certificate_file (an io.BytesIO).
    key_size: Value stored as args.key_size.
    passphrase: Value stored as args.passphrase.
    not_valid_before: Value stored as args.not_valid_before.
    valid_duration: Value stored as args.valid_duration.
    output_certificate_file: File-like object stored as
      args.output_certificate_file. A new io.BytesIO is used when None.
    output_private_key_file: File-like object stored as
      args.output_private_key_file. A new io.BytesIO is used when None.
    intermediate_private_key_passphrase: Value stored as
      args.intermediate_private_key_passphrase. It defaults to None, as
      before; pass the passphrase the intermediate key was created with when
      it is encrypted.

  Returns:
    An ArgParseObject carrying the fields above.
  """
  args = ArgParseObject()
  args.key_size = key_size
  args.not_valid_before = not_valid_before
  args.valid_duration = valid_duration
  args.intermediate_private_key_passphrase = (
      intermediate_private_key_passphrase)
  args.output_certificate_file = _file_or_buffer(output_certificate_file)
  args.output_private_key_file = _file_or_buffer(output_private_key_file)
  args.passphrase = passphrase
  args.intermediate_private_key_file = io.BytesIO(intermediate_key_bytes)
  args.intermediate_certificate_file = io.BytesIO(
      intermediate_certificate_bytes)
  return args


def create_intermediate_certificate_and_key_bytes(key_size=4096,
                                                  passphrase=None,
                                                  pem_format=True):
  """Creates an intermediate certificate and key.

  Generates a CSR (and its private key) with oem_certificate.generate_csr. A
  fresh root certificate is then used to issue the intermediate certificate
  via oem_certificate.generate_intermediate_certificate.

  Args:
    key_size: Key size passed to the CSR generation step.
    passphrase: Passphrase passed to the CSR generation step.
    pem_format: If True, the certificate (which the generator emits as DER)
      is re-encoded as PEM; otherwise the DER bytes are returned unchanged.

  Returns:
    A (private_key_bytes, certificate_bytes) tuple. private_key_bytes is
    exactly what generate_csr wrote to its private key output.
  """
  csr_args = setup_csr_args(key_size=key_size, passphrase=passphrase)
  oem_certificate.generate_csr(csr_args)
  csr_bytes = csr_args.output_csr_file.getvalue()

  root_key, root_certificate = create_root_certificate_and_key()
  args = setup_intermediate_cert_args(csr_bytes, root_key, root_certificate)
  oem_certificate.generate_intermediate_certificate(args)

  cert_bytes = args.output_certificate_file.getvalue()
  if pem_format:
    # The generator's output is parsed as DER here, so DER -> PEM re-encode.
    cert = x509.load_der_x509_certificate(cert_bytes,
                                          backends.default_backend())
    cert_bytes = cert.public_bytes(serialization.Encoding.PEM)
  return (csr_args.output_private_key_file.getvalue(), cert_bytes)