Source code for app.services.rate_limit_counter_backend

"""
Pluggable counter backends used by ``AbusePreventionService``.

The DB-backed backend writes to ``abuse_limit_counter`` (preserves the
historical behavior and keeps existing migrations relevant). The Redis-backed
backend uses ``INCR`` + ``SET NX EX`` which is atomic, ~50 us per request,
and naturally distributed across API instances -- avoiding the cost of an
``UPDATE`` on a hot row in Postgres.

Selection is driven by ``configs.ABUSE_PREVENTION_BACKEND`` and wired in
``app/core/container.py``.
"""

from __future__ import annotations

import logging
from datetime import datetime, timezone
from typing import Any, Optional, Protocol, runtime_checkable

from app.repository.abuse_limit_counter_repository import AbuseLimitCounterRepository

logger = logging.getLogger(__name__)


@runtime_checkable
class RateLimitCounterBackend(Protocol):
    """Structural interface for per-bucket rate-limit counter stores.

    Any object exposing a matching ``increment_and_get`` coroutine method
    satisfies this protocol. Because the class is ``@runtime_checkable``,
    ``isinstance`` checks are also possible; they only verify that the
    method is present, not its signature.
    """

    async def increment_and_get(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
        ttl_seconds: int,
    ) -> int:
        """Bump one bucket's counter and report the resulting value.

        Args:
            scope_type (str): Dimension being limited.
            scope_value (str): Concrete value within ``scope_type``.
            window_name (str): Name of the time window/bucket.
            window_start (datetime): Start of the bucket window.
            ttl_seconds (int): Lifetime hint for the bucket, in seconds.
                Implementations are free to ignore it.

        Returns:
            int: The counter value after the increment.
        """
        ...
class DatabaseRateLimitCounterBackend:
    """Counter backend that keeps its buckets in ``abuse_limit_counter``.

    The bucket key already pins down the window boundaries, so this backend
    has no use for ``ttl_seconds``. Expired rows are meant to be reaped
    offline (a future periodic job), never on the request path.
    """

    def __init__(self, repository: AbuseLimitCounterRepository) -> None:
        self._repository = repository

    async def increment_and_get(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
        ttl_seconds: int,
    ) -> int:
        """Delegate the increment to the repository and return the new count.

        Args:
            scope_type (str): Dimension being limited.
            scope_value (str): Concrete value within ``scope_type``.
            window_name (str): Name of the time window/bucket.
            window_start (datetime): Start of the bucket window.
            ttl_seconds (int): Accepted for interface parity; not forwarded.

        Returns:
            int: The counter value after the increment.
        """
        # Deliberately unused: window_start carries the window semantics for
        # the DB backend, so the TTL is never passed to the repository.
        _ = ttl_seconds

        bucket = {
            "scope_type": scope_type,
            "scope_value": scope_value,
            "window_name": window_name,
            "window_start": window_start,
        }
        new_value = await self._repository.increment_and_get(**bucket)
        return new_value
class RedisRateLimitCounterBackend:
    """Counter backend built on Redis ``SET NX EX`` + ``INCR``.

    One Redis key exists per bucket. ``SET key 0 EX ttl NX`` creates a zeroed
    counter together with its TTL on the first request of a bucket, and is a
    no-op for later requests; ``INCR`` then bumps the counter. Both commands
    are queued on one transactional pipeline so they travel in a single
    round-trip under ``MULTI/EXEC`` and no other client can observe the key
    without a TTL.
    """

    def __init__(self, client, key_prefix: str = "game:rl:") -> None:
        self._client = client
        self._key_prefix = key_prefix

    @staticmethod
    def _bucket_epoch(window_start: datetime) -> int:
        """Convert a bucket's window start to whole UTC epoch seconds.

        Args:
            window_start (datetime): Start of the bucket window. A naive
                value is interpreted as UTC.

        Returns:
            int: Unix timestamp that makes the Redis key unique per window.
        """
        utc = timezone.utc
        if window_start.tzinfo is not None:
            aware = window_start.astimezone(utc)
        else:
            aware = window_start.replace(tzinfo=utc)
        return int(aware.timestamp())

    def _build_key(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
    ) -> str:
        """Compose the prefixed Redis key identifying one rate-limit bucket.

        Layout: ``<prefix><window_name>:<scope_type>:<scope_value>:<epoch>``.

        Args:
            scope_type (str): Dimension being limited.
            scope_value (str): Concrete value within ``scope_type``.
            window_name (str): Name of the time window.
            window_start (datetime): Start of the bucket window.

        Returns:
            str: The fully-qualified, prefixed Redis key.
        """
        epoch = self._bucket_epoch(window_start)
        return "{}{}:{}:{}:{}".format(
            self._key_prefix, window_name, scope_type, scope_value, epoch
        )

    async def increment_and_get(
        self,
        scope_type: str,
        scope_value: str,
        window_name: str,
        window_start: datetime,
        ttl_seconds: int,
    ) -> int:
        """Increment the bucket's Redis counter and return the new count.

        Args:
            scope_type (str): Dimension being limited.
            scope_value (str): Concrete value within ``scope_type``.
            window_name (str): Name of the time window/bucket.
            window_start (datetime): Start of the bucket window.
            ttl_seconds (int): Key TTL; values below 1 are raised to 1 second.

        Returns:
            int: The counter value after the increment.
        """
        bucket_key = self._build_key(scope_type, scope_value, window_name, window_start)
        expiry = max(1, int(ttl_seconds))

        txn = self._client.pipeline(transaction=True)
        # NX makes this a no-op once the bucket exists, so the TTL is set
        # exactly once, by the first request of the window.
        txn.set(bucket_key, 0, ex=expiry, nx=True)
        txn.incr(bucket_key)

        replies = await txn.execute()
        # replies[0] is the SET outcome; replies[1] is the INCR result.
        return int(replies[1])
def build_redis_client_from_url(url: str) -> Any:
    """Create an async Redis client for the given connection URL.

    The ``redis`` package is imported inside the function so that it is only
    required when the Redis backend is actually selected; deployments using
    the DB backend never need it installed.

    Args:
        url (str): Redis connection URL handed to ``redis.asyncio.from_url``.

    Returns:
        Any: The client returned by ``redis.asyncio.from_url``, created with
        ``decode_responses=True``.

    Raises:
        RuntimeError: If the ``redis`` package cannot be imported.
    """
    try:
        from redis import asyncio as aioredis
    except ImportError as missing:  # pragma: no cover - import guard
        message = (
            "ABUSE_PREVENTION_BACKEND=redis requires the 'redis' package. "
            "Install with `poetry add redis@^8.0` or set "
            "ABUSE_PREVENTION_BACKEND=database to keep the DB-backed limiter."
        )
        raise RuntimeError(message) from missing
    else:
        return aioredis.from_url(url, decode_responses=True)
def build_rate_limit_counter_backend(
    repository: AbuseLimitCounterRepository,
    backend_name: str,
    redis_url: Optional[str],
    redis_key_prefix: str,
) -> RateLimitCounterBackend:
    """Select and construct the configured counter backend.

    Selection is fail-safe: every misconfiguration resolves to the DB
    backend with a warning, so a bad deploy still rate-limits (just slower)
    instead of opening the floodgates.

    * ``"redis"`` with a usable URL -> Redis backend.
    * ``"redis"`` with a missing/blank URL -> DB backend, with a warning.
    * ``"database"``, empty or ``None`` -> DB backend.
    * anything else (e.g. a typo) -> DB backend, with a warning.

    Args:
        repository (AbuseLimitCounterRepository): Repository handed to the
            DB-backed counter.
        backend_name (str): Configured backend name; compared
            case-insensitively after trimming whitespace.
        redis_url (Optional[str]): Redis connection URL. Surrounding
            whitespace is trimmed; a blank value counts as missing.
        redis_key_prefix (str): Prefix applied to every Redis bucket key.

    Returns:
        RateLimitCounterBackend: The selected backend instance.

    Raises:
        RuntimeError: If Redis is selected with a URL but the ``redis``
            package is not installed.
    """
    # A blank or whitespace-only name means "not configured", not "unknown".
    normalized = (backend_name or "").strip().lower() or "database"

    if normalized == "redis":
        # Trim first so a whitespace-only or padded value from the
        # environment takes the fallback path or is passed on cleanly.
        url = (redis_url or "").strip()
        if not url:
            logger.warning(
                "ABUSE_PREVENTION_BACKEND=redis but REDIS_URL is empty; "
                "falling back to the database-backed counter."
            )
            return DatabaseRateLimitCounterBackend(repository)
        client = build_redis_client_from_url(url)
        return RedisRateLimitCounterBackend(client, key_prefix=redis_key_prefix)

    if normalized != "database":
        # Surface typos such as "reddis" instead of switching backend silently.
        logger.warning(
            "Unrecognized ABUSE_PREVENTION_BACKEND=%r; "
            "falling back to the database-backed counter.",
            backend_name,
        )
    return DatabaseRateLimitCounterBackend(repository)