Source code for app.repository.wallet_repository
from contextlib import AbstractAsyncContextManager
from typing import Callable, Optional
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import configs
from app.core.exceptions import NotFoundError
from app.model.wallet import Wallet
from app.repository.base_repository import BaseRepository
[docs]
class WalletRepository(BaseRepository):
"""
Repository class for wallets.
"""
def __init__(
self,
session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
model=Wallet,
) -> None:
super().__init__(session_factory, model)
[docs]
async def upsert_points_balance(
self,
user_id,
points_delta: int,
api_key: Optional[str] = None,
oauth_user_id: Optional[str] = None,
session: Optional[AsyncSession] = None,
auto_commit: bool = True,
) -> Wallet:
"""
Atomically increments a user's wallet points balance via
``INSERT ... ON CONFLICT DO UPDATE``.
"""
if session is None and not auto_commit:
raise ValueError(
"auto_commit=False requires an external session managed by the caller."
)
if session is None:
async with self.session_factory() as managed_session:
return await self.upsert_points_balance(
user_id=user_id,
points_delta=points_delta,
api_key=api_key,
oauth_user_id=oauth_user_id,
session=managed_session,
auto_commit=auto_commit,
)
wallet_table = self.model.__table__
insert_stmt = insert(wallet_table).values(
userId=user_id,
coinsBalance=0.0,
pointsBalance=points_delta,
conversionRate=configs.DEFAULT_CONVERTION_RATE_POINTS_TO_COIN,
apiKey_used=api_key,
oauth_user_id=oauth_user_id,
)
on_conflict_updates = {
"pointsBalance": wallet_table.c.pointsBalance + points_delta,
"updated_at": func.now(),
}
if api_key is not None:
on_conflict_updates["apiKey_used"] = api_key
if oauth_user_id is not None:
on_conflict_updates["oauth_user_id"] = oauth_user_id
upsert_stmt = insert_stmt.on_conflict_do_update(
index_elements=[wallet_table.c.userId],
set_=on_conflict_updates,
).returning(wallet_table.c.id)
wallet_id = (await session.execute(upsert_stmt)).scalar_one()
if auto_commit:
await session.commit()
else:
await session.flush()
wallet = (
(
await session.execute(
select(self.model).filter(self.model.id == wallet_id)
)
)
.scalars()
.first()
)
if wallet is None:
raise NotFoundError(detail=f"Wallet not found by id: {wallet_id}")
return wallet