10 Commits
1.4.1 ... 1.4.2

16 changed files with 234 additions and 83 deletions

View File

@@ -5,6 +5,38 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.4.2] - 2025-08-14
### Added
- **Session Management for API Requests**: Enhanced API reliability with retry logic
- Implemented session management for tags functionality with automatic retry mechanisms
- Improved API request stability and error handling
- **Series Year Configuration**: New `series_year` option for title naming control
- Added configurable `series_year` option to control year inclusion in series titles
- Enhanced YAML configuration with series year handling options
- **Audio Language Override**: New audio language selection option
- Added `audio_language` option to override default language selection for audio tracks
- Provides more granular control over audio track selection
- **Vault Key Reception Control**: Enhanced vault security options
- Added `no_push` option to Vault and its subclasses to control key reception
- Improved key management security and flexibility
### Changed
- **HLS Segment Processing**: Enhanced segment retrieval and merging capabilities
- Enhanced segment retrieval to allow all file types for better compatibility
- Improved segment merging with recursive file search and fallback to binary concatenation
- Fixed issues with VTT files from HLS not being found correctly due to format changes
- Added cleanup of empty segment directories after processing
- **Documentation**: Updated README.md with latest information
### Fixed
- **Audio Track Selection**: Improved per-language logic for audio tracks
- Adjusted `per_language` logic to ensure correct audio track selection
- Fixed issue where all tracks for selected language were being downloaded instead of just the intended ones
## [1.4.1] - 2025-08-08 ## [1.4.1] - 2025-08-08
### Added ### Added

View File

@@ -2,6 +2,10 @@
<img width="16" height="16" alt="no_encryption" src="https://github.com/user-attachments/assets/6ff88473-0dd2-4bbc-b1ea-c683d5d7a134" /> unshackle <img width="16" height="16" alt="no_encryption" src="https://github.com/user-attachments/assets/6ff88473-0dd2-4bbc-b1ea-c683d5d7a134" /> unshackle
<br/> <br/>
<sup><em>Movie, TV, and Music Archival Software</em></sup> <sup><em>Movie, TV, and Music Archival Software</em></sup>
<br/>
<a href="https://discord.gg/mHYyPaCbFK">
<img src="https://img.shields.io/discord/1395571732001325127?label=&logo=discord&logoColor=ffffff&color=7289DA&labelColor=7289DA" alt="Discord">
</a>
</p> </p>
## What is unshackle? ## What is unshackle?

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "unshackle" name = "unshackle"
version = "1.4.1" version = "1.4.2"
description = "Modular Movie, TV, and Music Archival Software." description = "Modular Movie, TV, and Music Archival Software."
authors = [{ name = "unshackle team" }] authors = [{ name = "unshackle team" }]
requires-python = ">=3.10,<3.13" requires-python = ">=3.10,<3.13"

View File

@@ -153,6 +153,13 @@ class dl:
default=[], default=[],
help="Language wanted for Video, you would use this if the video language doesn't match the audio.", help="Language wanted for Video, you would use this if the video language doesn't match the audio.",
) )
@click.option(
"-al",
"--a-lang",
type=LANGUAGE_RANGE,
default=[],
help="Language wanted for Audio, overrides -l/--lang for audio tracks.",
)
@click.option("-sl", "--s-lang", type=LANGUAGE_RANGE, default=["all"], help="Language wanted for Subtitles.") @click.option("-sl", "--s-lang", type=LANGUAGE_RANGE, default=["all"], help="Language wanted for Subtitles.")
@click.option("-fs", "--forced-subs", is_flag=True, default=False, help="Include forced subtitle tracks.") @click.option("-fs", "--forced-subs", is_flag=True, default=False, help="Include forced subtitle tracks.")
@click.option( @click.option(
@@ -413,6 +420,7 @@ class dl:
wanted: list[str], wanted: list[str],
lang: list[str], lang: list[str],
v_lang: list[str], v_lang: list[str],
a_lang: list[str],
s_lang: list[str], s_lang: list[str],
forced_subs: bool, forced_subs: bool,
sub_format: Optional[Subtitle.Codec], sub_format: Optional[Subtitle.Codec],
@@ -588,8 +596,9 @@ class dl:
if language not in processed_video_sort_lang: if language not in processed_video_sort_lang:
processed_video_sort_lang.append(language) processed_video_sort_lang.append(language)
audio_sort_lang = a_lang or lang
processed_audio_sort_lang = [] processed_audio_sort_lang = []
for language in lang: for language in audio_sort_lang:
if language == "orig": if language == "orig":
if title.language: if title.language:
orig_lang = str(title.language) if hasattr(title.language, "__str__") else title.language orig_lang = str(title.language) if hasattr(title.language, "__str__") else title.language
@@ -753,9 +762,10 @@ class dl:
if not title.tracks.audio: if not title.tracks.audio:
self.log.error(f"There's no {abitrate}kbps Audio Track...") self.log.error(f"There's no {abitrate}kbps Audio Track...")
sys.exit(1) sys.exit(1)
if lang: audio_languages = a_lang or lang
if audio_languages:
processed_lang = [] processed_lang = []
for language in lang: for language in audio_languages:
if language == "orig": if language == "orig":
if title.language: if title.language:
orig_lang = ( orig_lang = (
@@ -782,7 +792,7 @@ class dl:
selected_audio.append(highest_quality) selected_audio.append(highest_quality)
title.tracks.audio = selected_audio title.tracks.audio = selected_audio
elif "all" not in processed_lang: elif "all" not in processed_lang:
per_language = 0 if len(processed_lang) > 1 else 1 per_language = 1
title.tracks.audio = title.tracks.by_language( title.tracks.audio = title.tracks.by_language(
title.tracks.audio, processed_lang, per_language=per_language title.tracks.audio, processed_lang, per_language=per_language
) )

View File

@@ -1 +1 @@
__version__ = "1.4.1" __version__ = "1.4.2"

View File

@@ -91,6 +91,7 @@ class Config:
self.update_checks: bool = kwargs.get("update_checks", True) self.update_checks: bool = kwargs.get("update_checks", True)
self.update_check_interval: int = kwargs.get("update_check_interval", 24) self.update_check_interval: int = kwargs.get("update_check_interval", 24)
self.scene_naming: bool = kwargs.get("scene_naming", True) self.scene_naming: bool = kwargs.get("scene_naming", True)
self.series_year: bool = kwargs.get("series_year", True)
self.title_cache_time: int = kwargs.get("title_cache_time", 1800) # 30 minutes default self.title_cache_time: int = kwargs.get("title_cache_time", 1800) # 30 minutes default
self.title_cache_max_retention: int = kwargs.get("title_cache_max_retention", 86400) # 24 hours default self.title_cache_max_retention: int = kwargs.get("title_cache_max_retention", 86400) # 24 hours default

View File

@@ -4,6 +4,7 @@ import base64
import html import html
import json import json
import logging import logging
import os
import shutil import shutil
import subprocess import subprocess
import sys import sys
@@ -584,11 +585,24 @@ class HLS:
if DOWNLOAD_LICENCE_ONLY.is_set(): if DOWNLOAD_LICENCE_ONLY.is_set():
return return
if segment_save_dir.exists(): def find_segments_recursively(directory: Path) -> list[Path]:
segment_save_dir.rmdir() """Find all segment files recursively in any directory structure created by downloaders."""
segments = []
# First check direct files in the directory
if directory.exists():
segments.extend([x for x in directory.iterdir() if x.is_file()])
# If no direct files, recursively search subdirectories
if not segments:
for subdir in directory.iterdir():
if subdir.is_dir():
segments.extend(find_segments_recursively(subdir))
return sorted(segments)
# finally merge all the discontinuity save files together to the final path # finally merge all the discontinuity save files together to the final path
segments_to_merge = [x for x in sorted(save_dir.iterdir()) if x.is_file()] segments_to_merge = find_segments_recursively(save_dir)
if len(segments_to_merge) == 1: if len(segments_to_merge) == 1:
shutil.move(segments_to_merge[0], save_path) shutil.move(segments_to_merge[0], save_path)
else: else:
@@ -601,9 +615,16 @@ class HLS:
discontinuity_data = discontinuity_file.read_bytes() discontinuity_data = discontinuity_file.read_bytes()
f.write(discontinuity_data) f.write(discontinuity_data)
f.flush() f.flush()
os.fsync(f.fileno())
discontinuity_file.unlink() discontinuity_file.unlink()
save_dir.rmdir() # Clean up empty segment directory
if save_dir.exists() and save_dir.name.endswith("_segments"):
try:
save_dir.rmdir()
except OSError:
# Directory might not be empty, try removing recursively
shutil.rmtree(save_dir, ignore_errors=True)
progress(downloaded="Downloaded") progress(downloaded="Downloaded")
@@ -613,40 +634,75 @@ class HLS:
@staticmethod @staticmethod
def merge_segments(segments: list[Path], save_path: Path) -> int: def merge_segments(segments: list[Path], save_path: Path) -> int:
""" """
Concatenate Segments by first demuxing with FFmpeg. Concatenate Segments using FFmpeg concat with binary fallback.
Returns the file size of the merged file. Returns the file size of the merged file.
""" """
if not binaries.FFMPEG: # Track segment directories for cleanup
raise EnvironmentError("FFmpeg executable was not found but is required to merge HLS segments.") segment_dirs = set()
demuxer_file = segments[0].parent / "ffmpeg_concat_demuxer.txt"
demuxer_file.write_text("\n".join([f"file '{segment}'" for segment in segments]))
subprocess.check_call(
[
binaries.FFMPEG,
"-hide_banner",
"-loglevel",
"panic",
"-f",
"concat",
"-safe",
"0",
"-i",
demuxer_file,
"-map",
"0",
"-c",
"copy",
save_path,
]
)
demuxer_file.unlink()
for segment in segments: for segment in segments:
segment.unlink() # Track all parent directories that contain segments
current_dir = segment.parent
while current_dir.name and "_segments" in str(current_dir):
segment_dirs.add(current_dir)
current_dir = current_dir.parent
def cleanup_segments_and_dirs():
"""Clean up segments and directories after successful merge."""
for segment in segments:
segment.unlink(missing_ok=True)
for segment_dir in segment_dirs:
if segment_dir.exists():
try:
shutil.rmtree(segment_dir)
except OSError:
pass # Directory cleanup failed, but merge succeeded
# Try FFmpeg concat first (preferred method)
if binaries.FFMPEG:
try:
demuxer_file = save_path.parent / f"ffmpeg_concat_demuxer_{save_path.stem}.txt"
demuxer_file.write_text("\n".join([f"file '{segment.absolute()}'" for segment in segments]))
subprocess.check_call(
[
binaries.FFMPEG,
"-hide_banner",
"-loglevel",
"error",
"-f",
"concat",
"-safe",
"0",
"-i",
demuxer_file,
"-map",
"0",
"-c",
"copy",
save_path,
],
timeout=300, # 5 minute timeout
)
demuxer_file.unlink(missing_ok=True)
cleanup_segments_and_dirs()
return save_path.stat().st_size
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError) as e:
# FFmpeg failed, clean up demuxer file and fall back to binary concat
logging.getLogger("HLS").debug(f"FFmpeg concat failed ({e}), falling back to binary concatenation")
demuxer_file.unlink(missing_ok=True)
# Remove partial output file if it exists
save_path.unlink(missing_ok=True)
# Fallback: Binary concatenation
logging.getLogger("HLS").debug(f"Using binary concatenation for {len(segments)} segments")
with open(save_path, "wb") as output_file:
for segment in segments:
with open(segment, "rb") as segment_file:
output_file.write(segment_file.read())
cleanup_segments_and_dirs()
return save_path.stat().st_size return save_path.stat().st_size
@staticmethod @staticmethod

View File

@@ -81,7 +81,7 @@ class Episode(Title):
def __str__(self) -> str: def __str__(self) -> str:
return "{title}{year} S{season:02}E{number:02} {name}".format( return "{title}{year} S{season:02}E{number:02} {name}".format(
title=self.title, title=self.title,
year=f" {self.year}" if self.year else "", year=f" {self.year}" if self.year and config.series_year else "",
season=self.season, season=self.season,
number=self.number, number=self.number,
name=self.name or "", name=self.name or "",
@@ -95,13 +95,13 @@ class Episode(Title):
# Title [Year] SXXEXX Name (or Title [Year] SXX if folder) # Title [Year] SXXEXX Name (or Title [Year] SXX if folder)
if folder: if folder:
name = f"{self.title}" name = f"{self.title}"
if self.year: if self.year and config.series_year:
name += f" {self.year}" name += f" {self.year}"
name += f" S{self.season:02}" name += f" S{self.season:02}"
else: else:
name = "{title}{year} S{season:02}E{number:02} {name}".format( name = "{title}{year} S{season:02}E{number:02} {name}".format(
title=self.title.replace("$", "S"), # e.g., Arli$$ title=self.title.replace("$", "S"), # e.g., Arli$$
year=f" {self.year}" if self.year else "", year=f" {self.year}" if self.year and config.series_year else "",
season=self.season, season=self.season,
number=self.number, number=self.number,
name=self.name or "", name=self.name or "",
@@ -197,7 +197,7 @@ class Series(SortedKeyList, ABC):
def __str__(self) -> str: def __str__(self) -> str:
if not self: if not self:
return super().__str__() return super().__str__()
return self[0].title + (f" ({self[0].year})" if self[0].year else "") return self[0].title + (f" ({self[0].year})" if self[0].year and config.series_year else "")
def tree(self, verbose: bool = False) -> Tree: def tree(self, verbose: bool = False) -> Tree:
seasons = Counter(x.season for x in self) seasons = Counter(x.season for x in self)

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Optional, Tuple from typing import Optional, Tuple
import requests import requests
from requests.adapters import HTTPAdapter, Retry
from unshackle.core import binaries from unshackle.core import binaries
from unshackle.core.config import config from unshackle.core.config import config
@@ -25,6 +26,22 @@ HEADERS = {"User-Agent": "unshackle-tags/1.0"}
log = logging.getLogger("TAGS") log = logging.getLogger("TAGS")
def _get_session() -> requests.Session:
"""Create a requests session with retry logic for network failures."""
session = requests.Session()
session.headers.update(HEADERS)
retry = Retry(
total=3, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504], allowed_methods=["GET", "POST"]
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)
session.mount("http://", adapter)
return session
def _api_key() -> Optional[str]: def _api_key() -> Optional[str]:
return config.tmdb_api_key or os.getenv("TMDB_API_KEY") return config.tmdb_api_key or os.getenv("TMDB_API_KEY")
@@ -59,7 +76,8 @@ def search_simkl(title: str, year: Optional[int], kind: str) -> Tuple[Optional[d
filename += " 2160p.mkv" filename += " 2160p.mkv"
try: try:
resp = requests.post("https://api.simkl.com/search/file", json={"file": filename}, headers=HEADERS, timeout=30) session = _get_session()
resp = session.post("https://api.simkl.com/search/file", json={"file": filename}, timeout=30)
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
log.debug("Simkl API response received") log.debug("Simkl API response received")
@@ -139,17 +157,21 @@ def search_tmdb(title: str, year: Optional[int], kind: str) -> Tuple[Optional[in
if year is not None: if year is not None:
params["year" if kind == "movie" else "first_air_date_year"] = year params["year" if kind == "movie" else "first_air_date_year"] = year
r = requests.get( try:
f"https://api.themoviedb.org/3/search/{kind}", session = _get_session()
params=params, r = session.get(
headers=HEADERS, f"https://api.themoviedb.org/3/search/{kind}",
timeout=30, params=params,
) timeout=30,
r.raise_for_status() )
js = r.json() r.raise_for_status()
results = js.get("results") or [] js = r.json()
log.debug("TMDB returned %d results", len(results)) results = js.get("results") or []
if not results: log.debug("TMDB returned %d results", len(results))
if not results:
return None, None
except requests.RequestException as exc:
log.warning("Failed to search TMDB for %s: %s", title, exc)
return None, None return None, None
best_ratio = 0.0 best_ratio = 0.0
@@ -196,10 +218,10 @@ def get_title(tmdb_id: int, kind: str) -> Optional[str]:
return None return None
try: try:
r = requests.get( session = _get_session()
r = session.get(
f"https://api.themoviedb.org/3/{kind}/{tmdb_id}", f"https://api.themoviedb.org/3/{kind}/{tmdb_id}",
params={"api_key": api_key}, params={"api_key": api_key},
headers=HEADERS,
timeout=30, timeout=30,
) )
r.raise_for_status() r.raise_for_status()
@@ -219,10 +241,10 @@ def get_year(tmdb_id: int, kind: str) -> Optional[int]:
return None return None
try: try:
r = requests.get( session = _get_session()
r = session.get(
f"https://api.themoviedb.org/3/{kind}/{tmdb_id}", f"https://api.themoviedb.org/3/{kind}/{tmdb_id}",
params={"api_key": api_key}, params={"api_key": api_key},
headers=HEADERS,
timeout=30, timeout=30,
) )
r.raise_for_status() r.raise_for_status()
@@ -243,16 +265,21 @@ def external_ids(tmdb_id: int, kind: str) -> dict:
return {} return {}
url = f"https://api.themoviedb.org/3/{kind}/{tmdb_id}/external_ids" url = f"https://api.themoviedb.org/3/{kind}/{tmdb_id}/external_ids"
log.debug("Fetching external IDs for %s %s", kind, tmdb_id) log.debug("Fetching external IDs for %s %s", kind, tmdb_id)
r = requests.get(
url, try:
params={"api_key": api_key}, session = _get_session()
headers=HEADERS, r = session.get(
timeout=30, url,
) params={"api_key": api_key},
r.raise_for_status() timeout=30,
js = r.json() )
log.debug("External IDs response: %s", js) r.raise_for_status()
return js js = r.json()
log.debug("External IDs response: %s", js)
return js
except requests.RequestException as exc:
log.warning("Failed to fetch external IDs for %s %s: %s", kind, tmdb_id, exc)
return {}
def _apply_tags(path: Path, tags: dict[str, str]) -> None: def _apply_tags(path: Path, tags: dict[str, str]) -> None:

View File

@@ -4,8 +4,9 @@ from uuid import UUID
class Vault(metaclass=ABCMeta): class Vault(metaclass=ABCMeta):
def __init__(self, name: str): def __init__(self, name: str, no_push: bool = False):
self.name = name self.name = name
self.no_push = no_push
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.name} {type(self).__name__}" return f"{self.name} {type(self).__name__}"

View File

@@ -57,7 +57,7 @@ class Vaults:
"""Add a KID:KEY to all Vaults, optionally with an exclusion.""" """Add a KID:KEY to all Vaults, optionally with an exclusion."""
success = 0 success = 0
for vault in self.vaults: for vault in self.vaults:
if vault != excluding: if vault != excluding and not vault.no_push:
try: try:
success += vault.add_key(self.service, kid, key) success += vault.add_key(self.service, kid, key)
except (PermissionError, NotImplementedError): except (PermissionError, NotImplementedError):
@@ -68,13 +68,15 @@ class Vaults:
""" """
Add multiple KID:KEYs to all Vaults. Duplicate Content Keys are skipped. Add multiple KID:KEYs to all Vaults. Duplicate Content Keys are skipped.
PermissionErrors when the user cannot create Tables are absorbed and ignored. PermissionErrors when the user cannot create Tables are absorbed and ignored.
Vaults with no_push=True are skipped.
""" """
success = 0 success = 0
for vault in self.vaults: for vault in self.vaults:
try: if not vault.no_push:
success += bool(vault.add_keys(self.service, kid_keys)) try:
except (PermissionError, NotImplementedError): success += bool(vault.add_keys(self.service, kid_keys))
pass except (PermissionError, NotImplementedError):
pass
return success return success

View File

@@ -15,6 +15,11 @@ set_terminal_bg: false
# false for style - Prime Suspect S07E01 The Final Act - Part One # false for style - Prime Suspect S07E01 The Final Act - Part One
scene_naming: true scene_naming: true
# Whether to include the year in series names for episodes and folders (default: true)
# true for style - Show Name (2023) S01E01 Episode Name
# false for style - Show Name S01E01 Episode Name
series_year: true
# Check for updates from GitHub repository on startup (default: true) # Check for updates from GitHub repository on startup (default: true)
update_checks: true update_checks: true
@@ -101,6 +106,8 @@ remote_cdm:
secret: secret_key secret: secret_key
# Key Vaults store your obtained Content Encryption Keys (CEKs) # Key Vaults store your obtained Content Encryption Keys (CEKs)
# Use 'no_push: true' to prevent a vault from receiving pushed keys
# while still allowing it to provide keys when requested
key_vaults: key_vaults:
- type: SQLite - type: SQLite
name: Local name: Local
@@ -110,6 +117,7 @@ key_vaults:
# name: "Remote Vault" # name: "Remote Vault"
# uri: "https://key-vault.example.com" # uri: "https://key-vault.example.com"
# token: "secret_token" # token: "secret_token"
# no_push: true # This vault will only provide keys, not receive them
# - type: MySQL # - type: MySQL
# name: "MySQL Vault" # name: "MySQL Vault"
# host: "127.0.0.1" # host: "127.0.0.1"
@@ -117,6 +125,7 @@ key_vaults:
# database: vault # database: vault
# username: user # username: user
# password: pass # password: pass
# no_push: false # Default behavior - vault both provides and receives keys
# Choose what software to use to download data # Choose what software to use to download data
downloader: aria2c downloader: aria2c

View File

@@ -10,8 +10,8 @@ from unshackle.core.vault import Vault
class API(Vault): class API(Vault):
"""Key Vault using a simple RESTful HTTP API call.""" """Key Vault using a simple RESTful HTTP API call."""
def __init__(self, name: str, uri: str, token: str): def __init__(self, name: str, uri: str, token: str, no_push: bool = False):
super().__init__(name) super().__init__(name, no_push)
self.uri = uri.rstrip("/") self.uri = uri.rstrip("/")
self.session = Session() self.session = Session()
self.session.headers.update({"User-Agent": f"unshackle v{__version__}"}) self.session.headers.update({"User-Agent": f"unshackle v{__version__}"})

View File

@@ -18,7 +18,15 @@ class InsertResult(Enum):
class HTTP(Vault): class HTTP(Vault):
"""Key Vault using HTTP API with support for both query parameters and JSON payloads.""" """Key Vault using HTTP API with support for both query parameters and JSON payloads."""
def __init__(self, name: str, host: str, password: str, username: Optional[str] = None, api_mode: str = "query"): def __init__(
self,
name: str,
host: str,
password: str,
username: Optional[str] = None,
api_mode: str = "query",
no_push: bool = False,
):
""" """
Initialize HTTP Vault. Initialize HTTP Vault.
@@ -28,8 +36,9 @@ class HTTP(Vault):
password: Password for query mode or API token for json mode password: Password for query mode or API token for json mode
username: Username (required for query mode, ignored for json mode) username: Username (required for query mode, ignored for json mode)
api_mode: "query" for query parameters or "json" for JSON API api_mode: "query" for query parameters or "json" for JSON API
no_push: If True, this vault will not receive pushed keys
""" """
super().__init__(name) super().__init__(name, no_push)
self.url = host self.url = host
self.password = password self.password = password
self.username = username self.username = username

View File

@@ -12,12 +12,12 @@ from unshackle.core.vault import Vault
class MySQL(Vault): class MySQL(Vault):
"""Key Vault using a remotely-accessed mysql database connection.""" """Key Vault using a remotely-accessed mysql database connection."""
def __init__(self, name: str, host: str, database: str, username: str, **kwargs): def __init__(self, name: str, host: str, database: str, username: str, no_push: bool = False, **kwargs):
""" """
All extra arguments provided via **kwargs will be sent to pymysql.connect. All extra arguments provided via **kwargs will be sent to pymysql.connect.
This can be used to provide more specific connection information. This can be used to provide more specific connection information.
""" """
super().__init__(name) super().__init__(name, no_push)
self.slug = f"{host}:{database}:{username}" self.slug = f"{host}:{database}:{username}"
self.conn_factory = ConnectionFactory( self.conn_factory = ConnectionFactory(
dict(host=host, db=database, user=username, cursorclass=DictCursor, **kwargs) dict(host=host, db=database, user=username, cursorclass=DictCursor, **kwargs)

View File

@@ -12,8 +12,8 @@ from unshackle.core.vault import Vault
class SQLite(Vault): class SQLite(Vault):
"""Key Vault using a locally-accessed sqlite DB file.""" """Key Vault using a locally-accessed sqlite DB file."""
def __init__(self, name: str, path: Union[str, Path]): def __init__(self, name: str, path: Union[str, Path], no_push: bool = False):
super().__init__(name) super().__init__(name, no_push)
self.path = Path(path).expanduser() self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict? # TODO: Use a DictCursor or such to get fetches as dict?
self.conn_factory = ConnectionFactory(self.path) self.conn_factory = ConnectionFactory(self.path)