# Source code for app.repository.abuse_limit_counter_repository

from contextlib import AbstractAsyncContextManager
from datetime import datetime, timezone
from typing import Callable

from sqlalchemy import select
from sqlalchemy import update as sa_update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.model.abuse_limit_counter import AbuseLimitCounter
from app.repository.base_repository import BaseRepository


class AbuseLimitCounterRepository(BaseRepository):
    """
    Repository for abuse prevention counters.

    The increment operation is safe under concurrent writes:
    - update existing bucket atomically (`counter = counter + 1`)
    - if bucket does not exist, insert
    - if concurrent insert collides, retry update
    """

    def __init__(
        self,
        session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
        model=AbuseLimitCounter,
    ) -> None:
        super().__init__(session_factory, model)

    async def increment_and_get(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
    ) -> int:
        """
        Atomically increment a rate-limit bucket and return its new value.

        Concurrency-safe: it first tries an atomic ``counter = counter + 1``
        update; if the bucket does not yet exist it inserts it, and if a
        concurrent insert collides it rolls back and retries the update.
        The counter is read back inside the same transaction as the
        increment (before commit), so the returned value is the one this
        call produced rather than one that already includes later
        concurrent increments.

        ``window_start`` is normalized to UTC: naive values are taken to be
        UTC already, aware values are converted to UTC.

        Args:
            scope_type (str): Dimension being limited (e.g. ``"ip"``,
                ``"api_key"``).
            scope_value (str): Concrete value within ``scope_type``.
            window_name (str): Name of the limit window (e.g.
                ``"per_minute"``).
            window_start (datetime): Start of the bucket's time window.

        Returns:
            int: The counter value after this increment. ``0`` is only
            possible if the bucket disappeared between an insert collision
            and the retried update.
        """
        window_start = self._normalize_window_start(window_start)

        async with self.session_factory() as session:
            filters = self._build_filters(
                scope_type=scope_type,
                scope_value=scope_value,
                window_name=window_name,
                window_start=window_start,
            )

            # Fast path: the bucket already exists.
            if await self._increment_existing(session, filters):
                # Read before committing: on transactional backends the
                # UPDATE keeps the row write-locked until commit, so no other
                # increment can interleave between our write and this read.
                value = await self._read_counter(session, filters)
                await session.commit()
                return value

            # Bucket not found: try to create it with an initial count of 1.
            session.add(
                self.model(
                    scopeType=scope_type,
                    scopeValue=scope_value,
                    windowName=window_name,
                    windowStart=window_start,
                    counter=1,
                )
            )
            try:
                await session.commit()
            except IntegrityError:
                # A concurrent request inserted the same bucket first;
                # discard our failed insert and increment theirs instead.
                await session.rollback()
                await self._increment_existing(session, filters)
                value = await self._read_counter(session, filters)
                await session.commit()
                return value
            return 1

    @staticmethod
    def _normalize_window_start(window_start: datetime) -> datetime:
        """
        Return ``window_start`` as a timezone-aware UTC datetime.

        Naive datetimes (no usable UTC offset) are assumed to already be in
        UTC and are only tagged; aware datetimes are converted, so the same
        instant always produces the same bucket key regardless of the
        caller's time zone.
        """
        if window_start.tzinfo is None or window_start.utcoffset() is None:
            return window_start.replace(tzinfo=timezone.utc)
        return window_start.astimezone(timezone.utc)

    async def _increment_existing(self, session: AsyncSession, filters) -> int:
        """
        Execute ``counter = counter + 1`` on the bucket matched by ``filters``.

        Also stamps ``updated_at`` with the current UTC time. Does not commit.

        Args:
            session (AsyncSession): Active session to execute within.
            filters: Bucket-identifying filters from :meth:`_build_filters`.

        Returns:
            int: Number of rows the UPDATE matched (``0`` when the bucket
            does not exist yet).
        """
        result = await session.execute(
            sa_update(self.model)
            .where(*filters)
            .values(
                {
                    self.model.counter: self.model.counter + 1,
                    self.model.updated_at: datetime.now(timezone.utc),
                }
            )
            # Bulk UPDATE; no in-session ORM objects need to be refreshed.
            .execution_options(synchronize_session=False)
        )
        return result.rowcount

    def _build_filters(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
    ) -> tuple:
        """
        Build the equality filters that uniquely identify a counter bucket.

        Returns:
            tuple: SQLAlchemy boolean expressions for scope type/value, window
            name and window start, used in ``WHERE`` clauses.
        """
        return (
            self.model.scopeType == scope_type,
            self.model.scopeValue == scope_value,
            self.model.windowName == window_name,
            self.model.windowStart == window_start,
        )

    async def _read_counter(self, session: AsyncSession, filters) -> int:
        """
        Read the current counter value for a bucket within ``session``.

        Args:
            session (AsyncSession): Active session to query within.
            filters: Bucket-identifying filters from :meth:`_build_filters`.

        Returns:
            int: The stored counter, or ``0`` when the bucket is absent.
        """
        value = (
            await session.execute(select(self.model.counter).where(*filters))
        ).scalar()
        return int(value or 0)