mirror of
https://github.com/rommapp/romm.git
synced 2026-04-23 06:54:40 +00:00
228 lines
7.9 KiB
Python
228 lines
7.9 KiB
Python
"""SSH/SFTP handler for Push-Pull sync mode.
|
|
|
|
Provides methods to connect to remote devices via SSH, list remote save files,
|
|
and perform bidirectional file transfers using SFTP.
|
|
|
|
SSH keys are expected to be pre-mounted on the server (e.g. via Docker volume)
|
|
at the path configured by SYNC_SSH_KEYS_PATH. Keys are looked up by device_id
|
|
({SYNC_SSH_KEYS_PATH}/{device_id}.pem) or via an explicit ssh_key_path in the
|
|
device's sync_config.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import os
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import asyncssh
|
|
from anyio import Path as AnyioPath
|
|
from anyio import open_file
|
|
|
|
from config import SYNC_SSH_KEYS_PATH, SYNC_SSH_KNOWN_HOSTS_PATH
|
|
from logger.logger import log
|
|
|
|
|
|
@dataclass
|
|
class RemoteSaveInfo:
|
|
"""Information about a save file on a remote device."""
|
|
|
|
path: str
|
|
file_name: str
|
|
platform_slug: str
|
|
file_size: int
|
|
mtime: datetime
|
|
content_hash: str | None = None
|
|
|
|
|
|
class SSHSyncHandler:
|
|
"""Handles SSH/SFTP operations for push-pull sync mode.
|
|
|
|
SSH keys are expected to be pre-mounted on the server filesystem at
|
|
SYNC_SSH_KEYS_PATH. The handler looks up keys by device_id convention
|
|
({keys_path}/{device_id}.pem) or uses an explicit path from sync_config.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self.keys_path = Path(SYNC_SSH_KEYS_PATH)
|
|
self.keys_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
def _resolve_key_path(self, device_id: str, sync_config: dict) -> str | None:
|
|
"""Resolve the SSH key path for a device.
|
|
|
|
Checks, in order:
|
|
1. Explicit ssh_key_path in sync_config
|
|
2. Convention-based path: {SYNC_SSH_KEYS_PATH}/{device_id}.pem
|
|
"""
|
|
explicit = sync_config.get("ssh_key_path")
|
|
if explicit and os.path.isfile(explicit):
|
|
return explicit
|
|
|
|
convention_path = self.keys_path / f"{device_id}.pem"
|
|
if convention_path.is_file():
|
|
return str(convention_path)
|
|
|
|
return None
|
|
|
|
async def connect(
|
|
self, sync_config: dict, device_id: str | None = None
|
|
) -> asyncssh.SSHClientConnection:
|
|
"""Establish an SSH connection using device sync_config.
|
|
|
|
SSH keys should be pre-mounted on the server. The handler resolves
|
|
the key by checking sync_config.ssh_key_path first, then falls back
|
|
to the convention-based path {SYNC_SSH_KEYS_PATH}/{device_id}.pem.
|
|
|
|
sync_config should contain:
|
|
- ssh_host: hostname or IP
|
|
- ssh_port: port (default 22)
|
|
- ssh_username: username
|
|
- ssh_key_path: explicit path to private key (optional)
|
|
- ssh_password: password (optional, fallback if no key found)
|
|
"""
|
|
host = sync_config["ssh_host"]
|
|
port = sync_config.get("ssh_port", 22)
|
|
username = sync_config.get("ssh_username", "root")
|
|
|
|
if not AnyioPath(SYNC_SSH_KNOWN_HOSTS_PATH).is_file():
|
|
raise FileNotFoundError(
|
|
f"SSH known_hosts file not found at {SYNC_SSH_KNOWN_HOSTS_PATH}. "
|
|
"Mount a known_hosts file or set SYNC_SSH_KNOWN_HOSTS_PATH."
|
|
)
|
|
|
|
connect_kwargs: dict[str, Any] = {
|
|
"host": host,
|
|
"port": port,
|
|
"username": username,
|
|
"known_hosts": SYNC_SSH_KNOWN_HOSTS_PATH,
|
|
}
|
|
|
|
# Resolve key path (explicit or convention-based)
|
|
key_path = self._resolve_key_path(device_id or "", sync_config)
|
|
if key_path:
|
|
connect_kwargs["client_keys"] = [key_path]
|
|
elif sync_config.get("ssh_password"):
|
|
connect_kwargs["password"] = sync_config["ssh_password"]
|
|
else:
|
|
raise ValueError(
|
|
f"No SSH authentication method available for {host}. "
|
|
f"Mount a key at {self.keys_path}/{{device_id}}.pem or "
|
|
"provide ssh_key_path/ssh_password in sync_config."
|
|
)
|
|
|
|
log.info(f"Connecting to {username}@{host}:{port}")
|
|
return await asyncssh.connect(**connect_kwargs)
|
|
|
|
async def list_remote_saves(
|
|
self,
|
|
conn: asyncssh.SSHClientConnection,
|
|
save_directories: list[dict],
|
|
) -> list[RemoteSaveInfo]:
|
|
"""List save files on a remote device.
|
|
|
|
save_directories is a list of dicts with keys:
|
|
- platform_slug: str
|
|
- path: str (remote directory path)
|
|
- extension: str (optional, file extension filter, e.g. ".srm")
|
|
"""
|
|
results: list[RemoteSaveInfo] = []
|
|
|
|
async with conn.start_sftp_client() as sftp:
|
|
for dir_config in save_directories:
|
|
platform_slug = dir_config["platform_slug"]
|
|
remote_path = dir_config["path"]
|
|
extension = dir_config.get("extension", "")
|
|
|
|
try:
|
|
entries = await sftp.listdir(remote_path)
|
|
except asyncssh.SFTPNoSuchFile:
|
|
log.warning(f"Remote directory not found: {remote_path}")
|
|
continue
|
|
|
|
for entry in entries:
|
|
if extension and not entry.endswith(extension):
|
|
continue
|
|
|
|
full_remote_path = f"{remote_path}/{entry}"
|
|
try:
|
|
attrs = await sftp.stat(full_remote_path)
|
|
if not attrs.type == asyncssh.constants.FILEXFER_TYPE_REGULAR:
|
|
continue
|
|
|
|
mtime = datetime.fromtimestamp(
|
|
attrs.mtime or 0, tz=timezone.utc
|
|
)
|
|
results.append(
|
|
RemoteSaveInfo(
|
|
path=full_remote_path,
|
|
file_name=entry,
|
|
platform_slug=platform_slug,
|
|
file_size=attrs.size or 0,
|
|
mtime=mtime,
|
|
)
|
|
)
|
|
except asyncssh.SFTPError as e:
|
|
log.warning(f"Failed to stat {full_remote_path}: {e}")
|
|
|
|
return results
|
|
|
|
async def download_save(
|
|
self,
|
|
conn: asyncssh.SSHClientConnection,
|
|
remote_path: str,
|
|
local_path: str | None = None,
|
|
) -> tuple[str, str]:
|
|
"""Download a save file from a remote device.
|
|
|
|
Returns (local_temp_path, content_hash).
|
|
"""
|
|
if local_path is None:
|
|
fd, local_path = tempfile.mkstemp(prefix="romm_sync_")
|
|
os.close(fd)
|
|
|
|
async with conn.start_sftp_client() as sftp:
|
|
await sftp.get(remote_path, local_path)
|
|
|
|
# Compute hash
|
|
hash_obj = hashlib.md5(usedforsecurity=False)
|
|
async with await open_file(local_path, "rb") as f:
|
|
while chunk := await f.read(8192):
|
|
hash_obj.update(chunk)
|
|
|
|
return local_path, hash_obj.hexdigest()
|
|
|
|
async def upload_save(
|
|
self,
|
|
conn: asyncssh.SSHClientConnection,
|
|
local_path: str,
|
|
remote_path: str,
|
|
) -> None:
|
|
"""Upload a save file to a remote device."""
|
|
async with conn.start_sftp_client() as sftp:
|
|
# Ensure remote directory exists
|
|
remote_dir = str(Path(remote_path).parent)
|
|
try:
|
|
await sftp.mkdir(remote_dir)
|
|
except asyncssh.SFTPError:
|
|
pass # Directory likely already exists
|
|
|
|
await sftp.put(local_path, remote_path)
|
|
log.info(f"Uploaded {local_path} -> {remote_path}")
|
|
|
|
async def delete_remote_save(
|
|
self,
|
|
conn: asyncssh.SSHClientConnection,
|
|
remote_path: str,
|
|
) -> None:
|
|
"""Delete a save file from a remote device."""
|
|
async with conn.start_sftp_client() as sftp:
|
|
await sftp.remove(remote_path)
|
|
log.info(f"Deleted remote file: {remote_path}")
|
|
|
|
|
|
ssh_sync_handler = SSHSyncHandler()
|