# Source code for app.repository.user_repository

from contextlib import AbstractAsyncContextManager
from typing import Callable, Optional

from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import NotFoundError
from app.model.users import Users
from app.repository.base_repository import BaseRepository


class UserRepository(BaseRepository):
    """Repository for ``Users`` rows, looked up and created by external user ID."""

    def __init__(
        self,
        session_factory: Callable[..., AbstractAsyncContextManager[AsyncSession]],
        model=Users,
    ) -> None:
        """Initialise the repository via ``BaseRepository``.

        Args:
            session_factory: Callable returning an async context manager that
                yields an ``AsyncSession``.
            model: ORM model this repository operates on; defaults to ``Users``.
        """
        super().__init__(session_factory, model)

    async def create_user_by_externalUserId(
        self,
        externalUserId: str,
        oauth_user_id: Optional[str] = None,
    ) -> Users:
        """Create a user with the given external user ID in a fresh session.

        Concurrency-safe: when a parallel transaction inserts the same
        ``externalUserId`` first and the commit fails with ``IntegrityError``,
        the already-existing row is returned instead of raising to the caller.

        Args:
            externalUserId: External identifier of the user to create.
            oauth_user_id: Optional OAuth user ID stored on the new row. It is
                not applied to an already-existing row returned on conflict.

        Returns:
            The newly created (refreshed) user, or the existing user with the
            same ``externalUserId``.

        Raises:
            IntegrityError: If the commit fails and no row with this
                ``externalUserId`` can be found afterwards.
        """
        async with self.session_factory() as session:
            # Build the same model the fallback lookup below queries.
            user = self.model(externalUserId=externalUserId, oauth_user_id=oauth_user_id)
            session.add(user)
            try:
                await session.commit()
            except IntegrityError:
                # The failed transaction must be rolled back before the
                # session can be used for the fallback lookup.
                await session.rollback()
                result = await session.execute(
                    select(self.model).filter_by(externalUserId=externalUserId)
                )
                existing = result.scalars().first()
                if existing is not None:
                    return existing
                raise
            await session.refresh(user)
            return user

    async def get_or_create_by_externalUserId(
        self,
        externalUserId: str,
        oauth_user_id: Optional[str] = None,
        session: Optional[AsyncSession] = None,
        auto_commit: bool = True,
    ) -> Users:
        """Return the user for ``externalUserId``, creating it if missing.

        Runs a single PostgreSQL ``INSERT ... ON CONFLICT DO UPDATE`` on the
        ``externalUserId`` column, then reloads the row as an ORM object.

        Args:
            externalUserId: External identifier to look up or insert.
            oauth_user_id: Inserted on create. On conflict it overwrites the
                stored value only when it is not ``None``.
            session: Optional caller-managed session. When ``None``, a session
                is opened from ``session_factory`` for this call only.
            auto_commit: When ``True``, commit after the upsert (this commits
                a caller-supplied session too); when ``False``, only flush and
                leave committing to the caller.

        Returns:
            The user row, freshly populated from the database.

        Raises:
            ValueError: If ``auto_commit`` is ``False`` without a ``session``.
            NotFoundError: If the upserted row cannot be loaded afterwards.
        """
        if session is None:
            if not auto_commit:
                raise ValueError(
                    "auto_commit=False requires an external session managed by the caller."
                )
            async with self.session_factory() as managed_session:
                return await self.get_or_create_by_externalUserId(
                    externalUserId=externalUserId,
                    oauth_user_id=oauth_user_id,
                    session=managed_session,
                    auto_commit=auto_commit,
                )

        users_table = self.model.__table__
        insert_stmt = insert(users_table).values(
            externalUserId=externalUserId,
            oauth_user_id=oauth_user_id,
        )
        # On conflict always bump updated_at; only overwrite oauth_user_id when
        # a value was supplied, so an existing link is not cleared to NULL.
        update_values = {"updated_at": func.now()}
        if oauth_user_id is not None:
            update_values["oauth_user_id"] = oauth_user_id
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=[users_table.c.externalUserId],
            set_=update_values,
        ).returning(users_table.c.id)
        # RETURNING yields the id for both the inserted and the updated row.
        user_id = (await session.execute(upsert_stmt)).scalar_one()

        if auto_commit:
            await session.commit()
        else:
            await session.flush()

        # The upsert is a Core statement, so an instance of this user already
        # in the session's identity map would keep stale attribute values;
        # populate_existing overwrites them from the freshly selected row.
        result = await session.execute(
            select(self.model)
            .filter(self.model.id == user_id)
            .execution_options(populate_existing=True)
        )
        user = result.scalars().first()
        if user is None:
            raise NotFoundError(
                detail=f"User not found after upsert by externalUserId: {externalUserId}"
            )
        return user