Files
provisioning_sdk_source/oem_certificate_generator/oem_certificate_test.py
Kongqun Yang 8d17e4549a Export provisioning sdk
Change-Id: I4d47d80444c9507f84896767dc676112ca11e901
2017-01-24 20:06:25 -08:00

539 lines
24 KiB
Python

################################################################################
# Copyright 2016 Google Inc.
#
# This software is licensed under the terms defined in the Widevine Master
# License Agreement. For a copy of this agreement, please contact
# widevine-licensing@google.com.
################################################################################
import datetime
import os
import shutil
import StringIO
import tempfile
import textwrap
import unittest
from cryptography import x509
from cryptography.hazmat import backends
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509 import oid
import oem_certificate
class ArgParseObject(object):
  """Bare attribute container that stands in for an argparse.Namespace.

  Tests attach whatever attributes a command handler reads (for example
  `key_size` or `output_csr_file`) directly onto an instance.
  """
class OemCertificateTest(unittest.TestCase):
  """End-to-end tests for the oem_certificate command handlers.

  Each test passes an ArgParseObject to a handler. The object's file
  attributes are in-memory StringIO buffers, and the test parses the
  produced bytes back with the cryptography library to verify them.
  """

  def _setup_csr_args(self, key_size=4096, passphrase=None):
    """Returns args for oem_certificate.generate_csr with in-memory outputs."""
    args = ArgParseObject()
    args.key_size = key_size
    args.country_name = 'US'
    args.state_or_province_name = 'WA'
    args.locality_name = 'Kirkland'
    args.organization_name = 'CompanyXYZ'
    args.organizational_unit_name = 'ContentProtection'
    args.output_csr_file = StringIO.StringIO()
    args.output_private_key_file = StringIO.StringIO()
    args.passphrase = passphrase
    return args

  def _assert_public_keys_match(self, public_key, private_key):
    """Asserts that public_key has the same public numbers as private_key."""
    self.assertEqual(public_key.public_numbers(),
                     private_key.public_key().public_numbers())

  def test_widevine_system_id(self):
    system_id = 1234567890123
    self.assertEqual(
        oem_certificate.WidevineSystemId.from_int_value(system_id).int_value(),
        system_id)

  def test_generate_csr(self):
    args = self._setup_csr_args()
    oem_certificate.generate_csr(args)
    # Verify CSR.
    csr = x509.load_pem_x509_csr(args.output_csr_file.getvalue(),
                                 backends.default_backend())
    expected_subject = [
        (oid.NameOID.COUNTRY_NAME, args.country_name),
        (oid.NameOID.STATE_OR_PROVINCE_NAME, args.state_or_province_name),
        (oid.NameOID.LOCALITY_NAME, args.locality_name),
        (oid.NameOID.ORGANIZATION_NAME, args.organization_name),
        (oid.NameOID.ORGANIZATIONAL_UNIT_NAME, args.organizational_unit_name),
    ]
    for name_oid, expected_value in expected_subject:
      self.assertEqual(
          csr.subject.get_attributes_for_oid(name_oid)[0].value,
          expected_value,
          msg=str(name_oid))
    private_key = serialization.load_der_private_key(
        args.output_private_key_file.getvalue(),
        args.passphrase,
        backend=backends.default_backend())
    self.assertEqual(private_key.key_size, args.key_size)
    self.assertEqual(csr.public_key().key_size, args.key_size)
    # Verify csr and private key match.
    self._assert_public_keys_match(csr.public_key(), private_key)

  def test_generate_csr_with_keysize4096_and_passphrase(self):
    args = self._setup_csr_args(key_size=4096, passphrase='passphrase_4096')
    oem_certificate.generate_csr(args)
    # The private key must only be readable with the chosen passphrase.
    private_key = serialization.load_der_private_key(
        args.output_private_key_file.getvalue(),
        'passphrase_4096',
        backend=backends.default_backend())
    csr = x509.load_pem_x509_csr(args.output_csr_file.getvalue(),
                                 backends.default_backend())
    self.assertEqual(private_key.key_size, 4096)
    self.assertEqual(csr.public_key().key_size, 4096)
    # Verify csr and private key match.
    self._assert_public_keys_match(csr.public_key(), private_key)

  def _create_root_certificate_and_key(self):
    """Returns (key, certificate) for a fresh 3072-bit RSA root.

    The subject name (CN=root_cert) is passed to _build_certificate for both
    name arguments, so this is presumably a self-signed root. The meaning of
    the remaining positional arguments is defined by oem_certificate.
    """
    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=3072,
        backend=backends.default_backend())
    subject_name = x509.Name(
        [x509.NameAttribute(oid.NameOID.COMMON_NAME, u'root_cert')])
    certificate = oem_certificate._build_certificate(
        subject_name, subject_name, None,
        datetime.datetime(2001, 8, 9), 1000, key.public_key(), key, True)
    return (key, certificate)

  def _setup_intermediate_cert_args(self, csr_bytes, root_key,
                                    root_certificate):
    """Returns args for oem_certificate.generate_intermediate_certificate.

    The root key is serialized as passphrase-encrypted PKCS8 DER and the
    root certificate as DER, both exposed as in-memory input files.
    """
    args = ArgParseObject()
    args.not_valid_before = datetime.datetime(2001, 8, 9)
    # Days: the resulting certificate expires on 2001-11-17.
    args.valid_duration = 100
    args.system_id = 1234554321
    args.csr_file = StringIO.StringIO(csr_bytes)
    args.root_private_key_passphrase = 'root_passphrase'
    args.output_certificate_file = StringIO.StringIO()
    serialized_private_key = root_key.private_bytes(
        serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.BestAvailableEncryption(
            args.root_private_key_passphrase))
    serialized_certificate = root_certificate.public_bytes(
        serialization.Encoding.DER)
    args.root_certificate_file = StringIO.StringIO(serialized_certificate)
    args.root_private_key_file = StringIO.StringIO(serialized_private_key)
    return args

  def test_generate_intermediate_certificate(self):
    csr_args = self._setup_csr_args()
    oem_certificate.generate_csr(csr_args)
    csr_bytes = csr_args.output_csr_file.getvalue()
    csr = x509.load_pem_x509_csr(csr_bytes, backends.default_backend())
    root_key, root_certificate = self._create_root_certificate_and_key()
    args = self._setup_intermediate_cert_args(csr_bytes, root_key,
                                              root_certificate)
    oem_certificate.generate_intermediate_certificate(args)
    cert = x509.load_der_x509_certificate(
        args.output_certificate_file.getvalue(), backends.default_backend())
    self.assertEqual(cert.issuer, root_certificate.subject)
    self.assertEqual(cert.subject, csr.subject)
    system_id_raw_bytes = cert.extensions.get_extension_for_oid(
        oem_certificate.WidevineSystemId.oid).value.value
    self.assertEqual(
        oem_certificate.WidevineSystemId(system_id_raw_bytes).int_value(),
        args.system_id)
    self.assertEqual(cert.not_valid_before, datetime.datetime(2001, 8, 9))
    self.assertEqual(cert.not_valid_after, datetime.datetime(2001, 11, 17))
    # Raises InvalidSignature if the root key did not sign the certificate.
    root_key.public_key().verify(cert.signature, cert.tbs_certificate_bytes,
                                 padding.PKCS1v15(),
                                 cert.signature_hash_algorithm)

  def test_generate_intermediate_with_cert_mismatch_root_cert_and_key(self):
    root_key1, _ = self._create_root_certificate_and_key()
    _, root_certificate2 = self._create_root_certificate_and_key()
    args = self._setup_intermediate_cert_args('some csr data', root_key1,
                                              root_certificate2)
    with self.assertRaises(ValueError) as context:
      oem_certificate.generate_intermediate_certificate(args)
    self.assertIn('certificate does not match', str(context.exception))

  def _setup_leaf_cert_args(self,
                            intermediate_key_bytes,
                            intermediate_certificate_bytes,
                            key_size=1024,
                            passphrase=None):
    """Returns args for oem_certificate.generate_leaf_certificate."""
    args = ArgParseObject()
    args.key_size = key_size
    args.not_valid_before = datetime.datetime(2001, 8, 9)
    # Days: the resulting certificate expires on 2023-07-05.
    args.valid_duration = 8000
    args.intermediate_private_key_passphrase = None
    args.output_certificate_file = StringIO.StringIO()
    args.output_private_key_file = StringIO.StringIO()
    args.passphrase = passphrase
    args.intermediate_private_key_file = StringIO.StringIO(
        intermediate_key_bytes)
    args.intermediate_certificate_file = StringIO.StringIO(
        intermediate_certificate_bytes)
    return args

  def _create_intermediate_certificate_and_key_bytes(self,
                                                     key_size=4096,
                                                     passphrase=None):
    """Returns (private key bytes, DER certificate bytes) of an intermediate.

    The intermediate is issued by a freshly generated root via the real
    generate_csr and generate_intermediate_certificate handlers.
    """
    csr_args = self._setup_csr_args(key_size, passphrase)
    oem_certificate.generate_csr(csr_args)
    csr_bytes = csr_args.output_csr_file.getvalue()
    root_key, root_certificate = self._create_root_certificate_and_key()
    args = self._setup_intermediate_cert_args(csr_bytes, root_key,
                                              root_certificate)
    oem_certificate.generate_intermediate_certificate(args)
    return (csr_args.output_private_key_file.getvalue(),
            args.output_certificate_file.getvalue())

  def test_generate_leaf_certificate(self):
    intermediate_key_bytes, intermediate_certificate_bytes = (
        self._create_intermediate_certificate_and_key_bytes())
    args = self._setup_leaf_cert_args(intermediate_key_bytes,
                                      intermediate_certificate_bytes)
    oem_certificate.generate_leaf_certificate(args)
    certificate_chain = oem_certificate.X509CertificateChain.load_der(
        args.output_certificate_file.getvalue())
    certificates = list(certificate_chain)
    self.assertEqual(len(certificates), 2)
    # The chain is ordered leaf first, then its issuing intermediate.
    intermediate_cert = certificates[1]
    leaf_cert = certificates[0]
    self.assertEqual(
        intermediate_cert.public_bytes(serialization.Encoding.DER),
        intermediate_certificate_bytes)
    intermediate_cert.public_key().verify(leaf_cert.signature,
                                          leaf_cert.tbs_certificate_bytes,
                                          padding.PKCS1v15(),
                                          leaf_cert.signature_hash_algorithm)
    self.assertEqual(leaf_cert.not_valid_before, datetime.datetime(2001, 8, 9))
    self.assertEqual(leaf_cert.not_valid_after, datetime.datetime(2023, 7, 5))
    system_id_raw_bytes = leaf_cert.extensions.get_extension_for_oid(
        oem_certificate.WidevineSystemId.oid).value.value
    self.assertEqual(
        oem_certificate.WidevineSystemId(system_id_raw_bytes).int_value(),
        1234554321)
    leaf_key = serialization.load_der_private_key(
        args.output_private_key_file.getvalue(),
        args.passphrase,
        backend=backends.default_backend())
    self.assertEqual(leaf_key.key_size, args.key_size)
    self.assertEqual(leaf_cert.public_key().key_size, args.key_size)
    # Verify cert and private key match.
    self._assert_public_keys_match(leaf_cert.public_key(), leaf_key)

  def test_generate_leaf_certificate_with_keysize4096_and_passphrase(self):
    intermediate_key_bytes, intermediate_certificate_bytes = (
        self._create_intermediate_certificate_and_key_bytes())
    args = self._setup_leaf_cert_args(
        intermediate_key_bytes,
        intermediate_certificate_bytes,
        key_size=4096,
        passphrase='leaf passphrase')
    oem_certificate.generate_leaf_certificate(args)
    # The private key must only be readable with the chosen passphrase.
    leaf_key = serialization.load_der_private_key(
        args.output_private_key_file.getvalue(),
        'leaf passphrase',
        backend=backends.default_backend())
    certificates = list(oem_certificate.X509CertificateChain.load_der(
        args.output_certificate_file.getvalue()))
    leaf_cert = certificates[0]
    # Check the generated key and certificate, not the input args.
    self.assertEqual(leaf_key.key_size, 4096)
    self.assertEqual(leaf_cert.public_key().key_size, 4096)
    self._assert_public_keys_match(leaf_cert.public_key(), leaf_key)

  def test_get_csr_info(self):
    args = self._setup_csr_args()
    oem_certificate.generate_csr(args)
    args.file = StringIO.StringIO(args.output_csr_file.getvalue())
    output = StringIO.StringIO()
    oem_certificate.get_info(args, output)
    expected_info = """\
CSR Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Key Size: 4096"""
    self.assertEqual(output.getvalue(), textwrap.dedent(expected_info))

  def test_get_certificate_info(self):
    _, intermediate_certificate_bytes = (
        self._create_intermediate_certificate_and_key_bytes())
    args = ArgParseObject()
    args.file = StringIO.StringIO(intermediate_certificate_bytes)
    output = StringIO.StringIO()
    oem_certificate.get_info(args, output)
    expected_info = """\
Certificate Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Issuer Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.3, name=commonName)>, value=u'root_cert')>
Key Size: 4096
Widevine System Id: 1234554321
Not valid before: 2001-08-09 00:00:00
Not valid after: 2001-11-17 00:00:00"""
    self.assertEqual(output.getvalue(), textwrap.dedent(expected_info))

  def test_get_certificate_chain_info(self):
    intermediate_key_bytes, intermediate_certificate_bytes = (
        self._create_intermediate_certificate_and_key_bytes())
    args = self._setup_leaf_cert_args(intermediate_key_bytes,
                                      intermediate_certificate_bytes)
    oem_certificate.generate_leaf_certificate(args)
    args.file = StringIO.StringIO(args.output_certificate_file.getvalue())
    output = StringIO.StringIO()
    oem_certificate.get_info(args, output)
    expected_info = """\
Certificate Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.3, name=commonName)>, value=u'1234554321-leaf')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Issuer Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Key Size: 1024
Widevine System Id: 1234554321
Not valid before: 2001-08-09 00:00:00
Not valid after: 2023-07-05 00:00:00
Certificate Subject Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.6, name=countryName)>, value=u'US')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.8, name=stateOrProvinceName)>, value=u'WA')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.7, name=localityName)>, value=u'Kirkland')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, value=u'CompanyXYZ')>
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.11, name=organizationalUnitName)>, value=u'ContentProtection')>
Issuer Name:
<NameAttribute(oid=<ObjectIdentifier(oid=2.5.4.3, name=commonName)>, value=u'root_cert')>
Key Size: 4096
Widevine System Id: 1234554321
Not valid before: 2001-08-09 00:00:00
Not valid after: 2001-11-17 00:00:00"""
    self.assertEqual(output.getvalue(), textwrap.dedent(expected_info))

  @staticmethod
  def _remove_if_exists(path):
    """Deletes path if it is still present; used as a test cleanup."""
    if os.path.exists(path):
      os.remove(path)

  def test_secure_erase(self):
    args = ArgParseObject()
    args.file = tempfile.NamedTemporaryFile(delete=False)
    # With delete=False nothing else removes the file, so a failing
    # secure_erase would leak it. Cleanups run LIFO: close first, then remove.
    self.addCleanup(self._remove_if_exists, args.file.name)
    self.addCleanup(args.file.close)
    args.passes = 2
    self.assertTrue(os.path.exists(args.file.name))
    oem_certificate.secure_erase(args)
    self.assertFalse(os.path.exists(args.file.name))
class OemCertificateArgParseTest(unittest.TestCase):
  """Tests that oem_certificate.create_parser maps flags onto parsed args."""

  def setUp(self):
    self.parser = oem_certificate.create_parser()
    self.test_dir = tempfile.mkdtemp()
    # File objects the parser opened while parsing; closed in tearDown.
    self._opened_files = []

  def tearDown(self):
    # Close parser-opened handles before removing the directory. Some
    # platforms refuse to delete files that are still open.
    for opened_file in self._opened_files:
      opened_file.close()
    shutil.rmtree(self.test_dir)

  def _parse_args(self, cmds):
    """Parses cmds and records every file object the parser opened."""
    args = self.parser.parse_args(cmds)
    for value in vars(args).values():
      if callable(getattr(value, 'close', None)):
        self._opened_files.append(value)
    return args

  def _fill_file_with_dummy_contents(self, file_name):
    """Creates file_name with placeholder bytes so read-mode flags can open it."""
    with open(file_name, 'wb') as f:
      f.write(b'dummy')

  def test_generate_csr(self):
    cmds = ('generate_csr --key_size 4096 -C USA -ST WA '
            '-L Kirkland -O Company -OU Widevine').split()
    output_private_key_file = os.path.join(self.test_dir, 'private_key')
    output_csr_file = os.path.join(self.test_dir, 'csr')
    cmds.extend([
        '--output_csr_file', output_csr_file, '--output_private_key_file',
        output_private_key_file, '--passphrase', 'pass'
    ])
    args = self._parse_args(cmds)
    self.assertEqual(args.key_size, 4096)
    self.assertEqual(args.country_name, 'USA')
    self.assertEqual(args.state_or_province_name, 'WA')
    self.assertEqual(args.locality_name, 'Kirkland')
    self.assertEqual(args.organization_name, 'Company')
    self.assertEqual(args.organizational_unit_name, 'Widevine')
    self.assertEqual(args.output_csr_file.name, output_csr_file)
    self.assertEqual(args.output_csr_file.mode, 'wb')
    self.assertEqual(args.output_private_key_file.name, output_private_key_file)
    self.assertEqual(args.output_private_key_file.mode, 'wb')
    self.assertEqual(args.passphrase, 'pass')
    self.assertEqual(args.func, oem_certificate.generate_csr)

  def test_generate_csr_invalid_key_size(self):
    cmds = ('generate_csr --key_size unknown -C USA -ST WA '
            '-L Kirkland -O Company -OU Widevine').split()
    output_private_key_file = os.path.join(self.test_dir, 'private_key')
    output_csr_file = os.path.join(self.test_dir, 'csr')
    cmds.extend([
        '--output_csr_file', output_csr_file, '--output_private_key_file',
        output_private_key_file, '--passphrase', 'pass'
    ])
    # argparse reports usage errors by exiting with status 2.
    with self.assertRaises(SystemExit) as context:
      self._parse_args(cmds)
    self.assertEqual(context.exception.code, 2)

  def test_generate_intermediate_cert(self):
    cmds = (
        'generate_intermediate_certificate --valid_duration 10 --system_id 100'
    ).split()
    csr_file = os.path.join(self.test_dir, 'csr')
    self._fill_file_with_dummy_contents(csr_file)
    root_certificate_file = os.path.join(self.test_dir, 'root_cert')
    self._fill_file_with_dummy_contents(root_certificate_file)
    root_private_key_file = os.path.join(self.test_dir, 'root_private_key')
    self._fill_file_with_dummy_contents(root_private_key_file)
    output_certificate_file = os.path.join(self.test_dir, 'cert')
    cmds.extend([
        '--csr_file', csr_file, '--root_certificate_file',
        root_certificate_file, '--root_private_key_file', root_private_key_file,
        '--root_private_key_passphrase', 'root_key',
        '--output_certificate_file', output_certificate_file
    ])
    args = self._parse_args(cmds)
    # No --not_valid_before was given, so it is expected to default to "now".
    self.assertAlmostEqual(
        args.not_valid_before,
        datetime.datetime.today(),
        delta=datetime.timedelta(seconds=60))
    self.assertEqual(args.valid_duration, 10)
    self.assertEqual(args.system_id, 100)
    self.assertEqual(args.csr_file.name, csr_file)
    self.assertEqual(args.csr_file.mode, 'rb')
    self.assertEqual(args.root_certificate_file.name, root_certificate_file)
    self.assertEqual(args.root_certificate_file.mode, 'rb')
    self.assertEqual(args.root_private_key_file.name, root_private_key_file)
    self.assertEqual(args.root_private_key_file.mode, 'rb')
    self.assertEqual(args.root_private_key_passphrase, 'root_key')
    self.assertEqual(args.output_certificate_file.name, output_certificate_file)
    self.assertEqual(args.output_certificate_file.mode, 'wb')
    self.assertEqual(args.func,
                     oem_certificate.generate_intermediate_certificate)

  def test_generate_leaf_cert(self):
    cmds = ('generate_leaf_certificate --not_valid_before 2016-01-02 '
            '--valid_duration 10').split()
    intermediate_certificate_file = os.path.join(self.test_dir,
                                                 'intermediate_cert')
    self._fill_file_with_dummy_contents(intermediate_certificate_file)
    intermediate_private_key_file = os.path.join(self.test_dir,
                                                 'intermediate_private_key')
    self._fill_file_with_dummy_contents(intermediate_private_key_file)
    output_certificate_file = os.path.join(self.test_dir, 'cert')
    output_private_key_file = os.path.join(self.test_dir, 'key')
    cmds.extend([
        '--intermediate_certificate_file', intermediate_certificate_file,
        '--intermediate_private_key_file', intermediate_private_key_file,
        '--intermediate_private_key_passphrase', 'intermediate_key',
        '--output_certificate_file', output_certificate_file,
        '--output_private_key_file', output_private_key_file, '--passphrase',
        'leaf_key'
    ])
    args = self._parse_args(cmds)
    self.assertEqual(args.not_valid_before, datetime.datetime(2016, 1, 2))
    self.assertEqual(args.valid_duration, 10)
    self.assertEqual(args.intermediate_certificate_file.name,
                     intermediate_certificate_file)
    self.assertEqual(args.intermediate_certificate_file.mode, 'rb')
    self.assertEqual(args.intermediate_private_key_file.name,
                     intermediate_private_key_file)
    self.assertEqual(args.intermediate_private_key_file.mode, 'rb')
    self.assertEqual(args.intermediate_private_key_passphrase,
                     'intermediate_key')
    self.assertEqual(args.output_certificate_file.name, output_certificate_file)
    self.assertEqual(args.output_certificate_file.mode, 'wb')
    self.assertEqual(args.output_private_key_file.name, output_private_key_file)
    self.assertEqual(args.output_private_key_file.mode, 'wb')
    self.assertEqual(args.passphrase, 'leaf_key')
    self.assertEqual(args.func, oem_certificate.generate_leaf_certificate)

  def test_generate_leaf_cert_invalid_date(self):
    cmds = ('generate_leaf_certificate --not_valid_before invaid-date '
            '--valid_duration 10').split()
    intermediate_certificate_file = os.path.join(self.test_dir,
                                                 'intermediate_cert')
    self._fill_file_with_dummy_contents(intermediate_certificate_file)
    intermediate_private_key_file = os.path.join(self.test_dir,
                                                 'intermediate_private_key')
    self._fill_file_with_dummy_contents(intermediate_private_key_file)
    output_certificate_file = os.path.join(self.test_dir, 'cert')
    output_private_key_file = os.path.join(self.test_dir, 'key')
    cmds.extend([
        '--intermediate_certificate_file', intermediate_certificate_file,
        '--intermediate_private_key_file', intermediate_private_key_file,
        '--intermediate_private_key_passphrase', 'intermediate_key',
        '--output_certificate_file', output_certificate_file,
        '--output_private_key_file', output_private_key_file, '--passphrase',
        'leaf_key'
    ])
    # argparse reports usage errors by exiting with status 2.
    with self.assertRaises(SystemExit) as context:
      self._parse_args(cmds)
    self.assertEqual(context.exception.code, 2)

  def test_secure_erase(self):
    file_path = os.path.join(self.test_dir, 'file')
    self._fill_file_with_dummy_contents(file_path)
    cmds = ['erase', '-F', file_path, '--passes', '2']
    args = self._parse_args(cmds)
    self.assertEqual(args.passes, 2)
    self.assertEqual(args.file.name, file_path)
    self.assertEqual(args.file.mode, 'a')
    self.assertEqual(args.func, oem_certificate.secure_erase)

  def test_get_info(self):
    file_path = os.path.join(self.test_dir, 'file')
    self._fill_file_with_dummy_contents(file_path)
    cmds = ['info', '-F', file_path]
    args = self._parse_args(cmds)
    self.assertEqual(args.file.name, file_path)
    self.assertEqual(args.file.mode, 'rb')
    self.assertEqual(args.func, oem_certificate.get_info)

  def test_arbitrary_commands(self):
    with self.assertRaises(SystemExit) as context:
      self._parse_args(['unsupport', '--commands'])
    self.assertEqual(context.exception.code, 2)

  def test_no_argument(self):
    with self.assertRaises(SystemExit) as context:
      self._parse_args([])
    self.assertEqual(context.exception.code, 2)
if __name__ == '__main__':
  # Executed directly as a script: discover and run this module's TestCases.
  unittest.main()