Source code for app.repository.abuse_limit_counter_repository

from contextlib import AbstractContextManager
from datetime import datetime, timezone
from typing import Callable

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.model.abuse_limit_counter import AbuseLimitCounter
from app.repository.base_repository import BaseRepository


class AbuseLimitCounterRepository(BaseRepository):
    """
    Repository for abuse prevention counters.

    A counter bucket is addressed by the model columns ``scopeType``,
    ``scopeValue``, ``windowName`` and ``windowStart``.

    The increment operation is implemented to be safe under concurrent writes:
    - update existing bucket atomically (`counter = counter + 1`)
    - if bucket does not exist, insert
    - if concurrent insert collides, retry update
    """

    def __init__(
        self,
        session_factory: Callable[..., AbstractContextManager[Session]],
        model=AbuseLimitCounter,
    ) -> None:
        """
        Forward the session factory and model to ``BaseRepository``.

        Args:
            session_factory: Callable returning a context manager that yields
                a SQLAlchemy ``Session``.
            model: Mapped class used for counter rows; defaults to
                ``AbuseLimitCounter``.
        """
        super().__init__(session_factory, model)

    def increment_and_get(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
    ) -> int:
        """
        Increments a counter bucket and returns the resulting value.

        Args:
            scope_type: Value matched against / stored in ``scopeType``.
            scope_value: Value matched against / stored in ``scopeValue``.
            window_name: Value matched against / stored in ``windowName``.
            window_start: Start of the counting window. A naive datetime is
                interpreted as UTC.

        Returns:
            The counter value after this call's increment (``1`` when the
            bucket was created by this call).

        Raises:
            sqlalchemy.exc.SQLAlchemyError: Any database error other than the
                ``IntegrityError`` raised by a colliding insert.
        """
        if window_start.tzinfo is None:
            window_start = window_start.replace(tzinfo=timezone.utc)

        with self.session_factory() as session:
            filters = self._build_filters(
                scope_type=scope_type,
                scope_value=scope_value,
                window_name=window_name,
                window_start=window_start,
            )

            # Fast path: the bucket already exists.
            value = self._increment_existing(session, filters)
            if value is not None:
                return value

            # No row matched, so try to create the bucket with an initial count.
            try:
                session.add(
                    self.model(
                        scopeType=scope_type,
                        scopeValue=scope_value,
                        windowName=window_name,
                        windowStart=window_start,
                        counter=1,
                    )
                )
                session.commit()
                return 1
            except IntegrityError:
                # NOTE(review): presumably a unique constraint on the bucket
                # key fired because a concurrent caller inserted the same
                # bucket first - confirm against the model definition.
                session.rollback()

            # The bucket should exist now; increment it instead.
            value = self._increment_existing(session, filters)
            if value is not None:
                return value

            # The row vanished again between the failed insert and the retry.
            # Report whatever is stored (0 when nothing is).
            session.commit()
            return self._read_counter(session, filters)

    def _increment_existing(self, session: Session, filters):
        """
        Add one to an existing bucket and return its new value.

        The counter is read back *before* committing, i.e. inside the same
        transaction that issued the UPDATE. Reading after the commit would let
        a concurrent increment slip in between and be reported as this call's
        result.

        Returns:
            The incremented counter as ``int``, or ``None`` when no row matched
            ``filters`` (in which case nothing is committed).
        """
        updated_rows = (
            session.query(self.model)
            .filter(*filters)
            .update(
                {
                    # Column expression, so the addition happens in SQL.
                    self.model.counter: self.model.counter + 1,
                    self.model.updated_at: datetime.now(timezone.utc),
                },
                synchronize_session=False,
            )
        )
        if not updated_rows:
            return None

        value = self._read_counter(session, filters)
        session.commit()
        return value

    def _build_filters(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
    ):
        """Return the tuple of equality criteria that selects one bucket."""
        return (
            self.model.scopeType == scope_type,
            self.model.scopeValue == scope_value,
            self.model.windowName == window_name,
            self.model.windowStart == window_start,
        )

    def _read_counter(self, session: Session, filters) -> int:
        """Return the stored counter for ``filters``; ``0`` if missing or NULL."""
        value = session.query(self.model.counter).filter(*filters).scalar()
        return int(value or 0)