Source code for app.repository.task_repository
from contextlib import AbstractAsyncContextManager
from typing import Callable
from sqlalchemy import delete as sa_delete
from sqlalchemy import func, select
from sqlalchemy import update as sa_update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.config import configs
from app.core.exceptions import DuplicatedError, NotFoundError
from app.model.task_params import TasksParams
from app.model.tasks import Tasks
from app.model.user_points import UserPoints
from app.repository.base_repository import BaseRepository
from app.util.query_builder import dict_to_sqlalchemy_filter_options
[docs]
class TaskRepository(BaseRepository):
"""
Repository class for tasks.
"""
def __init__(
self,
session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
model=Tasks,
model_task_params=TasksParams,
model_user_points=UserPoints,
) -> None:
self.model_task_params = model_task_params
self.model_user_points = model_user_points
super().__init__(session_factory, model)
[docs]
async def read_by_gameId(self, schema, eager: bool = False):
"""
Reads tasks filtered by gameId (and any other schema fields).
"""
async with self.session_factory() as session:
schema_as_dict = schema.model_dump(exclude_none=True)
ordering = schema_as_dict.get("ordering", configs.ORDERING)
order_query = (
getattr(self.model, ordering[1:]).desc()
if ordering.startswith("-")
else getattr(self.model, ordering).asc()
)
page = schema_as_dict.get("page", configs.PAGE)
page_size = schema_as_dict.get("page_size", configs.PAGE_SIZE)
filter_options = dict_to_sqlalchemy_filter_options(
self.model, schema.model_dump(exclude_none=True)
)
stmt = select(self.model).filter(filter_options)
if eager:
for eager_rel in getattr(self.model, "eagers", []):
stmt = stmt.options(joinedload(getattr(self.model, eager_rel)))
count_stmt = select(func.count()).select_from(
select(self.model).filter(filter_options).subquery()
)
stmt = stmt.order_by(order_query)
if page_size != "all":
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
items = (await session.execute(stmt)).unique().scalars().all()
total_count = (await session.execute(count_stmt)).scalar_one()
return {
"items": items,
"search_options": {
"page": page,
"page_size": page_size,
"ordering": ordering,
"total_count": total_count,
},
}
[docs]
async def read_by_gameId_and_externalTaskId(self, gameId, externalTaskId: str):
"""
Look up a task by its game and external identifier.
Args:
gameId: Internal identifier of the owning game.
externalTaskId (str): External identifier of the task.
Returns:
Tasks | None: The matching task, or ``None`` if not found.
"""
async with self.session_factory() as session:
stmt = select(self.model).filter(
self.model.gameId == gameId,
self.model.externalTaskId == externalTaskId,
)
return (await session.execute(stmt)).scalars().first()
[docs]
async def get_points_and_users_by_taskId(self, taskId):
"""
Fetch a task by its internal id, raising if it does not exist.
Args:
taskId: Internal task identifier.
Returns:
Tasks: The matching task.
Raises:
NotFoundError: If no task has the given id.
"""
async with self.session_factory() as session:
stmt = select(self.model).filter(self.model.id == taskId)
result = (await session.execute(stmt)).scalars().first()
if not result:
raise NotFoundError(detail=f"Task not found by id : {taskId}")
return result
[docs]
async def patch_by_id(self, taskId, fields: dict):
"""
Apply a small ``fields`` dict to the task identified by
``taskId``. Returns the refreshed row. Used by the
``PATCH /games/{gameId}/tasks/{taskId}`` flow so the assignments
admin view can rewrite ``strategyId`` (and ``status``) without a
full upsert.
Raises :class:`NotFoundError` if the task does not exist.
"""
async with self.session_factory() as session:
stmt = select(self.model).filter(self.model.id == taskId)
task = (await session.execute(stmt)).scalars().first()
if not task:
raise NotFoundError(detail=f"Task not found by id : {taskId}")
for key, value in fields.items():
setattr(task, key, value)
await session.commit()
await session.refresh(task)
return task
[docs]
async def delete_task_by_id(self, task_id):
"""
Delete a single task and everything that hangs off it.
Mirrors the per-task branch of
:meth:`GameRepository.delete_game_by_id`: task params and the
user-points rows that reference the task are removed first so the
FK constraints don't block the final ``DELETE`` on the task row.
Returns ``True`` on success.
Raises :class:`NotFoundError` if the task does not exist.
"""
try:
async with self.session_factory() as session:
task = (
(
await session.execute(
select(self.model).filter(self.model.id == task_id)
)
)
.scalars()
.first()
)
if not task:
raise NotFoundError(detail=f"Not found id : {task_id}")
await session.execute(
sa_delete(self.model_task_params).where(
self.model_task_params.taskId == task_id
)
)
await session.execute(
sa_delete(self.model_user_points).where(
self.model_user_points.taskId == task_id
)
)
await session.delete(task)
await session.commit()
return True
except IntegrityError as e:
raise DuplicatedError(detail=str(e.orig))
except NotFoundError:
raise
except Exception as e:
raise NotFoundError(detail=str(e))
[docs]
async def list_by_strategy_id(self, strategy_id: str):
"""
Return all tasks whose ``strategyId`` matches the given value.
Rollback cascade companion to
:meth:`GameRepository.list_by_strategy_id`.
"""
async with self.session_factory() as session:
stmt = select(self.model).filter(self.model.strategyId == strategy_id)
return list((await session.execute(stmt)).scalars().all())
[docs]
async def bulk_update_strategy_id(
self, *, old_strategy_id: str, new_strategy_id: str
) -> int:
"""
Rewrite every task's ``strategyId`` from ``old_strategy_id`` to
``new_strategy_id`` in a single UPDATE. Returns the row count.
Rollback cascade companion to
:meth:`GameRepository.bulk_update_strategy_id`.
"""
async with self.session_factory() as session:
result = await session.execute(
sa_update(self.model)
.where(self.model.strategyId == old_strategy_id)
.values(strategyId=new_strategy_id)
)
await session.commit()
return int(result.rowcount or 0)