Files
romm/backend/handler/sync/ssh_handler.py
T
Georges-Antoine Assi 26387c3bc1 fix check
2026-04-06 09:09:46 -04:00

228 lines
7.9 KiB
Python

"""SSH/SFTP handler for Push-Pull sync mode.
Provides methods to connect to remote devices via SSH, list remote save files,
and perform bidirectional file transfers using SFTP.
SSH keys are expected to be pre-mounted on the server (e.g. via Docker volume)
at the path configured by SYNC_SSH_KEYS_PATH. Keys are looked up by device_id
({SYNC_SSH_KEYS_PATH}/{device_id}.pem) or via an explicit ssh_key_path in the
device's sync_config.
"""
from __future__ import annotations

import hashlib
import os
import posixpath
import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import asyncssh
from anyio import Path as AnyioPath
from anyio import open_file

from config import SYNC_SSH_KEYS_PATH, SYNC_SSH_KNOWN_HOSTS_PATH
from logger.logger import log
@dataclass
class RemoteSaveInfo:
"""Information about a save file on a remote device."""
path: str
file_name: str
platform_slug: str
file_size: int
mtime: datetime
content_hash: str | None = None
class SSHSyncHandler:
    """Handles SSH/SFTP operations for push-pull sync mode.

    SSH keys are expected to be pre-mounted on the server filesystem at
    SYNC_SSH_KEYS_PATH. The handler looks up keys by device_id convention
    ({keys_path}/{device_id}.pem) or uses an explicit path from sync_config.
    """

    def __init__(self) -> None:
        self.keys_path = Path(SYNC_SSH_KEYS_PATH)
        # Ensure the key directory exists; this runs on every instantiation.
        self.keys_path.mkdir(parents=True, exist_ok=True)

    def _resolve_key_path(self, device_id: str, sync_config: dict) -> str | None:
        """Resolve the SSH private-key path for a device.

        Checks, in order:
          1. An explicit ``ssh_key_path`` in ``sync_config`` that points to
             an existing file.
          2. The convention-based path ``{keys_path}/{device_id}.pem``.

        The convention lookup is skipped when ``device_id`` is empty (so no
        stray ``{keys_path}/.pem`` is ever picked up). It also rejects ids
        whose key path would not sit directly inside ``keys_path`` (for
        example ids containing path separators or ``..`` components).

        Args:
            device_id: Device identifier used for the convention lookup.
            sync_config: The device's sync configuration mapping.

        Returns:
            The key path as a string, or None if no usable key file exists.
        """
        # NOTE(review): the explicit path is used as-is and may point
        # anywhere on the server; confirm sync_config is admin-controlled.
        explicit = sync_config.get("ssh_key_path")
        if explicit:
            if os.path.isfile(explicit):
                return explicit
            log.warning(
                f"Configured ssh_key_path {explicit} is not a file; "
                "falling back to convention-based key lookup"
            )

        if not device_id:
            return None

        convention_path = self.keys_path / f"{device_id}.pem"
        # Lexical containment check: "../x" or "a/b" would escape keys_path.
        if convention_path.parent != self.keys_path:
            return None
        if convention_path.is_file():
            return str(convention_path)
        return None

    async def connect(
        self, sync_config: dict, device_id: str | None = None
    ) -> asyncssh.SSHClientConnection:
        """Establish an SSH connection using a device's sync_config.

        The key is resolved from sync_config.ssh_key_path first, then from
        the convention-based path {SYNC_SSH_KEYS_PATH}/{device_id}.pem. A
        password is used only when no key file is found.

        sync_config should contain:
            - ssh_host: hostname or IP (required)
            - ssh_port: port (default 22)
            - ssh_username: username (default "root")
            - ssh_key_path: explicit path to a private key (optional)
            - ssh_password: password (optional, fallback if no key found)

        Raises:
            KeyError: if ``ssh_host`` is missing from sync_config.
            FileNotFoundError: if the known_hosts file does not exist.
            ValueError: if neither a key file nor a password is available.
        """
        host = sync_config["ssh_host"]
        port = sync_config.get("ssh_port", 22)
        username = sync_config.get("ssh_username", "root")

        # anyio's Path.is_file() is a coroutine and must be awaited. The
        # bare coroutine object is always truthy, so without the await this
        # check could never trigger.
        if not await AnyioPath(SYNC_SSH_KNOWN_HOSTS_PATH).is_file():
            raise FileNotFoundError(
                f"SSH known_hosts file not found at {SYNC_SSH_KNOWN_HOSTS_PATH}. "
                "Mount a known_hosts file or set SYNC_SSH_KNOWN_HOSTS_PATH."
            )

        connect_kwargs: dict[str, Any] = {
            "host": host,
            "port": port,
            "username": username,
            "known_hosts": SYNC_SSH_KNOWN_HOSTS_PATH,
        }

        # Resolve key path (explicit or convention-based)
        key_path = self._resolve_key_path(device_id or "", sync_config)
        if key_path:
            connect_kwargs["client_keys"] = [key_path]
        elif sync_config.get("ssh_password"):
            connect_kwargs["password"] = sync_config["ssh_password"]
        else:
            raise ValueError(
                f"No SSH authentication method available for {host}. "
                f"Mount a key at {self.keys_path}/{{device_id}}.pem or "
                "provide ssh_key_path/ssh_password in sync_config."
            )

        log.info(f"Connecting to {username}@{host}:{port}")
        return await asyncssh.connect(**connect_kwargs)

    async def list_remote_saves(
        self,
        conn: asyncssh.SSHClientConnection,
        save_directories: list[dict],
    ) -> list[RemoteSaveInfo]:
        """List save files on a remote device.

        save_directories is a list of dicts with keys:
            - platform_slug: str
            - path: str (remote directory path)
            - extension: str (optional, file extension filter, e.g. ".srm")

        A directory that cannot be listed (missing, permission denied, ...)
        or an entry that cannot be stat'ed is logged and skipped. One bad
        directory does not abort the listing of the others.

        Returns:
            One RemoteSaveInfo per regular file found.
        """
        results: list[RemoteSaveInfo] = []
        async with conn.start_sftp_client() as sftp:
            for dir_config in save_directories:
                platform_slug = dir_config["platform_slug"]
                remote_path = dir_config["path"]
                extension = dir_config.get("extension", "")

                try:
                    entries = await sftp.listdir(remote_path)
                except asyncssh.SFTPNoSuchFile:
                    log.warning(f"Remote directory not found: {remote_path}")
                    continue
                except asyncssh.SFTPError as e:
                    log.warning(f"Failed to list remote directory {remote_path}: {e}")
                    continue

                for entry in entries:
                    if extension and not entry.endswith(extension):
                        continue
                    full_remote_path = f"{remote_path}/{entry}"
                    try:
                        attrs = await sftp.stat(full_remote_path)
                    except asyncssh.SFTPError as e:
                        log.warning(f"Failed to stat {full_remote_path}: {e}")
                        continue

                    # Only regular files are reported; directories and other
                    # entry types are skipped.
                    if attrs.type != asyncssh.constants.FILEXFER_TYPE_REGULAR:
                        continue

                    results.append(
                        RemoteSaveInfo(
                            path=full_remote_path,
                            file_name=entry,
                            platform_slug=platform_slug,
                            file_size=attrs.size or 0,
                            mtime=datetime.fromtimestamp(
                                attrs.mtime or 0, tz=timezone.utc
                            ),
                        )
                    )
        return results

    async def download_save(
        self,
        conn: asyncssh.SSHClientConnection,
        remote_path: str,
        local_path: str | None = None,
    ) -> tuple[str, str]:
        """Download a save file from a remote device.

        Args:
            conn: An open SSH connection.
            remote_path: Path of the file on the remote device.
            local_path: Destination path. When None, a temporary file
                (prefix ``romm_sync_``) is created. If the download fails,
                that temporary file is removed before the error propagates.

        Returns:
            (local_path, content_hash) where content_hash is the MD5 hex
            digest of the downloaded bytes.
        """
        created_temp = False
        if local_path is None:
            fd, local_path = tempfile.mkstemp(prefix="romm_sync_")
            os.close(fd)
            created_temp = True

        try:
            async with conn.start_sftp_client() as sftp:
                await sftp.get(remote_path, local_path)

            # Compute hash (MD5 is used as a content fingerprint here)
            hash_obj = hashlib.md5(usedforsecurity=False)
            async with await open_file(local_path, "rb") as f:
                while chunk := await f.read(8192):
                    hash_obj.update(chunk)
        except BaseException:
            # BaseException so task cancellation also cleans up; only remove
            # the file if we created it, never a caller-supplied path.
            if created_temp:
                try:
                    os.unlink(local_path)
                except OSError:
                    pass
            raise

        return local_path, hash_obj.hexdigest()

    async def upload_save(
        self,
        conn: asyncssh.SSHClientConnection,
        local_path: str,
        remote_path: str,
    ) -> None:
        """Upload a local save file to ``remote_path`` on a remote device.

        Missing parent directories on the remote side are created first, on
        a best-effort basis. A failure there is logged and the upload is
        still attempted, so any real error surfaces from the transfer.
        """
        # SFTP paths use "/" separators, independent of the server's OS.
        remote_dir = posixpath.dirname(remote_path)
        async with conn.start_sftp_client() as sftp:
            if remote_dir and remote_dir != "/":
                try:
                    await sftp.makedirs(remote_dir, exist_ok=True)
                except asyncssh.SFTPError as e:
                    log.warning(
                        f"Could not create remote directory {remote_dir}: {e}"
                    )
            await sftp.put(local_path, remote_path)
        log.info(f"Uploaded {local_path} -> {remote_path}")

    async def delete_remote_save(
        self,
        conn: asyncssh.SSHClientConnection,
        remote_path: str,
    ) -> None:
        """Delete a save file from a remote device.

        SFTP errors (e.g. the file does not exist) propagate to the caller.
        """
        async with conn.start_sftp_client() as sftp:
            await sftp.remove(remote_path)
        log.info(f"Deleted remote file: {remote_path}")
# Shared module-level handler. Constructing it runs SSHSyncHandler.__init__
# at import time, which creates the configured keys directory if missing.
ssh_sync_handler: SSHSyncHandler = SSHSyncHandler()