diff --git a/unshackle/commands/kv.py b/unshackle/commands/kv.py index 60498c8..035f7f7 100644 --- a/unshackle/commands/kv.py +++ b/unshackle/commands/kv.py @@ -12,84 +12,113 @@ from unshackle.core.vault import Vault from unshackle.core.vaults import Vaults +def _load_vaults(vault_names: list[str]) -> Vaults: + """Load and validate vaults by name.""" + vaults = Vaults() + for vault_name in vault_names: + vault_config = next((x for x in config.key_vaults if x["name"] == vault_name), None) + if not vault_config: + raise click.ClickException(f"Vault ({vault_name}) is not defined in the config.") + + vault_type = vault_config["type"] + vault_args = vault_config.copy() + del vault_args["type"] + + if not vaults.load(vault_type, **vault_args): + raise click.ClickException(f"Failed to load vault ({vault_name}).") + + return vaults + + +def _process_service_keys(from_vault: Vault, service: str, log: logging.Logger) -> dict[str, str]: + """Get and validate keys from a vault for a specific service.""" + content_keys = list(from_vault.get_keys(service)) + + bad_keys = {kid: key for kid, key in content_keys if not key or key.count("0") == len(key)} + for kid, key in bad_keys.items(): + log.warning(f"Skipping NULL key: {kid}:{key}") + + return {kid: key for kid, key in content_keys if kid not in bad_keys} + + +def _copy_service_data(to_vault: Vault, from_vault: Vault, service: str, log: logging.Logger) -> int: + """Copy data for a single service between vaults.""" + content_keys = _process_service_keys(from_vault, service, log) + total_count = len(content_keys) + + if total_count == 0: + log.info(f"{service}: No keys found in {from_vault}") + return 0 + + try: + added = to_vault.add_keys(service, content_keys) + except PermissionError: + log.warning(f"{service}: No permission to create table in {to_vault}, skipped") + return 0 + + existed = total_count - added + + if added > 0 and existed > 0: + log.info(f"{service}: {added} added, {existed} skipped ({total_count} total)") + elif added > 0: + log.info(f"{service}: {added} added ({total_count} total)") + else: + log.info(f"{service}: {existed} skipped (all existed)") + + return added + + @click.group(short_help="Manage and configure Key Vaults.", context_settings=context_settings) def kv() -> None: """Manage and configure Key Vaults.""" @kv.command() -@click.argument("to_vault", type=str) -@click.argument("from_vaults", nargs=-1, type=click.UNPROCESSED) +@click.argument("to_vault_name", type=str) +@click.argument("from_vault_names", nargs=-1, type=click.UNPROCESSED) @click.option("-s", "--service", type=str, default=None, help="Only copy data to and from a specific service.") -def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) -> None: +def copy(to_vault_name: str, from_vault_names: list[str], service: Optional[str] = None) -> None: """ Copy data from multiple Key Vaults into a single Key Vault. Rows with matching KIDs are skipped unless there's no KEY set. Existing data is not deleted or altered. - The `to_vault` argument is the key vault you wish to copy data to. + The `to_vault_name` argument is the key vault you wish to copy data to. It should be the name of a Key Vault defined in the config. - The `from_vaults` argument is the key vault(s) you wish to take + The `from_vault_names` argument is the key vault(s) you wish to take data from. You may supply multiple key vaults. """ - if not from_vaults: + if not from_vault_names: raise click.ClickException("No Vaults were specified to copy data from.") log = logging.getLogger("kv") - vaults = Vaults() - for vault_name in [to_vault] + list(from_vaults): - vault = next((x for x in config.key_vaults if x["name"] == vault_name), None) - if not vault: - raise click.ClickException(f"Vault ({vault_name}) is not defined in the config.") - vault_type = vault["type"] - vault_args = vault.copy() - del vault_args["type"] - if not vaults.load(vault_type, **vault_args): - raise click.ClickException(f"Failed to load vault ({vault_name}).") + all_vault_names = [to_vault_name] + list(from_vault_names) + vaults = _load_vaults(all_vault_names) - to_vault: Vault = vaults.vaults[0] - from_vaults: list[Vault] = vaults.vaults[1:] + to_vault = vaults.vaults[0] + from_vaults = vaults.vaults[1:] + + vault_names = ", ".join([v.name for v in from_vaults]) + log.info(f"Copying data from {vault_names} → {to_vault.name}") - log.info(f"Copying data from {', '.join([x.name for x in from_vaults])}, into {to_vault.name}") if service: service = Services.get_tag(service) - log.info(f"Only copying data for service {service}") + log.info(f"Filtering by service: {service}") total_added = 0 for from_vault in from_vaults: - if service: - services = [service] - else: - services = from_vault.get_services() - - for service_ in services: - log.info(f"Getting data from {from_vault} for {service_}") - content_keys = list(from_vault.get_keys(service_)) # important as it's a generator we iterate twice - - bad_keys = {kid: key for kid, key in content_keys if not key or key.count("0") == len(key)} - - for kid, key in bad_keys.items(): - log.warning(f"Cannot add a NULL Content Key to a Vault, skipping: {kid}:{key}") - - content_keys = {kid: key for kid, key in content_keys if kid not in bad_keys} - - total_count = len(content_keys) - log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}") - - try: - added = to_vault.add_keys(service_, content_keys) - except PermissionError: - log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...") - continue + services_to_copy = [service] if service else from_vault.get_services() + for service_tag in services_to_copy: + added = _copy_service_data(to_vault, from_vault, service_tag, log) total_added += added - existed = total_count - added - log.info(f"{to_vault} ({service_}): {added} newly added, {existed} already existed (skipped)") - - log.info(f"{to_vault}: {total_added} total newly added") + if total_added > 0: + log.info(f"Successfully added {total_added} new keys to {to_vault}") + else: + log.info("Copy completed - no new keys to add") @kv.command() @@ -106,9 +135,9 @@ def sync(ctx: click.Context, vaults: list[str], service: Optional[str] = None) - if not len(vaults) > 1: raise click.ClickException("You must provide more than one Vault to sync.") - ctx.invoke(copy, to_vault=vaults[0], from_vaults=vaults[1:], service=service) + ctx.invoke(copy, to_vault_name=vaults[0], from_vault_names=vaults[1:], service=service) for i in range(1, len(vaults)): - ctx.invoke(copy, to_vault=vaults[i], from_vaults=[vaults[i - 1]], service=service) + ctx.invoke(copy, to_vault_name=vaults[i], from_vault_names=[vaults[i - 1]], service=service) @kv.command() @@ -135,15 +164,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None: log = logging.getLogger("kv") service = Services.get_tag(service) - vaults_ = Vaults() - for vault_name in vaults: - vault = next((x for x in config.key_vaults if x["name"] == vault_name), None) - if not vault: - raise click.ClickException(f"Vault ({vault_name}) is not defined in the config.") - vault_type = vault["type"] - vault_args = vault.copy() - del vault_args["type"] - vaults_.load(vault_type, **vault_args) + vaults_ = _load_vaults(list(vaults)) data = file.read_text(encoding="utf8") kid_keys: dict[str, str] = {} @@ -173,15 +194,7 @@ def prepare(vaults: list[str]) -> None: """Create Service Tables on Vaults if not yet created.""" log = logging.getLogger("kv") - vaults_ = Vaults() - for vault_name in vaults: - vault = next((x for x in config.key_vaults if x["name"] == vault_name), None) - if not vault: - raise click.ClickException(f"Vault ({vault_name}) is not defined in the config.") - vault_type = vault["type"] - vault_args = vault.copy() - del vault_args["type"] - vaults_.load(vault_type, **vault_args) + vaults_ = _load_vaults(vaults) for vault in vaults_: if hasattr(vault, "has_table") and hasattr(vault, "create_table"): diff --git a/unshackle/vaults/MySQL.py b/unshackle/vaults/MySQL.py index ecd7a90..01b35bb 100644 --- a/unshackle/vaults/MySQL.py +++ b/unshackle/vaults/MySQL.py @@ -131,16 +131,27 @@ class MySQL(Vault): if any(isinstance(kid, UUID) for kid, key_ in kid_keys.items()): kid_keys = {kid.hex if isinstance(kid, UUID) else kid: key_ for kid, key_ in kid_keys.items()} + if not kid_keys: + return 0 + conn = self.conn_factory.get() cursor = conn.cursor() try: + placeholders = ",".join(["%s"] * len(kid_keys)) + cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", list(kid_keys.keys())) + existing_kids = {row["kid"] for row in cursor.fetchall()} + + new_keys = {kid: key for kid, key in kid_keys.items() if kid not in existing_kids} + + if not new_keys: + return 0 + cursor.executemany( - # TODO: SQL injection risk - f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)", - kid_keys.items(), + f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)", + new_keys.items(), ) - return cursor.rowcount + return len(new_keys) finally: conn.commit() cursor.close() diff --git a/unshackle/vaults/SQLite.py b/unshackle/vaults/SQLite.py index 5eaf6a8..d796bfa 100644 --- a/unshackle/vaults/SQLite.py +++ b/unshackle/vaults/SQLite.py @@ -102,16 +102,27 @@ class SQLite(Vault): if any(isinstance(kid, UUID) for kid, key_ in kid_keys.items()): kid_keys = {kid.hex if isinstance(kid, UUID) else kid: key_ for kid, key_ in kid_keys.items()} + if not kid_keys: + return 0 + conn = self.conn_factory.get() cursor = conn.cursor() try: + placeholders = ",".join(["?"] * len(kid_keys)) + cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", list(kid_keys.keys())) + existing_kids = {row[0] for row in cursor.fetchall()} + + new_keys = {kid: key for kid, key in kid_keys.items() if kid not in existing_kids} + + if not new_keys: + return 0 + cursor.executemany( - # TODO: SQL injection risk - f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)", - kid_keys.items(), + f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)", + new_keys.items(), ) - return cursor.rowcount + return len(new_keys) finally: conn.commit() cursor.close()