import ipaddress
from datetime import datetime, timezone
from typing import Optional
from fastapi import Request
from app.core.config import configs
from app.core.exceptions import TooManyRequestsError
from app.services.rate_limit_counter_backend import RateLimitCounterBackend
[docs]
class AbusePreventionService:
"""
Service responsible for abuse prevention checks on sensitive endpoints.
"""
def __init__(self, counter_backend: RateLimitCounterBackend) -> None:
self.counter_backend = counter_backend
@staticmethod
def _peer_is_trusted_proxy(ip: Optional[str]) -> bool:
"""
Return whether ``ip`` falls within a configured trusted-proxy network.
Used to decide if ``X-Forwarded-For`` can be trusted to derive the real
client IP.
Args:
ip (Optional[str]): The direct peer IP to test.
Returns:
bool: ``True`` if ``ip`` is valid and within ``TRUSTED_PROXY_IPS``.
"""
if not ip:
return False
trusted = configs.TRUSTED_PROXY_IPS
if not trusted:
return False
try:
addr = ipaddress.ip_address(ip)
except ValueError:
return False
return any(addr in network for network in trusted)
[docs]
async def enforce_task_mutation_limits(
self,
api_key: Optional[str],
client_ip: Optional[str],
external_user_id: Optional[str],
now: Optional[datetime] = None,
) -> None:
"""
Enforces abuse controls for task mutation endpoints (`/points`, `/action`).
"""
if not configs.ABUSE_PREVENTION_ENABLED:
return
now = self._normalize_now(now)
window_seconds = max(1, int(configs.ABUSE_RATE_LIMIT_WINDOW_SECONDS))
short_window_name = f"task_mutation_short_{window_seconds}s"
short_window_start = self._get_window_bucket_start(now, window_seconds)
short_ttl = self._ttl_for_seconds(window_seconds)
normalized_api_key = self._normalize_scope_value(api_key)
normalized_ip = self._normalize_scope_value(client_ip)
normalized_external_user = self._normalize_scope_value(external_user_id)
await self._enforce_limit(
scope_type="api_key",
scope_value=normalized_api_key,
window_name=short_window_name,
window_start=short_window_start,
ttl_seconds=short_ttl,
max_allowed=int(configs.ABUSE_RATE_LIMIT_PER_API_KEY),
error_detail="API key rate limit exceeded for sensitive task operations.",
)
await self._enforce_limit(
scope_type="ip",
scope_value=normalized_ip,
window_name=short_window_name,
window_start=short_window_start,
ttl_seconds=short_ttl,
max_allowed=int(configs.ABUSE_RATE_LIMIT_PER_IP),
error_detail="IP rate limit exceeded for sensitive task operations.",
)
await self._enforce_limit(
scope_type="external_user",
scope_value=normalized_external_user,
window_name=short_window_name,
window_start=short_window_start,
ttl_seconds=short_ttl,
max_allowed=int(configs.ABUSE_RATE_LIMIT_PER_EXTERNAL_USER),
error_detail="externalUserId rate limit exceeded for sensitive task operations.",
)
daily_window_name = "task_mutation_daily"
daily_window_start = self._get_daily_bucket_start(now)
daily_ttl = self._ttl_for_seconds(86400)
await self._enforce_limit(
scope_type="api_key",
scope_value=normalized_api_key,
window_name=daily_window_name,
window_start=daily_window_start,
ttl_seconds=daily_ttl,
max_allowed=int(configs.ABUSE_DAILY_QUOTA_PER_API_KEY),
error_detail="Daily API key quota exceeded for sensitive task operations.",
)
async def _enforce_limit(
self,
scope_type: str,
scope_value: Optional[str],
window_name: str,
window_start: datetime,
ttl_seconds: int,
max_allowed: int,
error_detail: str,
) -> None:
"""
Increment one rate-limit bucket and raise if it exceeds the cap.
No-ops when the scope value is empty or the cap is non-positive.
Args:
scope_type (str): Dimension being limited (``api_key``/``ip``/...).
scope_value (Optional[str]): Concrete value within ``scope_type``.
window_name (str): Name of the time window/bucket.
window_start (datetime): Start of the bucket window.
ttl_seconds (int): TTL applied to the bucket.
max_allowed (int): Maximum requests allowed in the window.
error_detail (str): Message used when the limit is exceeded.
Raises:
TooManyRequestsError: If the counter exceeds ``max_allowed``.
"""
if not scope_value:
return
if max_allowed <= 0:
return
counter = await self.counter_backend.increment_and_get(
scope_type=scope_type,
scope_value=scope_value,
window_name=window_name,
window_start=window_start,
ttl_seconds=ttl_seconds,
)
if counter > max_allowed:
raise TooManyRequestsError(detail=error_detail)
@staticmethod
def _normalize_scope_value(value: Optional[str]) -> Optional[str]:
"""
Trim a scope value, collapsing blanks to ``None``.
Args:
value (Optional[str]): Raw scope value.
Returns:
Optional[str]: The stripped value, or ``None`` if empty/blank.
"""
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
return normalized
@staticmethod
def _normalize_now(now: Optional[datetime]) -> datetime:
"""
Return ``now`` as a UTC-aware datetime, defaulting to the current time.
Args:
now (Optional[datetime]): Caller-supplied time, naive or aware.
Returns:
datetime: A timezone-aware UTC datetime.
"""
if now is None:
return datetime.now(timezone.utc)
if now.tzinfo is None:
return now.replace(tzinfo=timezone.utc)
return now.astimezone(timezone.utc)
@staticmethod
def _get_window_bucket_start(now: datetime, window_seconds: int) -> datetime:
"""
Align ``now`` down to the start of its fixed-size time bucket.
Args:
now (datetime): The reference time.
window_seconds (int): Bucket width in seconds.
Returns:
datetime: UTC start of the bucket containing ``now``.
"""
now = now.astimezone(timezone.utc)
unix_now = int(now.timestamp())
bucket_start = unix_now - (unix_now % window_seconds)
return datetime.fromtimestamp(bucket_start, tz=timezone.utc)
@staticmethod
def _get_daily_bucket_start(now: datetime) -> datetime:
"""
Return midnight UTC of ``now``'s day (the daily-quota bucket start).
Args:
now (datetime): The reference time.
Returns:
datetime: UTC start of the current day.
"""
now = now.astimezone(timezone.utc)
return datetime(
year=now.year,
month=now.month,
day=now.day,
tzinfo=timezone.utc,
)
@staticmethod
def _ttl_for_seconds(window_seconds: int) -> int:
"""
Compute a bucket TTL from a window length plus a skew buffer.
Args:
window_seconds (int): The window length in seconds.
Returns:
int: TTL in seconds (window plus ``RATE_LIMIT_TTL_BUFFER_SECONDS``)
so buckets survive clock skew between instances and Redis.
"""
# Extra buffer absorbs clock skew between API instances and Redis so
# a bucket cannot die just before the next request promotes it.
buffer = max(1, int(configs.RATE_LIMIT_TTL_BUFFER_SECONDS))
return max(1, int(window_seconds)) + buffer