feat(vault): Add no_push option to Vault and its subclasses to control key reception

This commit is contained in:
Andy
2025-08-08 23:38:52 +00:00
parent abc3b4f1a4
commit b2686ca2b1
7 changed files with 30 additions and 14 deletions

View File

@@ -4,8 +4,9 @@ from uuid import UUID
class Vault(metaclass=ABCMeta):
def __init__(self, name: str):
def __init__(self, name: str, no_push: bool = False):
self.name = name
self.no_push = no_push
def __str__(self) -> str:
return f"{self.name} {type(self).__name__}"

View File

@@ -57,7 +57,7 @@ class Vaults:
"""Add a KID:KEY to all Vaults, optionally with an exclusion."""
success = 0
for vault in self.vaults:
if vault != excluding:
if vault != excluding and not vault.no_push:
try:
success += vault.add_key(self.service, kid, key)
except (PermissionError, NotImplementedError):
@@ -68,13 +68,15 @@ class Vaults:
"""
Add multiple KID:KEYs to all Vaults. Duplicate Content Keys are skipped.
PermissionErrors when the user cannot create Tables are absorbed and ignored.
Vaults with no_push=True are skipped.
"""
success = 0
for vault in self.vaults:
try:
success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError):
pass
if not vault.no_push:
try:
success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError):
pass
return success

View File

@@ -101,6 +101,8 @@ remote_cdm:
secret: secret_key
# Key Vaults store your obtained Content Encryption Keys (CEKs)
# Use 'no_push: true' to prevent a vault from receiving pushed keys
# while still allowing it to provide keys when requested
key_vaults:
- type: SQLite
name: Local
@@ -110,6 +112,7 @@ key_vaults:
# name: "Remote Vault"
# uri: "https://key-vault.example.com"
# token: "secret_token"
# no_push: true # This vault will only provide keys, not receive them
# - type: MySQL
# name: "MySQL Vault"
# host: "127.0.0.1"
@@ -117,6 +120,7 @@ key_vaults:
# database: vault
# username: user
# password: pass
# no_push: false # Default behavior - vault both provides and receives keys
# Choose what software to use to download data
downloader: aria2c

View File

@@ -10,8 +10,8 @@ from unshackle.core.vault import Vault
class API(Vault):
"""Key Vault using a simple RESTful HTTP API call."""
def __init__(self, name: str, uri: str, token: str):
super().__init__(name)
def __init__(self, name: str, uri: str, token: str, no_push: bool = False):
super().__init__(name, no_push)
self.uri = uri.rstrip("/")
self.session = Session()
self.session.headers.update({"User-Agent": f"unshackle v{__version__}"})

View File

@@ -18,7 +18,15 @@ class InsertResult(Enum):
class HTTP(Vault):
"""Key Vault using HTTP API with support for both query parameters and JSON payloads."""
def __init__(self, name: str, host: str, password: str, username: Optional[str] = None, api_mode: str = "query"):
def __init__(
self,
name: str,
host: str,
password: str,
username: Optional[str] = None,
api_mode: str = "query",
no_push: bool = False,
):
"""
Initialize HTTP Vault.
@@ -28,8 +36,9 @@ class HTTP(Vault):
password: Password for query mode or API token for json mode
username: Username (required for query mode, ignored for json mode)
api_mode: "query" for query parameters or "json" for JSON API
no_push: If True, this vault will not receive pushed keys
"""
super().__init__(name)
super().__init__(name, no_push)
self.url = host
self.password = password
self.username = username

View File

@@ -12,12 +12,12 @@ from unshackle.core.vault import Vault
class MySQL(Vault):
"""Key Vault using a remotely-accessed mysql database connection."""
def __init__(self, name: str, host: str, database: str, username: str, **kwargs):
def __init__(self, name: str, host: str, database: str, username: str, no_push: bool = False, **kwargs):
"""
All extra arguments provided via **kwargs will be sent to pymysql.connect.
This can be used to provide more specific connection information.
"""
super().__init__(name)
super().__init__(name, no_push)
self.slug = f"{host}:{database}:{username}"
self.conn_factory = ConnectionFactory(
dict(host=host, db=database, user=username, cursorclass=DictCursor, **kwargs)

View File

@@ -12,8 +12,8 @@ from unshackle.core.vault import Vault
class SQLite(Vault):
"""Key Vault using a locally-accessed sqlite DB file."""
def __init__(self, name: str, path: Union[str, Path]):
super().__init__(name)
def __init__(self, name: str, path: Union[str, Path], no_push: bool = False):
super().__init__(name, no_push)
self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict?
self.conn_factory = ConnectionFactory(self.path)