Files
provisioning_sdk_source/oem_certificate_generator/oem_certificate_test.py
Kongqun Yang 84f66d2320 NewProvisioningSession expects pkcs8 private key and SHA race fix
-------------
Fix SHA hashing to remove race condition. This change
fixes the implementation by passing in the digest buffer.

-------------
The input to ProvisioningEngine::NewProvisioningSession should be
a PKCS#8 private key instead of a PKCS#1 private key.

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=151273394

Change-Id: Ibcdff7757b2ac2878ee8b1b88365083964bfa10a
2017-03-28 14:51:50 -07:00

461 lines
21 KiB
Python

################################################################################
# Copyright 2016 Google Inc.
#
# This software is licensed under the terms defined in the Widevine Master
# License Agreement. For a copy of this agreement, please contact
# widevine-licensing@google.com.
################################################################################
import datetime
import os
import shutil
import StringIO
import tempfile
import textwrap
import unittest
from cryptography import x509
from cryptography.hazmat import backends
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.x509 import oid
import oem_certificate
import oem_certificate_generator_helper as oem_cert_gen_helper
class ArgParseObject(object):
  """Bare attribute holder used in place of a parsed-arguments object.

  Tests attach whichever attributes the function under test reads (for
  example ``file`` and ``passes``) directly onto an instance.
  """
class OemCertificateTest(unittest.TestCase):
  """Tests for the oem_certificate generate/info/erase entry points.

  Each test builds an argument object (mostly through oem_cert_gen_helper),
  calls the matching oem_certificate function and inspects what was written
  to the argument object's output files via getvalue().
  """

  # ---------------------------------------------------------------- helpers

  def _load_csr(self, args):
    """Parses the PEM CSR written to args.output_csr_file."""
    return x509.load_pem_x509_csr(args.output_csr_file.getvalue(),
                                  backends.default_backend())

  def _load_private_key(self, args, passphrase):
    """Parses the DER private key written to args.output_private_key_file."""
    return serialization.load_der_private_key(
        args.output_private_key_file.getvalue(),
        passphrase,
        backend=backends.default_backend())

  def _assert_key_pair(self, public_key, private_key, key_size):
    """Asserts both keys have key_size bits and belong to the same pair."""
    self.assertEqual(private_key.key_size, key_size)
    self.assertEqual(public_key.key_size, key_size)
    self.assertEqual(public_key.public_numbers(),
                     private_key.public_key().public_numbers())

  def _assert_signed_by(self, issuer_public_key, cert):
    """Checks cert's signature against issuer_public_key.

    verify() signals a bad signature by raising, so no assertion is needed.
    """
    issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes,
                             padding.PKCS1v15(),
                             cert.signature_hash_algorithm)

  def _system_id(self, cert):
    """Returns the integer Widevine system id stored in cert's extension."""
    raw_bytes = cert.extensions.get_extension_for_oid(
        oem_certificate.WidevineSystemId.oid).value.value
    return oem_certificate.WidevineSystemId(raw_bytes).int_value()

  def _get_info(self, args, file_bytes):
    """Runs oem_certificate.get_info on file_bytes and returns its output."""
    args.file = StringIO.StringIO(file_bytes)
    output = StringIO.StringIO()
    oem_certificate.get_info(args, output)
    return output.getvalue()

  # ------------------------------------------------------------------ tests

  def test_widevine_system_id(self):
    system_id = 1234567890123
    widevine_id = oem_certificate.WidevineSystemId.from_int_value(system_id)
    self.assertEqual(widevine_id.int_value(), system_id)

  def test_generate_csr(self):
    args = oem_cert_gen_helper.setup_csr_rgs()
    oem_certificate.generate_csr(args)

    # Verify CSR.
    csr = self._load_csr(args)
    subject = csr.subject
    expected_subject = (
        (oid.NameOID.COUNTRY_NAME, args.country_name),
        (oid.NameOID.STATE_OR_PROVINCE_NAME, args.state_or_province_name),
        (oid.NameOID.LOCALITY_NAME, args.locality_name),
        (oid.NameOID.ORGANIZATION_NAME, args.organization_name),
        (oid.NameOID.ORGANIZATIONAL_UNIT_NAME,
         args.organizational_unit_name),
    )
    for name_oid, expected_value in expected_subject:
      attributes = subject.get_attributes_for_oid(name_oid)
      self.assertEqual(attributes[0].value, expected_value)

    # Verify csr and private key match.
    private_key = self._load_private_key(args, args.passphrase)
    self._assert_key_pair(csr.public_key(), private_key, args.key_size)

  def test_generate_csr_with_keysize4096_and_passphrase(self):
    args = oem_cert_gen_helper.setup_csr_rgs(
        key_size=4096, passphrase='passphrase_4096')
    oem_certificate.generate_csr(args)

    private_key = self._load_private_key(args, 'passphrase_4096')
    csr = self._load_csr(args)
    # Verify csr and private key match.
    self._assert_key_pair(csr.public_key(), private_key, 4096)

  def test_generate_intermediate_certificate(self):
    csr_args = oem_cert_gen_helper.setup_csr_rgs()
    oem_certificate.generate_csr(csr_args)
    csr_bytes = csr_args.output_csr_file.getvalue()
    csr = x509.load_pem_x509_csr(csr_bytes, backends.default_backend())

    root_key, root_certificate = (
        oem_cert_gen_helper.create_root_certificate_and_key())
    args = oem_cert_gen_helper.setup_intermediate_cert_args(
        csr_bytes, root_key, root_certificate)
    oem_certificate.generate_intermediate_certificate(args)

    cert = x509.load_der_x509_certificate(
        args.output_certificate_file.getvalue(), backends.default_backend())
    self.assertEqual(cert.issuer, root_certificate.subject)
    self.assertEqual(cert.subject, csr.subject)
    self.assertEqual(self._system_id(cert), args.system_id)
    self.assertEqual(cert.not_valid_before, datetime.datetime(2001, 8, 9))
    self.assertEqual(cert.not_valid_after, datetime.datetime(2001, 11, 17))
    self._assert_signed_by(root_key.public_key(), cert)

  def test_generate_intermediate_with_cert_mismatch_root_cert_and_key(self):
    # Key and certificate come from two independent root pairs, so they
    # cannot match each other.
    root_key1, _ = oem_cert_gen_helper.create_root_certificate_and_key()
    _, root_certificate2 = oem_cert_gen_helper.create_root_certificate_and_key()
    args = oem_cert_gen_helper.setup_intermediate_cert_args(
        'some csr data', root_key1, root_certificate2)

    with self.assertRaises(ValueError) as context:
      oem_certificate.generate_intermediate_certificate(args)
    # assertIn reports both operands on failure, unlike assertTrue(a in b).
    self.assertIn('certificate does not match', str(context.exception))

  def test_generate_leaf_certificate(self):
    intermediate_key_bytes, intermediate_certificate_bytes = (
        oem_cert_gen_helper.create_intermediate_certificate_and_key_bytes())
    args = oem_cert_gen_helper.setup_leaf_cert_args(
        intermediate_key_bytes, intermediate_certificate_bytes)
    oem_certificate.generate_leaf_certificate(args)

    certificate_chain = oem_certificate.X509CertificateChain.load_der(
        args.output_certificate_file.getvalue())
    certificates = list(certificate_chain)
    self.assertEqual(len(certificates), 2)
    leaf_cert, intermediate_cert = certificates

    self.assertEqual(
        intermediate_cert.public_bytes(serialization.Encoding.DER),
        intermediate_certificate_bytes)
    self._assert_signed_by(intermediate_cert.public_key(), leaf_cert)
    self.assertEqual(leaf_cert.not_valid_before, datetime.datetime(2001, 8, 9))
    self.assertEqual(leaf_cert.not_valid_after, datetime.datetime(2023, 7, 5))
    self.assertEqual(self._system_id(leaf_cert), 2001)

    # Verify cert and private key match.
    leaf_key = self._load_private_key(args, args.passphrase)
    self._assert_key_pair(leaf_cert.public_key(), leaf_key, args.key_size)

  def test_generate_leaf_certificate_with_keysize4096_and_passphrase(self):
    intermediate_key_bytes, intermediate_certificate_bytes = (
        oem_cert_gen_helper.create_intermediate_certificate_and_key_bytes())
    args = oem_cert_gen_helper.setup_leaf_cert_args(
        intermediate_key_bytes,
        intermediate_certificate_bytes,
        key_size=4096,
        passphrase='leaf passphrase')
    oem_certificate.generate_leaf_certificate(args)

    # Loading with the passphrase proves the key was encrypted with it.
    leaf_key = self._load_private_key(args, 'leaf passphrase')
    self.assertEqual(4096, args.key_size)
    # Check the key that was actually generated, not only the request.
    self.assertEqual(leaf_key.key_size, 4096)

  def test_get_csr_info(self):
    args = oem_cert_gen_helper.setup_csr_rgs()
    oem_certificate.generate_csr(args)
    actual_info = self._get_info(args, args.output_csr_file.getvalue())
    expected_info = """\
CSR Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Key Size: 4096"""
    self.assertEqual(actual_info, textwrap.dedent(expected_info))

  def test_get_certificate_info(self):
    _, intermediate_certificate_bytes = (
        oem_cert_gen_helper.create_intermediate_certificate_and_key_bytes())
    actual_info = self._get_info(ArgParseObject(),
                                 intermediate_certificate_bytes)
    expected_info = """\
Certificate Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Issuer Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.3, name=commonName)>, value=u'root_cert')>
Key Size: 4096
Widevine System Id: 2001
Not valid before: 2001-08-09 00:00:00
Not valid after: 2001-11-17 00:00:00"""
    self.assertEqual(actual_info, textwrap.dedent(expected_info))

  def test_get_certificate_chain_info(self):
    intermediate_key_bytes, intermediate_certificate_bytes = (
        oem_cert_gen_helper.create_intermediate_certificate_and_key_bytes())
    args = oem_cert_gen_helper.setup_leaf_cert_args(
        intermediate_key_bytes, intermediate_certificate_bytes)
    oem_certificate.generate_leaf_certificate(args)
    actual_info = self._get_info(args,
                                 args.output_certificate_file.getvalue())
    expected_info = """\
Certificate Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.3, name=commonName)>, value=u'2001-leaf')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Issuer Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Key Size: 1024
Widevine System Id: 2001
Not valid before: 2001-08-09 00:00:00
Not valid after: 2023-07-05 00:00:00
Certificate Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Issuer Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.3, name=commonName)>, value=u'root_cert')>
Key Size: 4096
Widevine System Id: 2001
Not valid before: 2001-08-09 00:00:00
Not valid after: 2001-11-17 00:00:00"""
    self.assertEqual(actual_info, textwrap.dedent(expected_info))

  def test_secure_erase(self):
    args = ArgParseObject()
    args.file = tempfile.NamedTemporaryFile(delete=False)
    args.passes = 2
    file_path = args.file.name

    def remove_leftover_file():
      # delete=False means nothing else removes the file if the erase fails.
      if os.path.exists(file_path):
        os.remove(file_path)

    self.addCleanup(remove_leftover_file)

    self.assertTrue(os.path.exists(file_path))
    oem_certificate.secure_erase(args)
    self.assertFalse(os.path.exists(file_path))
class OemCertificateArgParseTest(unittest.TestCase):
  """Tests for the command-line parser from oem_certificate.create_parser.

  Every test gets a fresh parser and a private temporary directory; the
  directory (and every file created in it) is removed in tearDown.
  """

  def setUp(self):
    self.parser = oem_certificate.create_parser()
    self.test_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.test_dir)

  # ---------------------------------------------------------------- helpers

  def _path(self, name):
    """Returns the path of `name` inside this test's temporary directory."""
    return os.path.join(self.test_dir, name)

  def _fill_file_with_dummy_contents(self, file_name):
    """Creates (or truncates) file_name and writes placeholder bytes to it."""
    with open(file_name, 'wb') as f:
      # The file is opened in binary mode, so write a bytes literal.
      f.write(b'dummy')

  def _create_dummy_file(self, name):
    """Creates a placeholder input file in the test dir; returns its path."""
    file_path = self._path(name)
    self._fill_file_with_dummy_contents(file_path)
    return file_path

  def _assert_parser_rejects(self, cmds):
    """Asserts that parsing cmds makes the parser exit with status 2."""
    with self.assertRaises(SystemExit) as context:
      self.parser.parse_args(cmds)
    self.assertEqual(context.exception.code, 2)

  def _generate_csr_cmds(self, key_size):
    """Builds a generate_csr command line using the given --key_size text.

    Returns:
      (cmds, output_csr_file, output_private_key_file).
    """
    output_private_key_file = self._path('private_key')
    output_csr_file = self._path('csr')
    cmds = [
        'generate_csr', '--key_size', key_size, '-C', 'USA', '-ST', 'WA',
        '-L', 'Kirkland', '-O', 'Company', '-OU', 'Widevine',
        '--output_csr_file', output_csr_file,
        '--output_private_key_file', output_private_key_file,
        '--passphrase', 'pass',
    ]
    return cmds, output_csr_file, output_private_key_file

  def _generate_leaf_cert_cmds(self, not_valid_before):
    """Builds a generate_leaf_certificate command line.

    The two intermediate input files are created with placeholder contents.

    Returns:
      (cmds, paths) where paths maps 'intermediate_certificate',
      'intermediate_private_key', 'output_certificate' and
      'output_private_key' to the file paths used on the command line.
    """
    paths = {}
    paths['intermediate_certificate'] = self._create_dummy_file(
        'intermediate_cert')
    paths['intermediate_private_key'] = self._create_dummy_file(
        'intermediate_private_key')
    paths['output_certificate'] = self._path('cert')
    paths['output_private_key'] = self._path('key')
    cmds = [
        'generate_leaf_certificate',
        '--not_valid_before', not_valid_before,
        '--valid_duration', '10',
        '--intermediate_certificate_file', paths['intermediate_certificate'],
        '--intermediate_private_key_file', paths['intermediate_private_key'],
        '--intermediate_private_key_passphrase', 'intermediate_key',
        '--output_certificate_file', paths['output_certificate'],
        '--output_private_key_file', paths['output_private_key'],
        '--passphrase', 'leaf_key',
    ]
    return cmds, paths

  # ------------------------------------------------------------------ tests

  def test_generate_csr(self):
    cmds, output_csr_file, output_private_key_file = (
        self._generate_csr_cmds('4096'))
    args = self.parser.parse_args(cmds)

    self.assertEqual(args.key_size, 4096)
    self.assertEqual(args.country_name, 'USA')
    self.assertEqual(args.state_or_province_name, 'WA')
    self.assertEqual(args.locality_name, 'Kirkland')
    self.assertEqual(args.organization_name, 'Company')
    self.assertEqual(args.organizational_unit_name, 'Widevine')
    self.assertEqual(args.output_csr_file.name, output_csr_file)
    self.assertEqual(args.output_csr_file.mode, 'wb')
    self.assertEqual(args.output_private_key_file.name, output_private_key_file)
    self.assertEqual(args.output_private_key_file.mode, 'wb')
    self.assertEqual(args.passphrase, 'pass')
    self.assertEqual(args.func, oem_certificate.generate_csr)

  def test_generate_csr_invalid_key_size(self):
    cmds, _, _ = self._generate_csr_cmds('unknown')
    self._assert_parser_rejects(cmds)

  def test_generate_intermediate_cert(self):
    csr_file = self._create_dummy_file('csr')
    root_certificate_file = self._create_dummy_file('root_cert')
    root_private_key_file = self._create_dummy_file('root_private_key')
    output_certificate_file = self._path('cert')
    cmds = [
        'generate_intermediate_certificate',
        '--valid_duration', '10',
        '--system_id', '100',
        '--csr_file', csr_file,
        '--root_certificate_file', root_certificate_file,
        '--root_private_key_file', root_private_key_file,
        '--root_private_key_passphrase', 'root_key',
        '--output_certificate_file', output_certificate_file,
    ]
    args = self.parser.parse_args(cmds)

    # --not_valid_before was not given, so the parsed value is expected to
    # be close to the current time.
    self.assertAlmostEqual(
        args.not_valid_before,
        datetime.datetime.today(),
        delta=datetime.timedelta(seconds=60))
    self.assertEqual(args.valid_duration, 10)
    self.assertEqual(args.system_id, 100)
    self.assertEqual(args.csr_file.name, csr_file)
    self.assertEqual(args.csr_file.mode, 'rb')
    self.assertEqual(args.root_certificate_file.name, root_certificate_file)
    self.assertEqual(args.root_certificate_file.mode, 'rb')
    self.assertEqual(args.root_private_key_file.name, root_private_key_file)
    self.assertEqual(args.root_private_key_file.mode, 'rb')
    self.assertEqual(args.root_private_key_passphrase, 'root_key')
    self.assertEqual(args.output_certificate_file.name, output_certificate_file)
    self.assertEqual(args.output_certificate_file.mode, 'wb')
    self.assertEqual(args.func,
                     oem_certificate.generate_intermediate_certificate)

  def test_generate_leaf_cert(self):
    cmds, paths = self._generate_leaf_cert_cmds('2016-01-02')
    args = self.parser.parse_args(cmds)

    self.assertEqual(args.not_valid_before, datetime.datetime(2016, 1, 2))
    self.assertEqual(args.valid_duration, 10)
    self.assertEqual(args.intermediate_certificate_file.name,
                     paths['intermediate_certificate'])
    self.assertEqual(args.intermediate_certificate_file.mode, 'rb')
    self.assertEqual(args.intermediate_private_key_file.name,
                     paths['intermediate_private_key'])
    self.assertEqual(args.intermediate_private_key_file.mode, 'rb')
    self.assertEqual(args.intermediate_private_key_passphrase,
                     'intermediate_key')
    self.assertEqual(args.output_certificate_file.name,
                     paths['output_certificate'])
    self.assertEqual(args.output_certificate_file.mode, 'wb')
    self.assertEqual(args.output_private_key_file.name,
                     paths['output_private_key'])
    self.assertEqual(args.output_private_key_file.mode, 'wb')
    self.assertEqual(args.passphrase, 'leaf_key')
    self.assertEqual(args.func, oem_certificate.generate_leaf_certificate)

  def test_generate_leaf_cert_invalid_date(self):
    cmds, _ = self._generate_leaf_cert_cmds('invaid-date')
    self._assert_parser_rejects(cmds)

  def test_secure_erase(self):
    file_path = self._create_dummy_file('file')
    args = self.parser.parse_args(['erase', '-F', file_path, '--passes', '2'])

    self.assertEqual(args.passes, 2)
    self.assertEqual(args.file.name, file_path)
    self.assertEqual(args.file.mode, 'a')
    self.assertEqual(args.func, oem_certificate.secure_erase)

  def test_get_info(self):
    file_path = self._create_dummy_file('file')
    args = self.parser.parse_args(['info', '-F', file_path])

    self.assertEqual(args.file.name, file_path)
    self.assertEqual(args.file.mode, 'rb')
    self.assertEqual(args.func, oem_certificate.get_info)

  def test_arbitrary_commands(self):
    self._assert_parser_rejects(['unsupport', '--commands'])

  def test_no_argument(self):
    self._assert_parser_rejects([])
# Script entry point: run every test case defined in this module.
if __name__ == '__main__':
  unittest.main()