├── src ├── __init__.py ├── api │ ├── __init__.py │ ├── v1 │ │ ├── __init__.py │ │ ├── tools │ │ │ ├── __init__.py │ │ │ └── validate.py │ │ ├── events │ │ │ ├── __init__.py │ │ │ └── email.py │ │ ├── integration │ │ │ ├── __init__.py │ │ │ └── dishka.py │ │ ├── subscribers │ │ │ └── __init__.py │ │ ├── dtos │ │ │ ├── status.py │ │ │ ├── role.py │ │ │ ├── auth.py │ │ │ ├── __init__.py │ │ │ ├── user.py │ │ │ └── base.py │ │ ├── constants.py │ │ ├── endpoints │ │ │ ├── admin │ │ │ │ ├── __init__.py │ │ │ │ └── user.py │ │ │ ├── healthcheck.py │ │ │ ├── __init__.py │ │ │ ├── user.py │ │ │ └── auth.py │ │ ├── middlewares │ │ │ ├── __init__.py │ │ │ └── auth.py │ │ ├── handlers │ │ │ ├── user │ │ │ │ ├── __init__.py │ │ │ │ ├── select.py │ │ │ │ ├── update.py │ │ │ │ ├── select_many.py │ │ │ │ └── create.py │ │ │ ├── auth │ │ │ │ ├── __init__.py │ │ │ │ ├── refresh.py │ │ │ │ ├── logout.py │ │ │ │ ├── confirm.py │ │ │ │ ├── login.py │ │ │ │ └── register.py │ │ │ └── __init__.py │ │ ├── setup.py │ │ ├── permission.py │ │ └── di.py │ ├── common │ │ ├── __init__.py │ │ ├── broker │ │ │ ├── __init__.py │ │ │ └── nats │ │ │ │ ├── __init__.py │ │ │ │ ├── connection.py │ │ │ │ ├── message.py │ │ │ │ ├── core.py │ │ │ │ └── stream.py │ │ ├── bus │ │ │ ├── __init__.py │ │ │ └── core.py │ │ ├── interfaces │ │ │ ├── __init__.py │ │ │ ├── event_bus.py │ │ │ ├── mediator.py │ │ │ ├── handler.py │ │ │ ├── broker.py │ │ │ └── dto.py │ │ ├── serializers │ │ │ ├── __init__.py │ │ │ └── msgspec.py │ │ ├── mediator │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ └── core.py │ │ ├── events │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── nats.py │ │ ├── middlewares │ │ │ ├── __init__.py │ │ │ ├── proccess_time.py │ │ │ └── request_id.py │ │ ├── tools.py │ │ ├── converters.py │ │ ├── exceptions.py │ │ └── docs.py │ ├── server │ │ ├── utils.py │ │ ├── __init__.py │ │ └── granian.py │ └── setup.py ├── common │ ├── __init__.py │ ├── tools │ │ ├── __init__.py │ │ ├── types.py │ │ ├── text.py │ 
│ ├── cache.py │ │ ├── singleton.py │ │ └── files.py │ ├── di │ │ ├── __init__.py │ │ ├── marker.py │ │ ├── base.py │ │ └── container.py │ ├── logger.py │ └── exceptions.py ├── settings │ ├── __init__.py │ └── core.py ├── services │ ├── cache │ │ ├── __init__.py │ │ └── redis.py │ ├── provider │ │ ├── __init__.py │ │ ├── types.py │ │ ├── middleware │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── manager.py │ │ │ ├── logging.py │ │ │ └── error.py │ │ ├── response.py │ │ ├── errors.py │ │ ├── base.py │ │ └── aiohttp.py │ ├── interfaces │ │ ├── __init__.py │ │ ├── encrypt.py │ │ └── hasher.py │ ├── security │ │ ├── __init__.py │ │ ├── argon2.py │ │ ├── aes.py │ │ └── jwt.py │ ├── external │ │ ├── request.py │ │ ├── __init__.py │ │ └── client.py │ ├── internal │ │ ├── __init__.py │ │ └── auth.py │ └── __init__.py ├── database │ ├── interfaces │ │ ├── __init__.py │ │ ├── gateway.py │ │ └── crud.py │ ├── repositories │ │ ├── types │ │ │ ├── __init__.py │ │ │ ├── role.py │ │ │ └── user.py │ │ ├── base.py │ │ ├── __init__.py │ │ ├── role.py │ │ ├── user.py │ │ └── crud.py │ ├── models │ │ ├── base │ │ │ ├── __init__.py │ │ │ ├── mixins │ │ │ │ ├── __init__.py │ │ │ │ ├── with_time.py │ │ │ │ └── with_id.py │ │ │ └── core.py │ │ ├── types.py │ │ ├── __init__.py │ │ ├── role.py │ │ └── user.py │ ├── types.py │ ├── exceptions.py │ ├── connection.py │ ├── __init__.py │ ├── _utils.py │ ├── manager.py │ └── tools.py └── __main__.py ├── tests ├── __init__.py └── conftest.py ├── nats └── nats.conf ├── migrations ├── README ├── script.py.mako ├── versions │ ├── 02_82e484eb7a8e_unique.py │ └── 01_3d1a9bc30546_initial.py └── env.py ├── jetstream └── stream.json ├── .pre-commit-config.yaml ├── Dockerfile ├── alembic.ini ├── .env.example ├── makefile ├── docker-compose.dev.yml ├── pyproject.toml ├── README.md └── .gitignore /src/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v1/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/settings/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/common/broker/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/common/bus/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v1/events/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/services/cache/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/services/provider/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/common/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/common/serializers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v1/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/v1/subscribers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/database/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/services/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/database/repositories/types/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nats/nats.conf: 
-------------------------------------------------------------------------------- 1 | # Paste your NATS configuration below -------------------------------------------------------------------------------- /migrations/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /jetstream/stream.json: -------------------------------------------------------------------------------- 1 | { 2 | "streams": [ 3 | ] 4 | } 5 | -------------------------------------------------------------------------------- /src/database/models/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Base 2 | 3 | __all__ = ("Base",) 4 | -------------------------------------------------------------------------------- /src/database/types.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | type OrderBy = Literal["asc", "desc"]  # ascending or descending sort order 4 | -------------------------------------------------------------------------------- /src/api/common/mediator/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import MediatorImpl 2 | 3 | __all__ = ("MediatorImpl",) 4 | -------------------------------------------------------------------------------- /src/common/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .files import open_file_sync 2 | 3 | __all__ = ("open_file_sync",) 4 | -------------------------------------------------------------------------------- /src/api/v1/dtos/status.py: -------------------------------------------------------------------------------- 1 | from src.api.v1.dtos.base import DTO 2 | 3 | 4 | class Status(DTO):  # response DTO used by the /healthcheck endpoint 5 | status: bool  # /healthcheck always returns Status(status=True) 6 |
-------------------------------------------------------------------------------- /src/api/server/utils.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | 4 | def workers_count() -> int: 5 | return max(1, mp.cpu_count() - 1)  # one fewer than the CPU count, floored at 1 so a single-CPU host still gets a worker 6 | -------------------------------------------------------------------------------- /src/common/tools/types.py: -------------------------------------------------------------------------------- 1 | def is_typevar[T](t: T) -> bool: 2 | return hasattr(t, "__bound__") or hasattr(t, "__constraints__") 3 | -------------------------------------------------------------------------------- /src/services/provider/types.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | type RequestMethodType = Literal[ 4 | "GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD" 5 | ] 6 | -------------------------------------------------------------------------------- /src/common/di/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import inject 2 | from .container import container 3 | from .marker import Depends, FromDepends 4 | 5 | __all__ = ("container", "Depends", "FromDepends", "inject") 6 | -------------------------------------------------------------------------------- /src/api/v1/dtos/role.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from uuid_utils.compat import UUID 4 | 5 | from src.api.v1.dtos.base import DTO 6 | 7 | 8 | class Role(DTO): 9 | uuid: UUID 10 | name: str 11 | -------------------------------------------------------------------------------- /src/common/tools/text.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any 3 | 4 | 5 | def pascal_to_snake(obj: Any) -> str: 6 | snake_case = re.sub(r"(?
bytes: ... 6 | 7 | def decrypt(self, encrypted_text: bytes) -> str: ... 8 | -------------------------------------------------------------------------------- /src/database/models/types.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | from typing import Literal 3 | 4 | 5 | class RoleTypeEnum(StrEnum): 6 | ADMIN = "Admin" 7 | USER = "User" 8 | 9 | 10 | type Roles = Literal["Admin", "User"]  # duplicates RoleTypeEnum's values; keep the two in sync 11 | -------------------------------------------------------------------------------- /src/services/interfaces/hasher.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | 4 | class AbstractHasher(Protocol): 5 | def hash_password(self, plain: str) -> str: ... 6 | 7 | def verify_password(self, hashed: str, plain: str) -> bool: ... 8 | -------------------------------------------------------------------------------- /src/database/repositories/types/role.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, TypedDict 2 | 3 | from src.database.models.types import Roles 4 | 5 | RoleLoads = Literal["users"] 6 | 7 | 8 | class CreateRoleType(TypedDict): 9 | name: Roles 10 | -------------------------------------------------------------------------------- /src/api/common/interfaces/event_bus.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | from src.api.common.events.base import Event 4 | 5 | 6 | @runtime_checkable 7 | class EventBus(Protocol): 8 | async def publish(self, event: Event) -> None: ...
9 | -------------------------------------------------------------------------------- /src/database/models/base/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | from .with_id import IDMixin, UUIDMixin 2 | from .with_time import DeletedTimeMixin, TimeMixin 3 | 4 | __all__ = ( 5 | "IDMixin", 6 | "TimeMixin", 7 | "UUIDMixin", 8 | "DeletedTimeMixin", 9 | ) 10 | -------------------------------------------------------------------------------- /src/api/common/interfaces/mediator.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | from src.api.common.interfaces.dto import DTO 4 | 5 | 6 | @runtime_checkable 7 | class Mediator(Protocol): 8 | async def send[Q: DTO, R: DTO](self, query: Q) -> R: ...  # structural interface: send a query DTO, await a result DTO 9 | -------------------------------------------------------------------------------- /src/database/exceptions.py: -------------------------------------------------------------------------------- 1 | class DatabaseError(Exception):  # common base class for the database-layer errors below 2 | pass 3 | 4 | 5 | class CommitError(DatabaseError): 6 | pass 7 | 8 | 9 | class RollbackError(DatabaseError): 10 | pass 11 | 12 | 13 | class InvalidParamsError(DatabaseError): 14 | pass 15 | -------------------------------------------------------------------------------- /src/api/v1/constants.py: -------------------------------------------------------------------------------- 1 | from typing import Final 2 | 3 | MIN_PAGINATION_LIMIT: Final[int] = 10 4 | MAX_PAGINATION_LIMIT: Final[int] = 200 5 | 6 | 7 | MIN_PASSWORD_LENGTH: Final[int] = 8 8 | MAX_PASSWORD_LENGTH: Final[int] = 32 9 | 10 | MAX_LOGIN_LENGTH: Final[int] = 128 11 | -------------------------------------------------------------------------------- /src/api/v1/tools/validate.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | EMAIL_REGEX = re.compile(r"^[^@]+@[^@]+\.[^@]+$")  # permissive: exactly one '@', non-empty local part, and a '.' with characters on both sides in the domain part 4 | 5 | 6 | def
validate_email(value: str) -> str: 7 | if not EMAIL_REGEX.fullmatch(value):  # fullmatch already anchors, so the pattern's ^/$ are redundant but harmless 8 | raise ValueError(f"Invalid email: {value}") 9 | return value 10 | -------------------------------------------------------------------------------- /src/api/common/broker/nats/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import NatsBroker, NatsJetStreamBroker 2 | from .message import NatsJetStreamMessage, NatsMessage 3 | 4 | __all__ = ( 5 | "NatsBroker", 6 | "NatsMessage", 7 | "NatsJetStreamBroker", 8 | "NatsJetStreamMessage", 9 | ) 10 | -------------------------------------------------------------------------------- /src/api/common/events/__init__.py: -------------------------------------------------------------------------------- 1 | from .nats import ( 2 | NatsEvent, 3 | NatsJetStreamEvent, 4 | rebuild_jetstream_event, 5 | rebuild_nats_event, 6 | ) 7 | 8 | __all__ = ( 9 | "NatsEvent", 10 | "NatsJetStreamEvent", 11 | "rebuild_nats_event", 12 | "rebuild_jetstream_event", 13 | ) 14 | -------------------------------------------------------------------------------- /src/api/v1/endpoints/admin/__init__.py: -------------------------------------------------------------------------------- 1 | from src.api.v1.integration.dishka import DishkaRouter 2 | 3 | from .user import UserController 4 | 5 | 6 | def setup_admin_controllers() -> DishkaRouter: 7 | router = DishkaRouter("/admin", route_handlers=[])  # '/admin' path prefix, assuming DishkaRouter keeps litestar.Router's (path, route_handlers) signature 8 | router.register(UserController) 9 | return router 10 | -------------------------------------------------------------------------------- /src/common/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from src.settings.core import ( 4 | DATETIME_FORMAT, 5 | LOG_LEVEL, 6 | LOGGING_FORMAT, 7 | PROJECT_NAME, 8 | ) 9 | 10 | logging.basicConfig(level=LOG_LEVEL, format=LOGGING_FORMAT, datefmt=DATETIME_FORMAT)  # configures the root logger as a side effect of importing this module 11 | 12 | log = logging.getLogger(PROJECT_NAME) 13 |
-------------------------------------------------------------------------------- /src/api/common/interfaces/handler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any 3 | 4 | from src.api.common.interfaces.dto import DTO 5 | 6 | 7 | class Handler[Q: DTO, R](abc.ABC): 8 | __slots__ = () 9 | 10 | @abc.abstractmethod 11 | async def __call__(self, query: Q) -> R: ... 12 | 13 | 14 | type HandlerType = Handler[Any, Any] 15 | -------------------------------------------------------------------------------- /src/common/di/marker.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, TypeVar, Union, cast 2 | 3 | from dishka import FromDishka 4 | 5 | T = TypeVar("T") 6 | 7 | if TYPE_CHECKING: 8 | Depends = Union[T, T] 9 | else: 10 | 11 | class Depends(FromDishka): ... 12 | 13 | 14 | def FromDepends[T]() -> Depends[T]: 15 | return cast(Depends[T], ...) 16 | -------------------------------------------------------------------------------- /src/api/common/middlewares/__init__.py: -------------------------------------------------------------------------------- 1 | from litestar.types.composite_types import Middleware 2 | 3 | from .proccess_time import ProcessTimeHeader 4 | from .request_id import XRequestIdMiddleware 5 | 6 | 7 | def setup_middlewares() -> tuple[Middleware, ...]: 8 | return ( 9 | XRequestIdMiddleware(), 10 | ProcessTimeHeader(), 11 | ) 12 | -------------------------------------------------------------------------------- /src/api/v1/dtos/auth.py: -------------------------------------------------------------------------------- 1 | from src.api.v1.dtos.base import DTO 2 | 3 | 4 | class Fingerprint(DTO): 5 | fingerprint: str 6 | 7 | 8 | class Login(Fingerprint): 9 | login: str 10 | password: str 11 | 12 | 13 | class Register(Fingerprint): 14 | login: str 15 | password: str 16 | 17 | 18 | class VerificationCode(DTO): 19 | code:
str 20 | -------------------------------------------------------------------------------- /src/api/v1/events/email.py: -------------------------------------------------------------------------------- 1 | import msgspec 2 | 3 | from src.api.common.events.nats import NatsEvent, rebuild_nats_event 4 | 5 | 6 | @rebuild_nats_event(subject="email.send") 7 | class SendEmail[T](NatsEvent, kw_only=True): 8 | from_: str = msgspec.field(name="from")  # 'from' is a Python keyword, so the attribute is renamed and mapped back on the wire 9 | to: str 10 | title: str 11 | template: str 12 | props: T | None = None 13 | -------------------------------------------------------------------------------- /src/api/v1/endpoints/healthcheck.py: -------------------------------------------------------------------------------- 1 | from litestar import get, status_codes 2 | 3 | from src.api.v1 import dtos 4 | 5 | 6 | @get( 7 | "/healthcheck", 8 | tags=["Healthcheck"], 9 | status_code=status_codes.HTTP_200_OK, 10 | exclude_from_auth=True, 11 | ) 12 | async def healthcheck_endpoint() -> dtos.Status: 13 | return dtos.Status(status=True) 14 | -------------------------------------------------------------------------------- /src/services/security/__init__.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | import string 3 | from typing import Final 4 | 5 | DEFAULT_LENGTH: Final[int] = 12 6 | 7 | 8 | def generate_password(length: int = DEFAULT_LENGTH) -> str: 9 | characters = string.ascii_letters + string.digits + string.punctuation 10 | password = "".join(secrets.choice(characters) for _ in range(length))  # secrets uses a CSPRNG; random (Mersenne Twister) is predictable and unsuitable for credentials 11 | 12 | return password 13 | -------------------------------------------------------------------------------- /src/database/repositories/types/user.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, TypedDict 2 | 3 | from uuid_utils.compat import UUID 4 | 5 | UserLoads = Literal["role"] 6 | 7 | 8 | class CreateUserType(TypedDict): 9 | login: str 10 | password: str
11 | role_uuid: UUID 12 | 13 | 14 | class UpdateUserType(TypedDict, total=False): 15 | login: str | None 16 | password: str | None 17 | role_uuid: UUID | None 18 | -------------------------------------------------------------------------------- /src/api/v1/middlewares/__init__.py: -------------------------------------------------------------------------------- 1 | from litestar import Router 2 | from litestar.middleware.base import DefineMiddleware 3 | 4 | from .auth import AuthenticationMiddleware 5 | 6 | 7 | def setup_middlewares(router: Router) -> None: 8 | router.middleware.append( 9 | DefineMiddleware( 10 | AuthenticationMiddleware, 11 | exclude_http_methods=["OPTIONS", "HEAD"], 12 | ) 13 | ) 14 | -------------------------------------------------------------------------------- /src/__main__.py: -------------------------------------------------------------------------------- 1 | from contextlib import suppress 2 | 3 | from src.api.server import serve 4 | from src.api.setup import init_app 5 | from src.api.v1.setup import init_v1_router 6 | from src.settings.core import load_settings 7 | 8 | settings = load_settings() 9 | app = init_app(settings, init_v1_router(settings=settings)) 10 | 11 | 12 | if __name__ == "__main__": 13 | with suppress(KeyboardInterrupt): 14 | serve(settings=settings) 15 | -------------------------------------------------------------------------------- /src/common/tools/cache.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | def default_key_builder( 5 | *keys: str, 6 | data: dict[str, Any] | None = None, 7 | separator: str = ".", 8 | **kwargs: Any, 9 | ) -> str: 10 | parts = list(keys) 11 | if data: 12 | parts.extend(f"{k}{separator}{v}" for k, v in data.items()) 13 | 14 | parts.extend(f"{k}{separator}{v}" for k, v in kwargs.items()) 15 | return separator.join(parts) 16 | --------------------------------------------------------------------------------
/src/services/external/request.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any 3 | 4 | from src.services.provider.base import AsyncProvider 5 | 6 | 7 | class Request[ResultType]:  # NOTE(review): not an abc.ABC subclass, so the @abc.abstractmethod below is not enforced at instantiation 8 | async def __call__(self, provider: AsyncProvider, **kwargs: Any) -> ResultType: 9 | return await self.handle(provider, **kwargs) 10 | 11 | @abc.abstractmethod 12 | async def handle(self, provider: AsyncProvider, **kwargs: Any) -> ResultType: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /src/api/v1/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | from litestar import Router 2 | 3 | from .admin import setup_admin_controllers 4 | from .auth import AuthController 5 | from .healthcheck import healthcheck_endpoint 6 | from .user import UserController 7 | 8 | 9 | def setup_controllers(router: Router) -> None: 10 | router.register(healthcheck_endpoint) 11 | router.register(UserController) 12 | router.register(AuthController) 13 | 14 | router.register(setup_admin_controllers()) 15 | -------------------------------------------------------------------------------- /src/services/provider/middleware/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | BaseRequestMiddleware, 3 | RequestMiddlewareType, 4 | ) 5 | from .error import RequestErrorMiddleware 6 | from .logging import RequestLoggingMiddleware 7 | from .manager import RequestMiddlewareManager 8 | 9 | __all__ = ( 10 | "BaseRequestMiddleware", 11 | "RequestMiddlewareType", 12 | "RequestMiddlewareManager", 13 | "RequestLoggingMiddleware", 14 | "RequestErrorMiddleware", 15 | ) 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 |
repos: 2 | - repo: https://github.com/pre-commit/mirrors-mypy 3 | rev: v1.13.0 4 | hooks: 5 | - id: mypy 6 | args: ["--config-file=pyproject.toml"] 7 | files: ^src/ 8 | language: system 9 | 10 | - repo: https://github.com/charliermarsh/ruff-pre-commit 11 | rev: v0.8.2 12 | hooks: 13 | - id: ruff  # lint/fix first: --fix can emit code that the formatter must then normalize 14 | args: ["--config=pyproject.toml", "--fix"] 15 | language: system 16 | - id: ruff-format 17 | -------------------------------------------------------------------------------- /src/api/common/broker/nats/connection.py: -------------------------------------------------------------------------------- 1 | from nats.aio.client import Client as NatsClient 2 | from src.api.common.broker.nats.core import NatsBroker 3 | from src.settings.core import Settings 4 | 5 | 6 | async def create_nats_connection(broker: NatsBroker, settings: Settings) -> NatsClient: 7 | await broker.nats.connect( 8 | servers=settings.nats.servers, 9 | user=settings.nats.user, 10 | password=settings.nats.password, 11 | ) 12 | return broker.nats 13 | -------------------------------------------------------------------------------- /src/api/common/interfaces/broker.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Protocol, runtime_checkable 3 | 4 | from src.api.common.events.base import Event 5 | 6 | 7 | @dataclass 8 | class BrokerMessage: ... 9 | 10 | 11 | @runtime_checkable 12 | class Broker[E: Event, M: BrokerMessage](Protocol): 13 | async def publish(self, message: M) -> None: ... 14 | 15 | def _build_message(self, event: E) -> M: ...
16 | 17 | 18 | type BrokerType = Broker[Any, Any] 19 | -------------------------------------------------------------------------------- /src/database/repositories/base.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.ext.asyncio import AsyncSession 2 | 3 | from src.database.models.base import Base 4 | from src.database.repositories.crud import CRUDRepository 5 | 6 | 7 | class BaseRepository[M: Base]: 8 | __slots__ = ("_session", "_crud", "model") 9 | 10 | def __init__(self, session: AsyncSession, model: type[M]) -> None: 11 | self._session = session 12 | self.model = model 13 | self._crud = CRUDRepository(session, model)  # generic CRUD helper sharing this repository's session and model 14 | -------------------------------------------------------------------------------- /src/api/v1/handlers/user/__init__.py: -------------------------------------------------------------------------------- 1 | from .create import CreateUserHandler, CreateUserQuery 2 | from .select import SelectUserHandler, SelectUserQuery 3 | from .select_many import SelectManyUserHandler, SelectManyUserQuery 4 | from .update import UpdateUserHandler, UpdateUserQuery 5 | 6 | __all__ = ( 7 | "CreateUserHandler", 8 | "CreateUserQuery", 9 | "SelectManyUserHandler", 10 | "SelectManyUserQuery", 11 | "SelectUserHandler", 12 | "SelectUserQuery", 13 | "UpdateUserHandler", 14 | "UpdateUserQuery", 15 | ) 16 | -------------------------------------------------------------------------------- /src/api/v1/dtos/__init__.py: -------------------------------------------------------------------------------- 1 | from .auth import Fingerprint, Login, Register, VerificationCode 2 | from .base import DTO 3 | from .role import Role 4 | from .status import Status 5 | from .user import UpdateUser, User 6 | 7 | __all__ = ( 8 | "DTO", 9 | "User", 10 | "Register", 11 | "Fingerprint", 12 | "Login", 13 | "Role", 14 | "Status", 15 | "VerificationCode", 16 | "UpdateUser", 17 | ) 18 | 19 | 20 | class OffsetResult[T](DTO):  # NOTE(review): not listed in __all__ above, so star-imports of this package skip it 21 | data: list[T] 22
| offset: int 23 | limit: int 24 | total: int 25 | -------------------------------------------------------------------------------- /src/database/connection.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from sqlalchemy.ext.asyncio import ( 4 | AsyncEngine, 5 | AsyncSession, 6 | async_sessionmaker, 7 | create_async_engine, 8 | ) 9 | 10 | type SessionFactoryType = async_sessionmaker[AsyncSession] 11 | 12 | 13 | def create_sa_engine(url: str, **kwargs: Any) -> AsyncEngine: 14 | return create_async_engine(url, **kwargs) 15 | 16 | 17 | def create_sa_session_factory(engine: AsyncEngine) -> SessionFactoryType: 18 | return async_sessionmaker(engine, autoflush=False, expire_on_commit=False)  # no implicit flush before queries; loaded attributes stay usable after commit 19 | -------------------------------------------------------------------------------- /src/common/di/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | from dishka import AsyncContainer 4 | from dishka.integrations.base import wrap_injection 5 | 6 | from src.common.di.container import container 7 | 8 | 9 | def inject[T, **P](func: Callable[P, T]) -> Callable[P, T]: 10 | def container_getter(*args: Any, **kwargs: Any) -> AsyncContainer: 11 | return container.get_container()  # ignores the call's arguments; always resolves from the global container 12 | 13 | return wrap_injection( 14 | func=func, 15 | container_getter=container_getter, 16 | is_async=True, 17 | manage_scope=True,  # presumably dishka enters/exits a child scope per call; verify against the installed dishka version 18 | ) 19 | -------------------------------------------------------------------------------- /src/api/common/broker/nats/message.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | from src.api.common.interfaces.broker import BrokerMessage 5 | 6 | 7 | @dataclass 8 | class NatsMessage(BrokerMessage): 9 | subject: str = "" 10 | payload: bytes = b"" 11 | reply: str = "" 12 | headers: dict[str, Any] | None = None 13 | 14 | 15 |
@dataclass 16 | class NatsJetStreamMessage(BrokerMessage): 17 | subject: str = "" 18 | payload: bytes = b"" 19 | timeout: float | None = None 20 | stream: str | None = None 21 | headers: dict[str, Any] | None = None 22 | -------------------------------------------------------------------------------- /src/api/v1/handlers/auth/__init__.py: -------------------------------------------------------------------------------- 1 | from .confirm import ConfirmRegisterHandler, ConfirmRegisterQuery 2 | from .login import LoginHandler, LoginQuery 3 | from .logout import LogoutHandler, LogoutQuery 4 | from .refresh import RefreshTokenHandler, RefreshTokenQuery 5 | from .register import RegisterHandler, RegisterQuery 6 | 7 | __all__ = ( 8 | "ConfirmRegisterHandler", 9 | "ConfirmRegisterQuery", 10 | "RegisterHandler", 11 | "RegisterQuery", 12 | "LoginHandler", 13 | "LoginQuery", 14 | "LogoutHandler", 15 | "LogoutQuery", 16 | "RefreshTokenQuery", 17 | "RefreshTokenHandler", 18 | ) 19 | -------------------------------------------------------------------------------- /src/api/common/serializers/msgspec.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import msgspec 4 | 5 | 6 | def msgpack_encoder(obj: Any, *args: Any, **kw: Any) -> bytes: 7 | return msgspec.msgpack.encode(obj, *args, **kw) 8 | 9 | 10 | def msgpack_decoder(obj: Any, *args: Any, **kw: Any) -> Any: 11 | return msgspec.msgpack.decode(obj, *args, strict=kw.pop("strict", False), **kw) 12 | 13 | 14 | def msgspec_encoder(obj: Any, *args: Any, **kw: Any) -> str: 15 | return msgspec.json.encode(obj, *args, **kw).decode(encoding="utf-8") 16 | 17 | 18 | def msgspec_decoder(obj: Any, *args: Any, **kw: Any) -> Any: 19 | return msgspec.json.decode(obj, *args, **kw) 20 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM 
python:3.12.4 2 | 3 | ENV PYTHONDONTWRITEBYTECODE 1 4 | ENV PYTHONBUFFERED 1 5 | 6 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ 7 | 8 | RUN apt-get update && apt-get install -y --no-install-recommends \ 9 | curl \ 10 | ca-certificates && \ 11 | rm -rf /var/lib/apt/lists/* 12 | 13 | ADD https://astral.sh/uv/install.sh /uv-installer.sh 14 | 15 | RUN sh /uv-installer.sh && rm /uv-installer.sh 16 | 17 | ENV PATH="/root/.local/bin/:$PATH" 18 | ENV UV_LINK_MODE=copy 19 | 20 | COPY . . 21 | 22 | EXPOSE 8091 23 | 24 | RUN uv sync --frozen --no-dev --compile-bytecode 25 | CMD ["sh", "-c", "uv run alembic upgrade head && uv run python -m src"] 26 | -------------------------------------------------------------------------------- /src/api/server/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from src.settings.core import Settings 4 | 5 | from .granian import run_api_granian 6 | 7 | 8 | def serve(settings: Settings, suffix: str = "app", **kwargs: Any) -> None: 9 | match settings.server.type: 10 | case "granian": 11 | run_api_granian(f"src.__main__:{suffix}", settings.server, **kwargs) 12 | case _: 13 | raise ValueError( 14 | f"Unsupported server type: '{settings.server.type}'. " 15 | "Currently only 'granian' server type is supported. " 16 | "Please update your server configuration to use 'granian'." 17 | ) 18 | -------------------------------------------------------------------------------- /src/database/models/__init__.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import RelationshipProperty 2 | 3 | from src.database._utils import frozendict 4 | 5 | from .base import Base 6 | from .role import Role 7 | from .user import User 8 | 9 | __all__ = ("Base", "User", "Role") 10 | 11 | 12 | def _retrieve_relationships() -> dict[ 13 | type[Base], tuple[RelationshipProperty[type[Base]], ...] 
14 | ]: 15 | return { 16 | mapper.class_: tuple(mapper.relationships.values()) 17 | for mapper in Base.registry.mappers 18 | } 19 | 20 | 21 | MODELS_RELATIONSHIPS_NODE: frozendict[ 22 | type[Base], tuple[RelationshipProperty[type[Base]], ...] 23 | ] = frozendict(_retrieve_relationships()) 24 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | [alembic] 2 | script_location = ./migrations 3 | 4 | prepend_sys_path = . 5 | 6 | [loggers] 7 | keys = root,sqlalchemy,alembic 8 | 9 | [handlers] 10 | keys = console 11 | 12 | [formatters] 13 | keys = generic 14 | 15 | [logger_root] 16 | level = WARN 17 | handlers = console 18 | qualname = 19 | 20 | [logger_sqlalchemy] 21 | level = WARN 22 | handlers = 23 | qualname = sqlalchemy.engine 24 | 25 | [logger_alembic] 26 | level = INFO 27 | handlers = 28 | qualname = alembic 29 | 30 | [handler_console] 31 | class = StreamHandler 32 | args = (sys.stderr,) 33 | level = NOTSET 34 | formatter = generic 35 | 36 | [formatter_generic] 37 | format = %(levelname)-5.5s [%(name)s] %(message)s 38 | datefmt = %H:%M:%S 39 | -------------------------------------------------------------------------------- /src/api/v1/setup.py: -------------------------------------------------------------------------------- 1 | from litestar import Router 2 | 3 | from src.api.common import tools 4 | from src.api.v1.di import setup_dependencies 5 | from src.api.v1.endpoints import setup_controllers 6 | from src.api.v1.integration.dishka import DishkaRouter 7 | from src.api.v1.middlewares import setup_middlewares 8 | from src.settings.core import Settings 9 | 10 | 11 | def init_v1_router( 12 | *sub_routers: Router, path: str = "/v1", settings: Settings 13 | ) -> tools.RouterState: 14 | router = DishkaRouter(path, route_handlers=sub_routers) 15 | 16 | setup_controllers(router) 17 | setup_middlewares(router) 18 | state = 
setup_dependencies(settings) 19 | 20 | return tools.RouterState(router=router, state=state) 21 | -------------------------------------------------------------------------------- /src/common/di/container.py: -------------------------------------------------------------------------------- 1 | from dishka import AsyncContainer, Provider, make_async_container 2 | 3 | 4 | class DynamicContainer: 5 | __slots__ = ("_container",) 6 | 7 | def __init__(self) -> None: 8 | self._container: AsyncContainer | None = None 9 | 10 | def add_providers(self, *providers: Provider) -> None: 11 | new_container = make_async_container(*providers) 12 | new_container.parent_container = self._container 13 | self._container = new_container 14 | 15 | def get_container(self) -> AsyncContainer: 16 | if not self._container: 17 | raise RuntimeError("Container is not initialized") 18 | return self._container 19 | 20 | 21 | container = DynamicContainer() 22 | -------------------------------------------------------------------------------- /src/common/tools/singleton.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | type _AnyDependency = Callable[[], Any] 4 | 5 | 6 | def singleton[DependencyType]( 7 | dependency: DependencyType, 8 | ) -> Callable[[], DependencyType]: 9 | def singleton_factory() -> DependencyType: 10 | return dependency 11 | 12 | return singleton_factory 13 | 14 | 15 | def lazy[T]( 16 | v: Callable[..., T], *args: _AnyDependency, **deps: _AnyDependency 17 | ) -> Callable[[], T]: 18 | def _factory() -> T: 19 | return v(*(arg() for arg in args), **{k: dep() for k, dep in deps.items()}) 20 | 21 | return _factory 22 | 23 | 24 | def lazy_single[T, D](v: Callable[[D], T], dep: Callable[[], D]) -> Callable[[], T]: 25 | return lazy(v, dep) 26 | -------------------------------------------------------------------------------- /src/database/models/base/mixins/with_time.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from sqlalchemy import DateTime, func 4 | from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column 5 | 6 | 7 | @declarative_mixin 8 | class TimeMixin: 9 | created_at: Mapped[datetime] = mapped_column( 10 | DateTime(timezone=True), nullable=False, server_default=func.now() 11 | ) 12 | updated_at: Mapped[datetime] = mapped_column( 13 | DateTime(timezone=True), 14 | nullable=False, 15 | onupdate=func.now(), 16 | server_default=func.now(), 17 | ) 18 | 19 | 20 | @declarative_mixin 21 | class DeletedTimeMixin(TimeMixin): 22 | deleted_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=True) 23 | -------------------------------------------------------------------------------- /src/services/external/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | from src.services.provider.base import AsyncProvider 4 | from src.settings.core import Settings 5 | 6 | 7 | class ExternalServiceGateway: 8 | __slots__ = ("provider", "settings", "_cache") 9 | 10 | def __init__(self, provider: AsyncProvider, settings: Settings) -> None: 11 | self.provider = provider 12 | self.settings = settings 13 | self._cache: dict[str, Any] = {} 14 | 15 | def _from_cache[S](self, key: str, factory: Callable[..., S], **kwargs: Any) -> S: 16 | if not (cached := self._cache.get(key)): 17 | cached = factory(self.provider, **kwargs) 18 | self._cache[key] = cached 19 | 20 | return cached 21 | -------------------------------------------------------------------------------- /migrations/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic 
import op 11 | import sqlalchemy as sa 12 | ${imports if imports else ""} 13 | 14 | # revision identifiers, used by Alembic. 15 | revision: str = ${repr(up_revision)} 16 | down_revision: Union[str, None] = ${repr(down_revision)} 17 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} 18 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} 19 | 20 | 21 | def upgrade() -> None: 22 | ${upgrades if upgrades else "pass"} 23 | 24 | 25 | def downgrade() -> None: 26 | ${downgrades if downgrades else "pass"} 27 | -------------------------------------------------------------------------------- /src/database/models/base/mixins/with_id.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Integer, text 2 | from sqlalchemy.dialects.postgresql import UUID 3 | from sqlalchemy.orm import Mapped, MappedAsDataclass, declarative_mixin, mapped_column 4 | from uuid_utils.compat import UUID as uuid_type 5 | from uuid_utils.compat import uuid4 6 | 7 | 8 | @declarative_mixin 9 | class IDMixin(MappedAsDataclass, init=False): 10 | id: Mapped[int] = mapped_column( 11 | Integer, 12 | primary_key=True, 13 | ) 14 | 15 | 16 | @declarative_mixin 17 | class UUIDMixin: 18 | uuid: Mapped[uuid_type] = mapped_column( 19 | UUID(as_uuid=True), 20 | primary_key=True, 21 | default=uuid4, 22 | server_default=text("gen_random_uuid()"), 23 | nullable=False, 24 | ) 25 | -------------------------------------------------------------------------------- /src/api/v1/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | from src.api.common.mediator import MediatorImpl 2 | 3 | from . 
import auth, user 4 | 5 | 6 | def setup_handlers(mediator: MediatorImpl) -> None: 7 | mediator.register(auth.RegisterQuery, auth.RegisterHandler) 8 | mediator.register(auth.LoginQuery, auth.LoginHandler) 9 | mediator.register(auth.LogoutQuery, auth.LogoutHandler) 10 | mediator.register(auth.RefreshTokenQuery, auth.RefreshTokenHandler) 11 | mediator.register(auth.ConfirmRegisterQuery, auth.ConfirmRegisterHandler) 12 | mediator.register(user.CreateUserQuery, user.CreateUserHandler) 13 | mediator.register(user.SelectUserQuery, user.SelectUserHandler) 14 | mediator.register(user.SelectManyUserQuery, user.SelectManyUserHandler) 15 | mediator.register(user.UpdateUserQuery, user.UpdateUserHandler) 16 | -------------------------------------------------------------------------------- /src/api/server/granian.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from granian.constants import Interfaces 4 | from granian.server import Server as Granian 5 | 6 | from src.api.server.utils import workers_count 7 | from src.common.logger import log 8 | from src.settings.core import ServerSettings 9 | 10 | 11 | def run_api_granian( 12 | target: Any, 13 | settings: ServerSettings, 14 | **kwargs: Any, 15 | ) -> None: 16 | server = Granian( 17 | target, 18 | address=settings.host, 19 | port=settings.port, 20 | workers=settings.workers if settings.workers != "auto" else workers_count(), 21 | runtime_threads=settings.threads, 22 | log_access=settings.log, 23 | interface=Interfaces.ASGI, 24 | **kwargs, 25 | ) 26 | log.info("Running API Granian") 27 | server.serve() 28 | -------------------------------------------------------------------------------- /src/api/v1/dtos/user.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from msgspec import Meta 4 | from uuid_utils.compat import UUID 5 | 6 | from src.api.v1.constants import MAX_PASSWORD_LENGTH, 
MIN_PASSWORD_LENGTH 7 | from src.api.v1.dtos.base import DTO 8 | from src.api.v1.dtos.role import Role 9 | 10 | 11 | class User(DTO): 12 | uuid: UUID 13 | login: str 14 | active: bool 15 | 16 | # relations 17 | role: Role | None = None 18 | 19 | 20 | class UpdateUser(DTO): 21 | password: ( 22 | Annotated[ 23 | str, 24 | Meta( 25 | min_length=MIN_PASSWORD_LENGTH, 26 | max_length=MAX_PASSWORD_LENGTH, 27 | description=( 28 | f"Password between `{MIN_PASSWORD_LENGTH}` and " 29 | f"`{MAX_PASSWORD_LENGTH}` characters long" 30 | ), 31 | ), 32 | ] 33 | | None 34 | ) = None 35 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # database settings 2 | DB_URI=postgresql+asyncpg://{}:{}@{}:{}/{} 3 | DB_HOST=localhost 4 | DB_PORT=54522 5 | DB_USER=username 6 | DB_NAME=dbname 7 | DB_PASSWORD=password 8 | DB_CONNECTION_POOL_SIZE=20 9 | DB_CONNECTION_MAX_OVERFLOW=70 10 | 11 | # server settings 12 | SERVER_HOST=0.0.0.0 13 | SERVER_PORT=8029 14 | SERVER_TYPE=granian 15 | SERVER_WORKERS=1 16 | 17 | #redis settings 18 | REDIS_HOST=localhost 19 | REDIS_PORT=63122 20 | 21 | # nats settings 22 | NATS_SERVERS=["nats://localhost:4222"] 23 | NATS_USER=admin 24 | NATS_PASSWORD=123456789123456789123456789123456789 25 | 26 | # jwt settings 27 | CIPHER_ALGORITHM=EdDSA 28 | CIPHER_SECRET_KEY=... 29 | CIPHER_PUBLIC_KEY=... 
30 | CIPHER_ACCESS_TOKEN_EXPIRE_SECONDS=1800 31 | CIPHER_REFRESH_TOKEN_EXPIRE_SECONDS=604800 32 | 33 | APP_LOG_LEVEL=DEBUG 34 | APP_ROOT_PATH=/api 35 | APP_PRODUCTION=False 36 | APP_TITLE=Title 37 | APP_DEBUG=False 38 | APP_VERSION=0.0.1 39 | APP_SWAGGER=True 40 | -------------------------------------------------------------------------------- /src/database/models/role.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | import sqlalchemy as sa 4 | import sqlalchemy.orm as orm 5 | 6 | from src.common.tools.text import pascal_to_snake 7 | from src.database.models import Base, types 8 | from src.database.models.base import mixins 9 | 10 | if TYPE_CHECKING: 11 | from src.database.models import User 12 | 13 | 14 | class Role(mixins.UUIDMixin, mixins.TimeMixin, Base): 15 | name: orm.Mapped[types.RoleTypeEnum] = orm.mapped_column( 16 | sa.Enum( 17 | types.RoleTypeEnum, 18 | native_enum=False, 19 | values_callable=lambda x: [e.value for e in x], 20 | name=pascal_to_snake(types.RoleTypeEnum), 21 | ), 22 | index=True, 23 | ) 24 | 25 | users: orm.Mapped[list["User"]] = orm.relationship( 26 | "User", 27 | back_populates="role", 28 | primaryjoin="Role.uuid == foreign(User.role_uuid)", 29 | ) 30 | -------------------------------------------------------------------------------- /src/services/internal/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | from src.services.cache.redis import RedisCache 4 | from src.services.security.jwt import JWT 5 | 6 | from .auth import AuthService 7 | 8 | 9 | class InternalServiceGateway: 10 | __slots__ = ("_cache", "_jwt", "_redis") 11 | 12 | def __init__(self, jwt: JWT, redis: RedisCache) -> None: 13 | self._redis = redis 14 | self._jwt = jwt 15 | self._cache: dict[str, Any] = {} 16 | 17 | @property 18 | def auth(self) -> AuthService: 19 | return self._from_cache("auth", 
AuthService, jwt=self._jwt, cache=self._redis) 20 | 21 | def _from_cache[S](self, key: str, factory: Callable[..., S], **kwargs: Any) -> S: 22 | if not (cached := self._cache.get(key)): 23 | cached = factory(**kwargs) 24 | self._cache[key] = cached 25 | 26 | return cached 27 | 28 | 29 | __all__ = ("InternalServiceGateway",) 30 | -------------------------------------------------------------------------------- /src/api/v1/handlers/auth/refresh.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from src.api.common.interfaces.handler import Handler 4 | from src.api.v1 import dtos 5 | from src.api.v1.dtos.base import DTO 6 | from src.common.exceptions import UnAuthorizedError 7 | from src.services import InternalServiceGateway 8 | from src.services.internal.auth import TokensExpire 9 | 10 | 11 | class RefreshTokenQuery(DTO): 12 | data: dtos.Fingerprint 13 | refresh_token: str 14 | 15 | 16 | @dataclass(slots=True) 17 | class RefreshTokenHandler(Handler[RefreshTokenQuery, TokensExpire]): 18 | internal_gateway: InternalServiceGateway 19 | 20 | async def __call__(self, query: RefreshTokenQuery) -> TokensExpire: 21 | if not (refresh_token := query.refresh_token): 22 | raise UnAuthorizedError("Not allowed") 23 | 24 | return await self.internal_gateway.auth.verify_refresh( 25 | query.data.fingerprint, refresh_token 26 | ) 27 | -------------------------------------------------------------------------------- /src/api/v1/handlers/auth/logout.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import uuid_utils.compat as uuid 4 | 5 | from src.api.common.interfaces.handler import Handler 6 | from src.api.v1 import dtos 7 | from src.api.v1.dtos.base import DTO 8 | from src.common.exceptions import UnAuthorizedError 9 | from src.services import InternalServiceGateway 10 | 11 | 12 | class LogoutQuery(DTO): 13 | user_uuid: 
uuid.UUID 14 | refresh_token: str 15 | 16 | 17 | @dataclass(slots=True) 18 | class LogoutHandler(Handler[LogoutQuery, dtos.Status]): 19 | internal_gateway: InternalServiceGateway 20 | 21 | async def __call__(self, query: LogoutQuery) -> dtos.Status: 22 | if not (refresh_token := query.refresh_token): 23 | raise UnAuthorizedError("Not allowed") 24 | 25 | result = await self.internal_gateway.auth.invalidate_refresh( 26 | refresh_token, query.user_uuid 27 | ) 28 | return dtos.Status(status=result) 29 | -------------------------------------------------------------------------------- /src/api/common/middlewares/proccess_time.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from litestar.datastructures import MutableScopeHeaders 4 | from litestar.enums import ScopeType 5 | from litestar.middleware import ASGIMiddleware 6 | from litestar.types import ASGIApp, Message, Receive, Scope, Send 7 | 8 | 9 | class ProcessTimeHeader(ASGIMiddleware): 10 | scopes = (ScopeType.HTTP,) 11 | 12 | async def handle( 13 | self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp 14 | ) -> None: 15 | start_time = time.perf_counter() 16 | 17 | async def send_wrapper(message: Message) -> None: 18 | if message["type"] == "http.response.start": 19 | process_time = time.perf_counter() - start_time 20 | headers = MutableScopeHeaders.from_message(message=message) 21 | headers["X-Process-Time"] = str(process_time) 22 | await send(message) 23 | 24 | await next_app(scope, receive, send_wrapper) 25 | -------------------------------------------------------------------------------- /src/services/provider/middleware/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing as t 3 | 4 | from src.services.provider.response import Response 5 | from src.services.provider.types import RequestMethodType 6 | 7 | 8 | class CallNextMiddlewareType(t.Protocol): 9 | async def 
__call__( 10 | self, method: RequestMethodType, url_or_endpoint: str, **kw: t.Any 11 | ) -> Response: ... 12 | 13 | 14 | class RequestMiddlewareType(t.Protocol): 15 | async def __call__( 16 | self, 17 | call_next: CallNextMiddlewareType, 18 | method: RequestMethodType, 19 | url_or_endpoint: str, 20 | **kw: t.Any, 21 | ) -> Response: ... 22 | 23 | 24 | class BaseRequestMiddleware(abc.ABC): 25 | @abc.abstractmethod 26 | async def __call__( 27 | self, 28 | call_next: CallNextMiddlewareType, 29 | method: RequestMethodType, 30 | url_or_endpoint: str, 31 | **kw: t.Any, 32 | ) -> Response: 33 | raise NotImplementedError 34 | -------------------------------------------------------------------------------- /src/api/v1/permission.py: -------------------------------------------------------------------------------- 1 | from litestar.connection import ASGIConnection 2 | from litestar.datastructures.state import State 3 | from litestar.handlers.base import BaseRouteHandler 4 | 5 | from src.api.v1 import dtos 6 | from src.common.exceptions import UnAuthorizedError 7 | from src.database.models.types import Roles 8 | 9 | 10 | class Permission: 11 | __slots__ = ("roles",) 12 | 13 | def __init__(self, *roles: Roles) -> None: 14 | self.roles = roles 15 | 16 | async def __call__( 17 | self, 18 | connection: ASGIConnection[BaseRouteHandler, dtos.User, None, State], 19 | _: BaseRouteHandler, 20 | ) -> None: 21 | valid_roles = self._ensure_valid_roles(connection.user) 22 | if not valid_roles: 23 | raise UnAuthorizedError("Not allowed") 24 | 25 | return 26 | 27 | def _ensure_valid_roles(self, user: dtos.User) -> bool: 28 | assert user.role is not None 29 | return any(role in user.role.name for role in self.roles) 30 | -------------------------------------------------------------------------------- /src/database/models/base/core.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Dict 3 | 4 | from 
sqlalchemy.ext.declarative import declared_attr 5 | from sqlalchemy.orm import DeclarativeBase 6 | 7 | 8 | class Base(DeclarativeBase): 9 | __abstract__: bool = True 10 | 11 | @declared_attr.directive 12 | def __tablename__(cls) -> str: 13 | return re.sub(r"(? Dict[str, Any]: 16 | result: Dict[str, Any] = {} 17 | for attr, value in self.__dict__.items(): 18 | if attr.startswith("_"): 19 | continue 20 | if isinstance(value, Base): 21 | result[attr] = value.as_dict() 22 | elif isinstance(value, list | tuple): 23 | result[attr] = type(value)( 24 | v.as_dict() if isinstance(v, Base) else v for v in value 25 | ) 26 | else: 27 | result[attr] = value 28 | 29 | return result 30 | -------------------------------------------------------------------------------- /src/api/v1/handlers/user/select.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import uuid_utils.compat as uuid 4 | 5 | from src.api.common.interfaces.handler import Handler 6 | from src.api.v1 import dtos 7 | from src.api.v1.dtos.base import DTO 8 | from src.database import DBGateway 9 | from src.database.repositories.types.user import UserLoads 10 | 11 | 12 | class SelectUserQuery(DTO): 13 | loads: tuple[UserLoads, ...] 
14 | user_uuid: uuid.UUID | None = None 15 | login: str | None = None 16 | 17 | 18 | @dataclass(slots=True) 19 | class SelectUserHandler(Handler[SelectUserQuery, dtos.User]): 20 | database: DBGateway 21 | 22 | async def __call__(self, query: SelectUserQuery) -> dtos.User: 23 | async with self.database.manager.session: 24 | user = await self.database.user.select( 25 | *query.loads, 26 | user_uuid=query.user_uuid, 27 | login=query.login, 28 | ) 29 | 30 | return dtos.User.from_mapping(user.result().as_dict()) 31 | -------------------------------------------------------------------------------- /src/common/tools/files.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | ModeType = Literal[ 4 | "r", 5 | "w", 6 | "a", 7 | "rb", 8 | "wb", 9 | "ab", 10 | "r+", 11 | "w+", 12 | "a+", 13 | "rb+", 14 | "wb+", 15 | "ab+", 16 | ] 17 | 18 | 19 | def open_file_sync(path: str, mode: ModeType) -> Any: 20 | if mode not in { 21 | "r", 22 | "w", 23 | "a", 24 | "rb", 25 | "wb", 26 | "ab", 27 | "r+", 28 | "w+", 29 | "a+", 30 | "rb+", 31 | "wb+", 32 | "ab+", 33 | }: 34 | raise ValueError(f"Invalid mode '{mode}'. Must be a valid file access mode.") 35 | 36 | try: 37 | with open(path, mode) as file: 38 | if "r" in mode: 39 | return file.read() 40 | return 41 | except FileNotFoundError as e: 42 | raise FileNotFoundError(f"File not found: {path}") from e 43 | except OSError as e: 44 | raise OSError(f"Error accessing file {path}: {e}") from e 45 | -------------------------------------------------------------------------------- /migrations/versions/02_82e484eb7a8e_unique.py: -------------------------------------------------------------------------------- 1 | """unique 2 | 3 | Revision ID: 02_82e484eb7a8e 4 | Revises: 01_3d1a9bc30546 5 | Create Date: 2025-04-12 01:53:21.346825 6 | 7 | """ 8 | 9 | from typing import Sequence, Union 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 
14 | revision: str = "02_82e484eb7a8e" 15 | down_revision: Union[str, None] = "01_3d1a9bc30546" 16 | branch_labels: Union[str, Sequence[str], None] = None 17 | depends_on: Union[str, Sequence[str], None] = None 18 | 19 | 20 | def upgrade() -> None: 21 | # ### commands auto generated by Alembic - please adjust! ### 22 | op.drop_index("ix_user_login", table_name="user") 23 | op.create_index(op.f("ix_user_login"), "user", ["login"], unique=True) 24 | # ### end Alembic commands ### 25 | 26 | 27 | def downgrade() -> None: 28 | # ### commands auto generated by Alembic - please adjust! ### 29 | op.drop_index(op.f("ix_user_login"), table_name="user") 30 | op.create_index("ix_user_login", "user", ["login"], unique=False) 31 | # ### end Alembic commands ### 32 | -------------------------------------------------------------------------------- /src/database/repositories/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | from src.common.exceptions import ForbiddenError, NotFoundError 4 | 5 | type ExceptionType = Literal["update", "delete", "select", "create", "exists"] 6 | 7 | exception_template: dict[ExceptionType, Exception] = { 8 | "update": ForbiddenError("Cannot be updated"), 9 | "delete": ForbiddenError("Cannot be deleted"), 10 | "select": NotFoundError("Not found"), 11 | "create": ForbiddenError("Cannot be created"), 12 | "exists": NotFoundError("Not found"), 13 | } 14 | 15 | 16 | class Result[T]: 17 | __slots__ = ("data", "exception_type") 18 | 19 | def __init__(self, exception_type: ExceptionType, data: Optional[T]): 20 | self.data = data 21 | self.exception_type = exception_type 22 | 23 | def result_or_none(self) -> Optional[T]: 24 | return self.data 25 | 26 | def result(self) -> T: 27 | if self.data is not None: 28 | return self.data 29 | 30 | raise exception_template[self.exception_type] 31 | 32 | 33 | __all__ = ("Result",) 34 | 
-------------------------------------------------------------------------------- /src/api/common/middlewares/request_id.py: -------------------------------------------------------------------------------- 1 | from litestar.datastructures import MutableScopeHeaders 2 | from litestar.enums import ScopeType 3 | from litestar.middleware import ASGIMiddleware 4 | from litestar.types import ASGIApp, Message, Receive, Scope, Send 5 | from uuid_utils.compat import uuid4 6 | 7 | 8 | class XRequestIdMiddleware(ASGIMiddleware): 9 | scopes = (ScopeType.HTTP,) 10 | 11 | def __init__(self, header_name: str = "X-Request-Id") -> None: 12 | self.header_name = header_name 13 | 14 | async def handle( 15 | self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp 16 | ) -> None: 17 | request_id = uuid4().hex 18 | scope["state"]["request_id"] = request_id 19 | 20 | async def send_wrapper(message: Message) -> None: 21 | if message["type"] == "http.response.start": 22 | headers = MutableScopeHeaders.from_message(message=message) 23 | headers[self.header_name] = request_id 24 | 25 | await send(message) 26 | 27 | await next_app(scope, receive, send_wrapper) 28 | -------------------------------------------------------------------------------- /src/api/v1/handlers/user/update.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import uuid_utils.compat as uuid 4 | 5 | from src.api.common.interfaces.handler import Handler 6 | from src.api.v1 import dtos 7 | from src.api.v1.dtos.base import DTO 8 | from src.database import DBGateway 9 | from src.services.interfaces.hasher import AbstractHasher 10 | 11 | 12 | class UpdateUserQuery(DTO): 13 | user_uuid: uuid.UUID 14 | password: str | None = None 15 | 16 | 17 | @dataclass(slots=True) 18 | class UpdateUserHandler(Handler[UpdateUserQuery, dtos.User]): 19 | database: DBGateway 20 | hasher: AbstractHasher 21 | 22 | async def __call__(self, query: UpdateUserQuery)
-> dtos.User: 23 | async with self.database: 24 | if query.password is not None:  # hash every supplied value, including "" (a truthiness check would persist it unhashed) 25 | query.password = self.hasher.hash_password(query.password) 26 | 27 | user = await self.database.user.update( 28 | query.user_uuid, 29 | **query.as_mapping(exclude_none=True, exclude={"user_uuid"}), 30 | ) 31 | 32 | return dtos.User.from_mapping(user.result().as_dict()) 33 | -------------------------------------------------------------------------------- /src/database/models/user.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import uuid 4 | from typing import TYPE_CHECKING 5 | 6 | import sqlalchemy as sa 7 | import sqlalchemy.orm as orm 8 | 9 | from src.database.models import Base 10 | from src.database.models.base import mixins 11 | 12 | if TYPE_CHECKING: 13 | from src.database.models import Role 14 | 15 | 16 | class User(mixins.UUIDMixin, mixins.TimeMixin, Base): 17 | login: orm.Mapped[str] = orm.mapped_column( 18 | sa.String, index=True, unique=True, nullable=False 19 | ) 20 | password: orm.Mapped[str] = orm.mapped_column(sa.String, nullable=False) 21 | active: orm.Mapped[bool] = orm.mapped_column( 22 | sa.Boolean, default=True, index=True, nullable=False 23 | ) 24 | 25 | role_uuid: orm.Mapped[uuid.UUID] = orm.mapped_column( 26 | sa.UUID(True), 27 | sa.ForeignKey("role.uuid", ondelete="RESTRICT", name="fk_user_role"), 28 | ) 29 | 30 | role: orm.Mapped[Role] = orm.relationship( 31 | "Role", 32 | back_populates="users", 33 | primaryjoin="User.role_uuid == Role.uuid", 34 | ) 35 | -------------------------------------------------------------------------------- /src/api/common/interfaces/dto.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from typing import Any, Protocol, Self, runtime_checkable 3 | 4 | 5 | @runtime_checkable 6 | class Serializable(Protocol): 7 | def as_string( 8 | self, 
exclude_none: bool = False, exclude:
set[str] | None = None 9 | ) -> str: ... 10 | def as_mapping( 11 | self, exclude_none: bool = False, exclude: set[str] | None = None 12 | ) -> Mapping[str, Any]: ... 13 | def as_bytes( 14 | self, exclude_none: bool = False, exclude: set[str] | None = None 15 | ) -> bytes: ... 16 | 17 | 18 | @runtime_checkable 19 | class Deserializable(Protocol): 20 | @classmethod 21 | def from_string(cls, value: str) -> Self: ... 22 | @classmethod 23 | def from_mapping(cls, value: Mapping[str, Any]) -> Self: ... 24 | @classmethod 25 | def from_bytes(cls, value: bytes) -> Self: ... 26 | @classmethod 27 | def from_attributes(cls, value: Any) -> Self: ... 28 | 29 | 30 | @runtime_checkable 31 | class DTO(Serializable, Deserializable, Protocol): 32 | def copy(self, update: Mapping[str, Any] | None = None) -> Self: ... 33 | -------------------------------------------------------------------------------- /src/database/interfaces/gateway.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from types import TracebackType 4 | from typing import Optional, Protocol, Self, Type 5 | 6 | 7 | class AsyncContextManager(Protocol): 8 | async def __aexit__( 9 | self, 10 | exc_type: Optional[Type[BaseException]], 11 | exc_value: Optional[BaseException], 12 | traceback: Optional[TracebackType], 13 | ) -> None: ... 14 | 15 | async def __aenter__(self) -> AsyncContextManager: ...
16 | 17 | 18 | class BaseGateway: 19 | __slots__ = ("_context_manager",) 20 | 21 | def __init__(self, context_manager: AsyncContextManager) -> None: 22 | self._context_manager = context_manager 23 | 24 | async def __aenter__(self) -> Self: 25 | await self._context_manager.__aenter__() 26 | return self 27 | 28 | async def __aexit__( 29 | self, 30 | exc_type: Optional[type[BaseException]], 31 | exc_value: Optional[BaseException], 32 | traceback: Optional[TracebackType], 33 | ) -> None: 34 | await self._context_manager.__aexit__(exc_type, exc_value, traceback) 35 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .PHONY: upgrade 2 | upgrade: ## Run Alembic migrations 3 | alembic upgrade head 4 | 5 | .PHONY: downgrade 6 | downgrade: ## Rollback Alembic migrations 7 | alembic downgrade -1 8 | 9 | .PHONY: generate 10 | generate: ## Autogenerate an Alembic revision (usage: make generate NAME="message") 11 | alembic revision --autogenerate -m "$(NAME)" 12 | 13 | .PHONY: docker_build 14 | docker_build: ## Build Docker image 15 | docker compose build 16 | 17 | .PHONY: docker_rebuild 18 | docker_rebuild: ## Stop containers and rebuild Docker image without cache 19 | docker compose down 20 | docker compose build --no-cache 21 | 22 | .PHONY: docker_up 23 | docker_up: ## Run Docker container 24 | docker compose up -d 25 | 26 | .PHONY: docker_dev_up 27 | docker_dev_up: ## Run the dev stack (docker-compose.dev.yml) 28 | docker compose -f docker-compose.dev.yml up -d 29 | 30 | .PHONY: docker_dev_rebuild 31 | docker_dev_rebuild: ## Stop the dev stack and rebuild its images without cache 32 | docker compose -f docker-compose.dev.yml down 33 | docker compose -f docker-compose.dev.yml build --no-cache 34 | 35 | .PHONY: docker_dev_down 36 | docker_dev_down: ## Stop the dev stack 37 | docker compose -f docker-compose.dev.yml down 38 | 39 | .PHONY: docker_down 40 | docker_down: ## Stop Docker 
container 41 | docker compose down 42 | --------------------------------------------------------------------------------
/src/api/common/tools.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections.abc import Callable, Sequence 3 | from dataclasses import dataclass, field 4 | from typing import Any 5 | 6 | from litestar import Router 7 | from litestar.datastructures import State 8 | from litestar.types import LifespanHook 9 | 10 | 11 | class ClosableProxy: 12 | __slots__ = ( 13 | "_target", 14 | "_close_fn", 15 | ) 16 | 17 | def __init__(self, target: Any, close_fn: Callable[[], Any]) -> None: 18 | self._target = target 19 | self._close_fn = close_fn 20 | 21 | async def close(self) -> None: 22 | # Call first, then await the result only if it is awaitable. Inspecting 23 | # the callable (iscoroutinefunction) misses partials and lambdas that 24 | # return a coroutine, which would then never be awaited. 25 | result = self._close_fn() 26 | if inspect.isawaitable(result): 27 | await result 28 | 29 | def __getattr__(self, key: str) -> Any: 30 | return getattr(self._target, key) 31 | 32 | def __repr__(self) -> str: 33 | return f"{self._target!r}" 34 | 35 | 36 | @dataclass(frozen=True) 37 | class RouterState: 38 | router: Router 39 | state: State 40 | on_startup: Sequence[LifespanHook] = field(default_factory=list) 41 | on_shutdown: Sequence[LifespanHook] = field(default_factory=list) 42 | -------------------------------------------------------------------------------- /src/api/common/converters.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import uuid 4 | from datetime import date, datetime, time, timedelta 5 | from decimal import Decimal 6 | from typing import Any 7 | 8 | import msgspec 9 | from litestar.serialization.msgspec_hooks import ( 10 | default_deserializer, 11 | default_serializer, 12 | ) 13 | 14 | 15 | def convert_to[T](cls: type[T], value: Any, **kw: Any) -> T: 16 | return msgspec.convert( 17 | value, 18 | cls, 19 | dec_hook=default_deserializer, 20 | builtin_types=( 21 | bytes, 22 | bytearray, 23 | datetime, 24 | time, 25 | date, 26 | timedelta, 27 |
uuid.UUID, 28 | Decimal, 29 | ), 30 | **kw, 31 | ) 32 | 33 | 34 | def convert_from(value: Any, **kw: Any) -> Any: 35 | return msgspec.to_builtins( 36 | value, 37 | enc_hook=default_serializer, 38 | builtin_types=( 39 | datetime, 40 | date, 41 | timedelta, 42 | Decimal, 43 | uuid.UUID, 44 | bytes, 45 | bytearray, 46 | memoryview, 47 | time, 48 | ), 49 | **kw, 50 | ) 51 | -------------------------------------------------------------------------------- /src/api/common/events/base.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from datetime import datetime 3 | from typing import Any 4 | 5 | import msgspec 6 | from uuid_utils.compat import UUID, uuid4 7 | 8 | from src.api.common.converters import convert_from 9 | from src.api.common.serializers.msgspec import msgpack_encoder 10 | 11 | 12 | class Event(msgspec.Struct, kw_only=True): 13 | event_id: UUID = msgspec.field(default_factory=uuid4) 14 | event_timestamp: float = msgspec.field(default_factory=lambda: datetime.now().timestamp())  # evaluated per event; a pre-bound datetime.now().timestamp would freeze the import-time value 15 | 16 | def as_mapping( 17 | self, exclude_none: bool = False, exclude: set[str] | None = None 18 | ) -> Mapping[str, Any]: 19 | result: dict[str, Any] = convert_from(self) 20 | if exclude_none: 21 | result = {k: v for k, v in result.items() if v is not None} 22 | if exclude: 23 | result = {k: v for k, v in result.items() if k not in exclude} 24 | 25 | return result 26 | 27 | def as_bytes( 28 | self, exclude_none: bool = False, exclude: set[str] | None = None 29 | ) -> bytes: 30 | return msgpack_encoder( 31 | self.as_mapping(exclude_none=exclude_none, exclude=exclude) 32 | if exclude_none or exclude 33 | else self 34 | ) 35 | -------------------------------------------------------------------------------- /docker-compose.dev.yml: -------------------------------------------------------------------------------- 1 | services: 2 | redis: 3 | container_name: redis 4 | image: 
redis:7-alpine 5 | ports: 6 | - '63122:6379' 7 |
8 | postgres: 9 | container_name: postgres 10 | image: "postgres:15.6-alpine3.19" 11 | restart: unless-stopped 12 | ports: 13 | - '54522:5432' 14 | networks: 15 | - interconnect 16 | - postgres.network 17 | env_file: 18 | - "./.env" 19 | environment: 20 | POSTGRES_USER: ${DB_USER:-$USER} 21 | POSTGRES_DB: ${DB_NAME:-$USER} 22 | POSTGRES_PASSWORD: ${DB_PASSWORD} 23 | volumes: 24 | - postgres.data:/var/lib/postgresql/data:rw  # default PGDATA of the official postgres image 25 | healthcheck: 26 | test: ["CMD-SHELL", "pg_isready -d $${POSTGRES_DB} -U $${POSTGRES_USER}"] 27 | interval: 10s 28 | timeout: 60s 29 | retries: 5 30 | start_period: 10s 31 | 32 | nats: 33 | container_name: 'nats' 34 | image: 'nats:latest' 35 | restart: 'unless-stopped' 36 | ports: 37 | - '4222:4222' 38 | - '6222:6222' 39 | - '8222:8222' 40 | command: 41 | - "--config" 42 | - "/etc/nats/nats.conf" 43 | volumes: 44 | - ./nats/nats.conf:/etc/nats/nats.conf 45 | - nats_data:/nats-assets/assets 46 | networks: 47 | - interconnect 48 | 49 | 50 | volumes: 51 | postgres.data: {} 52 | nats_data: {} 53 | networks: 54 | interconnect: 55 | driver: bridge 56 | postgres.network: {} 57 | -------------------------------------------------------------------------------- /src/common/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | 4 | class AppException(Exception): 5 | def __init__( 6 | self, 7 | message: str = "App Exception", 8 | headers: Optional[Dict[str, Any]] = None, 9 | ) -> None: 10 | self.content = {"message": message} 11 | self.headers = headers 12 | 13 | def as_dict(self) -> Dict[str, Any]: 14 | return self.__dict__ 15 | 16 | 17 | class DetailedError(AppException): 18 | def __init__( 19 | self, 20 | message: str, 21 | headers: Optional[Dict[str, Any]] = None, 22 | **additional: Any, 23 | ) -> None: 24 | super().__init__( 25 | message=message, 26 | headers=headers, 27 | ) 28 | self.content |= additional 29 | 30 | def __str__(self) -> 
str: 31 | return
f"{type(self).__name__}: {self.content}\nHeaders: {self.headers or ''}" 32 | 33 | 34 | class UnAuthorizedError(DetailedError): ... 35 | 36 | 37 | class NotFoundError(DetailedError): ... 38 | 39 | 40 | class BadRequestError(DetailedError): ... 41 | 42 | 43 | class TooManyRequestsError(DetailedError): ... 44 | 45 | 46 | class ServiceUnavailableError(DetailedError): ... 47 | 48 | 49 | class BadGatewayError(DetailedError): ... 50 | 51 | 52 | class ForbiddenError(DetailedError): ... 53 | 54 | 55 | class ServiceNotImplementedError(DetailedError): ... 56 | 57 | 58 | class ConflictError(DetailedError): ... 59 | -------------------------------------------------------------------------------- /src/api/v1/handlers/user/select_many.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import uuid_utils.compat as uuid 4 | 5 | from src.api.common.interfaces.handler import Handler 6 | from src.api.v1 import dtos 7 | from src.api.v1.dtos.base import DTO 8 | from src.database import DBGateway 9 | from src.database.repositories.types.user import UserLoads 10 | from src.database.types import OrderBy 11 | 12 | 13 | class SelectManyUserQuery(DTO): 14 | loads: tuple[UserLoads, ...]
15 | login: str | None = None 16 | role_uuid: uuid.UUID | None = None 17 | order_by: OrderBy = "desc" 18 | offset: int = 0 19 | limit: int = 10 20 | 21 | 22 | @dataclass(slots=True) 23 | class SelectManyUserHandler(Handler[SelectManyUserQuery, dtos.OffsetResult[dtos.User]]): 24 | database: DBGateway 25 | 26 | async def __call__( 27 | self, query: SelectManyUserQuery 28 | ) -> dtos.OffsetResult[dtos.User]: 29 | async with self.database.manager.session: 30 | total, users = ( 31 | await self.database.user.select_many( 32 | *query.loads, 33 | **query.as_mapping(exclude={"loads"}), 34 | ) 35 | ).result() 36 | 37 | return dtos.OffsetResult[dtos.User]( 38 | data=[dtos.User.from_mapping(user.as_dict()) for user in users], 39 | offset=query.offset, 40 | limit=query.limit, 41 | total=total, 42 | ) 43 | -------------------------------------------------------------------------------- /src/services/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Callable 3 | 4 | from src.services.cache.redis import RedisCache 5 | from src.services.security.jwt import JWT 6 | from src.settings.core import Settings 7 | 8 | from .external import ExternalServiceGateway 9 | from .internal import InternalServiceGateway 10 | from .provider.base import AsyncProvider 11 | 12 | 13 | @dataclass(slots=True) 14 | class ServiceFactory: 15 | provider: AsyncProvider 16 | settings: Settings 17 | jwt: JWT 18 | redis: RedisCache 19 | _cache: dict[str, Any] = field(default_factory=dict) 20 | 21 | def internal(self) -> InternalServiceGateway: 22 | return self._from_cache( 23 | "internal", 24 | InternalServiceGateway, 25 | jwt=self.jwt, 26 | redis=self.redis, 27 | ) 28 | 29 | def external(self) -> ExternalServiceGateway: 30 | return self._from_cache( 31 | "external", 32 | ExternalServiceGateway, 33 | provider=self.provider, 34 | settings=self.settings, 35 | ) 36 | 37 | def _from_cache[S](self, key:
str, factory: Callable[..., S], **kwargs: Any) -> S: 38 | if (cached := self._cache.get(key)) is None:  # absence check; a falsy cached object must not be rebuilt 39 | cached = factory(**kwargs) 40 | self._cache[key] = cached 41 | 42 | return cached 43 | 44 | 45 | __all__ = ( 46 | "ServiceFactory", 47 | "InternalServiceGateway", 48 | "ExternalServiceGateway", 49 | ) 50 | -------------------------------------------------------------------------------- /src/services/provider/response.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from types import TracebackType 3 | from typing import ( 4 | Any, 5 | Optional, 6 | Protocol, 7 | ) 8 | 9 | 10 | class JsonResponse(Protocol): 11 | async def json(self, **kwargs: Any) -> dict[Any, Any]: ... 12 | 13 | 14 | class TextResponse(Protocol): 15 | async def text(self, **kwargs: Any) -> str: ... 16 | 17 | 18 | class BytesResponse(Protocol): 19 | async def read(self) -> bytes: ... 20 | 21 | 22 | class UrlResponse(Protocol): 23 | @property 24 | def url(self) -> str: ... 25 | 26 | 27 | class StatusResponse(Protocol): 28 | @property 29 | def status(self) -> int: ... 30 | 31 | 32 | class HeadersResponse(Protocol): 33 | @property 34 | def headers(self) -> Mapping[str, Any]: ... 35 | 36 | 37 | class CookiesResponse(Protocol): 38 | @property 39 | def cookies(self) -> Mapping[str, Any]: ... 40 | 41 | 42 | class ContextManagerResponse(Protocol): 43 | async def __aenter__(self) -> "ContextManagerResponse": ... 44 | 45 | async def __aexit__( 46 | self, 47 | exc_type: Optional[type[BaseException]], 48 | exc_value: Optional[BaseException], 49 | traceback: Optional[TracebackType], 50 | ) -> None: ...
51 | 52 | 53 | class Response( 54 | JsonResponse, 55 | TextResponse, 56 | BytesResponse, 57 | HeadersResponse, 58 | CookiesResponse, 59 | UrlResponse, 60 | StatusResponse, 61 | ContextManagerResponse, 62 | Protocol, 63 | ): 64 | pass 65 | -------------------------------------------------------------------------------- /src/api/common/mediator/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Optional, cast 3 | 4 | from src.api.common.interfaces.handler import HandlerType 5 | 6 | type Dependency = Callable[[], Any] | Any 7 | 8 | 9 | def _resolve_factory[T](handler: Callable[[], T] | T, compare_with: type[Any]) -> T: 10 | if isinstance(handler, compare_with): 11 | return cast(T, handler) 12 | 13 | return handler() if callable(handler) else handler 14 | 15 | 16 | def _predict_dependency_or_raise( 17 | provided: dict[str, Any], 18 | required: dict[str, Any], 19 | non_checkable: Optional[set[str]] = None, 20 | ) -> dict[str, Any]: 21 | non_checkable = non_checkable or set() 22 | missing = [k for k in provided if k not in required and k not in non_checkable] 23 | if missing: 24 | missing_details = ", ".join(f"`{k}`:`{provided[k]}`" for k in missing) 25 | raise TypeError(f"Did you forget to set dependency for {missing_details}?") 26 | 27 | return {k: required.get(k, provided[k]) for k in provided} 28 | 29 | 30 | def _create_handler_factory[D: Dependency]( 31 | handler: type[HandlerType], **dependencies: D 32 | ) -> Callable[[], HandlerType]: 33 | def _factory() -> HandlerType: 34 | return handler( 35 | **{k: v() if callable(v) else v for k, v in dependencies.items()} 36 | ) 37 | 38 | return _factory 39 | 40 | 41 | def _retrieve_handler_params(handler: type[HandlerType]) -> dict[str, Any]: 42 | return {k: v.annotation for k, v in inspect.signature(handler).parameters.items()} 43 | -------------------------------------------------------------------------------- /pyproject.toml: 
--------------------------------------------------------------------------------
1 | [project] 2 | name = "template" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [{ name = "ocbunknown" }] 6 | requires-python = "~=3.12" 7 | dependencies = [ 8 | "litestar==2.15.2", 9 | "uvloop>=0.21.0,<0.22", 10 | "granian>=2.2.0,<3", 11 | "pydantic-settings>=2.8.1,<3", 12 | "uuid-utils>=0.10.0,<0.11", 13 | "redis>=5.2.1,<6", 14 | "types-redis>=4.6.0.20241004,<5", 15 | "dishka>=1.5.2,<2", 16 | "nats-py>=2.10.0,<3", 17 | "aiohttp>=3.11.15,<4", 18 | "pydantic>=2.11.1,<3", 19 | "argon2-cffi>=23.1.0,<24", 20 | "pyjwt>=2.10.1,<3", 21 | "sqlalchemy>=2.0.40,<3", 22 | "alembic>=1.15.2,<2", 23 | "asyncpg>=0.30.0,<0.31", 24 | "greenlet>=3.1.1", 25 | ] 26 | 27 | 28 | [dependency-groups] 29 | dev = [ 30 | "mypy>=1.15.0,<2", 31 | "pre-commit>=4.2.0", 32 | "pytest>=8.3.5", 33 | "pytest-asyncio>=0.24.0", 34 | "ruff>=0.11.2,<0.12", 35 | "testcontainers>=4.10.0", 36 | ] 37 | 38 | 39 | [tool.mypy] 40 | warn_unused_ignores = false 41 | follow_imports_for_stubs = true 42 | pretty = true 43 | show_absolute_path = true 44 | hide_error_codes = false 45 | show_error_context = true 46 | strict = true 47 | warn_unreachable = true 48 | warn_no_return = true 49 | 50 | [tool.ruff] 51 | lint.ignore = ["E501", "B008", "C901", "W191", "UP007", "UP006", "UP035"] 52 | 53 | lint.select = ["E", "W", "F", "I", "C", "B", "UP"] 54 | 55 | [tool.pytest.ini_options] 56 | pythonpath = [".", "src"] 57 | testpaths = ["tests"] 58 | asyncio_mode = "auto" 59 | asyncio_default_fixture_loop_scope = "function" 60 | addopts = ["-p", "no:warnings"]  # each list item is passed as one argv token 61 | -------------------------------------------------------------------------------- /src/database/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | from src.database import models 4 | from src.database.connection import SessionFactoryType 5 | from src.database.interfaces.gateway import BaseGateway 6 | from
src.database.manager import TransactionManager 7 | from src.database.repositories.role import RoleRepository 8 | from src.database.repositories.user import UserRepository 9 | 10 | 11 | class DBGateway(BaseGateway): 12 | __slots__ = ("manager", "_cache") 13 | 14 | def __init__(self, manager: TransactionManager) -> None: 15 | super().__init__(manager) 16 | self.manager = manager 17 | self._cache: dict[str, Any] = {} 18 | 19 | @property 20 | def user(self) -> UserRepository: 21 | return self._from_cache("user", UserRepository, model=models.User) 22 | 23 | @property 24 | def role(self) -> RoleRepository: 25 | return self._from_cache("role", RoleRepository, model=models.Role) 26 | 27 | def _from_cache[S](self, key: str, factory: Callable[..., S], **kwargs: Any) -> S: 28 | if (cached := self._cache.get(key)) is None:  # absence check; a falsy cached repository must not be rebuilt 29 | cached = factory(self.manager.session, **kwargs) 30 | self._cache[key] = cached 31 | 32 | return cached 33 | 34 | 35 | def create_database_factory( 36 | manager: type[TransactionManager], session_factory: SessionFactoryType 37 | ) -> Callable[[], DBGateway]: 38 | def _create() -> DBGateway: 39 | return DBGateway(manager(session_factory())) 40 | 41 | return _create 42 | 43 | 44 | __all__ = ("DBGateway", "create_database_factory") 45 | -------------------------------------------------------------------------------- /src/database/_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Iterator, Mapping 4 | from typing import Any 5 | 6 | 7 | def _filter_none(values: list[tuple[str, Any]]) -> dict[str, Any]: 8 | return {k: v for k, v in values if v is not None} 9 | 10 | 11 | class frozendict[K, V](Mapping[K, V]): 12 | def __init__(self, *args: Any, **kwargs: Any) -> None: 13 | self._dict = dict(*args, **kwargs) 14 | self._hash: int | None = None 15 | 16 | def __getitem__(self, key: Any) -> Any: 17 | return self._dict[key] 
18 | 19 | def
__contains__(self, key: Any) -> Any: 20 | return key in self._dict 21 | 22 | def copy(self, **add_or_replace: Any) -> frozendict[K, V]: 23 | return type(self)(self, **add_or_replace) 24 | 25 | def __iter__(self) -> Iterator[K]: 26 | return iter(self._dict) 27 | 28 | def __len__(self) -> int: 29 | return len(self._dict) 30 | 31 | def __repr__(self) -> str: 32 | return f"<{type(self).__name__} {self._dict!r}>" 33 | 34 | def __eq__(self, other: Any) -> bool: 35 | if isinstance(other, frozendict): 36 | return self._dict == other._dict 37 | 38 | if isinstance(other, dict): 39 | return self._dict == other 40 | 41 | return NotImplemented 42 | 43 | def __hash__(self) -> int: 44 | if self._hash is None: 45 | h = 0 46 | for key, value in self._dict.items(): 47 | h ^= hash((key, value)) 48 | self._hash = h 49 | 50 | return self._hash 51 | -------------------------------------------------------------------------------- /src/services/external/client.py: -------------------------------------------------------------------------------- 1 | from types import TracebackType 2 | from typing import Any, Optional, Self, Unpack 3 | from urllib.parse import urljoin 4 | 5 | from src.services.external.request import Request 6 | from src.services.provider.aiohttp import AiohttpProvider, ParamsType 7 | from src.services.provider.base import AsyncProvider 8 | 9 | 10 | class Client: 11 | __slots__ = ("_provider", "_url") 12 | 13 | def __init__( 14 | self, 15 | url: str = "", 16 | provider: AsyncProvider | None = None, 17 | *, 18 | proxy: str | None, 19 | **options: Unpack[ParamsType], 20 | ) -> None: 21 | self._provider = provider or AiohttpProvider(url, proxy, **options) 22 | self._url = url 23 | 24 | @property 25 | def url(self) -> str: 26 | return self._provider.url or self._url 27 | 28 | def endpoint_url(self, endpoint: str) -> str: 29 | return urljoin(self.url, endpoint) 30 | 31 | async def send[R](self, request: Request[R], /, **kwargs: Any) -> R: 32 | return await request(self._provider,
**kwargs) 33 | 34 | async def __call__[R](self, request: Request[R], /, **kwargs: Any) -> R: 35 | return await self.send(request, **kwargs) 36 | 37 | async def __aenter__(self) -> Self:  # NOTE(review): the provider is not entered here but is exited in __aexit__ — confirm this is intended 38 | return self 39 | 40 | async def __aexit__( 41 | self, 42 | exc_type: Optional[type[BaseException]], 43 | exc_value: Optional[BaseException], 44 | traceback: Optional[TracebackType], 45 | ) -> None: 46 | await self._provider.__aexit__(exc_type, exc_value, traceback) 47 | -------------------------------------------------------------------------------- /src/services/provider/middleware/manager.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from functools import partial 3 | 4 | from src.services.provider.middleware.base import ( 5 | CallNextMiddlewareType, 6 | RequestMiddlewareType, 7 | ) 8 | 9 | 10 | class RequestMiddlewareManager(t.Sequence[RequestMiddlewareType]): 11 | def __init__(self, *middlewares: RequestMiddlewareType) -> None: 12 | self._middlewares: t.List[RequestMiddlewareType] = list(middlewares) 13 | 14 | def register( 15 | self, 16 | middleware: RequestMiddlewareType, 17 | ) -> RequestMiddlewareType: 18 | self._middlewares.append(middleware) 19 | return middleware 20 | 21 | __call__ = register 22 | 23 | def unregister(self, middleware: RequestMiddlewareType) -> None: 24 | self._middlewares.remove(middleware) 25 | 26 | @t.overload 27 | def __getitem__(self, item: int) -> RequestMiddlewareType: ... 28 | @t.overload 29 | def __getitem__(self, item: slice) -> t.Sequence[RequestMiddlewareType]: ...
30 | def __getitem__( 31 | self, item: t.Union[int, slice] 32 | ) -> t.Union[RequestMiddlewareType, t.Sequence[RequestMiddlewareType]]: 33 | return self._middlewares[item] 34 | 35 | def __len__(self) -> int: 36 | return len(self._middlewares) 37 | 38 | def wrap_middleware( 39 | self, callback: CallNextMiddlewareType, **kw: t.Any 40 | ) -> CallNextMiddlewareType: 41 | middleware = partial(callback, **kw) 42 | 43 | for m in reversed(self._middlewares): 44 | middleware = partial(m, middleware) 45 | 46 | return middleware 47 | -------------------------------------------------------------------------------- /src/services/security/argon2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import Any, Dict, Literal 3 | 4 | from argon2 import Parameters, PasswordHasher 5 | from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError 6 | from argon2.profiles import ( 7 | CHEAPEST, 8 | PRE_21_2, 9 | RFC_9106_HIGH_MEMORY, 10 | RFC_9106_LOW_MEMORY, 11 | ) 12 | 13 | from src.services.interfaces.hasher import AbstractHasher 14 | 15 | ProfileType = Literal[ 16 | "RFC_9106_LOW_MEMORY", "RFC_9106_HIGH_MEMORY", "CHEAPEST", "PRE_21_2", "DEFAULT" 17 | ] 18 | PROFILES: Dict[str, Parameters] = { 19 | "RFC_9106_LOW_MEMORY": RFC_9106_LOW_MEMORY, 20 | "RFC_9106_HIGH_MEMORY": RFC_9106_HIGH_MEMORY, 21 | "CHEAPEST": CHEAPEST, 22 | "PRE_21_2": PRE_21_2, 23 | } 24 | 25 | 26 | class Argon2(AbstractHasher): 27 | __slots__ = ("_hasher",) 28 | 29 | def __init__(self, hasher: PasswordHasher) -> None: 30 | self._hasher = hasher 31 | 32 | def hash_password(self, plain: str) -> str: 33 | return self._hasher.hash(plain) 34 | 35 | def verify_password(self, hashed: str, plain: str) -> bool: 36 | try: 37 | return self._hasher.verify(hashed, plain) 38 | except (VerificationError, VerifyMismatchError, InvalidHashError):  # a malformed stored hash is a failed check, not a server error 39 | return False 40 | 41 | 42 | def 
get_argon2_hasher(profile: ProfileType = "DEFAULT", **kwargs: Any) -> Argon2: 43
| if profile == "DEFAULT": # use PasswordHasher's built-in defaults so they follow the installed argon2-cffi release 44 | kw = {} 45 | else: 46 | kw = asdict(PROFILES[profile]) 47 | kw.pop("version", None) 48 | 49 | return Argon2(PasswordHasher(**(kw | kwargs))) 50 | -------------------------------------------------------------------------------- /src/database/interfaces/crud.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections.abc import Mapping, Sequence 3 | from typing import Any, Optional 4 | 5 | 6 | class AbstractCRUDRepository[EntryType](abc.ABC): 7 | __slots__ = ("model",) 8 | 9 | def __init__(self, model: type[EntryType]) -> None: 10 | self.model = model 11 | 12 | @abc.abstractmethod 13 | async def insert(self, **values: Mapping[str, Any]) -> Optional[EntryType]:  # NOTE(review): this types each keyword value as a Mapping; `**values: Any` is probably intended (same on update) 14 | raise NotImplementedError 15 | 16 | @abc.abstractmethod 17 | async def insert_many( 18 | self, values: Sequence[Mapping[str, Any]] 19 | ) -> Sequence[EntryType]: 20 | raise NotImplementedError 21 | 22 | @abc.abstractmethod 23 | async def select(self, *clauses: Any) -> Optional[EntryType]: 24 | raise NotImplementedError 25 | 26 | @abc.abstractmethod 27 | async def select_many( 28 | self, *clauses: Any, offset: Optional[int] = None, limit: Optional[int] = None 29 | ) -> Sequence[EntryType]: 30 | raise NotImplementedError 31 | 32 | @abc.abstractmethod 33 | async def update( 34 | self, *clauses: Any, **values: Mapping[str, Any] 35 | ) -> Sequence[EntryType]: 36 | raise NotImplementedError 37 | 38 | @abc.abstractmethod 39 | async def update_many(self, values: Sequence[Mapping[str, Any]]) -> Any: 40 | raise NotImplementedError 41 | 42 | @abc.abstractmethod 43 | async def delete(self, *clauses: Any) -> Sequence[EntryType]: 44 | raise NotImplementedError 45 | 46 | async def exists(self, *clauses: Any) -> bool: 47 | raise NotImplementedError 48 | 49 | async def 
count(self, *clauses: Any) -> int: 50 | raise NotImplementedError 51 |
-------------------------------------------------------------------------------- /src/services/provider/middleware/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import typing as t 5 | 6 | from src.services.provider.middleware.base import ( 7 | BaseRequestMiddleware, 8 | CallNextMiddlewareType, 9 | ) 10 | from src.services.provider.response import Response 11 | from src.services.provider.types import RequestMethodType 12 | 13 | 14 | class RequestLoggingMiddleware(BaseRequestMiddleware): 15 | __slots__ = ( 16 | "logger", 17 | "detailed", 18 | ) 19 | 20 | def __init__( 21 | self, 22 | name: str = "LOG", 23 | level: logging._Level = "INFO", 24 | detailed: bool = False, 25 | custom_logger: t.Optional[logging.Logger] = None, 26 | ) -> None: 27 | self.detailed = detailed 28 | self.logger = custom_logger or self._setup_logger(name=name, level=level) 29 | 30 | def _setup_logger(self, name: str, level: logging._Level) -> logging.Logger: 31 | logger = logging.getLogger(name) 32 | logger.setLevel(level) 33 | if not logger.handlers:  # getLogger(name) returns a shared logger; do not stack one handler per instance 34 | logger.addHandler(logging.StreamHandler()) 35 | return logger 36 | 37 | async def __call__( 38 | self, 39 | call_next: CallNextMiddlewareType, 40 | method: RequestMethodType, 41 | url_or_endpoint: str, 42 | **kw: t.Any, 43 | ) -> Response: 44 | if not self.detailed: 45 | msg = f"Making request with method: {method}, url: {url_or_endpoint}" 46 | else: 47 | msg = f"Making request with:\nMethod: {method}\nUrl: {url_or_endpoint}\nParams: {kw}"  # NOTE(review): kw is logged verbatim and may carry headers/credentials 48 | 49 | self.logger.info(msg) 50 | 51 | return await call_next(method=method, url_or_endpoint=url_or_endpoint, **kw) 52 | -------------------------------------------------------------------------------- /src/services/security/aes.py: 
-------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | 4 | from cryptography.hazmat.backends import default_backend 5 | from
cryptography.hazmat.primitives import padding 6 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 7 | 8 | from src.services.interfaces.encrypt import AbstractEncrypt 9 | 10 | 11 | class AESCipher(AbstractEncrypt):  # NOTE(review): AES-CBC with no MAC is unauthenticated (malleable, padding-oracle risk); consider an AEAD mode such as AES-GCM 12 | def __init__(self, key: str) -> None: 13 | self._key = self._decode_key(key) 14 | 15 | def encrypt(self, plaintext: str) -> bytes: 16 | iv = os.urandom(16) 17 | 18 | cipher = Cipher( 19 | algorithms.AES(self._key), modes.CBC(iv), backend=default_backend() 20 | ) 21 | encryptor = cipher.encryptor() 22 | 23 | padder = padding.PKCS7(algorithms.AES.block_size).padder() 24 | padded_data = padder.update(plaintext.encode()) + padder.finalize() 25 | 26 | ciphertext = encryptor.update(padded_data) + encryptor.finalize() 27 | 28 | return iv + ciphertext 29 | 30 | def decrypt(self, ciphertext: bytes) -> str: 31 | iv = ciphertext[:16] 32 | actual_ciphertext = ciphertext[16:] 33 | 34 | cipher = Cipher( 35 | algorithms.AES(self._key), modes.CBC(iv), backend=default_backend() 36 | ) 37 | decryptor = cipher.decryptor() 38 | 39 | padded_plaintext = decryptor.update(actual_ciphertext) + decryptor.finalize() 40 | 41 | unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() 42 | plaintext = unpadder.update(padded_plaintext) + unpadder.finalize() 43 | 44 | return plaintext.decode() 45 | 46 | def _decode_key(self, encoded_key: str) -> bytes: 47 | return base64.b64decode(encoded_key) 48 | -------------------------------------------------------------------------------- /src/api/common/bus/core.py: -------------------------------------------------------------------------------- 1 | from typing import Self, cast, get_args 2 | 3 | from src.api.common.events.base import Event 4 | from src.api.common.interfaces.broker import BrokerType 5 | from src.common.tools.types import is_typevar 6 | 7 | 8 | class EventBusImpl: 9 | __slots__ = ("_brokers_registry", "_brokers") 10 | 11 | 
def __init__(self) -> None: 12 | self._brokers_registry:
dict[type[Event], BrokerType] = {} 13 | self._brokers: set[BrokerType] = set() 14 | 15 | async def publish(self, event: Event) -> None: 16 | broker = self._resolve_broker(event) 17 | await broker.publish(broker._build_message(event)) 18 | 19 | def middleware(self) -> Self: 20 | return self 21 | 22 | @classmethod 23 | def builder(cls) -> Self: 24 | return cls() 25 | 26 | def brokers(self, *brokers: BrokerType) -> Self: 27 | self._brokers = set(brokers) 28 | return self 29 | 30 | def build(self) -> Self: 31 | for broker in self._brokers: 32 | self._brokers_registry[self._resolve_event(broker)] = broker 33 | 34 | return self 35 | 36 | def _resolve_broker(self, event: Event) -> BrokerType: 37 | # Walk the MRO so the most specific registered event type wins at any subclass depth; unmatched events fall through to ValueError. 38 | for event_type in type(event).__mro__: 39 | if (broker := self._brokers_registry.get(cast(type[Event], event_type))) is not None: 40 | return broker 41 | 42 | raise ValueError(f"Broker for event {event} not found") 43 | 44 | def _resolve_event(self, broker: BrokerType) -> type[Event]: 45 | event = get_args(broker.__orig_bases__[0])[0] # type: ignore 46 | if is_typevar(event): 47 | raise TypeError(f"Event type {event} is a TypeVar") 48 | 49 | return cast(type[Event], event) 50 | -------------------------------------------------------------------------------- /src/api/v1/handlers/auth/confirm.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from src.api.common.interfaces.handler import Handler 4 | from src.api.v1 import dtos 5 | from src.api.v1.dtos.base import DTO 6 | from src.common.exceptions import NotFoundError 7 | from src.common.tools.cache import default_key_builder 8 | from src.database import DBGateway 9 | from src.services import InternalServiceGateway 10 | from src.services.cache.redis import RedisCache 11 | from src.services.interfaces.hasher import AbstractHasher 12 | from 
src.services.internal.auth import TokensExpire 13 | 14 | 15 | class
ConfirmRegisterQuery(DTO): 16 | code: str 17 | 18 | 19 | @dataclass(slots=True) 20 | class ConfirmRegisterHandler(Handler[ConfirmRegisterQuery, TokensExpire]): 21 | internal_gateway: InternalServiceGateway 22 | database: DBGateway 23 | hasher: AbstractHasher 24 | redis: RedisCache 25 | 26 | async def __call__(self, query: ConfirmRegisterQuery) -> TokensExpire: 27 | key = default_key_builder(code=query.code) 28 | 29 | cached_data = await self.redis.get(key) 30 | if not cached_data: 31 | raise NotFoundError("Code has expire or invalid") 32 | 33 | await self.redis.delete(key)  # NOTE(review): the code is consumed before the user row is created; if creation fails, a new code is needed 34 | cache_user = dtos.Register.from_string(cached_data) 35 | 36 | async with self.database: 37 | role = (await self.database.role.select(name="User")).result() 38 | user = ( 39 | await self.database.user.create( 40 | login=cache_user.login, 41 | password=self.hasher.hash_password(cache_user.password), 42 | role_uuid=role.uuid, 43 | ) 44 | ).result() 45 | 46 | return await self.internal_gateway.auth.login(cache_user.fingerprint, user.uuid) 47 | -------------------------------------------------------------------------------- /src/api/v1/endpoints/user.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from litestar import ( 4 | Controller, 5 | Request, 6 | get, 7 | patch, 8 | status_codes, 9 | ) 10 | from litestar.datastructures import State 11 | from litestar.middleware.rate_limit import RateLimitConfig 12 | from litestar.params import Parameter 13 | 14 | from src.api.common.interfaces.mediator import Mediator 15 | from src.api.v1 import dtos, handlers 16 | from src.common.di import Depends 17 | from src.database.repositories.types.user import UserLoads 18 | 19 | 20 | class UserController(Controller): 21 | path = "/users" 22 | tags = ["Users"] 23 | security = [{"BearerToken": []}] 24 | 25 | @get("/me", status_code=status_codes.HTTP_200_OK) 26 | async 
def select_user_endpoint( 27 | self, 28 | s: Annotated[ 29 |
tuple[UserLoads, ...], 30 | Parameter( 31 | required=False, 32 | default=(), 33 | description="Search for additional relations", 34 | ), 35 | ], 36 | request: Request[dtos.User, None, State], 37 | mediator: Depends[Mediator], 38 | ) -> dtos.User: 39 | return await mediator.send( 40 | handlers.user.SelectUserQuery(loads=s, user_uuid=request.user.uuid) 41 | ) 42 | 43 | @patch( 44 | status_code=status_codes.HTTP_200_OK, 45 | middleware=[RateLimitConfig(rate_limit=("minute", 5)).middleware], 46 | ) 47 | async def update_user_endpoint( 48 | self, 49 | data: dtos.UpdateUser, 50 | request: Request[dtos.User, None, State], 51 | mediator: Depends[Mediator], 52 | ) -> dtos.User: 53 | return await mediator.send( 54 | handlers.user.UpdateUserQuery( 55 | user_uuid=request.user.uuid, **data.as_mapping() 56 | ) 57 | ) 58 | -------------------------------------------------------------------------------- /src/api/v1/handlers/user/create.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Annotated 3 | 4 | import uuid_utils.compat as uuid 5 | from msgspec import Meta 6 | 7 | from src.api.common.interfaces.handler import Handler 8 | from src.api.v1 import dtos 9 | from src.api.v1.constants import ( 10 | MAX_LOGIN_LENGTH, 11 | MAX_PASSWORD_LENGTH, 12 | MIN_PASSWORD_LENGTH, 13 | ) 14 | from src.api.v1.tools.validate import validate_email 15 | from src.database import DBGateway 16 | from src.services.interfaces.hasher import AbstractHasher 17 | 18 | 19 | class CreateUserQuery(dtos.DTO): 20 | login: Annotated[ 21 | str, 22 | Meta( 23 | max_length=MAX_LOGIN_LENGTH, 24 | description=(f"Login maximum length `{MAX_LOGIN_LENGTH}`"), 25 | ), 26 | ] 27 | password: Annotated[ 28 | str, 29 | Meta( 30 | min_length=MIN_PASSWORD_LENGTH, 31 | max_length=MAX_PASSWORD_LENGTH, 32 | description=( 33 | f"Password between `{MIN_PASSWORD_LENGTH}` and " 34 | f"`{MAX_PASSWORD_LENGTH}` characters long" 35 | ), 36 | ),
37 | ] 38 | role_uuid: uuid.UUID 39 | 40 | def __post_init__(self) -> None: 41 | self.login = validate_email(self.login) 42 | 43 | 44 | @dataclass(slots=True) 45 | class CreateUserHandler(Handler[CreateUserQuery, dtos.User]): 46 | database: DBGateway 47 | hasher: AbstractHasher 48 | 49 | async def __call__(self, query: CreateUserQuery) -> dtos.User: 50 | async with self.database: 51 | user = await self.database.user.create( 52 | login=query.login, 53 | password=self.hasher.hash_password(query.password), 54 | role_uuid=query.role_uuid, 55 | ) 56 | 57 | return dtos.User.from_mapping(user.result().as_dict()) 58 | -------------------------------------------------------------------------------- /src/database/manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from types import TracebackType 4 | from typing import Optional, Union 5 | 6 | from sqlalchemy.exc import SQLAlchemyError 7 | from sqlalchemy.ext.asyncio import ( 8 | AsyncSession, 9 | AsyncSessionTransaction, 10 | async_sessionmaker, 11 | ) 12 | 13 | from src.database.exceptions import CommitError, RollbackError 14 | 15 | 16 | class TransactionManager: 17 | __slots__ = ( 18 | "session", 19 | "_transaction", 20 | ) 21 | 22 | def __init__( 23 | self, session_or_factory: Union[AsyncSession, async_sessionmaker[AsyncSession]] 24 | ) -> None: 25 | if isinstance(session_or_factory, async_sessionmaker): 26 | self.session = session_or_factory() 27 | else: 28 | self.session = session_or_factory 29 | 30 | self._transaction: Optional[AsyncSessionTransaction] = None 31 | 32 | async def __aexit__( 33 | self, 34 | exc_type: Optional[type[BaseException]], 35 | exc_value: Optional[BaseException], 36 | traceback: Optional[TracebackType], 37 | ) -> None: 38 | if self._transaction: 39 | if exc_type: 40 | await self.rollback() 41 | else: 42 | await self.commit() 43 | 44 | await self.close_transaction() 45 | 46 | async def __aenter__(self) -> 
TransactionManager: 47 | await self.create_transaction() 48 | return self 49 | 50 | async def commit(self) -> None: 51 | try: 52 | await self.session.commit() 53 | except SQLAlchemyError as err: 54 | raise CommitError from err 55 | 56 | async def rollback(self) -> None: 57 | try: 58 | await self.session.rollback() 59 | except SQLAlchemyError as err: 60 | raise RollbackError from err 61 | 62 | async def create_transaction(self) -> None: 63 | if not self.session.in_transaction() and self.session.is_active: 64 | self._transaction = await self.session.begin() 65 | 66 | async def close_transaction(self) -> None: 67 | if self.session.is_active: 68 | await self.session.close() 69 | -------------------------------------------------------------------------------- /src/api/v1/handlers/auth/login.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Annotated 3 | 4 | from msgspec import Meta 5 | 6 | from src.api.common.interfaces.handler import Handler 7 | from src.api.v1.constants import ( 8 | MAX_LOGIN_LENGTH, 9 | MAX_PASSWORD_LENGTH, 10 | MIN_PASSWORD_LENGTH, 11 | ) 12 | from src.api.v1.dtos.base import DTO 13 | from src.api.v1.tools.validate import validate_email 14 | from src.common.exceptions import ( 15 | ForbiddenError, 16 | UnAuthorizedError, 17 | ) 18 | from src.database import DBGateway 19 | from src.services import InternalServiceGateway 20 | from src.services.interfaces.hasher import AbstractHasher 21 | from src.services.internal.auth import TokensExpire 22 | 23 | 24 | class LoginQuery(DTO): 25 | login: Annotated[ 26 | str, 27 | Meta( 28 | max_length=MAX_LOGIN_LENGTH, 29 | description=(f"Login maximum length `{MAX_LOGIN_LENGTH}`"), 30 | ), 31 | ] 32 | password: Annotated[ 33 | str, 34 | Meta( 35 | min_length=MIN_PASSWORD_LENGTH, 36 | max_length=MAX_PASSWORD_LENGTH, 37 | description=( 38 | f"Password between `{MIN_PASSWORD_LENGTH}` and " 39 | f"`{MAX_PASSWORD_LENGTH}` 
characters long" 40 | ), 41 | ), 42 | ] 43 | fingerprint: str 44 | 45 | def __post_init__(self) -> None: 46 | self.login = validate_email(self.login) 47 | 48 | 49 | @dataclass(slots=True) 50 | class LoginHandler(Handler[LoginQuery, TokensExpire]): 51 | internal_gateway: InternalServiceGateway 52 | database: DBGateway 53 | hasher: AbstractHasher 54 | 55 | async def __call__(self, query: LoginQuery) -> TokensExpire: 56 | async with self.database.manager.session: 57 | user = (await self.database.user.select(login=query.login)).result() 58 | 59 | if not user.active: 60 | raise ForbiddenError("You have been blocked") 61 | if not self.hasher.verify_password(user.password, query.password): 62 | raise UnAuthorizedError("Incorrect password or login") 63 | 64 | return await self.internal_gateway.auth.login(query.fingerprint, user.uuid) 65 | -------------------------------------------------------------------------------- /src/services/security/jwt.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from datetime import UTC, datetime, timedelta 3 | from typing import ( 4 | Any, 5 | Literal, 6 | Optional, 7 | cast, 8 | ) 9 | 10 | import jwt 11 | 12 | from src.common.exceptions import ConflictError, UnAuthorizedError 13 | from src.settings.core import CipherSettings 14 | 15 | TokenType = Literal["access", "refresh"] 16 | 17 | 18 | class JWT: 19 | __slots__ = ("_settings",) 20 | 21 | def __init__(self, settings: CipherSettings) -> None: 22 | self._settings = settings 23 | 24 | def create( 25 | self, 26 | typ: TokenType, 27 | sub: str, 28 | expires_delta: Optional[timedelta] = None, 29 | **kw: Any, 30 | ) -> tuple[datetime, str]: 31 | now = datetime.now(UTC) 32 | if expires_delta: 33 | expire = now + expires_delta 34 | else: 35 | seconds_delta = ( 36 | self._settings.access_token_expire_seconds 37 | if typ == "access" 38 | else self._settings.refresh_token_expire_seconds 39 | ) 40 | expire = now + 
timedelta(seconds=seconds_delta) 41 | 42 | if now >= expire: 43 | raise ConflictError("Invalid expiration delta was provided") 44 | 45 | to_encode = { 46 | "exp": expire, 47 | "sub": sub, 48 | "iat": now, 49 | "type": typ, 50 | } 51 | try: 52 | token = jwt.encode( 53 | to_encode | kw, 54 | base64.b64decode(self._settings.secret_key), 55 | self._settings.algorithm, 56 | ) 57 | except jwt.PyJWTError as e: 58 | raise UnAuthorizedError("Token is expired") from e 59 | 60 | return expire, token 61 | 62 | def verify_token(self, token: str) -> dict[str, Any]: 63 | try: 64 | result = jwt.decode( 65 | token, 66 | base64.b64decode(self._settings.public_key), 67 | [self._settings.algorithm], 68 | ) 69 | except jwt.PyJWTError as e: 70 | raise UnAuthorizedError("Token is invalid or expired") from e 71 | 72 | return cast(dict[str, Any], result) 73 | -------------------------------------------------------------------------------- /src/api/common/broker/nats/core.py: -------------------------------------------------------------------------------- 1 | from nats.aio.client import Client as NatsClient 2 | from nats.js import JetStreamContext 3 | from src.api.common.broker.nats.message import NatsJetStreamMessage, NatsMessage 4 | from src.api.common.events.nats import NatsEvent, NatsJetStreamEvent 5 | from src.api.common.interfaces.broker import Broker 6 | 7 | 8 | class NatsBroker(Broker[NatsEvent, NatsMessage]): 9 | def __init__(self, nats: NatsClient) -> None: 10 | self.nats = nats 11 | 12 | async def publish(self, message: NatsMessage) -> None: 13 | await self.nats.publish( 14 | subject=message.subject, 15 | payload=message.payload, 16 | reply=message.reply, 17 | headers=message.headers, 18 | ) 19 | 20 | def _build_message(self, event: NatsEvent) -> NatsMessage: 21 | return NatsMessage( 22 | subject=event.subject, 23 | payload=event.as_bytes( 24 | exclude={ 25 | "subject", 26 | "_reply", 27 | "_headers", 28 | }, 29 | ), 30 | reply=event._reply, 31 | headers=event._headers or {}, 32 
| ) 33 | 34 | 35 | class NatsJetStreamBroker(Broker[NatsJetStreamEvent, NatsJetStreamMessage]): 36 | def __init__(self, jetstream: JetStreamContext) -> None: 37 | self.js = jetstream 38 | 39 | async def publish(self, message: NatsJetStreamMessage) -> None: 40 | await self.js.publish( 41 | subject=message.subject, 42 | payload=message.payload, 43 | stream=message.stream, 44 | timeout=message.timeout, 45 | headers=message.headers, 46 | ) 47 | 48 | def _build_message(self, event: NatsJetStreamEvent) -> NatsJetStreamMessage: 49 | return NatsJetStreamMessage( 50 | subject=event.subject, 51 | payload=event.as_bytes( 52 | exclude={ 53 | "subject", 54 | "_reply", 55 | "_headers", 56 | "_stream", 57 | "_timeout", 58 | }, 59 | ), 60 | stream=event._stream, 61 | timeout=event._timeout, 62 | headers=event._headers or {}, 63 | ) 64 | -------------------------------------------------------------------------------- /src/api/common/mediator/core.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Self 2 | 3 | from src.api.common.interfaces.dto import DTO 4 | from src.api.common.interfaces.handler import Handler, HandlerType 5 | from src.api.common.mediator.utils import ( 6 | _create_handler_factory, 7 | _predict_dependency_or_raise, 8 | _resolve_factory, 9 | _retrieve_handler_params, 10 | ) 11 | 12 | type HandlerLike = Callable[[], HandlerType] | HandlerType 13 | 14 | 15 | class MediatorImpl: 16 | __slots__ = ("_handlers_registry", "_handlers_factory", "_dependencies") 17 | 18 | def __init__(self) -> None: 19 | self._handlers_registry: dict[type[DTO], HandlerLike] = {} 20 | self._handlers_factory: set[Callable[[Self], None]] = set() 21 | self._dependencies: dict[str, Any] = {} 22 | 23 | @classmethod 24 | def builder(cls) -> Self: 25 | return cls() 26 | 27 | def dependencies(self, **dependencies: Any) -> Self: 28 | self._dependencies = dependencies 29 | return self 30 | 31 | def handlers(self, *handlers_factory: 
Callable[[Self], None]) -> Self: 32 | self._handlers_factory = set(handlers_factory) 33 | return self 34 | 35 | def build(self) -> Self: 36 | for handler_factory in self._handlers_factory: 37 | handler_factory(self) 38 | return self 39 | 40 | def middleware(self) -> Self: 41 | return self 42 | 43 | def register[Q: DTO](self, query: type[Q], handler: type[HandlerType]) -> None: 44 | prepared_deps = _predict_dependency_or_raise( 45 | provided=_retrieve_handler_params(handler), 46 | required=self._dependencies, 47 | non_checkable={ 48 | "query", 49 | }, 50 | ) 51 | self._handlers_registry[query] = _create_handler_factory( 52 | handler, **prepared_deps 53 | ) 54 | 55 | async def send[Q: DTO, R: DTO](self, query: Q) -> R: 56 | handler: Handler[Q, R] = self._get_handler(query) 57 | return await handler(query) 58 | 59 | def _get_handler[Q: DTO, R: DTO](self, query: Q) -> Handler[Q, R]: 60 | try: 61 | return _resolve_factory(self._handlers_registry[type(query)], Handler) 62 | except KeyError as e: 63 | raise KeyError(f"Handler for `{type(query)}` is not registered") from e 64 | -------------------------------------------------------------------------------- /src/api/v1/integration/dishka.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Callable, override 3 | 4 | from dishka.integrations.litestar import inject, inject_websocket 5 | from litestar import Controller, Router 6 | from litestar.handlers import BaseRouteHandler, HTTPRouteHandler, WebsocketListener 7 | from litestar.handlers.websocket_handlers import WebsocketListenerRouteHandler 8 | from litestar.handlers.websocket_handlers._utils import ListenerHandler 9 | from litestar.routes import BaseRoute 10 | from litestar.types import ControllerRouterHandler 11 | 12 | 13 | def _inject_based_on_handler_type( 14 | value: BaseRouteHandler, 15 | ) -> BaseRouteHandler: 16 | if isinstance(value, HTTPRouteHandler): 17 | value._fn = 
inject(value._fn) 18 | 19 | if isinstance(value, WebsocketListenerRouteHandler) and isinstance( 20 | value._fn, 21 | ListenerHandler, 22 | ): 23 | value = value(inject_websocket(value._fn._fn)) 24 | 25 | return value 26 | 27 | 28 | def _inject_route_handlers[**P]( 29 | get_route_handlers: Callable[P, list[BaseRouteHandler]], 30 | ) -> Callable[P, list[BaseRouteHandler]]: 31 | @wraps(get_route_handlers) 32 | def _wrapper(*args: P.args, **kwargs: P.kwargs) -> list[BaseRouteHandler]: 33 | return [ 34 | _inject_based_on_handler_type(route) 35 | for route in get_route_handlers(*args, **kwargs) 36 | ] 37 | 38 | return _wrapper 39 | 40 | 41 | def _resolve_value( 42 | router: Router, 43 | value: ControllerRouterHandler, 44 | ) -> ControllerRouterHandler: 45 | if isinstance(value, Router): 46 | return value 47 | 48 | if isinstance(value, BaseRouteHandler): 49 | return _inject_based_on_handler_type(value) 50 | 51 | if isinstance(value, type): 52 | if issubclass(value, Controller): 53 | value.get_route_handlers = _inject_route_handlers( # type: ignore[method-assign] 54 | value.get_route_handlers, 55 | ) 56 | if issubclass(value, WebsocketListener): 57 | return _inject_based_on_handler_type(value(router).to_handler()) 58 | 59 | return value 60 | 61 | 62 | class DishkaRouter(Router): 63 | __slots__ = () 64 | 65 | @override 66 | def register(self, value: ControllerRouterHandler) -> list[BaseRoute]: 67 | return super().register(_resolve_value(self, value)) 68 | -------------------------------------------------------------------------------- /src/services/provider/errors.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | 4 | class BaseError(Exception): 5 | pass 6 | 7 | 8 | class DetailedError(BaseError): 9 | url: Optional[str] = None 10 | 11 | def __init__(self, message: str, content: Any, url: Optional[str] = None) -> None: 12 | self.message = message 13 | self.content = content 14 | self.url = url 15 
| 16 | def __str__(self) -> str: 17 | message = self.message + f"\n{self.content}" 18 | if self.url: 19 | message += f"\n(background on this error at: {self.url})" 20 | return message 21 | 22 | def __repr__(self) -> str: 23 | return f"{type(self).__name__}('{self}')" 24 | 25 | 26 | class APIError(DetailedError): 27 | label: str = "Server says" 28 | 29 | def __init__( 30 | self, 31 | status_code: int, 32 | content: Any, 33 | message: str = "", 34 | url: Optional[str] = None, 35 | ) -> None: 36 | super().__init__(message=message, content=content, url=url) 37 | self.status_code = status_code 38 | 39 | def __str__(self) -> str: 40 | original_message = super().__str__() 41 | return f"Status: {self.status_code}: {self.label} - {original_message}" 42 | 43 | 44 | class ClientDecodeError(BaseError): 45 | def __init__(self, message: str, original: Exception, data: Any) -> None: 46 | self.message = message 47 | self.original = original 48 | self.data = data 49 | 50 | def __str__(self) -> str: 51 | original_type = type(self.original) 52 | return ( 53 | f"{self.message}\n" 54 | f"Caused from error: " 55 | f"{original_type.__module__}.{original_type.__name__}: {self.original}\n" 56 | f"Content: {self.data}" 57 | ) 58 | 59 | 60 | class NetworkError(BaseError): 61 | pass 62 | 63 | 64 | class BadRequestError(APIError): 65 | pass 66 | 67 | 68 | class NotFoundError(APIError): 69 | pass 70 | 71 | 72 | class ConflictError(APIError): 73 | pass 74 | 75 | 76 | class UnauthorizedError(APIError): 77 | pass 78 | 79 | 80 | class ForbiddenError(APIError): 81 | pass 82 | 83 | 84 | class EntityTooLarge(APIError): 85 | pass 86 | 87 | 88 | class ServerError(APIError): 89 | pass 90 | 91 | 92 | class NotValidMethodError(BaseError): 93 | pass 94 | 95 | 96 | class TooManyRequestsError(APIError): 97 | pass 98 | -------------------------------------------------------------------------------- /src/services/provider/base.py: -------------------------------------------------------------------------------- 
1 | from __future__ import annotations 2 | 3 | import abc 4 | from collections.abc import AsyncIterator, Mapping 5 | from types import TracebackType 6 | from typing import ( 7 | Any, 8 | Final, 9 | Optional, 10 | ) 11 | 12 | from src.services.provider.middleware.base import RequestMiddlewareType 13 | from src.services.provider.middleware.manager import RequestMiddlewareManager 14 | from src.services.provider.response import Response 15 | from src.services.provider.types import RequestMethodType 16 | 17 | DEFAULT_CHUNK_SIZE: Final[int] = 65536 # 50 mb 18 | 19 | 20 | class AsyncProvider(abc.ABC): 21 | __slots__ = ( 22 | "url", 23 | "manager", 24 | ) 25 | 26 | def __init__( 27 | self, 28 | url: Optional[str] = None, 29 | middlewares: tuple[RequestMiddlewareType, ...] = (), 30 | ) -> None: 31 | self.url = url 32 | self.manager = RequestMiddlewareManager(*middlewares) 33 | 34 | async def __aenter__(self) -> AsyncProvider: 35 | return self 36 | 37 | async def __aexit__( 38 | self, 39 | exc_type: Optional[type[BaseException]], 40 | exc_value: Optional[BaseException], 41 | traceback: Optional[TracebackType], 42 | ) -> None: 43 | await self.close_session() 44 | 45 | async def __call__( 46 | self, method: RequestMethodType, url_or_endpoint: str = "", **kw: Any 47 | ) -> Response: 48 | middleware = self.manager.wrap_middleware(self.make_request) 49 | 50 | return await middleware(method=method, url_or_endpoint=url_or_endpoint, **kw) 51 | 52 | @abc.abstractmethod 53 | async def make_request( 54 | self, method: RequestMethodType, url_or_endpoint: str = "", **kw: Any 55 | ) -> Response: 56 | raise NotImplementedError 57 | 58 | @abc.abstractmethod 59 | async def close_session(self) -> None: 60 | raise NotImplementedError 61 | 62 | @abc.abstractmethod 63 | async def stream_content( 64 | self, 65 | url_or_endpoint: str, 66 | chunk_size: int = DEFAULT_CHUNK_SIZE, 67 | raise_for_status: bool = True, 68 | **kw: Any, 69 | ) -> AsyncIterator[bytes]: 70 | yield b"" 71 | 72 | 
@abc.abstractmethod 73 | def update_cookies(self, values: Mapping[str, Any]) -> None: 74 | raise NotImplementedError 75 | 76 | @abc.abstractmethod 77 | def update_headers(self, values: Mapping[str, Any]) -> None: 78 | raise NotImplementedError 79 | -------------------------------------------------------------------------------- /src/services/cache/redis.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any 3 | 4 | import redis.asyncio as aioredis 5 | 6 | from src.settings.core import RedisSettings 7 | 8 | 9 | class RedisCache: 10 | __slots__ = ("_redis",) 11 | 12 | def __init__(self, redis: aioredis.Redis) -> None: # type: ignore[type-arg] 13 | self._redis = redis 14 | 15 | async def get( 16 | self, 17 | key: str, 18 | ) -> str | None: 19 | return await self._redis.get(key) 20 | 21 | async def set( 22 | self, key: str, value: Any, expire: float | timedelta | None = None, **kw: Any 23 | ) -> None: 24 | await self._redis.set(key, value, ex=expire, **kw) 25 | 26 | async def delete(self, *keys: str) -> None: 27 | found_keys = [ 28 | found for key in keys async for found in self._redis.scan_iter(key) 29 | ] 30 | if found_keys: 31 | await self._redis.delete(*found_keys) 32 | 33 | async def set_list( 34 | self, key: str, *values: Any, expire: float | timedelta | None = None, **kw: Any 35 | ) -> None: 36 | await self._redis.lpush(key, *(v for v in values)) 37 | 38 | if expire: 39 | await self._redis.expire( 40 | key, 41 | expire if isinstance(expire, timedelta) else timedelta(seconds=expire), 42 | **kw, 43 | ) 44 | 45 | async def get_list( 46 | self, 47 | key: str, 48 | **kw: Any, 49 | ) -> list[str]: 50 | start, end = kw.pop("start", 0), kw.pop("end", -1) 51 | return await self._redis.lrange(key, start, end) 52 | 53 | async def discard( 54 | self, 55 | key: str, 56 | value: Any, 57 | **kw: Any, 58 | ) -> None: 59 | count = kw.pop("count", 0) 60 | await self._redis.lrem(key, count, 
value) 61 | 62 | async def clear(self) -> None: 63 | await self._redis.flushall(asynchronous=True) 64 | 65 | async def exists(self, key: str) -> bool: 66 | return bool(await self._redis.keys(key)) 67 | 68 | async def keys(self) -> list[str]: 69 | return await self._redis.keys("*") 70 | 71 | async def close(self) -> None: 72 | await self._redis.aclose(close_connection_pool=True) # type: ignore 73 | 74 | 75 | def get_redis(settings: RedisSettings, **kw: Any) -> RedisCache: 76 | return RedisCache( 77 | aioredis.Redis( 78 | host=settings.host, 79 | port=settings.port, 80 | password=settings.password, 81 | decode_responses=True, 82 | **kw, 83 | ) 84 | ) 85 | -------------------------------------------------------------------------------- /src/api/common/broker/nats/stream.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import msgspec 4 | 5 | from nats.js.api import StreamConfig 6 | from nats.js.errors import NotFoundError 7 | from src.api.common.broker.nats.core import NatsJetStreamBroker 8 | from src.common.di import Depends, FromDepends, inject 9 | from src.common.logger import log 10 | from src.common.tools.files import open_file_sync 11 | from src.settings.core import path 12 | 13 | 14 | def load_stream_config() -> list[dict[str, Any]]: 15 | stream_path = path("jetstream", "stream.json") 16 | if not stream_path: 17 | raise FileNotFoundError(f"Config file {stream_path} not found") 18 | config = msgspec.json.decode(open_file_sync(stream_path, mode="r")) 19 | return list(config.get("streams", [])) 20 | 21 | 22 | def normalize_config(config: dict[str, Any]) -> dict[str, Any]: 23 | return { 24 | key: config.get(key) 25 | for key in ( 26 | "subjects", 27 | "storage", 28 | "max_msgs", 29 | "max_bytes", 30 | "max_age", 31 | "discard", 32 | ) 33 | } 34 | 35 | 36 | async def sync_stream(broker: NatsJetStreamBroker, new_config: dict[str, Any]) -> None: 37 | stream_name = new_config.get("name") 38 | if not 
stream_name: 39 | raise ValueError("Stream configuration must include the 'name' key") 40 | 41 | try: 42 | stream_info = await broker.js.stream_info(stream_name) 43 | current_config = stream_info.config 44 | 45 | if current_config.retention != new_config.get("retention"): 46 | log.info( 47 | "Retention policy change detected for stream '%s'. Recreating stream...", 48 | stream_name, 49 | ) 50 | await broker.js.delete_stream(stream_name) 51 | await broker.js.add_stream(**new_config) 52 | return 53 | 54 | if normalize_config(current_config.as_dict()) != normalize_config( 55 | StreamConfig().evolve(**new_config).as_dict() 56 | ): 57 | log.info("Updating stream '%s' configuration", stream_name) 58 | await broker.js.update_stream(**new_config) 59 | 60 | except NotFoundError: 61 | log.info("Stream '%s' not found. Creating it.", stream_name) 62 | await broker.js.add_stream(**new_config) 63 | 64 | 65 | @inject 66 | async def sync_streams( 67 | broker: Depends[NatsJetStreamBroker] = FromDepends(), 68 | ) -> None: 69 | try: 70 | for stream_config in load_stream_config(): 71 | await sync_stream(broker, stream_config) 72 | except Exception as e: 73 | log.error("Error syncing streams: %s", e) 74 | -------------------------------------------------------------------------------- /src/api/v1/handlers/auth/register.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from dataclasses import dataclass 3 | from datetime import timedelta 4 | from typing import Annotated 5 | 6 | from msgspec import Meta 7 | 8 | from src.api.common.interfaces.event_bus import EventBus 9 | from src.api.common.interfaces.handler import Handler 10 | from src.api.v1 import dtos 11 | from src.api.v1.constants import ( 12 | MAX_LOGIN_LENGTH, 13 | MAX_PASSWORD_LENGTH, 14 | MIN_PASSWORD_LENGTH, 15 | ) 16 | from src.api.v1.dtos.base import DTO 17 | from src.api.v1.events.email import SendEmail 18 | from src.api.v1.tools.validate import validate_email 19 | 
from src.common.exceptions import ( 20 | ConflictError, 21 | ) 22 | from src.common.tools.cache import default_key_builder 23 | from src.database import DBGateway 24 | from src.services import InternalServiceGateway 25 | from src.services.cache.redis import RedisCache 26 | 27 | 28 | class RegisterQuery(DTO): 29 | login: Annotated[ 30 | str, 31 | Meta( 32 | max_length=MAX_LOGIN_LENGTH, 33 | description=(f"Login maximum length `{MAX_LOGIN_LENGTH}`"), 34 | ), 35 | ] 36 | password: Annotated[ 37 | str, 38 | Meta( 39 | min_length=MIN_PASSWORD_LENGTH, 40 | max_length=MAX_PASSWORD_LENGTH, 41 | description=( 42 | f"Password between `{MIN_PASSWORD_LENGTH}` and " 43 | f"`{MAX_PASSWORD_LENGTH}` characters long" 44 | ), 45 | ), 46 | ] 47 | fingerprint: str 48 | 49 | def __post_init__(self) -> None: 50 | self.login = validate_email(self.login) 51 | 52 | 53 | @dataclass(slots=True) 54 | class RegisterHandler(Handler[RegisterQuery, dtos.Status]): 55 | internal_gateway: InternalServiceGateway 56 | database: DBGateway 57 | event_bus: EventBus 58 | redis: RedisCache 59 | 60 | async def __call__(self, query: RegisterQuery) -> dtos.Status: 61 | async with self.database.manager.session: 62 | user = (await self.database.user.exists(login=query.login)).result() 63 | 64 | if user: 65 | raise ConflictError("User already exists") 66 | 67 | code = secrets.token_hex(16) 68 | 69 | await self.redis.set( 70 | default_key_builder(code=code), 71 | query.as_string(), 72 | expire=timedelta(minutes=10), 73 | ) 74 | 75 | await self.event_bus.publish( 76 | SendEmail[dtos.VerificationCode]( 77 | from_="your-email@example.com", 78 | to=query.login, 79 | title="OcbUnknown Template", 80 | template="verify_email", 81 | props=dtos.VerificationCode(code=code), 82 | ) 83 | ) 84 | 85 | return dtos.Status(status=True) 86 | -------------------------------------------------------------------------------- /src/api/v1/middlewares/auth.py: 
-------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from typing import Any, Optional 3 | 4 | import uuid_utils.compat as uuid 5 | from litestar.connection import ASGIConnection 6 | from litestar.datastructures import State 7 | from litestar.handlers.base import BaseRouteHandler 8 | from litestar.middleware import ( 9 | AbstractAuthenticationMiddleware, 10 | AuthenticationResult, 11 | ) 12 | from litestar.types import ASGIApp, Method, Scopes 13 | 14 | from src.api.v1 import dtos 15 | from src.common.di import Depends, FromDepends, inject 16 | from src.common.exceptions import ( 17 | ForbiddenError, 18 | ServiceNotImplementedError, 19 | UnAuthorizedError, 20 | ) 21 | from src.database import DBGateway 22 | from src.services.internal import InternalServiceGateway 23 | 24 | 25 | class AuthenticationMiddleware(AbstractAuthenticationMiddleware): 26 | __slots__ = ("auth_header_name",) 27 | 28 | def __init__( 29 | self, 30 | app: ASGIApp, 31 | auth_header_name: str = "Authorization", 32 | exclude: str | list[str] = "schema", 33 | exclude_from_auth_key: str = "exclude_from_auth", 34 | exclude_http_methods: Sequence[Method] | None = None, 35 | scopes: Optional[Scopes] = None, 36 | ) -> None: 37 | super().__init__( 38 | app, exclude, exclude_from_auth_key, exclude_http_methods, scopes 39 | ) 40 | self.auth_header_name = auth_header_name 41 | 42 | async def authenticate_request( 43 | self, connection: ASGIConnection[BaseRouteHandler, Any, Any, State] 44 | ) -> AuthenticationResult: 45 | auth_header = connection.headers.get(self.auth_header_name) 46 | if not auth_header: 47 | raise UnAuthorizedError("Authorization is required") 48 | 49 | access_token = auth_header.partition(" ")[-1] 50 | 51 | user = await self._authenticate_token(access_token) 52 | return AuthenticationResult(user=user, auth=None) 53 | 54 | @inject 55 | async def _authenticate_token( 56 | self, 57 | access_token: str, 58 | 
internal_gateway: Depends[InternalServiceGateway] = FromDepends(), 59 | ) -> dtos.User: 60 | user_uuid = await internal_gateway.auth.verify_token(access_token, "access") 61 | return await self._authenticate_user(user_uuid) 62 | 63 | @inject 64 | async def _authenticate_user( 65 | self, 66 | user_uuid: uuid.UUID, 67 | database: Depends[DBGateway] = FromDepends(), 68 | ) -> dtos.User: 69 | async with database.manager.session: 70 | user = (await database.user.select("role", user_uuid=user_uuid)).result() 71 | 72 | if not user.active: 73 | raise ForbiddenError("You have been blocked") 74 | if not user.role: 75 | raise ServiceNotImplementedError("Role not found") 76 | 77 | return dtos.User.from_mapping(user.as_dict()) 78 | -------------------------------------------------------------------------------- /src/api/v1/dtos/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping 4 | from typing import Any, Self 5 | 6 | import msgspec 7 | 8 | from src.api.common.converters import convert_from, convert_to 9 | from src.api.common.serializers.msgspec import ( 10 | msgpack_decoder, 11 | msgpack_encoder, 12 | msgspec_decoder, 13 | msgspec_encoder, 14 | ) 15 | 16 | 17 | class DTO(msgspec.Struct): 18 | @classmethod 19 | def from_mapping(cls, value: Mapping[str, Any]) -> Self: 20 | return convert_to(cls, value, strict=False) 21 | 22 | @classmethod 23 | def from_string(cls, value: str) -> Self: 24 | return convert_to(cls, msgspec_decoder(value, strict=False), strict=False) 25 | 26 | @classmethod 27 | def from_attributes(cls, value: Any) -> Self: 28 | return convert_to(cls, value, strict=False, from_attributes=True) 29 | 30 | @classmethod 31 | def from_bytes(cls, value: bytes) -> Self: 32 | return convert_to(cls, msgpack_decoder(value, strict=False), strict=False) 33 | 34 | def as_mapping( 35 | self, exclude_none: bool = False, exclude: set[str] | None = None 36 | ) -> 
Mapping[str, Any]: 37 | result: dict[str, Any] = convert_from(self) 38 | if exclude_none: 39 | result = {k: v for k, v in result.items() if v is not None} 40 | if exclude: 41 | result = {k: v for k, v in result.items() if k not in exclude} 42 | 43 | return result 44 | 45 | def as_string( 46 | self, exclude_none: bool = False, exclude: set[str] | None = None 47 | ) -> str: 48 | return msgspec_encoder( 49 | self.as_mapping(exclude_none=exclude_none, exclude=exclude) 50 | if exclude_none or exclude 51 | else self 52 | ) 53 | 54 | def as_bytes( 55 | self, exclude_none: bool = False, exclude: set[str] | None = None 56 | ) -> bytes: 57 | return msgpack_encoder( 58 | self.as_mapping(exclude_none=exclude_none, exclude=exclude) 59 | if exclude_none or exclude 60 | else self 61 | ) 62 | 63 | def copy(self, update: Mapping[str, Any] | None = None) -> Self: 64 | return type(self)(**{**self.as_mapping(), **(update or {})}) 65 | 66 | 67 | class StrictDTO(msgspec.Struct): 68 | @classmethod 69 | def from_mapping(cls, value: Mapping[str, Any]) -> Self: 70 | return convert_to(cls, value, strict=True) 71 | 72 | @classmethod 73 | def from_string(cls, value: str) -> Self: 74 | return convert_to(cls, msgspec_decoder(value, strict=True), strict=True) 75 | 76 | @classmethod 77 | def from_attributes(cls, value: Any) -> Self: 78 | return convert_to(cls, value, strict=True, from_attributes=True) 79 | 80 | @classmethod 81 | def from_bytes(cls, value: bytes) -> Self: 82 | return convert_to(cls, msgpack_decoder(value, strict=True), strict=True) 83 | -------------------------------------------------------------------------------- /src/database/repositories/role.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from typing import Any, Optional, Unpack 3 | 4 | import uuid_utils.compat as uuid 5 | from sqlalchemy import ColumnExpressionArgument, UnaryExpression 6 | 7 | import src.database.models as models 8 | from 
src.database.exceptions import InvalidParamsError 9 | from src.database.models.types import Roles 10 | from src.database.repositories import Result 11 | from src.database.repositories.base import BaseRepository 12 | from src.database.repositories.types.role import ( 13 | CreateRoleType, 14 | RoleLoads, 15 | ) 16 | from src.database.tools import on_integrity, select_with_relationships 17 | from src.database.types import OrderBy 18 | 19 | 20 | class RoleRepository(BaseRepository[models.Role]): 21 | __slots__ = () 22 | 23 | @on_integrity("name") 24 | async def create(self, **data: Unpack[CreateRoleType]) -> Result[models.Role]: 25 | return Result("create", await self._crud.insert(**data)) 26 | 27 | async def select( 28 | self, 29 | *loads: RoleLoads, 30 | role_uuid: Optional[uuid.UUID] = None, 31 | name: Optional[Roles] = None, 32 | ) -> Result[models.Role]: 33 | if not any([role_uuid, name]): 34 | raise InvalidParamsError("at least one identifier must be provided") 35 | 36 | where_clauses: list[ColumnExpressionArgument[bool]] = [] 37 | 38 | if role_uuid: 39 | where_clauses.append(self.model.uuid == role_uuid) 40 | if name: 41 | where_clauses.append(self.model.name == name) 42 | 43 | stmt = select_with_relationships(*loads, model=self.model).where(*where_clauses) 44 | return Result("select", (await self._session.scalars(stmt)).unique().first()) 45 | 46 | async def select_many( 47 | self, 48 | *loads: RoleLoads, 49 | name: Optional[str] = None, 50 | order_by: OrderBy = "desc", 51 | offset: int = 0, 52 | limit: int = 10, 53 | ) -> tuple[int, Sequence[models.Role]]: 54 | where_clauses: list[ColumnExpressionArgument[bool]] = [] 55 | order_by_clauses: list[UnaryExpression[Any]] = [] 56 | 57 | if name: 58 | where_clauses.append(self.model.name.ilike(f"%{name}%")) 59 | if order_by: 60 | order_by_clauses.append(getattr(self.model.created_at, order_by)()) 61 | 62 | total = await self._crud.count(*where_clauses) 63 | if total <= 0: 64 | return total, [] 65 | 66 | stmt = ( 67 | 
select_with_relationships(*loads, model=self.model) 68 | .where(*where_clauses) 69 | .order_by(*order_by_clauses) 70 | .limit(limit) 71 | .offset(offset) 72 | ) 73 | 74 | results = (await self._session.scalars(stmt)).unique().all() 75 | return total, results 76 | 77 | async def exists(self, name: str) -> Result[bool]: 78 | return Result("exists", await self._crud.exists(self.model.name == name)) 79 | -------------------------------------------------------------------------------- /src/services/provider/middleware/error.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from http import HTTPStatus 3 | 4 | import src.services.provider.errors as err 5 | from src.services.provider.middleware.base import ( 6 | BaseRequestMiddleware, 7 | CallNextMiddlewareType, 8 | ) 9 | from src.services.provider.response import Response 10 | from src.services.provider.types import RequestMethodType 11 | 12 | 13 | async def _check_response(response: Response) -> Response: 14 | status_code = response.status 15 | if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED: 16 | return response 17 | 18 | content = await response.read() 19 | 20 | url = str(response.url) 21 | if status_code == HTTPStatus.BAD_REQUEST: 22 | raise err.BadRequestError( 23 | status_code=status_code, content=content, message="Bad request", url=url 24 | ) 25 | if status_code == HTTPStatus.TOO_MANY_REQUESTS: 26 | raise err.TooManyRequestsError( 27 | status_code=status_code, 28 | content=content, 29 | message="Too many requests", 30 | url=url, 31 | ) 32 | if status_code == HTTPStatus.NOT_FOUND: 33 | raise err.NotFoundError( 34 | status_code=status_code, content=content, message="Not found", url=url 35 | ) 36 | if status_code == HTTPStatus.CONFLICT: 37 | raise err.ConflictError( 38 | status_code=status_code, content=content, message="Conflict", url=url 39 | ) 40 | if status_code == HTTPStatus.UNAUTHORIZED: 41 | raise err.UnauthorizedError( 42 | 
status_code=status_code, 43 | content=content, 44 | message="Auth is required", 45 | url=url, 46 | ) 47 | if status_code == HTTPStatus.FORBIDDEN: 48 | raise err.ForbiddenError( 49 | status_code=status_code, 50 | content=content, 51 | message="You have no permissions", 52 | url=url, 53 | ) 54 | if status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: 55 | raise err.EntityTooLarge( 56 | status_code=status_code, 57 | content=content, 58 | message="Too large content", 59 | url=url, 60 | ) 61 | if status_code == HTTPStatus.INTERNAL_SERVER_ERROR: 62 | raise err.ServerError( 63 | status_code=status_code, 64 | content=content, 65 | message="Server is disabled or you are banned", 66 | url=url, 67 | ) 68 | 69 | raise err.APIError( 70 | status_code=status_code, content=content, message="Unknown Error", url=url 71 | ) 72 | 73 | 74 | class RequestErrorMiddleware(BaseRequestMiddleware): 75 | async def __call__( 76 | self, 77 | call_next: CallNextMiddlewareType, 78 | method: RequestMethodType, 79 | url_or_endpoint: str, 80 | **kw: t.Any, 81 | ) -> Response: 82 | response = await call_next(method=method, url_or_endpoint=url_or_endpoint, **kw) 83 | 84 | return await _check_response(response) 85 | -------------------------------------------------------------------------------- /src/api/common/events/nats.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | from src.api.common.events.base import Event 4 | 5 | 6 | class NatsEvent(Event): 7 | subject: str = "" 8 | _reply: str = "" 9 | _headers: dict[str, Any] | None = None 10 | 11 | 12 | class NatsJetStreamEvent(Event): 13 | subject: str = "" 14 | _stream: str | None = None 15 | _timeout: float | None = None 16 | _headers: dict[str, Any] | None = None 17 | 18 | 19 | def rebuild_nats_event[T: NatsEvent]( 20 | subject: str = "", 21 | subject_namespace: str | None = None, 22 | separator: str = ".", 23 | reply: str = "", 24 | headers: dict[str, Any] | None = None, 
25 | ) -> Callable[[type[T]], type[T]]: 26 | def _wrapper(cls: type[T]) -> type[T]: 27 | updated_subject = subject 28 | 29 | class UpdatedEvent(cls): # type: ignore 30 | subject: str = updated_subject 31 | _reply: str = reply 32 | if headers is not None: 33 | _headers: dict[str, Any] = headers 34 | 35 | def __post_init__(self) -> None: 36 | if subject_namespace and self.subject: 37 | self.subject = f"{subject_namespace}{separator}{self.subject}" 38 | elif not self.subject and subject: 39 | self.subject = subject 40 | 41 | UpdatedEvent.__name__ = cls.__name__ 42 | UpdatedEvent.__qualname__ = cls.__qualname__ 43 | UpdatedEvent.__module__ = cls.__module__ 44 | UpdatedEvent.__parameters__ = cls.__parameters__ # type: ignore 45 | UpdatedEvent.__orig_bases__ = cls.__orig_bases__ # type: ignore 46 | return UpdatedEvent 47 | 48 | return _wrapper 49 | 50 | 51 | def rebuild_jetstream_event[T: NatsJetStreamEvent]( 52 | subject: str = "", 53 | subject_namespace: str | None = None, 54 | separator: str = ".", 55 | stream: str | None = None, 56 | timeout: float | None = None, 57 | headers: dict[str, Any] | None = None, 58 | ) -> Callable[[type[T]], type[T]]: 59 | def _wrapper(cls: type[T]) -> type[T]: 60 | updated_subject = subject 61 | 62 | class UpdatedEvent(cls): # type: ignore 63 | subject: str = updated_subject 64 | if stream is not None: 65 | _stream: str = stream 66 | if timeout is not None: 67 | _timeout: float = timeout 68 | if headers is not None: 69 | _headers: dict[str, Any] = headers 70 | 71 | def __post_init__(self) -> None: 72 | if subject_namespace and self.subject: 73 | self.subject = f"{subject_namespace}{separator}{self.subject}" 74 | elif not self.subject and subject: 75 | self.subject = subject 76 | 77 | UpdatedEvent.__name__ = cls.__name__ 78 | UpdatedEvent.__qualname__ = cls.__qualname__ 79 | UpdatedEvent.__module__ = cls.__module__ 80 | UpdatedEvent.__parameters__ = cls.__parameters__ # type: ignore 81 | UpdatedEvent.__orig_bases__ = cls.__orig_bases__ # 
type: ignore 82 | return UpdatedEvent 83 | 84 | return _wrapper 85 | -------------------------------------------------------------------------------- /migrations/versions/01_3d1a9bc30546_initial.py: -------------------------------------------------------------------------------- 1 | """initial 2 | 3 | Revision ID: 01_3d1a9bc30546 4 | Revises: 5 | Create Date: 2025-04-06 05:00:58.131522 6 | 7 | """ 8 | 9 | from typing import Sequence, Union 10 | 11 | import sqlalchemy as sa 12 | from alembic import op 13 | 14 | # revision identifiers, used by Alembic. 15 | revision: str = "01_3d1a9bc30546" 16 | down_revision: Union[str, None] = None 17 | branch_labels: Union[str, Sequence[str], None] = None 18 | depends_on: Union[str, Sequence[str], None] = None 19 | 20 | 21 | def upgrade() -> None: 22 | # ### commands auto generated by Alembic - please adjust! ### 23 | op.create_table( 24 | "role", 25 | sa.Column( 26 | "name", 27 | sa.Enum("Admin", "User", name="role_type_enum", native_enum=False), 28 | nullable=False, 29 | ), 30 | sa.Column( 31 | "uuid", 32 | sa.UUID(), 33 | server_default=sa.text("gen_random_uuid()"), 34 | nullable=False, 35 | ), 36 | sa.Column( 37 | "created_at", 38 | sa.DateTime(timezone=True), 39 | server_default=sa.text("now()"), 40 | nullable=False, 41 | ), 42 | sa.Column( 43 | "updated_at", 44 | sa.DateTime(timezone=True), 45 | server_default=sa.text("now()"), 46 | nullable=False, 47 | ), 48 | sa.PrimaryKeyConstraint("uuid"), 49 | ) 50 | op.create_index(op.f("ix_role_name"), "role", ["name"], unique=False) 51 | op.create_table( 52 | "user", 53 | sa.Column("login", sa.String(), nullable=False), 54 | sa.Column("password", sa.String(), nullable=False), 55 | sa.Column("active", sa.Boolean(), nullable=False), 56 | sa.Column("role_uuid", sa.UUID(), nullable=False), 57 | sa.Column( 58 | "uuid", 59 | sa.UUID(), 60 | server_default=sa.text("gen_random_uuid()"), 61 | nullable=False, 62 | ), 63 | sa.Column( 64 | "created_at", 65 | sa.DateTime(timezone=True), 66 | 
server_default=sa.text("now()"), 67 | nullable=False, 68 | ), 69 | sa.Column( 70 | "updated_at", 71 | sa.DateTime(timezone=True), 72 | server_default=sa.text("now()"), 73 | nullable=False, 74 | ), 75 | sa.ForeignKeyConstraint( 76 | ["role_uuid"], ["role.uuid"], name="fk_user_role", ondelete="RESTRICT" 77 | ), 78 | sa.PrimaryKeyConstraint("uuid"), 79 | ) 80 | op.create_index(op.f("ix_user_active"), "user", ["active"], unique=False) 81 | op.create_index(op.f("ix_user_login"), "user", ["login"], unique=False) 82 | 83 | op.execute(sa.text("INSERT INTO role (name) VALUES ('Admin'), ('User')")) 84 | 85 | # ### end Alembic commands ### 86 | 87 | 88 | def downgrade() -> None: 89 | # ### commands auto generated by Alembic - please adjust! ### 90 | op.drop_index(op.f("ix_user_login"), table_name="user") 91 | op.drop_index(op.f("ix_user_active"), table_name="user") 92 | op.drop_table("user") 93 | op.drop_index(op.f("ix_role_name"), table_name="role") 94 | op.drop_table("role") 95 | # ### end Alembic commands ### 96 | -------------------------------------------------------------------------------- /src/api/common/exceptions.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from functools import partial 3 | from typing import Any, Callable 4 | 5 | from litestar import MediaType, Request, Response 6 | from litestar import status_codes as status 7 | from litestar.exceptions import ClientException 8 | from litestar.types import ExceptionHandlersMap 9 | from uuid_utils.compat import uuid4 10 | 11 | import src.common.exceptions as exc 12 | from src.common.logger import log 13 | 14 | JsonResponse = Response[dict[str, Any]] 15 | type BasicRequest = Request[Any, Any, Any] 16 | 17 | 18 | def setup_exception_handlers() -> ExceptionHandlersMap: 19 | return { 20 | exc.TooManyRequestsError: error_handler(status.HTTP_429_TOO_MANY_REQUESTS), 21 | exc.ServiceNotImplementedError: error_handler(status.HTTP_501_NOT_IMPLEMENTED), 22 | 
exc.ConflictError: error_handler(status.HTTP_409_CONFLICT), 23 | exc.AppException: error_handler(status.HTTP_500_INTERNAL_SERVER_ERROR), 24 | exc.NotFoundError: error_handler(status.HTTP_404_NOT_FOUND), 25 | exc.UnAuthorizedError: error_handler(status.HTTP_401_UNAUTHORIZED), 26 | exc.ForbiddenError: error_handler(status.HTTP_403_FORBIDDEN), 27 | exc.BadRequestError: error_handler(status.HTTP_400_BAD_REQUEST), 28 | exc.ServiceUnavailableError: error_handler(status.HTTP_503_SERVICE_UNAVAILABLE), 29 | exc.BadGatewayError: error_handler(status.HTTP_502_BAD_GATEWAY), 30 | Exception: error_handler(status.HTTP_500_INTERNAL_SERVER_ERROR), 31 | } 32 | 33 | 34 | def error_handler( 35 | status_code: int, 36 | ) -> Callable[..., JsonResponse]: 37 | return partial(app_error_handler, status_code=status_code) 38 | 39 | 40 | def app_error_handler( 41 | request: BasicRequest, exc: Exception, status_code: int 42 | ) -> JsonResponse: 43 | return handle_error( 44 | request, 45 | exception=exc, 46 | status_code=status_code, 47 | ) 48 | 49 | 50 | def handle_error( 51 | request: BasicRequest, 52 | exception: Exception, 53 | status_code: int, 54 | ) -> JsonResponse: 55 | ticket = request.state.get("request_id", uuid4().hex) 56 | tb_str = "".join( 57 | traceback.format_exception(type(exception), exception, exception.__traceback__) 58 | ) 59 | log.error( 60 | f"Ticket: {ticket}\n" 61 | f"Method: {request.method}\n" 62 | f"Path: {request.url.path}\n" 63 | f"Handle error: {type(exception).__name__} -> {exception.args}\n" 64 | f"Traceback:\n{tb_str}" 65 | ) 66 | if isinstance(exception, exc.AppException): 67 | error_dict = exception.as_dict() 68 | error_dict["content"] |= {"ticket": ticket} 69 | return JsonResponse( 70 | **error_dict, 71 | status_code=status_code, 72 | media_type=MediaType.JSON, 73 | ) 74 | 75 | if isinstance(exception, ClientException): 76 | return JsonResponse( 77 | content={ 78 | "message": type(exception).__name__, 79 | "details": exception.extra, 80 | "ticket": ticket, 81 
| }, 82 | status_code=exception.status_code, 83 | media_type=MediaType.JSON, 84 | ) 85 | 86 | response = JsonResponse( 87 | **exc.ForbiddenError( 88 | message="Something went wrong", 89 | ticket=ticket, 90 | ).as_dict(), 91 | status_code=status_code, 92 | media_type=MediaType.JSON, 93 | ) 94 | return response 95 | -------------------------------------------------------------------------------- /src/api/v1/endpoints/admin/user.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | import uuid_utils.compat as uuid 4 | from litestar import Controller, get, patch, post, status_codes 5 | from litestar.params import Parameter 6 | 7 | from src.api.common.interfaces.mediator import Mediator 8 | from src.api.v1 import dtos, handlers 9 | from src.api.v1.constants import MAX_PAGINATION_LIMIT, MIN_PAGINATION_LIMIT 10 | from src.api.v1.permission import Permission 11 | from src.common.di import Depends 12 | from src.database.repositories.types.user import UserLoads 13 | from src.database.types import OrderBy 14 | 15 | 16 | class UserController(Controller): 17 | path = "/users" 18 | tags = ["Admin | Users"] 19 | guards = [Permission("Admin")] 20 | security = [{"BearerToken": []}] 21 | 22 | @post( 23 | status_code=status_codes.HTTP_201_CREATED, 24 | exclude_from_auth=True, 25 | ) 26 | async def create_user_endpoint( 27 | self, 28 | data: handlers.user.CreateUserQuery, 29 | mediator: Depends[Mediator], 30 | ) -> dtos.User: 31 | return await mediator.send(data) 32 | 33 | @get("/{user_id:uuid}", status_code=status_codes.HTTP_200_OK) 34 | async def select_user_endpoint( 35 | self, 36 | s: Annotated[ 37 | tuple[UserLoads, ...], 38 | Parameter( 39 | required=False, 40 | default=(), 41 | description="Search for additional relations", 42 | ), 43 | ], 44 | user_id: uuid.UUID, 45 | login: Annotated[str | None, Parameter(default=None, required=False)], 46 | mediator: Depends[Mediator], 47 | ) -> dtos.User: 48 | return 
await mediator.send( 49 | handlers.user.SelectUserQuery(loads=s, user_uuid=user_id, login=login) 50 | ) 51 | 52 | @patch("/{user_id:uuid}", status_code=status_codes.HTTP_200_OK) 53 | async def update_user_endpoint( 54 | self, 55 | user_id: uuid.UUID, 56 | data: dtos.UpdateUser, 57 | mediator: Depends[Mediator], 58 | ) -> dtos.User: 59 | return await mediator.send( 60 | handlers.user.UpdateUserQuery(user_uuid=user_id, **data.as_mapping()) 61 | ) 62 | 63 | @get(status_code=status_codes.HTTP_200_OK) 64 | async def select_users_endpoint( 65 | self, 66 | s: Annotated[ 67 | tuple[UserLoads, ...], 68 | Parameter( 69 | required=False, 70 | default=(), 71 | description="Search for additional relations", 72 | ), 73 | ], 74 | login: Annotated[str | None, Parameter(default=None, required=False)], 75 | role_uuid: Annotated[uuid.UUID | None, Parameter(default=None, required=False)], 76 | order_by: Annotated[OrderBy, Parameter(default="desc", required=False)], 77 | offset: Annotated[int, Parameter(default=0)], 78 | limit: Annotated[ 79 | int, 80 | Parameter( 81 | default=MIN_PAGINATION_LIMIT, 82 | ge=MIN_PAGINATION_LIMIT, 83 | le=MAX_PAGINATION_LIMIT, 84 | required=False, 85 | ), 86 | ], 87 | mediator: Depends[Mediator], 88 | ) -> dtos.OffsetResult[dtos.User]: 89 | return await mediator.send( 90 | handlers.user.SelectManyUserQuery( 91 | loads=s, 92 | login=login, 93 | role_uuid=role_uuid, 94 | order_by=order_by, 95 | offset=offset, 96 | limit=limit, 97 | ) 98 | ) 99 | -------------------------------------------------------------------------------- /migrations/env.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections.abc import Iterable 3 | from logging.config import fileConfig 4 | from typing import List, Optional, Union 5 | 6 | from alembic import context 7 | from alembic.operations.ops import MigrationScript 8 | from alembic.runtime.migration import MigrationContext 9 | from alembic.script import 
ScriptDirectory 10 | from sqlalchemy import pool 11 | from sqlalchemy.engine import Connection 12 | from sqlalchemy.ext.asyncio import async_engine_from_config 13 | 14 | from src.database.models import Base 15 | from src.settings.core import load_settings 16 | 17 | config = context.config 18 | 19 | 20 | if config.config_file_name is not None: 21 | fileConfig(config.config_file_name) 22 | 23 | target_metadata = Base.metadata 24 | 25 | if not (url := config.get_main_option("sqlalchemy.url")): 26 | url = load_settings().db.url 27 | 28 | 29 | def add_number_to_migrations( 30 | context: MigrationContext, 31 | revision: Union[str, Iterable[Optional[str]]], 32 | directives: List[MigrationScript], 33 | ) -> None: 34 | migration_script = directives[0] 35 | head_revision = ScriptDirectory.from_config(context.config).get_current_head() # type: ignore 36 | if head_revision is None: 37 | new_rev_id = 1 38 | else: 39 | last_rev_id = int(head_revision.split("_")[0]) 40 | new_rev_id = last_rev_id + 1 41 | 42 | migration_script.rev_id = f"{new_rev_id:02}_{migration_script.rev_id}" 43 | 44 | 45 | def run_migrations_offline() -> None: 46 | """Run migrations in 'offline' mode. 47 | 48 | This configures the context with just a URL 49 | and not an Engine, though an Engine is acceptable 50 | here as well. By skipping the Engine creation 51 | we don't even need a DBAPI to be available. 52 | 53 | Calls to context.execute() here emit the given string to the 54 | script output. 
55 | 56 | """ 57 | 58 | context.configure( 59 | url=url, 60 | target_metadata=target_metadata, 61 | literal_binds=True, 62 | dialect_opts={"paramstyle": "named"}, 63 | ) 64 | 65 | with context.begin_transaction(): 66 | context.run_migrations() 67 | 68 | 69 | def do_run_migrations(connection: Connection) -> None: 70 | context.configure( 71 | connection=connection, 72 | target_metadata=target_metadata, 73 | process_revision_directives=add_number_to_migrations, 74 | ) 75 | 76 | with context.begin_transaction(): 77 | context.run_migrations() 78 | 79 | 80 | async def run_async_migrations() -> None: 81 | """In this scenario we need to create an Engine 82 | and associate a connection with the context. 83 | 84 | """ 85 | 86 | connectable = async_engine_from_config( 87 | config.get_section(config.config_ini_section, {}), 88 | prefix="sqlalchemy.", 89 | poolclass=pool.NullPool, 90 | future=True, 91 | url=url, 92 | ) 93 | 94 | async with connectable.connect() as connection: 95 | await connection.run_sync(do_run_migrations) 96 | 97 | await connectable.dispose() 98 | 99 | 100 | def run_migrations_online() -> None: 101 | """Run migrations in 'online' mode.""" 102 | 103 | connectable = config.attributes.get("connection", None) 104 | 105 | if connectable: 106 | do_run_migrations(connectable) 107 | else: 108 | asyncio.run(run_async_migrations()) 109 | 110 | 111 | if context.is_offline_mode(): 112 | run_migrations_offline() 113 | else: 114 | run_migrations_online() 115 | -------------------------------------------------------------------------------- /src/database/repositories/user.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from typing import Any, Optional, Unpack 3 | 4 | import uuid_utils.compat as uuid 5 | from sqlalchemy import ColumnExpressionArgument, UnaryExpression 6 | 7 | import src.database.models as models 8 | from src.database.exceptions import InvalidParamsError 9 | from 
src.database.repositories import Result 10 | from src.database.repositories.base import BaseRepository 11 | from src.database.repositories.types.user import ( 12 | CreateUserType, 13 | UpdateUserType, 14 | UserLoads, 15 | ) 16 | from src.database.tools import on_integrity, select_with_relationships 17 | from src.database.types import OrderBy 18 | 19 | 20 | class UserRepository(BaseRepository[models.User]): 21 | __slots__ = () 22 | 23 | @on_integrity("login") 24 | async def create(self, **data: Unpack[CreateUserType]) -> Result[models.User]: 25 | return Result("create", await self._crud.insert(**data)) 26 | 27 | async def select( 28 | self, 29 | *loads: UserLoads, 30 | user_uuid: Optional[uuid.UUID] = None, 31 | login: Optional[str] = None, 32 | ) -> Result[models.User]: 33 | if not any([user_uuid, login]): 34 | raise InvalidParamsError("at least one identifier must be provided") 35 | 36 | where_clauses: list[ColumnExpressionArgument[bool]] = [] 37 | 38 | if user_uuid: 39 | where_clauses.append(self.model.uuid == user_uuid) 40 | if login: 41 | where_clauses.append(self.model.login == login) 42 | 43 | stmt = select_with_relationships(*loads, model=self.model).where(*where_clauses) 44 | return Result("select", (await self._session.scalars(stmt)).unique().first()) 45 | 46 | @on_integrity("login") 47 | async def update( 48 | self, 49 | uuid: uuid.UUID, 50 | /, 51 | **data: Unpack[UpdateUserType], 52 | ) -> Result[models.User]: 53 | if not any([uuid]): 54 | raise InvalidParamsError("at least one identifier must be provided") 55 | 56 | result = await self._crud.update(self.model.uuid == uuid, **data) 57 | return Result("update", result[0] if result else None) 58 | 59 | async def delete( 60 | self, user_uuid: Optional[uuid.UUID] = None, login: Optional[str] = None 61 | ) -> Result[models.User]: 62 | if not any([user_uuid, login]): 63 | raise InvalidParamsError("at least one identifier must be provided") 64 | 65 | where_clauses: list[ColumnExpressionArgument[bool]] = [] 66 
| 67 | if user_uuid: 68 | where_clauses.append(self.model.uuid == user_uuid) 69 | if login: 70 | where_clauses.append(self.model.login == login) 71 | 72 | result = await self._crud.delete(*where_clauses) 73 | return Result("delete", result[0] if result else None) 74 | 75 | async def select_many( 76 | self, 77 | *loads: UserLoads, 78 | login: Optional[str] = None, 79 | role_uuid: Optional[uuid.UUID] = None, 80 | order_by: OrderBy = "desc", 81 | offset: int = 0, 82 | limit: int = 10, 83 | ) -> Result[tuple[int, Sequence[models.User]]]: 84 | where_clauses: list[ColumnExpressionArgument[bool]] = [] 85 | order_by_clauses: list[UnaryExpression[Any]] = [] 86 | 87 | if login: 88 | where_clauses.append(self.model.login.ilike(f"%{login}%")) 89 | if role_uuid: 90 | where_clauses.append(self.model.role_uuid == role_uuid) 91 | if order_by: 92 | order_by_clauses.append(getattr(self.model.created_at, order_by)()) 93 | 94 | total = await self._crud.count(*where_clauses) 95 | if total <= 0: 96 | return Result("select", (total, [])) 97 | 98 | stmt = ( 99 | select_with_relationships(*loads, model=self.model) 100 | .where(*where_clauses) 101 | .order_by(*order_by_clauses) 102 | .limit(limit) 103 | .offset(offset) 104 | ) 105 | 106 | results = (await self._session.scalars(stmt)).unique().all() 107 | return Result("select", (total, results)) 108 | 109 | async def exists(self, login: str) -> Result[bool]: 110 | return Result("exists", await self._crud.exists(self.model.login == login)) 111 | -------------------------------------------------------------------------------- /src/api/setup.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncIterator, Callable 2 | from contextlib import asynccontextmanager 3 | 4 | from dishka.integrations.litestar import setup_dishka 5 | from litestar import Litestar 6 | from litestar.config.app import AppConfig 7 | from litestar.config.cors import CORSConfig 8 | from 
litestar.middleware.logging import LoggingMiddlewareConfig 9 | from litestar.openapi.config import OpenAPIConfig 10 | from litestar.openapi.plugins import SwaggerRenderPlugin 11 | from litestar.openapi.spec.components import Components 12 | from litestar.openapi.spec.security_scheme import SecurityScheme 13 | from litestar.stores.redis import RedisStore 14 | from litestar.stores.registry import StoreRegistry 15 | from redis.asyncio import Redis 16 | 17 | from src.api.common.broker.nats.connection import create_nats_connection 18 | from src.api.common.broker.nats.core import NatsBroker 19 | from src.api.common.exceptions import setup_exception_handlers 20 | from src.api.common.middlewares import setup_middlewares 21 | from src.api.common.tools import ClosableProxy, RouterState 22 | from src.common.di import container 23 | from src.settings.core import Settings 24 | 25 | 26 | @asynccontextmanager 27 | async def lifespan(app: Litestar) -> AsyncIterator[None]: 28 | try: 29 | broker: NatsBroker = await app.state.dishka_container.get(NatsBroker) 30 | settings: Settings = await app.state.dishka_container.get(Settings) 31 | await create_nats_connection(broker, settings) 32 | 33 | yield 34 | finally: 35 | for value in app.state.values(): 36 | if isinstance(value, ClosableProxy): 37 | await value.close() 38 | 39 | 40 | def on_app_init( 41 | settings: Settings, *router_state: RouterState 42 | ) -> Callable[[AppConfig], AppConfig]: 43 | def wrapped(app_config: AppConfig) -> AppConfig: 44 | app_config.exception_handlers.update(setup_exception_handlers()) 45 | app_config.middleware.extend(setup_middlewares()) 46 | 47 | app_config.cors_config = CORSConfig( 48 | allow_origins=settings.server.origins, 49 | allow_credentials=True, 50 | allow_methods=list(settings.server.methods), 51 | allow_headers=settings.server.headers, 52 | ) 53 | 54 | if settings.app.debug: 55 | app_config.middleware.append(LoggingMiddlewareConfig().middleware) 56 | 57 | for r_state in router_state: 58 | 
app_config.route_handlers.append(r_state.router) 59 | app_config.state.update(r_state.state.dict()) 60 | app_config.on_startup.extend(r_state.on_startup) 61 | app_config.on_shutdown.extend(r_state.on_shutdown) 62 | 63 | return app_config 64 | 65 | return wrapped 66 | 67 | 68 | def security_components() -> list[Components]: 69 | return [ 70 | Components( 71 | security_schemes={ 72 | "BearerToken": SecurityScheme( 73 | type="http", 74 | scheme="Bearer", 75 | name="Authorization", 76 | bearer_format="JWT", 77 | description=None, 78 | ), 79 | }, 80 | ), 81 | ] 82 | 83 | 84 | def init_app(settings: Settings, *router_state: RouterState) -> Litestar: 85 | app = Litestar( 86 | path=settings.app.root_path, 87 | openapi_config=( 88 | OpenAPIConfig( 89 | title=settings.app.title, 90 | version=settings.app.version, 91 | components=security_components(), 92 | render_plugins=(SwaggerRenderPlugin(),), 93 | ) 94 | ) 95 | if settings.app.title and settings.app.swagger 96 | else None, 97 | debug=settings.app.debug, 98 | lifespan=[lifespan], 99 | stores=StoreRegistry( 100 | default_factory=RedisStore( 101 | Redis( 102 | **settings.redis.model_dump(exclude_none=True), 103 | decode_responses=True, 104 | ), 105 | handle_client_shutdown=True, 106 | ).with_namespace, 107 | ), 108 | on_app_init=[on_app_init(settings, *router_state)], 109 | ) 110 | setup_dishka(container.get_container(), app) 111 | 112 | return app 113 | -------------------------------------------------------------------------------- /src/services/internal/auth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import datetime 3 | from typing import Final, Literal 4 | 5 | import uuid_utils.compat as uuid 6 | from litestar.concurrency import sync_to_thread 7 | 8 | from src.common.exceptions import ForbiddenError 9 | from src.services.cache.redis import RedisCache 10 | from src.services.security.jwt import JWT 11 | 12 | 
DEFAULT_TOKENS_COUNT: Final[int] = 5 13 | 14 | TokenType = Literal["access", "refresh"] 15 | 16 | 17 | @dataclass(slots=True) 18 | class Token: 19 | token: str 20 | 21 | 22 | @dataclass(slots=True) 23 | class Tokens: 24 | access: str 25 | refresh: str 26 | 27 | 28 | @dataclass(slots=True) 29 | class TokensExpire: 30 | refresh_expire: datetime 31 | tokens: Tokens 32 | 33 | 34 | class AuthService: 35 | __slots__ = ("_jwt", "_cache") 36 | 37 | def __init__(self, jwt: JWT, cache: RedisCache) -> None: 38 | self._jwt = jwt 39 | self._cache = cache 40 | 41 | async def login(self, fingerprint: str, user_uuid: uuid.UUID) -> TokensExpire: 42 | _, access = await sync_to_thread( 43 | self._jwt.create, typ="access", sub=str(user_uuid) 44 | ) 45 | expire, refresh = await sync_to_thread( 46 | self._jwt.create, typ="refresh", sub=str(user_uuid) 47 | ) 48 | tokens = await self._cache.get_list(str(user_uuid)) 49 | 50 | if len(tokens) > DEFAULT_TOKENS_COUNT: 51 | await self._cache.delete(str(user_uuid)) 52 | 53 | await self._cache.set_list(str(user_uuid), f"{fingerprint}::{refresh}") 54 | 55 | return TokensExpire( 56 | refresh_expire=expire, 57 | tokens=Tokens(access=access, refresh=refresh), 58 | ) 59 | 60 | async def verify_refresh( 61 | self, 62 | fingerprint: str, 63 | refresh_token: str, 64 | ) -> TokensExpire: 65 | user_uuid = await self.verify_token(refresh_token, "refresh") 66 | token_pairs = await self._cache.get_list(str(user_uuid)) 67 | verified = None 68 | for pair in token_pairs: 69 | data = pair.split("::") 70 | if len(data) < 2: 71 | await self._cache.delete(str(user_uuid)) 72 | raise ForbiddenError( 73 | "Broken separator, try to login again. 
Token is not valid anymore" 74 | ) 75 | fp, cached_token, *_ = data 76 | if fp == fingerprint and cached_token == refresh_token: 77 | verified = pair 78 | break 79 | 80 | if not verified: 81 | await self._cache.delete(str(user_uuid)) 82 | raise ForbiddenError("Token is not valid anymore") 83 | 84 | await self._cache.discard(str(user_uuid), verified) 85 | _, access = await sync_to_thread( 86 | self._jwt.create, typ="access", sub=str(user_uuid) 87 | ) 88 | expire, refresh = await sync_to_thread( 89 | self._jwt.create, typ="refresh", sub=str(user_uuid) 90 | ) 91 | await self._cache.set_list(str(user_uuid), f"{fingerprint}::{refresh}") 92 | 93 | return TokensExpire( 94 | refresh_expire=expire, 95 | tokens=Tokens(access=access, refresh=refresh), 96 | ) 97 | 98 | async def invalidate_refresh( 99 | self, 100 | refresh_token: str, 101 | user_uuid: uuid.UUID, 102 | ) -> bool: 103 | await self.verify_token(refresh_token, "refresh") 104 | token_pairs = await self._cache.get_list(str(user_uuid)) 105 | for pair in token_pairs: 106 | data = pair.split("::") 107 | if len(data) < 2: 108 | await self._cache.delete(str(user_uuid)) 109 | break 110 | _, cached_token, *_ = data 111 | if cached_token == refresh_token: 112 | await self._cache.discard(str(user_uuid), pair) 113 | break 114 | 115 | return True 116 | 117 | async def verify_token( 118 | self, 119 | token: str, 120 | token_type: TokenType, 121 | ) -> uuid.UUID: 122 | payload = await sync_to_thread(self._jwt.verify_token, token) 123 | actual_token_type = payload.get("type") 124 | user_id = payload.get("sub") 125 | 126 | if actual_token_type != token_type: 127 | raise ForbiddenError("Invalid token") 128 | 129 | return uuid.UUID(user_id) 130 | -------------------------------------------------------------------------------- /src/database/repositories/crud.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping, Sequence 
4 | from typing import ( 5 | Any, 6 | Optional, 7 | cast, 8 | ) 9 | 10 | from sqlalchemy import ( 11 | ColumnExpressionArgument, 12 | CursorResult, 13 | UnaryExpression, 14 | delete, 15 | exists, 16 | func, 17 | insert, 18 | select, 19 | update, 20 | ) 21 | from sqlalchemy.dialects.postgresql import insert as postgres_insert 22 | from sqlalchemy.ext.asyncio import AsyncSession 23 | 24 | from src.database.interfaces.crud import AbstractCRUDRepository 25 | from src.database.models.base import Base 26 | 27 | 28 | class CRUDRepository[M: Base](AbstractCRUDRepository[M]): 29 | __slots__ = ("_session",) 30 | 31 | def __init__(self, session: AsyncSession, model: type[M]) -> None: 32 | super().__init__(model) 33 | self._session = session 34 | 35 | async def insert(self, **values: Any) -> Optional[M]: 36 | stmt = insert(self.model).values(**values).returning(self.model) 37 | return (await self._session.execute(stmt)).scalars().first() 38 | 39 | async def insert_many(self, data: Sequence[Mapping[str, Any]]) -> Sequence[M]: 40 | stmt = insert(self.model).returning(self.model) 41 | result = await self._session.scalars(stmt, data) 42 | return result.all() 43 | 44 | async def select( 45 | self, 46 | *clauses: ColumnExpressionArgument[bool], 47 | ) -> Optional[M]: 48 | stmt = select(self.model).where(*clauses) 49 | return (await self._session.execute(stmt)).scalars().first() 50 | 51 | async def select_many( 52 | self, 53 | *clauses: ColumnExpressionArgument[bool], 54 | order_by: Optional[list[UnaryExpression[Any]]] = None, 55 | offset: Optional[int] = None, 56 | limit: Optional[int] = None, 57 | ) -> Sequence[M]: 58 | stmt = select(self.model).where(*clauses).offset(offset).limit(limit) 59 | if order_by: 60 | stmt = stmt.order_by(*order_by) 61 | 62 | return (await self._session.execute(stmt)).scalars().all() 63 | 64 | async def update( 65 | self, *clauses: ColumnExpressionArgument[bool], **values: Any 66 | ) -> Sequence[M]: 67 | stmt = 
update(self.model).where(*clauses).values(**values).returning(self.model) 68 | return (await self._session.execute(stmt)).scalars().all() 69 | 70 | async def update_many(self, data: Sequence[Mapping[str, Any]]) -> CursorResult[Any]: 71 | return await self._session.execute(update(self.model), data) 72 | 73 | async def delete(self, *clauses: ColumnExpressionArgument[bool]) -> Sequence[M]: 74 | stmt = delete(self.model).where(*clauses).returning(self.model) 75 | return (await self._session.execute(stmt)).scalars().all() 76 | 77 | async def exists(self, *clauses: ColumnExpressionArgument[bool]) -> bool: 78 | stmt = exists(select(self.model).where(*clauses)).select() 79 | return cast(bool, await self._session.scalar(stmt)) 80 | 81 | async def count(self, *clauses: ColumnExpressionArgument[bool]) -> int: 82 | stmt = select(func.count()).where(*clauses).select_from(self.model) 83 | return cast(int, await self._session.scalar(stmt)) 84 | 85 | async def upsert(self, *conflict_columns: str, **values: Any) -> Optional[M]: 86 | stmt = ( 87 | postgres_insert(self.model) 88 | .values(**values) 89 | .on_conflict_do_update( 90 | index_elements=conflict_columns, 91 | set_={ 92 | key: values[key] for key in values if key not in conflict_columns 93 | }, 94 | ) 95 | .returning(self.model) 96 | ) 97 | return (await self._session.execute(stmt)).scalars().first() 98 | 99 | async def upsert_many( 100 | self, *conflict_columns: str, data: Sequence[Mapping[str, Any]] 101 | ) -> Sequence[M]: 102 | stmt = postgres_insert(self.model).values(data) 103 | stmt = stmt.on_conflict_do_update( 104 | index_elements=conflict_columns, 105 | set_={ 106 | key: getattr(stmt.excluded, key) 107 | for key in data[0] 108 | if key not in conflict_columns 109 | }, 110 | ).returning(self.model) # type: ignore 111 | 112 | return (await self._session.execute(stmt)).scalars().all() 113 | 114 | def with_query_model(self, model: type[M]) -> CRUDRepository[M]: 115 | return CRUDRepository(self._session, model) 116 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Litestar Template 🚀 2 | 3 |