Source code for app.repository.base_repository

from contextlib import AbstractAsyncContextManager
from typing import Callable, Optional

from sqlalchemy import and_
from sqlalchemy import delete as sa_delete
from sqlalchemy import func, select
from sqlalchemy import update as sa_update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.core.config import configs
from app.core.exceptions import DuplicatedError, NotFoundError
from app.util.query_builder import dict_to_sqlalchemy_filter_options


class BaseRepository:
    """
    Async base repository providing common CRUD operations on top of
    SQLAlchemy 2.0's ``AsyncSession``.

    Every public method opens its own session through ``session_factory``
    (an async context manager factory) unless stated otherwise; only
    :meth:`create` can join a caller-managed session.

    The model is expected to expose an ``id`` attribute (used by the
    ``*_by_id`` / ``update*`` methods) and may declare an ``eagers`` list
    naming the relationships to join when ``eager=True`` is requested.
    """

    def __init__(
        self,
        session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
        model,
    ) -> None:
        """
        Store the session factory and the mapped model class.

        Args:
            session_factory: Callable returning an async context manager that
                yields an ``AsyncSession``.
            model: Mapped model class this repository operates on.
        """
        self.session_factory = session_factory
        self.model = model

    def _with_eager_options(self, stmt, eager: bool):
        """Add a ``joinedload`` option per name in ``model.eagers`` when ``eager``."""
        if not eager:
            return stmt
        for relationship_name in getattr(self.model, "eagers", []):
            stmt = stmt.options(joinedload(getattr(self.model, relationship_name)))
        return stmt

    async def _update_and_reload(self, id, values: dict):
        """Write ``values`` to the row with primary key ``id``, then re-read it."""
        # An UPDATE without any SET values is not a meaningful statement, so
        # an empty patch skips the write and just returns the current row.
        if values:
            async with self.session_factory() as session:
                await session.execute(
                    sa_update(self.model).where(self.model.id == id).values(**values)
                )
                await session.commit()
        # Re-read only after the write session has been released so two
        # sessions are never held open at the same time.
        return await self.read_by_id(id)

    async def read_by_options(self, schema, eager: bool = False) -> dict:
        """
        Run a filtered, ordered and paginated query from a search schema.

        Non-null fields on ``schema`` are turned into ``WHERE`` conditions by
        ``dict_to_sqlalchemy_filter_options``; ``ordering``, ``page`` and
        ``page_size`` drive sorting and pagination (a ``page_size`` of
        ``"all"`` disables the limit). Missing values fall back to
        ``configs.ORDERING``, ``configs.PAGE`` and ``configs.PAGE_SIZE``.
        The total row count ignoring pagination is computed alongside the
        page.

        Args:
            schema: Search schema whose non-null fields become filters and
                which may carry ``ordering``/``page``/``page_size``.
            eager (bool): When ``True``, eagerly join the model's ``eagers``
                relationships.

        Returns:
            dict: ``{"items": [...], "search_options": {page, page_size,
            ordering, total_count}}``.
        """
        async with self.session_factory() as session:
            options = schema.model_dump(exclude_none=True)
            ordering = options.get("ordering", configs.ORDERING)
            page = options.get("page", configs.PAGE)
            page_size = options.get("page_size", configs.PAGE_SIZE)

            # A leading "-" selects descending order on the named attribute.
            if ordering.startswith("-"):
                order_clause = getattr(self.model, ordering[1:]).desc()
            else:
                order_clause = getattr(self.model, ordering).asc()

            filter_options = dict_to_sqlalchemy_filter_options(self.model, options)

            stmt = self._with_eager_options(select(self.model), eager)
            stmt = stmt.filter(filter_options).order_by(order_clause)
            if page_size != "all":
                stmt = stmt.limit(page_size).offset((page - 1) * page_size)

            # Count over the filtered (but unpaginated) rows.
            count_stmt = select(func.count()).select_from(
                select(self.model).filter(filter_options).subquery()
            )

            result = await session.execute(stmt)
            # unique() collapses duplicate parent rows produced by joined loads.
            items = result.unique().scalars().all()
            total_count = (await session.execute(count_stmt)).scalar_one()

            return {
                "items": items,
                "search_options": {
                    "page": page,
                    "page_size": page_size,
                    "ordering": ordering,
                    "total_count": total_count,
                },
            }

    async def read_by_id(
        self,
        id,
        eager: bool = False,
        not_found_raise_exception: bool = True,
        not_found_message: str = "Not found id : {id}",
    ):
        """
        Fetch a single row by primary key.

        Args:
            id: Primary-key value to look up.
            eager (bool): When ``True``, eagerly join the model's ``eagers``
                relationships.
            not_found_raise_exception (bool): Raise :class:`NotFoundError`
                when no row matches; when ``False`` return ``None`` instead.
            not_found_message (str): Format string for the not-found error
                (``{id}`` is substituted).

        Returns:
            The matching model instance, or ``None`` when missing and
            ``not_found_raise_exception`` is ``False``.

        Raises:
            NotFoundError: If no row matches and the flag is ``True``.
        """
        async with self.session_factory() as session:
            stmt = select(self.model).filter(self.model.id == id)
            stmt = self._with_eager_options(stmt, eager)
            result = (await session.execute(stmt)).unique().scalars().first()
            # Compare against None explicitly so a model with a custom
            # __bool__/__len__ is never mistaken for "not found".
            if result is None and not_found_raise_exception:
                raise NotFoundError(detail=not_found_message.format(id=id))
            return result

    async def read_by_column(
        self,
        column: str,
        value,
        eager: bool = False,
        only_one: bool = True,
        not_found_raise_exception: bool = True,
        not_found_message: str = "Not found {column} : {value}",
    ):
        """
        Fetch rows whose ``column`` equals ``value``.

        Args:
            column (str): Name of the model attribute to filter on.
            value: Value the column must equal.
            eager (bool): When ``True``, eagerly join the model's ``eagers``
                relationships.
            only_one (bool): Return the first match when ``True``; otherwise
                return the full list of matches.
            not_found_raise_exception (bool): When ``only_one`` is ``True``,
                raise :class:`NotFoundError` if nothing matches.
            not_found_message (str): Format string for the not-found error
                (``{column}`` and ``{value}`` are substituted).

        Returns:
            A single instance (or ``None``) when ``only_one`` is ``True``,
            otherwise a list of matching instances.

        Raises:
            NotFoundError: If ``only_one`` and nothing matches and the flag
                is ``True``.
        """
        async with self.session_factory() as session:
            stmt = select(self.model).filter(getattr(self.model, column) == value)
            stmt = self._with_eager_options(stmt, eager)
            rows = (await session.execute(stmt)).unique().scalars()

            if not only_one:
                return rows.all()

            result = rows.first()
            if result is None and not_found_raise_exception:
                raise NotFoundError(
                    detail=not_found_message.format(column=column, value=value)
                )
            return result

    async def create(
        self,
        schema,
        session: Optional[AsyncSession] = None,
        auto_commit: bool = True,
    ):
        """
        Insert a new row built from ``schema``.

        Supports two modes: a self-managed session (``session=None``,
        ``auto_commit=True``) that opens, commits and closes its own session,
        or an externally-managed session passed by the caller so the insert
        can take part in a larger transaction (``auto_commit=False`` flushes
        instead of committing).

        Args:
            schema: Pydantic schema dumped into the model's constructor.
            session (Optional[AsyncSession]): Caller-managed session to
                reuse; a new one is opened when ``None``.
            auto_commit (bool): Commit immediately when ``True``; otherwise
                only flush (requires an external ``session``).

        Returns:
            The persisted (and refreshed) model instance.

        Raises:
            ValueError: If ``auto_commit=False`` is requested without an
                external session.
            DuplicatedError: If the insert raises ``IntegrityError``.
        """
        if session is None and not auto_commit:
            raise ValueError(
                "auto_commit=False requires an external session managed by the caller."
            )

        if session is None:
            # Open a session of our own and re-enter with it.
            async with self.session_factory() as managed_session:
                return await self.create(
                    schema,
                    session=managed_session,
                    auto_commit=auto_commit,
                )

        entity = self.model(**schema.model_dump())
        try:
            session.add(entity)
            if auto_commit:
                await session.commit()
            else:
                await session.flush()
            await session.refresh(entity)
        except IntegrityError as e:
            # Roll back only in auto-commit mode; otherwise the transaction
            # is left for the caller to handle.
            if auto_commit:
                await session.rollback()
            raise DuplicatedError(detail=str(e.orig)) from e
        return entity

    async def update(self, id, schema):
        """
        Partially update a row, ignoring ``None`` fields on ``schema``.

        When every field on ``schema`` is ``None`` there is nothing to write;
        the UPDATE is skipped and the current row is returned.

        Args:
            id: Primary key of the row to update.
            schema: Pydantic schema; only its non-null fields are written.

        Returns:
            The model instance re-read after the update.

        Raises:
            NotFoundError: If no row with ``id`` exists when it is re-read.
        """
        return await self._update_and_reload(id, schema.model_dump(exclude_none=True))

    async def update_attr(self, id, column: str, value):
        """
        Update a single column of a row.

        Args:
            id: Primary key of the row to update.
            column (str): Name of the column to set.
            value: New value for the column.

        Returns:
            The model instance re-read after the update.

        Raises:
            NotFoundError: If no row with ``id`` exists when it is re-read.
        """
        return await self._update_and_reload(id, {column: value})

    async def whole_update(self, id, schema):
        """
        Fully replace a row's columns from ``schema`` (including ``None``).

        Unlike :meth:`update`, every field on ``schema`` is written, so this
        performs a complete overwrite rather than a partial patch.

        Args:
            id: Primary key of the row to update.
            schema: Pydantic schema dumped in full into the update.

        Returns:
            The model instance re-read after the update.

        Raises:
            NotFoundError: If no row with ``id`` exists when it is re-read.
        """
        return await self._update_and_reload(id, schema.model_dump())

    async def delete_by_id(self, id) -> None:
        """
        Delete a row by primary key.

        Args:
            id: Primary key of the row to delete.

        Raises:
            NotFoundError: If no row matches ``id``.
        """
        async with self.session_factory() as session:
            # Only the primary key is needed to prove the row exists.
            exists_stmt = select(self.model.id).filter(self.model.id == id)
            found_id = (await session.execute(exists_stmt)).scalars().first()
            if found_id is None:
                raise NotFoundError(detail=f"Not found id : {id}")
            await session.execute(sa_delete(self.model).where(self.model.id == id))
            await session.commit()

    async def read_by_columns(
        self,
        filters: dict,
        eager: bool = False,
        only_one: bool = True,
        not_found_raise_exception: bool = True,
    ):
        """
        Fetch rows matching several equality filters combined with ``AND``.

        An empty ``filters`` mapping adds no ``WHERE`` clause, so every row
        matches.

        Args:
            filters (dict): Mapping of column name to required value; all
                conditions must hold (``AND``).
            eager (bool): When ``True``, eagerly join the model's ``eagers``
                relationships.
            only_one (bool): Return the first match when ``True``; otherwise
                return all matches.
            not_found_raise_exception (bool): When ``only_one`` is ``True``,
                raise :class:`NotFoundError` if nothing matches.

        Returns:
            A single instance (or ``None``) when ``only_one`` is ``True``,
            otherwise a list of matching instances.

        Raises:
            NotFoundError: If ``only_one`` and nothing matches and the flag
                is ``True``.
        """
        async with self.session_factory() as session:
            stmt = self._with_eager_options(select(self.model), eager)
            conditions = [
                getattr(self.model, col) == val for col, val in filters.items()
            ]
            # where() joins multiple criteria with AND and, unlike and_(),
            # accepts an empty argument list.
            stmt = stmt.where(*conditions)
            rows = (await session.execute(stmt)).unique().scalars()

            if not only_one:
                return rows.all()

            result = rows.first()
            if result is None and not_found_raise_exception:
                raise NotFoundError(detail=f"Not found for filters: {filters}")
            return result