from contextlib import AbstractContextManager
from typing import Callable
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
from app.core.config import configs
from app.core.exceptions import DuplicatedError, NotFoundError
from app.util.query_builder import dict_to_sqlalchemy_filter_options
[docs]
class BaseRepository:
"""
Base repository providing common CRUD operations.
Attributes:
session_factory (Callable[..., AbstractContextManager[Session]]):
Factory for creating SQLAlchemy sessions.
model: SQLAlchemy model class.
"""
def __init__(
self, session_factory: Callable[..., AbstractContextManager[Session]], model
) -> None:
"""
Initializes the BaseRepository with the provided session factory and
model.
Args:
session_factory (Callable[..., AbstractContextManager[Session]]):
The session factory.
model: The SQLAlchemy model class.
"""
self.session_factory = session_factory
self.model = model
[docs]
def read_by_options(self, schema, eager=False):
"""
Reads records by specified options.
Args:
schema: The schema containing query options.
eager (bool): Whether to use eager loading.
Returns:
dict: Query results and search options.
"""
with self.session_factory() as session:
schema_as_dict = schema.dict(exclude_none=True)
ordering = schema_as_dict.get("ordering", configs.ORDERING)
order_query = (
getattr(self.model, ordering[1:]).desc()
if ordering.startswith("-")
else getattr(self.model, ordering).asc()
)
page = schema_as_dict.get("page", configs.PAGE)
page_size = schema_as_dict.get("page_size", configs.PAGE_SIZE)
filter_options = dict_to_sqlalchemy_filter_options(
self.model, schema.dict(exclude_none=True)
)
query = session.query(self.model)
if eager:
for eager in getattr(self.model, "eagers", []):
query = query.options(joinedload(getattr(self.model, eager)))
filtered_query = query.filter(filter_options)
query = filtered_query.order_by(order_query)
if page_size == "all":
query = query.all()
else:
query = query.limit(page_size).offset((page - 1) * page_size).all()
total_count = filtered_query.count()
return {
"items": query,
"search_options": {
"page": page,
"page_size": page_size,
"ordering": ordering,
"total_count": total_count,
},
}
[docs]
def read_by_id(
self,
id: int,
eager=False,
not_found_raise_exception=True,
not_found_message="Not found id : {id}",
):
"""
Reads a record by its ID.
Args:
id (int): The record ID.
eager (bool): Whether to use eager loading.
not_found_raise_exception (bool): Whether to raise an exception if
the record is not found.
not_found_message (str): The message for the not found exception.
Returns:
object: The record if found, otherwise None or raises an exception.
"""
with self.session_factory() as session:
query = session.query(self.model)
if eager:
for eager in getattr(self.model, "eagers", []):
query = query.options(joinedload(getattr(self.model, eager)))
query = query.filter(self.model.id == id).first()
if not query and not_found_raise_exception:
raise NotFoundError(detail=not_found_message.format(id=id))
if not not_found_raise_exception and not query:
return None
return query
[docs]
def read_by_column(
self,
column: str,
value: str,
eager=False,
only_one=True,
not_found_raise_exception=True,
not_found_message="Not found {column} : {value}",
):
"""
Reads records by a specified column and value.
Args:
column (str): The column name.
value (str): The value to filter by.
eager (bool): Whether to use eager loading.
only_one (bool): Whether to return only one record.
not_found_raise_exception (bool): Whether to raise an exception if
the record is not found.
not_found_message (str): The message for the not found exception.
Returns:
object or list: The record(s) if found, otherwise None or raises
an exception.
"""
with self.session_factory() as session:
query = session.query(self.model)
if eager:
for eager in getattr(self.model, "eagers", []):
query = query.options(joinedload(getattr(self.model, eager)))
if only_one:
query = query.filter(getattr(self.model, column) == value).first()
if not query and not_found_raise_exception:
raise NotFoundError(
detail=not_found_message.format(column=column, value=value)
)
return query
query = query.filter(getattr(self.model, column) == value).all()
return query
[docs]
async def create(self, schema):
"""
Creates a new record.
Args:
schema: The schema containing the record data.
Returns:
object: The created record.
"""
with self.session_factory() as session:
query = self.model(**schema.dict())
try:
session.add(query)
session.commit()
session.refresh(query)
except IntegrityError as e:
raise DuplicatedError(detail=str(e.orig))
return query
[docs]
def update(self, id: int, schema):
"""
Updates a record by its ID.
Args:
id (int): The record ID.
schema: The schema containing the updated data.
Returns:
object: The updated record.
"""
with self.session_factory() as session:
session.query(self.model).filter(self.model.id == id).update(
schema.dict(exclude_none=True)
)
session.commit()
return self.read_by_id(id)
[docs]
def update_attr(self, id: int, column: str, value):
"""
Updates a specific attribute of a record by its ID.
Args:
id (int): The record ID.
column (str): The column name.
value: The new value of the attribute.
Returns:
object: The updated record.
"""
with self.session_factory() as session:
session.query(self.model).filter(self.model.id == id).update(
{column: value}
)
session.commit()
return self.read_by_id(id)
[docs]
def whole_update(self, id: int, schema):
"""
Replaces a record entirely by its ID.
Args:
id (int): The record ID.
schema: The schema containing the new data.
Returns:
object: The updated record.
"""
with self.session_factory() as session:
session.query(self.model).filter(self.model.id == id).update(schema.dict())
session.commit()
return self.read_by_id(id)
[docs]
def delete_by_id(self, id: int):
"""
Deletes a record by its ID.
Args:
id (int): The record ID.
Returns:
None
"""
with self.session_factory() as session:
query = session.query(self.model).filter(self.model.id == id).first()
if not query:
raise NotFoundError(detail=f"Not found id : {id}")
session.delete(query)
session.commit()