# Mirror of https://github.com/rommapp/romm.git
# (synced 2026-04-23 06:54:40 +00:00; 193 lines, 6.0 KiB, Python)
from collections.abc import Sequence
from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from decorators.database import begin_session
from models.sync_session import SyncSession, SyncSessionStatus

from .base_handler import DBBaseHandler


class DBSyncSessionsHandler(DBBaseHandler):
    """Database access layer for ``SyncSession`` rows.

    Public methods are wrapped in ``@begin_session`` and accept an optional
    ``session`` keyword; presumably the decorator supplies a SQLAlchemy
    ``Session`` when the caller omits it (decorator defined elsewhere —
    confirm in ``decorators.database``).
    """

    # Statuses that mark a sync session as still running for its device.
    _ACTIVE_STATUSES = (
        SyncSessionStatus.PENDING,
        SyncSessionStatus.IN_PROGRESS,
    )

    @staticmethod
    def _get_by_id_or_raise(
        session: Session, session_id: int, action: str
    ) -> SyncSession:
        """Re-read a sync session by primary key after a write.

        Args:
            session: The active SQLAlchemy session.
            session_id: Primary key of the sync session.
            action: Short verb naming the calling operation; used only in
                the error message.

        Returns:
            The ``SyncSession`` row.

        Raises:
            NoResultFound: If no row with ``session_id`` exists.
        """
        result = session.scalar(select(SyncSession).filter_by(id=session_id))
        if result is None:
            raise NoResultFound(f"SyncSession {session_id} not found after {action}")
        return result

    @begin_session
    def create_session(
        self,
        device_id: str,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Insert a new PENDING sync session for ``user_id`` on ``device_id``.

        ``initiated_at`` is stamped with the current UTC time. The new row is
        flushed (not committed here) so database-assigned fields such as the
        primary key are presumably populated on the returned object.

        Returns:
            The newly added ``SyncSession``.
        """
        sync_session = SyncSession(
            device_id=device_id,
            user_id=user_id,
            status=SyncSessionStatus.PENDING,
            initiated_at=datetime.now(timezone.utc),
        )
        session.add(sync_session)
        session.flush()
        return sync_session

    @begin_session
    def get_session(
        self,
        session_id: int,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> SyncSession | None:
        """Return the sync session ``session_id`` owned by ``user_id``, or None."""
        return session.scalar(
            select(SyncSession).filter_by(id=session_id, user_id=user_id).limit(1)
        )

    @begin_session
    def get_active_session(
        self,
        device_id: str,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> SyncSession | None:
        """Return the newest PENDING/IN_PROGRESS session for a device, or None.

        "Newest" is by ``initiated_at`` descending; only sessions belonging to
        ``user_id`` on ``device_id`` are considered.
        """
        return session.scalar(
            select(SyncSession)
            .filter(
                SyncSession.device_id == device_id,
                SyncSession.user_id == user_id,
                SyncSession.status.in_(self._ACTIVE_STATUSES),
            )
            .order_by(SyncSession.initiated_at.desc())
            .limit(1)
        )

    @begin_session
    def update_session(
        self,
        session_id: int,
        data: dict,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Apply ``data`` as column updates to a sync session and return it.

        ``data`` maps ``SyncSession`` attribute names to new values. The
        UPDATE is filtered by ``session_id`` only; ownership is not checked
        here, so callers must verify it beforehand.

        An empty ``data`` skips the UPDATE entirely. An UPDATE with no SET
        values cannot be executed, so this case used to fail. The current row
        is returned unchanged.

        Raises:
            NoResultFound: If the session does not exist.
        """
        if data:
            session.execute(
                update(SyncSession)
                .where(SyncSession.id == session_id)
                .values(**data)
                .execution_options(synchronize_session="evaluate")
            )
        return self._get_by_id_or_raise(session, session_id, "update")

    @begin_session
    def increment_operations_completed(
        self,
        session_id: int,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> None:
        """Add one to ``operations_completed`` for a user's sync session.

        The increment is computed in SQL (``operations_completed + 1``) rather
        than read-modify-written in Python. Nothing happens, and no error is
        raised, if no row matches ``session_id``/``user_id``.
        """
        session.execute(
            update(SyncSession)
            .where(SyncSession.id == session_id, SyncSession.user_id == user_id)
            .values(
                operations_completed=SyncSession.operations_completed + 1,
            )
            .execution_options(synchronize_session="evaluate")
        )

    @begin_session
    def complete_session(
        self,
        session_id: int,
        operations_completed: int = 0,
        operations_failed: int = 0,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Mark a sync session COMPLETED and record its final counters.

        ``completed_at`` is stamped with the current UTC time.

        NOTE(review): both counters are written unconditionally. Omitting them
        therefore resets any value accumulated through
        ``increment_operations_completed`` to 0. Confirm callers always pass
        the final counts.

        Raises:
            NoResultFound: If the session does not exist.
        """
        session.execute(
            update(SyncSession)
            .where(SyncSession.id == session_id)
            .values(
                status=SyncSessionStatus.COMPLETED,
                completed_at=datetime.now(timezone.utc),
                operations_completed=operations_completed,
                operations_failed=operations_failed,
            )
            .execution_options(synchronize_session="evaluate")
        )
        return self._get_by_id_or_raise(session, session_id, "complete")

    @begin_session
    def fail_session(
        self,
        session_id: int,
        error_message: str | None = None,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Mark a sync session FAILED, stamping ``completed_at`` (UTC).

        ``error_message`` is always written, so passing None clears any
        previously stored message.

        Raises:
            NoResultFound: If the session does not exist.
        """
        session.execute(
            update(SyncSession)
            .where(SyncSession.id == session_id)
            .values(
                status=SyncSessionStatus.FAILED,
                completed_at=datetime.now(timezone.utc),
                error_message=error_message,
            )
            .execution_options(synchronize_session="evaluate")
        )
        return self._get_by_id_or_raise(session, session_id, "fail")

    @begin_session
    def cancel_active_sessions(
        self,
        device_id: str,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> int:
        """Cancel all active sessions for a device. Returns count of cancelled sessions.

        "Active" means PENDING or IN_PROGRESS. Each cancelled row gets status
        CANCELLED and ``completed_at`` set to the current UTC time. The count
        is the driver-reported ``rowcount`` of the UPDATE.
        """
        result = session.execute(
            update(SyncSession)
            .where(
                SyncSession.device_id == device_id,
                SyncSession.user_id == user_id,
                SyncSession.status.in_(self._ACTIVE_STATUSES),
            )
            .values(
                status=SyncSessionStatus.CANCELLED,
                completed_at=datetime.now(timezone.utc),
            )
            .execution_options(synchronize_session="evaluate")
        )
        return result.rowcount

    @begin_session
    def get_sessions(
        self,
        user_id: int,
        device_id: str | None = None,
        status: SyncSessionStatus | None = None,
        limit: int = 50,
        session: Session = None,  # type: ignore
    ) -> Sequence[SyncSession]:
        """List a user's sync sessions, newest ``initiated_at`` first.

        Args:
            user_id: Owner whose sessions are listed.
            device_id: Optional device filter. A falsy value (None or "")
                disables the filter.
            status: Optional status filter. None disables the filter.
            limit: Maximum number of rows returned.
            session: SQLAlchemy session.

        Returns:
            Up to ``limit`` ``SyncSession`` rows.
        """
        query = select(SyncSession).filter_by(user_id=user_id)

        if device_id:
            query = query.filter_by(device_id=device_id)

        if status is not None:
            query = query.filter_by(status=status)

        query = query.order_by(SyncSession.initiated_at.desc()).limit(limit)
        return session.scalars(query).all()