Source code for app.repository.strategy_definition_repository

"""
Repository for :class:`StrategyDefinition`.

Encapsulates the versioning queries the service relies on: max-version
lookup, status transitions, and tenant-scoped reads. Heavier business
logic (forking drafts on update, archiving published siblings on
publish) is intentionally left to the service so this layer stays a thin
SQL adapter.
"""

from contextlib import AbstractAsyncContextManager
from datetime import datetime
from typing import Callable, List, Optional

from sqlalchemy import and_, func, select
from sqlalchemy import update as sa_update
from sqlalchemy.ext.asyncio import AsyncSession

from app.model.strategy_definition import StrategyDefinition, StrategyDefinitionStatus
from app.repository.base_repository import BaseRepository


class StrategyDefinitionRepository(BaseRepository):
    """
    Async repository for the ``strategydefinition`` table.

    Each method enters its own ``self.session_factory()`` context, so
    consecutive calls on this class are not grouped into one transaction by
    the repository itself. NOTE(review): this assumes the factory yields a
    fresh session per call — confirm against ``BaseRepository`` / DI wiring.
    """

    # Upper bound applied to ``list_for_realm(limit=...)`` so a caller cannot
    # trigger an unbounded scan.
    _MAX_LIST_LIMIT = 500

    def __init__(
        self,
        session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
        model=StrategyDefinition,
    ) -> None:
        """
        :param session_factory: Callable returning an async context manager
            that yields an :class:`AsyncSession`.
        :param model: ORM model the queries target; defaults to
            :class:`StrategyDefinition`.
        """
        # Presumably stores ``session_factory`` / ``model`` as
        # ``self.session_factory`` / ``self.model`` (both are read below);
        # verify in BaseRepository.
        super().__init__(session_factory, model)

    async def list_for_realm(
        self,
        *,
        realmId: Optional[str],
        status: Optional[str] = None,
        type: Optional[str] = None,
        limit: int = 100,
    ) -> List[StrategyDefinition]:
        """
        List rows visible to one tenant.

        Rows are ordered by ``name`` ascending and, within a name, by
        ``version`` descending, so the newest version of each family comes
        first within that family.

        ``realmId=None`` compiles to ``realmId IS NULL`` and so selects only
        "global" rows; it does NOT mean "every realm". Callers that want a
        cross-realm listing need a separate admin helper (not implemented
        yet - tenancy stays strict).

        :param realmId: Tenant to scope by; ``None`` selects global rows.
        :param status: Optional exact-match filter on ``status``.
        :param type: Optional exact-match filter on ``type``.
        :param limit: Maximum number of rows, clamped to
            ``[1, _MAX_LIST_LIMIT]``.
        :returns: Matching rows (possibly empty).
        """
        stmt = select(self.model).where(self.model.realmId == realmId)
        if status is not None:
            stmt = stmt.where(self.model.status == status)
        if type is not None:
            stmt = stmt.where(self.model.type == type)
        # Clamp: never an empty page, never an unbounded scan.
        page_size = max(1, min(limit, self._MAX_LIST_LIMIT))
        stmt = stmt.order_by(
            self.model.name.asc(), self.model.version.desc()
        ).limit(page_size)
        async with self.session_factory() as session:
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_for_realm(
        self,
        *,
        id: str,
        realmId: Optional[str],
    ) -> Optional[StrategyDefinition]:
        """
        Fetch one row scoped by id + tenant.

        Returns ``None`` both when the id does not exist and when it belongs
        to another tenant, so the two cases are indistinguishable to the
        caller. The service maps that to a 404 so we don't leak existence
        across realms.

        :param id: Primary key of the row.
        :param realmId: Tenant the row must belong to (``None`` = global).
        :returns: The row, or ``None``.
        """
        stmt = select(self.model).where(
            and_(self.model.id == id, self.model.realmId == realmId)
        )
        async with self.session_factory() as session:
            result = await session.execute(stmt)
            return result.scalars().first()

    async def list_versions(
        self,
        *,
        realmId: Optional[str],
        name: str,
    ) -> List[StrategyDefinition]:
        """
        Return every row in a ``(realmId, name)`` family, newest version first.

        Used by the history endpoint, and also useful to the service when it
        looks for the latest draft or the published sibling.

        :param realmId: Tenant scope (``None`` = global).
        :param name: Family name.
        :returns: All versions of the family (empty if none).
        """
        stmt = (
            select(self.model)
            .where(
                and_(
                    self.model.realmId == realmId,
                    self.model.name == name,
                )
            )
            .order_by(self.model.version.desc())
        )
        async with self.session_factory() as session:
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_max_version(
        self,
        *,
        realmId: Optional[str],
        name: str,
    ) -> int:
        """
        Highest version number for a ``(realmId, name)`` family.

        Returns 0 when the family has no rows yet; SQL ``MAX`` over zero rows
        yields NULL.

        This is a plain read in its own session. It is NOT atomic with
        respect to a later insert: two concurrent callers can observe the
        same maximum and try to allocate the same next version.
        NOTE(review): safe allocation needs a unique constraint on
        ``(realmId, name, version)`` (with retry) or a caller-held lock —
        confirm which one the schema/service relies on.

        :param realmId: Tenant scope (``None`` = global).
        :param name: Family name.
        :returns: Current max version, or 0.
        """
        stmt = select(func.max(self.model.version)).where(
            and_(
                self.model.realmId == realmId,
                self.model.name == name,
            )
        )
        async with self.session_factory() as session:
            result = await session.execute(stmt)
            value = result.scalar_one_or_none()
            return 0 if value is None else int(value)

    async def get_published(
        self,
        *,
        realmId: Optional[str],
        name: str,
    ) -> Optional[StrategyDefinition]:
        """
        Return the currently published row of a family, or ``None``.

        Normally at most one row per family is PUBLISHED. If that invariant
        is briefly violated (e.g. mid-publish, before siblings are archived),
        the highest-version PUBLISHED row is returned, rather than an
        arbitrary one as an unordered ``LIMIT 1`` would give.

        :param realmId: Tenant scope (``None`` = global).
        :param name: Family name.
        :returns: The published row, or ``None``.
        """
        stmt = (
            select(self.model)
            .where(
                and_(
                    self.model.realmId == realmId,
                    self.model.name == name,
                    self.model.status == StrategyDefinitionStatus.PUBLISHED.value,
                )
            )
            # Deterministic pick when more than one PUBLISHED row exists.
            .order_by(self.model.version.desc())
            .limit(1)
        )
        async with self.session_factory() as session:
            result = await session.execute(stmt)
            return result.scalars().first()

    async def get_version(
        self,
        *,
        realmId: Optional[str],
        name: str,
        version: int,
    ) -> Optional[StrategyDefinition]:
        """
        Fetch a specific ``(realmId, name, version)`` row.

        Used by the rollback flow to locate the target version that is being
        promoted back to PUBLISHED. Returns ``None`` when the version does
        not exist in the family; the service maps that to a 404 so we don't
        leak cross-family info.

        :param realmId: Tenant scope (``None`` = global).
        :param name: Family name.
        :param version: Exact version number.
        :returns: The row, or ``None``.
        """
        stmt = select(self.model).where(
            and_(
                self.model.realmId == realmId,
                self.model.name == name,
                self.model.version == version,
            )
        )
        async with self.session_factory() as session:
            result = await session.execute(stmt)
            return result.scalars().first()

    async def set_status(
        self,
        *,
        id: str,
        status: str,
        publishedAt: Optional[datetime] = None,
    ) -> None:
        """
        Set ``status`` on the single row identified by ``id``, and commit.

        The update is issued as a SQL ``UPDATE ... WHERE id = :id`` in its
        own session and is committed immediately. If no row has ``id``, the
        statement matches nothing and the call is a silent no-op.

        :param id: Primary key of the row to update.
        :param status: New status value, stored as given. Other queries in
            this class compare against ``StrategyDefinitionStatus.<X>.value``
            strings, so pass the ``.value``.
        :param publishedAt: Written whenever it is not ``None``, independent
            of ``status``. Passing ``None`` leaves the existing column value
            untouched; it does not clear it.
        """
        values = {"status": status}
        if publishedAt is not None:
            values["publishedAt"] = publishedAt
        async with self.session_factory() as session:
            await session.execute(
                sa_update(self.model).where(self.model.id == id).values(**values)
            )
            await session.commit()