from contextlib import AbstractAsyncContextManager
from datetime import date, datetime, timedelta
from typing import Callable, Dict, List, Union

from sqlalchemy import String, case, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import BadRequestError
from app.model.games import Games
from app.model.logs import Logs
from app.model.tasks import Tasks
from app.model.user_actions import UserActions
from app.model.user_points import UserPoints
from app.model.users import Users
from app.repository.base_repository import BaseRepository
class DashboardRepository(BaseRepository):
    """
    Repository that aggregates dashboard metrics into time buckets.

    It counts new users, opened games and performed user actions, sums
    earned points, and counts log entries per level. Every metric is
    grouped by day, week-of-month or month of the aggregated model's
    ``created_at`` column and can be restricted to a date range.
    """

    def __init__(
        self,
        session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
        model_games=Games,
        model_tasks=Tasks,
        model_users=Users,
        model_logs=Logs,
        model_user_points=UserPoints,
        model_user_actions=UserActions,
    ) -> None:
        """
        Initializes the DashboardRepository with the provided session factory and models.

        Args:
            session_factory: Callable returning an async context manager that
                yields an ``AsyncSession``. Each aggregation query opens its
                own session through it.
            model_games: Games model. It is also passed to ``BaseRepository``
                as this repository's primary model.
            model_tasks: Tasks model. It is stored but not used by the
                queries defined in this class.
            model_users: Users model (source of ``new_users``).
            model_logs: Logs model (source of the per-level log counts).
            model_user_points: UserPoints model (source of ``points_earned``).
            model_user_actions: UserActions model (source of
                ``actions_performed``).
        """
        self.session_factory = session_factory
        self.model_games = model_games
        self.model_tasks = model_tasks
        self.model_users = model_users
        self.model_logs = model_logs
        self.model_user_points = model_user_points
        self.model_user_actions = model_user_actions
        super().__init__(session_factory, model_games)

    @staticmethod
    def _parse_date_boundary(value):
        """Coerce an incoming date string into a ``datetime``.

        The ``start_date``/``end_date`` query params arrive as strings
        (``Query(None)`` typed ``str``). Passing them straight into a
        comparison against the ``created_at`` (``timestamptz``) column made
        asyncpg bind them as ``varchar``. Postgres then rejected the query
        with ``operator does not exist: timestamp with time zone >= character
        varying`` (a 500). Parsing to a real ``datetime`` binds the right
        type.

        Empty/``None`` means "no boundary". ``date``/``datetime`` instances
        are returned unchanged. A trailing ``Z`` (UTC designator) is
        rewritten to ``+00:00``, because ``datetime.fromisoformat`` only
        accepts ``Z`` from Python 3.11 onwards.

        NOTE(review): strings without an offset produce naive datetimes. How
        the driver maps those onto ``timestamptz`` (UTC or server-local) is
        not visible here; confirm.

        Raises:
            BadRequestError: if ``value`` is not a parseable ISO 8601 value.
                A malformed value is a client error, not a server crash.
        """
        if value is None or value == "":
            return None
        if isinstance(value, (datetime, date)):
            return value
        text = value
        if isinstance(text, str) and text.endswith("Z"):
            text = text[:-1] + "+00:00"
        try:
            return datetime.fromisoformat(text)
        except (TypeError, ValueError) as exc:
            raise BadRequestError(
                f"Invalid date '{value}'. Expected ISO 8601, e.g. 2026-05-09."
            ) from exc

    @staticmethod
    def _is_date_only(value):
        """Return True when ``value`` names a whole calendar day.

        ``date`` instances qualify and ``datetime`` instances do not. Strings
        qualify when ``date.fromisoformat`` accepts them, which means they
        have no time-of-day component. Anything else returns False.
        """
        if isinstance(value, datetime):
            return False
        if isinstance(value, date):
            return True
        if not isinstance(value, str):
            return False
        try:
            date.fromisoformat(value)
        except ValueError:
            return False
        return True

    def process_query(
        self, query, model=None, start_date=None, end_date=None, group_by_column=None
    ):
        """
        Processes the SELECT statement by filtering and grouping the results.

        The date filter is applied to the ``created_at`` column of the model
        actually being aggregated (``model``). It previously always filtered
        ``model_users.created_at``, which cross-joined ``Users`` into queries
        over ``Games``/``Logs``/``UserPoints``/``UserActions`` and skewed
        their counts. ``model`` defaults to ``model_users`` to preserve the
        users-summary behaviour when a caller omits it.

        A date-only ``end_date`` (e.g. ``2026-05-09``) includes that whole
        day and is applied as ``created_at < end_date + 1 day``. Comparing
        with ``<=`` against the parsed midnight value used to drop every row
        created later on the end day. An ``end_date`` that carries a time
        component is applied as ``created_at <= end_date``.

        Args:
            query: The SELECT statement to refine.
            model: Model whose ``created_at`` column is filtered.
            start_date: Inclusive lower bound (ISO string, date, datetime,
                or None/"" for no bound).
            end_date: Inclusive upper bound (same accepted forms).
            group_by_column: Optional expression to GROUP BY.

        Returns:
            The statement with the date filters and GROUP BY applied.

        Raises:
            BadRequestError: if either date cannot be parsed.
        """
        date_model = model if model is not None else self.model_users
        start = self._parse_date_boundary(start_date)
        end = self._parse_date_boundary(end_date)
        if start is not None:
            query = query.filter(date_model.created_at >= start)
        if end is not None:
            if self._is_date_only(end_date):
                # Half-open upper bound keeps every timestamp on the end day.
                next_day = end + timedelta(days=1)
                query = query.filter(date_model.created_at < next_day)
            else:
                query = query.filter(date_model.created_at <= end)
        if group_by_column is not None:
            query = query.group_by(group_by_column)
        return query

    def _get_group_by_column(self, model, group_by: str):
        """
        Returns the labelled bucket expression for ``model.created_at``.

        - ``"day"``: ``date_trunc('day', created_at)`` labelled ``date``.
        - ``"week"``: a week-of-month bucket labelled ``week``, with values
          like ``week_<1-5>_<month>``. Days 1-7 are week 1, 8-14 week 2,
          15-21 week 3, 22-28 week 4 and 29+ week 5. If no branch matches
          (e.g. a NULL ``created_at``), the value is ``unknown_week``.
        - ``"month"``: the zero-padded month number (``01``-``12``)
          labelled ``month``.

        NOTE: the week and month labels contain no year, so rows from the
        same month of different years fall into the same bucket.

        Args:
            model: The model whose ``created_at`` column is bucketed.
            group_by: One of ``"day"``, ``"week"`` or ``"month"``.

        Returns:
            Any: The labelled group_by expression.

        Raises:
            BadRequestError: if ``group_by`` is not a supported value.
        """
        if group_by == "day":
            return func.date_trunc("day", model.created_at).label("date")
        elif group_by == "week":
            return case(
                (
                    func.extract("day", model.created_at).between(1, 7),
                    func.concat("week_1_", func.extract("month", model.created_at)),
                ),
                (
                    func.extract("day", model.created_at).between(8, 14),
                    func.concat("week_2_", func.extract("month", model.created_at)),
                ),
                (
                    func.extract("day", model.created_at).between(15, 21),
                    func.concat("week_3_", func.extract("month", model.created_at)),
                ),
                (
                    func.extract("day", model.created_at).between(22, 28),
                    func.concat("week_4_", func.extract("month", model.created_at)),
                ),
                (
                    func.extract("day", model.created_at) >= 29,
                    func.concat("week_5_", func.extract("month", model.created_at)),
                ),
                else_="unknown_week",
            ).label("week")
        elif group_by == "month":
            return func.lpad(
                func.cast(func.extract("month", model.created_at), String), 2, "0"
            ).label("month")
        else:
            raise BadRequestError(
                "Invalid group_by value. Choose 'day', 'week', or 'month'."
            )

    async def _execute_query(
        self, model, group_by_column, start_date, end_date, aggregation_field
    ) -> List[Dict[str, Union[str, int]]]:
        """Run one bucketed aggregation and return ``[{label, count}, ...]``.

        The statement is built, and its dates validated, before a session is
        opened, so a bad request never acquires a session. Rows are ordered
        by the bucket label so the series is deterministic. Day buckets come
        out chronologically; week/month labels sort by their label value.

        Args:
            model: Model whose ``created_at`` is filtered.
            group_by_column: Labelled bucket expression (selected, grouped
                and ordered by).
            start_date: Lower date bound, as accepted by ``process_query``.
            end_date: Upper date bound, as accepted by ``process_query``.
            aggregation_field: Aggregate expression, selected as ``count``.

        Returns:
            One dict per bucket: ``label`` is ``str()`` of the bucket value;
            ``count`` is the raw aggregate value.
        """
        stmt = select(group_by_column, aggregation_field.label("count"))
        stmt = self.process_query(
            stmt, model, start_date, end_date, group_by_column
        )
        stmt = stmt.order_by(group_by_column)
        async with self.session_factory() as session:
            rows = (await session.execute(stmt)).all()
        # Positional access: the attribute ``row.count`` can resolve to the
        # tuple ``count()`` method instead of the column on SQLAlchemy Row.
        return [{"label": str(row[0]), "count": row[1]} for row in rows]

    async def get_dashboard_summary(self, start_date, end_date, group_by):
        """
        Retrieves the dashboard summary.

        Args:
            start_date: The start date for the summary.
            end_date: The end date for the summary.
            group_by: The group by for the summary (e.g. day, week, month).

        Returns:
            Dict[str, Any]: ``new_users``, ``games_opened``, ``points_earned``
            and ``actions_performed``, each a list of ``{label, count}``.
            ``points_earned`` is a SUM, so a bucket whose points are all NULL
            yields ``None``.

        Raises:
            BadRequestError: on an invalid ``group_by`` or date value.
        """
        metrics = (
            ("new_users", self.model_users, func.count(self.model_users.id)),
            ("games_opened", self.model_games, func.count(self.model_games.id)),
            (
                "points_earned",
                self.model_user_points,
                func.sum(self.model_user_points.points),
            ),
            (
                "actions_performed",
                self.model_user_actions,
                func.count(self.model_user_actions.id),
            ),
        )
        summary = {}
        for key, model, aggregation in metrics:
            summary[key] = await self._execute_query(
                model,
                self._get_group_by_column(model, group_by),
                start_date,
                end_date,
                aggregation,
            )
        return summary

    async def get_dashboard_summary_logs(self, start_date, end_date, group_by):
        """
        Retrieves the dashboard summary logs.

        One query runs per log level. Each counts, per bucket, the rows whose
        ``log_level`` equals that level.

        Args:
            start_date: The start date for the summary.
            end_date: The end date for the summary.
            group_by: The group by for the summary (e.g. day, week, month).

        Returns:
            Dict[str, Any]: ``info``, ``success`` and ``error``, each a list
            of ``{label, count}``.

        Raises:
            BadRequestError: on an invalid ``group_by`` or date value.
        """
        group_by_column = self._get_group_by_column(self.model_logs, group_by)
        levels = (("info", "INFO"), ("success", "SUCCESS"), ("error", "ERROR"))
        summary = {}
        for key, level in levels:
            summary[key] = await self._execute_query(
                self.model_logs,
                group_by_column,
                start_date,
                end_date,
                func.count(case((self.model_logs.log_level == level, 1))),
            )
        return summary