├── src ├── tests │ ├── __init__.py │ ├── etc │ │ └── jwt.txt │ ├── 03_main_test.py │ ├── payloads.py │ ├── payload_models.py │ ├── 05_config_test.py │ ├── shared.py │ ├── configs │ │ └── bad_field.toml │ ├── 02_startup_events_test.py │ ├── 01_cli_test.py │ └── 04_accounting_test.py ├── accounting │ ├── piccolo_migrations │ │ └── __init__.py │ ├── authentication │ │ ├── __init__.py │ │ ├── schemas.py │ │ ├── routing.py │ │ ├── endpoints.py │ │ ├── jwt.py │ │ └── models.py │ ├── roles │ │ ├── __init__.py │ │ ├── routing.py │ │ ├── endpoints.py │ │ └── models.py │ ├── groups │ │ ├── __init__.py │ │ ├── routing.py │ │ ├── endpoints.py │ │ └── models.py │ ├── users │ │ ├── __init__.py │ │ ├── routing.py │ │ ├── endpoints.py │ │ └── models.py │ ├── rbac │ │ ├── __init__.py │ │ ├── checks.py │ │ ├── endpoints.py │ │ ├── routing.py │ │ └── models.py │ ├── __init__.py │ ├── piccolo_app.py │ ├── decorators.py │ └── schemas.py ├── configuration │ ├── piccolo_migrations │ │ └── __init__.py │ ├── __init__.py │ ├── loader.py │ ├── base.py │ ├── piccolo_app.py │ ├── model.py │ └── sections.py ├── utils │ ├── __init__.py │ ├── security.py │ ├── routing.py │ ├── crypto.py │ ├── api_versioning.py │ ├── exceptions.py │ ├── telemetry.py │ ├── logger.py │ ├── id_propagation.py │ ├── piccolo.py │ ├── events.py │ └── vault.py ├── cli │ ├── __init__.py │ ├── shared.py │ ├── console.py │ ├── config_loader.py │ ├── accounting.py │ └── db.py ├── context │ └── __init__.py ├── run.py ├── main.py ├── config.ini ├── config.yaml ├── config.toml ├── app.py └── piccolo_conf.py ├── CONTRIBUTING.md ├── mkdocs.yml ├── SECURITY.md ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── codeql-analysis.yml │ ├── apisec-scan.yml │ └── ci.yml ├── ci ├── init_db.py └── init_vault.py ├── LICENSE ├── docs ├── index.md └── cli.md ├── README.md ├── .gitignore ├── pyproject.toml └── CODE_OF_CONDUCT.md /src/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/tests/etc/jwt.txt: -------------------------------------------------------------------------------- 1 | localfilesecret -------------------------------------------------------------------------------- /src/accounting/piccolo_migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/configuration/piccolo_migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/configuration/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import config 2 | -------------------------------------------------------------------------------- /src/accounting/authentication/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Sessions 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .vault import ( 2 | Vault 3 | ) 4 | 5 | vault = Vault() 6 | -------------------------------------------------------------------------------- /src/accounting/roles/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .models import Role 2 | from .routing import role_router 3 | -------------------------------------------------------------------------------- /src/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from .db import app as db_app 2 | from .accounting import app as aaa_app 3 | -------------------------------------------------------------------------------- /src/accounting/groups/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Group 2 | from .routing import group_router 3 | -------------------------------------------------------------------------------- /src/accounting/users/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import User, T_U 2 | from .routing import user_router 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | Any PR are welcome 4 | 5 | Or just open an issue for a feature, bug or patch 6 | -------------------------------------------------------------------------------- /src/utils/security.py: -------------------------------------------------------------------------------- 1 | from secrets import token_hex 2 | 3 | 4 | def generate_random_string_token(bytes: int = 64) -> str: 5 | return token_hex(bytes) 6 | -------------------------------------------------------------------------------- /src/configuration/loader.py: -------------------------------------------------------------------------------- 1 | from .model import Configuration 2 | from context import config_file 3 | config = Configuration() 4 | config.load(config_file.get()) 5 | -------------------------------------------------------------------------------- /src/cli/shared.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | def prepare_db_through_vault(): 5 | from utils.events import load_vault_db_creds 6 | from utils import vault 7 | asyncio.run(vault.init()) 8 | asyncio.run(load_vault_db_creds()) 9 | -------------------------------------------------------------------------------- /src/accounting/rbac/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Permission, Policy, M2MUserGroup, M2MUserRole 2 | from .routing import ( 3 | rbac_user_router, 4 | rbac_group_router, 5 | rbac_permissions_router, 6 | rbac_policies_router, 7 | rbac_role_router 8 | ) 9 | -------------------------------------------------------------------------------- /src/context/__init__.py: -------------------------------------------------------------------------------- 1 | from contextvars import ContextVar 2 | import os 3 | 4 | DEFAULT_CONFIG_FILENAME: str = os.environ.get( 5 | 'X_FA_CONFIG_FILENAME', 'src/config.toml') 6 | config_file: ContextVar[str] = ContextVar( 7 | 'config_file', default=DEFAULT_CONFIG_FILENAME) 8 | -------------------------------------------------------------------------------- /src/tests/03_main_test.py: -------------------------------------------------------------------------------- 1 | from fastapi.testclient import TestClient 2 | from app import create_app 3 | 4 | app = create_app() 5 | client = TestClient(app) 6 | 7 | 8 | def test_read_main(): 9 | response = client.get("/") 
10 | assert response.status_code == 404 11 | -------------------------------------------------------------------------------- /src/accounting/authentication/schemas.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Token(BaseModel): 5 | """READ model for obtaining JWT""" 6 | access_token: str 7 | token_type: str 8 | 9 | 10 | class TokenData(BaseModel): 11 | """Extracted payload from JWT""" 12 | username: str | None = None 13 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: FastAPI-boilerplate 2 | site_description: FastAPI boilerplate with integrated AAA, Piccolo ORM, HC Vault and many other features 3 | repo_url: https://github.com/northpowered/fastapi-boilerplate 4 | repo_name: northpowered/fastapi-boilerplate 5 | theme: 6 | name: material 7 | language: en 8 | plugins: 9 | - tags -------------------------------------------------------------------------------- /src/utils/routing.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from starlette_exporter import handle_metrics 3 | 4 | 5 | misc_router = APIRouter( 6 | prefix="", 7 | tags=["Misc"], 8 | responses={404: {"description": "URL not found"}}, 9 | ) 10 | misc_router.add_api_route("/metrics", handle_metrics, summary='Prometeus metrics', 11 | description='Metrics from starlette_exporter module') 12 | -------------------------------------------------------------------------------- /src/accounting/__init__.py: -------------------------------------------------------------------------------- 1 | from .users import User, user_router 2 | from .roles import Role, role_router 3 | from .groups import Group, group_router 4 | from .rbac import ( 5 | Permission, 6 | Policy, 7 | M2MUserGroup, 8 | M2MUserRole, 9 | rbac_user_router, 10 | rbac_role_router, 11 | rbac_policies_router, 12 | rbac_group_router, 13 | rbac_permissions_router 14 | ) 15 | from .authentication import Sessions 16 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | | Version | Supported | 6 | | ------- | ------------------ | 7 | | 1.0.x | :white_check_mark: | 8 | | < 1.0 | :white_check_mark: | 9 | 10 | ## Reporting a Vulnerability 11 | 12 | All security updates from legacy code will be pushed to master branch in a week after official release 13 | 14 | BTW, you can open an issue or report about any other security update right here 15 | 16 | -------------------------------------------------------------------------------- /src/utils/crypto.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | PASSWORD_SALT: str = '79fd39c66d3_MY_INSECURE_SALT_453722dfb0b38148f4f0905b722283f269d2700dd4d753d' 4 | PASSWORD_ALGORITHM: str = 'sha512' 5 | PASSWORD_ITERATIONS: int = 100000 6 | 7 | 8 | def create_password_hash(plaintext: str) -> str: 9 | return hashlib.pbkdf2_hmac( 10 | PASSWORD_ALGORITHM, 11 | plaintext.encode('utf-8'), 12 | PASSWORD_SALT.encode('utf-8'), 13 | PASSWORD_ITERATIONS 14 | ).hex() 15 | -------------------------------------------------------------------------------- /src/tests/payloads.py: 
-------------------------------------------------------------------------------- 1 | from .payload_models import UserModel, RoleModel, GroupModel 2 | 3 | test_superuser_1: UserModel = UserModel.create() 4 | test_superuser_2: UserModel = UserModel.create() 5 | test_user_1: UserModel = UserModel.create() 6 | test_user_2: UserModel = UserModel.create() 7 | test_role_1: RoleModel = RoleModel(name='test_role_1') 8 | test_role_2: RoleModel = RoleModel(name='test_role_2') 9 | test_group_1: GroupModel = GroupModel(name='test_group_1') 10 | test_group_2: GroupModel = GroupModel(name='test_group_2') 11 | -------------------------------------------------------------------------------- /src/cli/console.py: -------------------------------------------------------------------------------- 1 | from rich import print as _print 2 | 3 | cli_prefix = "[magenta bold]CLI:>[/magenta bold]" 4 | 5 | 6 | def print(string: str): 7 | _print(f"{cli_prefix} {string}") 8 | 9 | 10 | def info(string: str): 11 | print(f":blue_circle: [blue]{string}[/ blue]") 12 | 13 | 14 | def success(string: str): 15 | print(f":green_circle: [green]{string}[/ green]") 16 | 17 | 18 | def warning(string: str): 19 | print(f":yellow_circle: [yellow]{string}[/ yellow]") 20 | 21 | 22 | def error(string: str): 23 | print(f":red_circle: [red bold]{string}[/ red bold]") 24 | -------------------------------------------------------------------------------- /src/cli/config_loader.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from context import DEFAULT_CONFIG_FILENAME 3 | import os 4 | 5 | 6 | def set_config(config_filename: str, remove_logger: bool = True) -> None: 7 | from context import config_file 8 | if remove_logger: 9 | from loguru import logger 10 | logger.remove() # Logger supression to beauty CLI output 11 | os.environ['X_FA_CONFIG_FILENAME'] = config_filename 12 | config_file.set(config_filename) 13 | 14 | 15 | config_default: str = typer.Option( 16 | default=DEFAULT_CONFIG_FILENAME, 17 | help='Path to CONFIG file' 18 | ) 19 | -------------------------------------------------------------------------------- /src/utils/api_versioning.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter as _APIRouter 2 | from typing import NamedTuple 3 | 4 | 5 | class APIVersion(NamedTuple): 6 | 7 | major: int = 1 8 | minor: int | None = None 9 | 10 | def __str__(self) -> str: 11 | if self.minor is not None: 12 | return f"{self.major}_{self.minor}" 13 | else: 14 | return f"{self.major}" 15 | 16 | 17 | class APIRouter(_APIRouter): 18 | def __init__(self, version: APIVersion | None = None, **kwargs): 19 | super().__init__(**kwargs) 20 | if version: 21 | self.prefix = f"/v{version}{self.prefix}" 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /src/accounting/rbac/checks.py: -------------------------------------------------------------------------------- 1 | from .models import Permission, Policy 2 | from accounting.users.models import User 3 | from accounting.roles.models import Role 4 | from configuration import config 5 | 6 | 7 | async def check_user_endpoint_policy(user: User, endpoint_name: str) -> bool: 8 | if not config.Main.is_prod_mode and not config.Security.is_rbac_enabled: 9 | return True 10 | if user.superuser: 11 | return True 12 | for role in user.roles: # type: ignore 13 | await role.join_m2m() 14 | for policy in role.policies: 15 | if endpoint_name == policy.object: 16 | return True 17 | else: 18 | continue 19 | return False 20 | -------------------------------------------------------------------------------- /src/configuration/base.py: -------------------------------------------------------------------------------- 1 | from pydantic import (BaseModel, ValidationError) 2 | from loguru import logger 3 | 4 | 5 | class BaseSectionModel(BaseModel): 6 | 7 | class Config: 8 | load_failed: bool = False 9 | 10 | def __repr__(self) -> str: 11 | return f"" 12 | 13 | def load(self, section_data: dict, section_name: str): 14 | try: 15 | return self.parse_obj(section_data) 16 | except ValidationError as ex: 17 | error = ex.errors()[0] 18 | self.Config.load_failed = True 19 | logger.error( 20 | f"{section_name} | {error.get('loc')[0]} | {error.get('msg')}" 21 | ) 22 | return None 23 | -------------------------------------------------------------------------------- /src/configuration/piccolo_app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Import all of the Tables subclasses in your app here, and register them with 3 | the APP_CONFIG. 4 | 5 | IMPORTANT! 
6 | Do NOT change this file 7 | This piccolo_app is only for a right drop/init management of CLI app 8 | """ 9 | 10 | import os 11 | 12 | from piccolo.conf.apps import AppConfig 13 | 14 | from piccolo.apps.migrations.tables import Migration 15 | 16 | CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) 17 | 18 | 19 | APP_CONFIG = AppConfig( 20 | app_name="configuration", 21 | migrations_folder_path=os.path.join( 22 | CURRENT_DIRECTORY, "piccolo_migrations" 23 | ), 24 | table_classes=[ 25 | Migration 26 | ], 27 | migration_dependencies=[], 28 | commands=[], 29 | ) 30 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | from cli.config_loader import set_config 2 | import uvicorn 3 | 4 | 5 | def run_app(config_file: str, reload: bool): 6 | """ 7 | Runs application with Uvicorn 8 | Application defined in app.py 9 | Method set_config() MUST be invoked before importing 10 | `config` from `configuration` module to set env var with config filename 11 | 12 | Args: 13 | config_file (str): path to config file 14 | reload (bool): watch file changes and reload server (useful for development) 15 | """ 16 | set_config(config_file, remove_logger=False) 17 | from configuration import config 18 | uvicorn.run( 19 | "app:app", 20 | reload=reload, 21 | host=config.Server.bind_address, 22 | port=config.Server.bind_port, 23 | ) 24 | -------------------------------------------------------------------------------- /ci/init_db.py: -------------------------------------------------------------------------------- 1 | from piccolo.engine.postgres import PostgresEngine 2 | import asyncio 3 | dsn: str = "postgresql://test:test@127.0.0.1:5432/test?sslmode=disable" 4 | 5 | engine = PostgresEngine( 6 | config={ 7 | 'dsn': dsn, 8 | } 9 | ) 10 | username: str = 'test3' 11 | database_name = 'test' 12 | asyncio.run(engine.run_ddl( 13 | f"create user {username} with password '{username}'")) 14 | asyncio.run(engine.run_ddl( 15 | f"grant all on database {database_name} to {username};")) 16 | asyncio.run(engine.run_ddl(f"grant all on schema public to {username};")) 17 | asyncio.run(engine.run_ddl( 18 | f"grant ALL ON ALL tables in schema public TO {username};")) 19 | asyncio.run(engine.run_ddl(f"alter database test owner to {username};")) 20 | asyncio.run(engine.run_ddl(f"alter schema public owner to {username};")) 21 | -------------------------------------------------------------------------------- /src/accounting/authentication/routing.py: -------------------------------------------------------------------------------- 1 | from utils.api_versioning import APIRouter, APIVersion 2 | from . 
import endpoints 3 | from .schemas import ( 4 | Token 5 | ) 6 | from accounting.schemas import UserRead 7 | auth_router = APIRouter( 8 | prefix="/auth", 9 | tags=["AAA->Authentication"], 10 | responses={ 11 | 404: {"description": "URL not found"}, 12 | 400: {"description": "Bad request"} 13 | }, 14 | ) 15 | 16 | auth_router.add_api_route( 17 | '/token', 18 | endpoints.login_for_access_token, 19 | response_model=Token, 20 | summary='Authenticate via JWT Bearer scheme', 21 | methods=['post'] 22 | ) 23 | 24 | auth_router.add_api_route( 25 | '/me', 26 | endpoints.get_current_user, 27 | response_model=UserRead, 28 | summary='Get current user', 29 | methods=['get'] 30 | ) 31 | -------------------------------------------------------------------------------- /src/accounting/piccolo_app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Import all of the Tables subclasses in your app here, and register them with 3 | the APP_CONFIG. 4 | """ 5 | 6 | import os 7 | 8 | from piccolo.conf.apps import AppConfig 9 | 10 | from .rbac import ( 11 | M2MUserGroup, 12 | M2MUserRole, 13 | Permission, 14 | Policy 15 | ) 16 | from .users import User 17 | from .groups import Group 18 | from .roles import Role 19 | from .authentication import Sessions 20 | CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) 21 | 22 | 23 | APP_CONFIG = AppConfig( 24 | app_name="accounting", 25 | migrations_folder_path=os.path.join( 26 | CURRENT_DIRECTORY, "piccolo_migrations" 27 | ), 28 | table_classes=[ 29 | User, 30 | Sessions, 31 | Role, 32 | Group, 33 | Permission, 34 | Policy, 35 | M2MUserGroup, 36 | M2MUserRole 37 | ], 38 | migration_dependencies=[], 39 | commands=[], 40 | ) 41 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from cli import db_app, aaa_app 3 | from cli.config_loader import config_default 4 | 5 | app = typer.Typer(no_args_is_help=True) 6 | 7 | app.add_typer(db_app, name='db') 8 | app.add_typer(aaa_app, name='aaa') 9 | 10 | 11 | @app.command() 12 | def run( 13 | config: str = config_default, 14 | reload: bool = typer.Option( 15 | False, 16 | is_flag=True, 17 | flag_value=True, 18 | help="Allow Uvicorn watch file changes and reload server" 19 | ) 20 | ): 21 | """ 22 | Run application in uvicorn server with defined config file 23 | We don`t test these functions directly, because it`s unable to start 24 | uvicorn server in CI workflows permanently 25 | """ 26 | 27 | from run import run_app # pragma: no cover 28 | run_app( # pragma: no cover 29 | config_file=config, 30 | reload=reload 31 | ) 32 | 33 | 34 | if __name__ == "__main__": 35 | app() # pragma: no cover 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 
22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /src/tests/payload_models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from mimesis import Person 3 | 4 | good_email_domains: list[str] = ['mydomain.com'] 5 | bad_email_domains: list[str] = ['mydomain.com1'] 6 | 7 | 8 | class UserModel(BaseModel): 9 | username: str 10 | password: str 11 | email: str # Using str instead of EmailStr to test input validation in CLI and CRUD 12 | 13 | @classmethod 14 | def create(cls, good_emails: bool = True): 15 | person: Person = Person('en') 16 | domains: list[str] = good_email_domains 17 | if not good_emails: 18 | domains = bad_email_domains # pragma: no cover 19 | return UserModel( 20 | username=person.username(), 21 | password=person.password(), 22 | email=person.email(domains=domains, unique=True) 23 | ) 24 | 25 | def to_cli_input(self): 26 | return f"{self.username}\n{self.password}\n{self.email}\n" 27 | 28 | 29 | class RoleModel(BaseModel): 30 | name: str 31 | active: bool = True 32 | 33 | 34 | class GroupModel(BaseModel): 35 | name: str 36 | active: bool = True 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Oleg Romanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/accounting/authentication/endpoints.py: -------------------------------------------------------------------------------- 1 | from fastapi.security import OAuth2PasswordRequestForm 2 | from accounting.users import User 3 | from .jwt import create_access_token, get_user_by_token 4 | from fastapi import Depends 5 | 6 | 7 | async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): 8 | """ 9 | Base auth endpoint with OAuth2PasswordBearer form 10 | for obtaining JWT 11 | 12 | Args:\n 13 | username: str - required 14 | password: str - required 15 | grant_type: str 16 | scope: str 17 | client_id: str 18 | client_secret: str 19 | 20 | 21 | Returns:\n 22 | access_token: str 23 | token_type: str 24 | """ 25 | user = await User.authenticate_user(form_data.username, form_data.password) 26 | access_token = create_access_token( 27 | data={"sub": user.username}) # type: ignore 28 | return {"access_token": access_token, "token_type": "bearer"} 29 | 30 | 31 | async def get_current_user(current_user: User = Depends(get_user_by_token)): 32 | """ 33 | Obtaining {USER} object of authenticated user 34 | 35 | Returns: 36 | {USER}, see accounting.users 37 | """ 38 | return current_user 39 | -------------------------------------------------------------------------------- /src/tests/05_config_test.py: -------------------------------------------------------------------------------- 1 | from .shared import load_config 2 | from configuration.sections import ServerSectionConfiguration 3 | 4 | 5 | def test_conf_load_toml(): 6 | c = load_config('src/config.toml') 7 | assert c 8 | assert not c.Config.load_failed 9 | assert "Accounting->Roles"], 12 | responses={ 13 | 404: {"description": "URL not found"}, 14 | 400: {"description": "Bad request"} 15 | }, 16 | dependencies=[Depends(get_user_by_token)], 17 | version=APIVersion(1) 18 | ) 19 | 20 | role_router.add_api_route( 21 | '/', 22 | RoleCRUD.get_all_roles, 23 | response_model=list[RoleRead], 24 | summary='Get all roles', 25 | methods=['get'] 26 | ) 27 | 28 | role_router.add_api_route( 29 | '/{id}', 30 | RoleCRUD.get_role, 31 | response_model=RoleRead, 32 | summary='Get role by ID pk', 33 | methods=['get'] 34 | ) 35 | 36 | role_router.add_api_route( 37 | '/', 38 | RoleCRUD.create_role, 39 | response_model=RoleRead, 40 | status_code=201, 41 | summary='Create role', 42 | methods=['post']) 43 | 44 | role_router.add_api_route( 45 | '/{id}', 46 | RoleCRUD.update_role, 47 | response_model=RoleRead, 48 | summary='Update role', 49 | methods=['put']) 50 | 51 | role_router.add_api_route( 52 | '/{id}', 53 | RoleCRUD.delete_role, 54 | status_code=204, 55 | summary='Delete role', 56 | methods=['delete']) 57 | -------------------------------------------------------------------------------- /src/accounting/groups/routing.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends 2 | from utils.api_versioning import APIRouter, APIVersion 3 | from .endpoints import GroupCRUD 4 | from accounting.schemas import ( 5 | GroupRead 6 | ) 7 | from accounting.authentication.jwt import get_user_by_token 8 | 9 | group_router = APIRouter( 10 | prefix="/accounting/groups", 11 | tags=["AAA->Accounting->Groups"], 12 | responses={ 13 | 404: {"description": "URL not found"}, 14 | 400: {"description": "Bad request"} 15 | }, 16 | dependencies=[Depends(get_user_by_token)], 17 | version=APIVersion(1) 18 | ) 19 | 20 | 
group_router.add_api_route( 21 | '/', 22 | GroupCRUD.get_all_groups, 23 | response_model=list[GroupRead], 24 | summary='Get all groups', 25 | methods=['get'] 26 | ) 27 | 28 | group_router.add_api_route( 29 | '/{id}', 30 | GroupCRUD.get_group, 31 | response_model=GroupRead, 32 | summary='Get group by ID pk', 33 | methods=['get'] 34 | ) 35 | 36 | group_router.add_api_route( 37 | '/', 38 | GroupCRUD.create_group, 39 | response_model=GroupRead, 40 | status_code=201, 41 | summary='Create group', 42 | methods=['post']) 43 | 44 | group_router.add_api_route( 45 | '/{id}', 46 | GroupCRUD.update_group, 47 | response_model=GroupRead, 48 | summary='Update group', 49 | methods=['put']) 50 | 51 | group_router.add_api_route( 52 | '/{id}', 53 | GroupCRUD.delete_group, 54 | status_code=204, 55 | summary='Delete group', 56 | methods=['delete']) 57 | -------------------------------------------------------------------------------- /src/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | from email.mime import base 2 | from fastapi import HTTPException 3 | from asyncpg.exceptions import IntegrityConstraintViolationError 4 | from pprint import pprint 5 | 6 | 7 | class IntegrityException(HTTPException): 8 | def __init__(self, base_exception: IntegrityConstraintViolationError): 9 | try: 10 | detail = base_exception.args[0] 11 | except (KeyError, TypeError, ValueError): 12 | detail = base_exception.message 13 | finally: 14 | raise HTTPException( 15 | status_code=400, 16 | detail=detail 17 | ) 18 | 19 | 20 | class ObjectNotFoundException(HTTPException): 21 | def __init__(self, object_name: str, object_id: str): 22 | raise HTTPException( 23 | status_code=404, 24 | detail=f'Object {object_name} with id {object_id} not found' 25 | ) 26 | 27 | 28 | class BaseBadRequestException(HTTPException): 29 | def __init__(self, message: str): 30 | raise HTTPException( 31 | status_code=400, 32 | detail=str(message) 33 | ) 34 | 35 | 36 | class UnauthorizedException(HTTPException): 37 | def __init__(self, details: str): 38 | raise HTTPException( 39 | status_code=401, 40 | detail=str(details) 41 | ) 42 | 43 | 44 | class PermissionDeniedException(HTTPException): 45 | def __init__(self, details: str): 46 | raise HTTPException( 47 | status_code=403, 48 | detail=str(details) 49 | ) 50 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Intro 2 | 3 | This project was created as a template for FastAPI applications. 4 | ## What`s in the box? 
5 | 6 | * [FastAPI](https://github.com/tiangolo/fastapi) as a base ASGI app 7 | * [Piccolo ORM](https://github.com/piccolo-orm/piccolo) for database operations 8 | * [Piccolo Admin GUI](https://github.com/piccolo-orm/piccolo_admin) for convenient database management 9 | * [Hashicorp Vault](https://github.com/hashicorp/vault) integration for DB credentials (with auto-rotation), JWT secrets and more 10 | * Custom `Accounting` CRUD application for managing 11 | * Users 12 | * Roles 13 | * Groups 14 | * Permissions 15 | * Security policies 16 | * JWT authentication 17 | * API versioning 18 | * [Pydantic](https://github.com/samuelcolvin/pydantic)-based flexible config file parser (`toml` and `yaml` formats supported) 19 | * [Typer](https://github.com/tiangolo/typer)-based CLI management 20 | * [Prometheus](https://github.com/prometheus/prometheus) endpoint based on [Starlette exporter](https://github.com/stephenhillier/starlette_exporter) 21 | * [OpenTelemetry](https://github.com/orgs/open-telemetry) collector 22 | * Request ID propagation for the logger, Request and Response (injection into headers) 23 | * CI pipeline for linting and testing (with coverage) 24 | ## Installation 25 | 26 | We strongly recommend using [Poetry](https://python-poetry.org/) 27 | 28 | Create a new virtual environment or enter an existing one 29 | 30 | >poetry shell 31 | 32 | Install all dependencies from *pyproject.toml* 33 | 34 | >poetry install 35 | 36 | Run your app 37 | 38 | >python3 main.py run -------------------------------------------------------------------------------- /src/accounting/decorators.py: -------------------------------------------------------------------------------- 1 | from fastapi import Request 2 | from functools import wraps 3 | from utils.exceptions import PermissionDeniedException, UnauthorizedException 4 | from loguru import logger 5 | from .authentication.jwt import get_user_by_token, decode_auth_header 6 | 7 | 8 | def AAA_endpoint_oauth2(): 9 | from .users import User 10 | 11 | def outer_wrapper(func): 12 | # Needed for the OpenAPI schema, to show that the endpoint requires a permission 13 | func.__setattr__('rbac_enable', True) 14 | 15 | @wraps(func) 16 | async def wrapper(request: Request, *args, **kwargs): 17 | from accounting.rbac.checks import check_user_endpoint_policy 18 | try: 19 | validator, token = decode_auth_header( 20 | request.headers.get('authorization', str())) 21 | current_user: User = await get_user_by_token(token) 22 | assert await check_user_endpoint_policy(current_user, func.__name__), "Permission denied" 23 | except AssertionError as ex: 24 | logger.debug( 25 | f'Access denied | User: {current_user.username} | Object: {func.__name__}') 26 | raise PermissionDeniedException(str(ex)) 27 | except IndexError: 28 | raise UnauthorizedException('Cannot decode token') 29 | else: 30 | logger.debug( 31 | f'Access permitted | User: {current_user.username} | Object: {func.__name__}') 32 | return await func(request, *args, **kwargs) 33 | return wrapper 34 | return outer_wrapper 35 | -------------------------------------------------------------------------------- /src/accounting/users/routing.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends 2 | from .endpoints import UserCRUD 3 | from accounting.schemas import ( 4 | UserRead, 5 | ) 6 | from accounting.authentication.jwt import get_user_by_token 7 | from utils.api_versioning import APIRouter, APIVersion 8 | user_router = APIRouter( 9 | prefix="/accounting/users", 10 |
tags=["AAA->Accounting->Users"], 11 | responses={ 12 | 404: {"description": "URL not found"}, 13 | 400: {"description": "Bad request"} 14 | }, 15 | dependencies=[Depends(get_user_by_token)], 16 | version=APIVersion(1) 17 | ) 18 | 19 | user_router.add_api_route( 20 | '/', 21 | UserCRUD.get_all_users, 22 | response_model=list[UserRead], 23 | summary='Get all users', 24 | methods=['get'] 25 | ) 26 | 27 | user_router.add_api_route( 28 | '/{id}', 29 | UserCRUD.get_user, 30 | response_model=UserRead, 31 | summary='Get user by ID pk', 32 | methods=['get'] 33 | ) 34 | 35 | user_router.add_api_route( 36 | '/', 37 | UserCRUD.create_user, 38 | response_model=UserRead, 39 | status_code=201, 40 | summary='Create user', 41 | methods=['post']) 42 | 43 | user_router.add_api_route( 44 | '/{id}', 45 | UserCRUD.update_user, 46 | response_model=UserRead, 47 | summary='Update user', 48 | methods=['put']) 49 | 50 | user_router.add_api_route( 51 | '/{id}', 52 | UserCRUD.patch_user, 53 | response_model=UserRead, 54 | summary='Change user password', 55 | methods=['patch']) 56 | 57 | user_router.add_api_route( 58 | '/{id}', 59 | UserCRUD.delete_user, 60 | status_code=204, 61 | summary='Delete user', 62 | methods=['delete']) 63 | -------------------------------------------------------------------------------- /docs/cli.md: -------------------------------------------------------------------------------- 1 | # CLI commands 2 | 3 | ## Base usage 4 | All commands have a `config` option with the default value `config.toml` 5 | Make sure the right config file is used for AAA and DB operations 6 | ``` 7 | Usage: main.py [OPTIONS] COMMAND [ARGS]... 8 | 9 | Options: 10 | --help Show this message and exit. 11 | 12 | Commands: 13 | aaa Operations with users and other AAA objects 14 | db Operations with DB 15 | run Run application in uvicorn server with defined config file 16 | ``` 17 | ## Development mode 18 | The `--reload` option invokes the vanilla Uvicorn reload manager 19 | >python3 main.py run --reload 20 | 21 | ## AAA management 22 | ### Creating superuser 23 | 24 | A **Superuser** has all available privileges, including access to the Piccolo Admin GUI, and ignores all **Roles** and **Policies** restrictions 25 | 26 | >python3 main.py aaa create superuser 27 | 28 | ### Creating user 29 | 30 | You can also create a simple **User** without any privileges 31 | 32 | >python3 main.py aaa create user 33 | 34 | ### Creating JWT secret 35 | 36 | You can also generate a JWT secret salt, which will be stored in the location defined in the config file 37 | 38 | >python3 main.py aaa create secret 39 | 40 | ## Database management 41 | 42 | All database commands take an `application` argument that selects which app's tables are used. It defaults to `all`, meaning all available applications in the project.
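For example, to create only the `accounting` app's tables while pointing at an explicit config file (a sketch; the `accounting` app name and the `--config` option are taken from this project's defaults): >python3 main.py db init accounting --config src/config.toml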
43 | 44 | ### Show DB schema 45 | >python3 main.py db show 46 | ### Create all tables 47 | >python3 main.py db init 48 | ### Drop all tables 49 | >python3 main.py db drop 50 | ### Create migrations without running 51 | >python3 main.py db mg create 52 | ### Run created migrations 53 | >python3 main.py db mg run 54 | -------------------------------------------------------------------------------- /src/tests/shared.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | def clear_migrations_files(): 6 | """ 7 | Find and remove all migration files 8 | """ 9 | current_dir: str = os.getcwd() 10 | migrations: list = list() 11 | for dir in os.walk(current_dir, followlinks=False): 12 | directory = os.path.relpath(dir[0], current_dir) 13 | if not directory.startswith('.') and 'piccolo_migrations' in directory: 14 | current_migrations: list = glob.glob(r'*.py', root_dir=directory) 15 | for current_migration in current_migrations: 16 | if current_migration != '__init__.py': 17 | migrations.append(f"{directory}/{current_migration}") # pragma: no cover 18 | for migration in migrations: 19 | if os.path.isfile(migration): # pragma: no cover 20 | os.remove(migration) # pragma: no cover 21 | 22 | 23 | def prepare_db_with_users(superuser, user): 24 | from cli.db import prepare_db_through_vault 25 | from main import app 26 | from typer.testing import CliRunner 27 | prepare_db_through_vault() 28 | runner = CliRunner() 29 | runner.invoke(app, ["db", "drop", "all"], input="y\n") 30 | runner.invoke(app, ["db", "init", "all"]) 31 | runner.invoke(app, ["aaa", "create", "superuser"], 32 | input=superuser.to_cli_input()) 33 | runner.invoke(app, ["aaa", "create", "user"], input=user.to_cli_input()) 34 | 35 | 36 | def load_config(filename: str): 37 | from configuration.model import Configuration 38 | try: 39 | config = Configuration() 40 | config.load(filename) 41 | except FileNotFoundError: 42 | return None 43 | else: 44 | return config 45 | -------------------------------------------------------------------------------- /src/config.ini: -------------------------------------------------------------------------------- 1 | [Main] 2 | application_mode = dev 3 | log_level = debug 4 | log_destination = stdout 5 | log_in_json = 0 6 | log_sql = 0 7 | timezone = +3 8 | enable_swagger = 1 9 | swagger_doc_url = /doc 10 | swagger_redoc_url = /redoc 11 | enable_security = 1 12 | 13 | [AdminGUI] 14 | admin_enable = 1 15 | admin_url = /admin/ 16 | 17 | 18 | [Server] 19 | bind_address = localhost 20 | bind_port = 8000 21 | base_url = example.com 22 | 23 | 24 | [Vault] 25 | vault_enable = 1 26 | vault_host = localhost 27 | vault_port = 8228 28 | vault_disable_tls = 1 29 | vault_auth_method = token 30 | #vault_token = hvs.e7zbhM4OadYPKTLqGNH9eCci 31 | vault_credentials = 32 | vault_try_to_unseal = 1 33 | #vault_key_type - json | keys 34 | #json - legacy json file from Vault, created at initialization of Vault instance/cluster 35 | # also can contain root_token string, which will be used to access Vault with TOKEN auth_method 36 | # Priority: 37 | # 1) vault_auth_token from config file 38 | # 2) root_token from json file 39 | #keys - simple txt file with unsealing key portions in base64, line by line 40 | vault_keyfile_type = json 41 | vault_unseal_keys = vault-cluster-vault-2022-07-12T08 03 48.497Z.json 42 | #vault_unseal_keys = vault-cluster-vault-2022-07-06T18 47 09.634Z.json 43 | 44 | [Database] 45 | db_driver = postgresql 46 | db_host = 127.0.0.1 47 | db_port 
= 5432 48 | db_name = fastapi-boilerplate 49 | db_username = fastapi-boilerplate 50 | db_password = fastapi-boilerplate 51 | 52 | db_vault_enable = 1 53 | db_vault_role = testrole02 54 | db_vault_static = 1 55 | db_vault_storage = database 56 | 57 | [Telemetry] 58 | enable = 1 59 | agent_type = jaeger 60 | agent_host = localhost 61 | agent_port = 6831 62 | 63 | [Security] 64 | enable_rbac = 1 -------------------------------------------------------------------------------- /src/config.yaml: -------------------------------------------------------------------------------- 1 | Main: 2 | application_mode: "dev" 3 | log_level: "debug" 4 | log_destination: "stdout" 5 | log_in_json: 0 6 | log_sql: 0 7 | timezone: +3 8 | enable_swagger: 1 9 | swagger_doc_url: "/doc" 10 | swagger_redoc_url: "/redoc" 11 | enable_security: 1 12 | 13 | AdminGUI: 14 | admin_enable: 1 15 | admin_url: "/admin/" 16 | 17 | 18 | Server: 19 | bind_address: "localhost" 20 | bind_port: 8000 21 | base_url: "example.com" 22 | 23 | 24 | Vault: 25 | vault_enable: 1 26 | vault_host: "localhost" 27 | vault_port: 8228 28 | vault_disable_tls: 1 29 | vault_auth_method: "token" 30 | #vault_token = "hvs.e7zbhM4OadYPKTLqGNH9eCci" 31 | #vault_credentials = 32 | vault_try_to_unseal: 1 33 | #vault_key_type - json | keys 34 | #json - legacy json file from Vault, created at initialization of Vault instance/cluster 35 | # also can contain root_token string, which will be used to access Vault with TOKEN auth_method 36 | # Priority: 37 | # 1) vault_auth_token from config file 38 | # 2) root_token from json file 39 | #keys - simple txt file with unsealing key portions in base64, line by line 40 | #vault_keyfile_type: "json" 41 | #vault_unseal_keys: "vault-cluster-vault-2022-07-12T08 03 48.497Z.json" 42 | #vault_unseal_keys = "vault-cluster-vault-2022-07-06T18 47 09.634Z.json" 43 | 44 | Database: 45 | db_driver: "postgresql" 46 | db_host: "127.0.0.1" 47 | db_port: 5432 48 | db_name: "fastapi-boilerplate" 49 | db_username: "fastapi-boilerplate" 50 | db_password: "fastapi-boilerplate" 51 | 52 | db_vault_enable: 1 53 | db_vault_role: "testrole02" 54 | db_vault_static: 1 55 | db_vault_storage: "database" 56 | 57 | Telemetry: 58 | enable: 1 59 | agent_type: "jaeger" 60 | agent_host: "localhost" 61 | agent_port: 6831 62 | 63 | Security: 64 | enable_rbac: 1 -------------------------------------------------------------------------------- /src/config.toml: -------------------------------------------------------------------------------- 1 | [Main] 2 | application_mode = "dev" 3 | log_level = "debug" 4 | log_destination = "stdout" 5 | log_in_json = 0 6 | log_sql = 0 7 | timezone = +3 8 | enable_swagger = 1 9 | swagger_doc_url = "/doc" 10 | swagger_redoc_url = "/redoc" 11 | enable_security = 1 12 | 13 | [AdminGUI] 14 | admin_enable = 1 15 | admin_url = "/admin/" 16 | 17 | 18 | [Server] 19 | bind_address = "localhost" 20 | bind_port = 8000 21 | base_url = "example.com" 22 | 23 | 24 | [Vault] 25 | vault_enable = 1 26 | vault_host = "localhost" 27 | vault_port = 8200 28 | vault_disable_tls = 1 29 | vault_auth_method = "token" 30 | vault_token = "test" 31 | #vault_credentials = 32 | vault_try_to_unseal = 1 33 | #vault_key_type - json | keys 34 | #json - legacy json file from Vault, created at initialization of Vault instance/cluster 35 | # also can contain root_token string, which will be used to access Vault with TOKEN auth_method 36 | # Priority: 37 | # 1) vault_auth_token from config file 38 | # 2) root_token from json file 39 | #keys - simple txt file with 
unsealing key portions in base64, line by line 40 | #vault_keyfile_type = 41 | #vault_unseal_keys = 42 | 43 | [Database] 44 | db_driver = "postgresql" 45 | db_host = "127.0.0.1" 46 | db_port = 5432 47 | db_name = "test" 48 | db_username = "test" 49 | db_password = "test" 50 | 51 | db_vault_enable = 1 52 | db_vault_role = "testrole" 53 | db_vault_static = 1 54 | db_vault_storage = "database" 55 | 56 | [Telemetry] 57 | enable = 1 58 | agent_type = "jaeger" 59 | agent_host = "localhost" 60 | agent_port = 6831 61 | trace_id_length = 12 62 | 63 | [Security] 64 | enable_rbac = 1 65 | login_with_username = 1 66 | login_with_email = 1 67 | jwt_algorithm = "HS256" 68 | jwt_ttl = 3600 69 | jwt_base_secret = "dev-secret-from-configfile" 70 | jwt_base_secret_storage = 'vault' 71 | jwt_base_secret_filename = 'secret.key1' 72 | jwt_base_secret_vault_storage_name = 'kv_test' 73 | jwt_base_secret_vault_secret_name = 'jwt' -------------------------------------------------------------------------------- /src/tests/configs/bad_field.toml: -------------------------------------------------------------------------------- 1 | [Main] 2 | application_mode = "dev1" 3 | log_level = "debug" 4 | log_destination = "stdout" 5 | log_in_json = 0 6 | log_sql = 0 7 | timezone = +3 8 | enable_swagger = 1 9 | swagger_doc_url = "/doc" 10 | swagger_redoc_url = "/redoc" 11 | enable_security = 1 12 | 13 | [AdminGUI] 14 | admin_enable = 1 15 | admin_url = "/admin/" 16 | 17 | 18 | [Server] 19 | bind_address = "localhost" 20 | bind_port = 8000 21 | base_url = "example.com" 22 | 23 | 24 | [Vault] 25 | vault_enable = 1 26 | vault_host = "localhost" 27 | vault_port = 8200 28 | vault_disable_tls = 1 29 | vault_auth_method = "token" 30 | vault_token = "test" 31 | #vault_credentials = 32 | vault_try_to_unseal = 1 33 | #vault_key_type - json | keys 34 | #json - legacy json file from Vault, created at initialization of Vault instance/cluster 35 | # also can contain root_token string, which will be used to access Vault with TOKEN auth_method 36 | # Priority: 37 | # 1) vault_auth_token from config file 38 | # 2) root_token from json file 39 | #keys - simple txt file with unsealing key portions in base64, line by line 40 | #vault_keyfile_type = 41 | #vault_unseal_keys = 42 | 43 | [Database] 44 | db_driver = "postgresql" 45 | db_host = "127.0.0.1" 46 | db_port = 5432 47 | db_name = "test" 48 | db_username = "test" 49 | db_password = "test" 50 | 51 | db_vault_enable = 1 52 | db_vault_role = "testrole" 53 | db_vault_static = 1 54 | db_vault_storage = "database" 55 | 56 | [Telemetry] 57 | enable = 1 58 | agent_type = "jaeger" 59 | agent_host = "localhost" 60 | agent_port = 6831 61 | trace_id_length = 12 62 | 63 | [Security] 64 | enable_rbac = 1 65 | login_with_username = 1 66 | login_with_email = 1 67 | jwt_algorithm = "HS256" 68 | jwt_ttl = 3600 69 | jwt_base_secret = "dev-secret-from-configfile" 70 | jwt_base_secret_storage = 'vault' 71 | jwt_base_secret_filename = 'secret.key1' 72 | jwt_base_secret_vault_storage_name = 'kv_test' 73 | jwt_base_secret_vault_secret_name = 'jwt' -------------------------------------------------------------------------------- /ci/init_vault.py: -------------------------------------------------------------------------------- 1 | import requests # type: ignore 2 | 3 | URL: str = "http://127.0.0.1:8200/" 4 | TOKEN: str = "test" 5 | HEADERS: dict = {'X-Vault-Token': TOKEN} 6 | 7 | database_mount: str = "database" 8 | kv_mount: str = "kv_test" 9 | 10 | db_host = 'postgres:5432' 11 | db_dsn: str = 
f"postgresql://test:test@{db_host}/test?sslmode=disable" 12 | db_role: str = "testrole" 13 | kv_secret_name: str = "jwt" 14 | 15 | 16 | def post(path: str, data: dict) -> requests.Response: 17 | return requests.post( 18 | url=f"{URL}{path}", 19 | json=data, 20 | headers=HEADERS 21 | ) 22 | 23 | 24 | """ VAULT DATABASE INIT """ 25 | 26 | print("Creating database secret engine") 27 | resp = post(f'v1/sys/mounts/{database_mount}', {"type": "database"}) 28 | print(f"{resp.status_code} --- {resp.text}") 29 | 30 | print("Creating database connection") 31 | resp = post( 32 | path=f"v1/{database_mount}/config/postgresql", 33 | data={ 34 | "plugin_name": "postgresql-database-plugin", 35 | "allowed_roles": "*", 36 | "connection_url": db_dsn, 37 | "username": "test", 38 | "password": "test" 39 | } 40 | ) 41 | print(f"{resp.status_code} --- {resp.text}") 42 | 43 | print("Creating static role") 44 | resp = post( 45 | path=f"v1/{database_mount}/static-roles/{db_role}", 46 | data={ 47 | "db_name": "postgresql", 48 | "rotation_statements": "ALTER USER \"{{name}}\" WITH PASSWORD '{{password}}';", 49 | "username": "test3", 50 | "rotation_period": "86400" 51 | } 52 | ) 53 | print(f"{resp.status_code} --- {resp.text}") 54 | 55 | """ VAULT KV INIT """ 56 | 57 | print("Creating KVv2 storage") 58 | resp = post( 59 | path=f"v1/sys/mounts/{kv_mount}", 60 | data={ 61 | "type": "kv", 62 | "options": { 63 | "version": "2" 64 | } 65 | } 66 | ) 67 | print(f"{resp.status_code} --- {resp.text}") 68 | 69 | print("Push something to KV") 70 | resp = post( 71 | path=f"v1/{kv_mount}/data/{kv_secret_name}", 72 | data={ 73 | "data": { 74 | "base_secret": "foobar" 75 | } 76 | } 77 | ) 78 | print(f"{resp.status_code} --- {resp.text}") 79 | -------------------------------------------------------------------------------- /src/utils/telemetry.py: -------------------------------------------------------------------------------- 1 | from opentelemetry import trace 2 | from opentelemetry.exporter.jaeger.thrift import JaegerExporter 3 | from opentelemetry.sdk.resources import SERVICE_NAME, Resource 4 | from opentelemetry.sdk.trace import TracerProvider 5 | from opentelemetry.sdk.trace.export import BatchSpanProcessor 6 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor 7 | from configuration import config 8 | from loguru import logger 9 | 10 | 11 | def server_request_hook(span, scope: dict): 12 | if span and span.is_recording(): 13 | pass 14 | 15 | 16 | def client_request_hook(span, scope: dict): 17 | if span and span.is_recording(): 18 | pass 19 | 20 | 21 | def client_response_hook(span, message: dict): 22 | if span and span.is_recording(): 23 | pass 24 | 25 | 26 | def enable_tracing(app): 27 | trace.set_tracer_provider( 28 | TracerProvider( 29 | resource=Resource.create({SERVICE_NAME: app.title}) 30 | ) 31 | ) 32 | 33 | jaeger_exporter = JaegerExporter( 34 | # configure agent 35 | agent_host_name=config.Telemetry.agent_host, 36 | # optional: configure also collector 37 | agent_port=config.Telemetry.agent_port, 38 | # collector_endpoint='http://localhost:14268/api/traces?format=jaeger.thrift', 39 | # username=xxxx, # optional 40 | # password=xxxx, # optional 41 | # max_tag_value_length=None # optionalormat=jaeger.thrift', 42 | # username=xxxx, # optional 43 | # password=xxxx, # optional 44 | # max_tag_value_length=None # optional 45 | ) 46 | 47 | span_processor = BatchSpanProcessor(jaeger_exporter) 48 | trace.get_tracer_provider().add_span_processor(span_processor) 49 | 50 | FastAPIInstrumentor.instrument_app(app) 
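# instrument_app() above instruments this specific app instance; the generic instrument() call below enables global FastAPI instrumentation and registers the request/response hooks defined above (currently no-op placeholders).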
51 | FastAPIInstrumentor().instrument( 52 | server_request_hook=server_request_hook, 53 | client_request_hook=client_request_hook, 54 | client_response_hook=client_response_hook 55 | ) 56 | logger.info( 57 | f'Telemetry exporter to {config.Telemetry.agent_host}:{config.Telemetry.agent_port} for {config.Telemetry.agent_type} enabled' 58 | ) 59 | 60 | 61 | tracer = trace.get_tracer(__name__) 62 | -------------------------------------------------------------------------------- /src/tests/02_startup_events_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from fastapi.testclient import TestClient 3 | from app import create_app 4 | 5 | app = create_app() 6 | client = TestClient(app) 7 | 8 | 9 | def test_load_endpoints(): 10 | from utils.events import load_endpoints 11 | load_endpoints(app) 12 | 13 | 14 | def test_enable_telemetry(): 15 | from utils.telemetry import enable_tracing 16 | enable_tracing(app) 17 | 18 | 19 | def test_creating_admin_gui(): 20 | from utils.events import create_admin_gui 21 | create_admin_gui(app, '/admin', 'foobar') 22 | 23 | 24 | def test_load_endpoint_permissions(): 25 | from utils.events import load_endpoint_permissions 26 | from cli.db import prepare_db_through_vault 27 | prepare_db_through_vault() 28 | asyncio.run(load_endpoint_permissions(app)) 29 | 30 | 31 | def test_load_base_jwt_secret_from_config(): 32 | from utils.events import load_base_jwt_secret 33 | asyncio.run(load_base_jwt_secret()) 34 | 35 | 36 | def test_load_base_jwt_secret_from_file(): 37 | from configuration import config 38 | from utils.events import load_base_jwt_secret 39 | asyncio.run( 40 | load_base_jwt_secret( 41 | jwt_base_secret=None, 42 | jwt_base_secret_storage='local', 43 | jwt_base_secret_filename='src/tests/etc/jwt.txt' 44 | ) 45 | ) 46 | assert config.Security.get_jwt_base_secret() == 'localfilesecret' 47 | 48 | 49 | def test_load_base_jwt_secret_from_vault(): 50 | from configuration import config 51 | from utils.events import load_base_jwt_secret 52 | from utils.vault import Vault 53 | vault: Vault = Vault( 54 | auth=Vault.VaultAuth( 55 | auth_method='token', 56 | token='test' 57 | ) 58 | ) 59 | asyncio.run( 60 | load_base_jwt_secret( 61 | jwt_base_secret=None, 62 | jwt_base_secret_vault_secret_name='jwt', 63 | jwt_base_secret_vault_storage_name='kv_test', 64 | vault=vault 65 | ) 66 | ) 67 | assert isinstance(config.Security.get_jwt_base_secret(), str) 68 | assert len(config.Security.get_jwt_base_secret()) > 10 69 | 70 | 71 | def test_reload_db_creds(): 72 | from utils.events import reload_db_creds 73 | asyncio.run(reload_db_creds()) 74 | -------------------------------------------------------------------------------- /src/accounting/roles/endpoints.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from .models import Role 3 | from accounting.schemas import ( 4 | RoleCreate, 5 | RoleUpdate, 6 | ) 7 | from fastapi import Request, Response 8 | from accounting.decorators import AAA_endpoint_oauth2 9 | 10 | 11 | class RoleCRUD(): 12 | 13 | @staticmethod 14 | @AAA_endpoint_oauth2() 15 | async def get_all_roles(request: Request, offset: int = 0, limit: int = 100): 16 | """ 17 | ### READ list[Role] with offset and limit 18 | #### Args:\n 19 | offset (int, optional): Defaults to 0.\n 20 | limit (int, optional): Defaults to 100.\n 21 | #### Returns: 22 | list[Role] 23 | """ 24 | return await Role.get_all(offset=offset, limit=limit) 25 | 26 | @staticmethod 27 | 
@AAA_endpoint_oauth2() 28 | async def get_role(request: Request, id: str): 29 | """ 30 | ### READ one {Role} by id 31 | #### Args:\n 32 | id (str): UUID4 PK 33 | #### Returns: 34 | Role | None 35 | """ 36 | return await Role.get_by_id(id) 37 | 38 | @staticmethod 39 | @AAA_endpoint_oauth2() 40 | async def create_role(request: Request, user: RoleCreate): 41 | """ 42 | ### CREATE role 43 | #### Args:\n 44 | role (Role): { 45 | name: str (Unique) 46 | active: bool = True 47 | } 48 | #### Returns: 49 | Role 50 | """ 51 | return await Role.add(**user.dict()) 52 | 53 | @staticmethod 54 | @AAA_endpoint_oauth2() 55 | async def update_role(request: Request, id: str, role: RoleUpdate): 56 | """ 57 | ### Update one role (full or partial) 58 | Args:\n 59 | role (Role): { 60 | name: str (Unique) 61 | active: bool 62 | } 63 | Returns: 64 | Role 65 | """ 66 | return await Role.update_by_id(id=id, data=role.dict(exclude_none=True)) 67 | 68 | @staticmethod 69 | @AAA_endpoint_oauth2() 70 | async def delete_role(request: Request, id: str): 71 | """ 72 | ### DELETE one role by ID 73 | #### Args:\n 74 | id (str): UUID4 PK 75 | #### Returns: 76 | None, code=204 77 | """ 78 | await Role.delete_by_id(id) 79 | return Response(status_code=204) 80 | -------------------------------------------------------------------------------- /src/accounting/groups/endpoints.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from .models import Group 3 | from accounting.schemas import ( 4 | GroupCreate, 5 | GroupUpdate 6 | ) 7 | from fastapi import Request, Response 8 | from accounting.decorators import AAA_endpoint_oauth2 9 | 10 | 11 | class GroupCRUD(): 12 | 13 | @staticmethod 14 | @AAA_endpoint_oauth2() 15 | async def get_all_groups(request: Request, offset: int = 0, limit: int = 100): 16 | """ 17 | ### READ list[Group] with offset and limit 18 | #### Args:\n 19 | offset (int, optional): Defaults to 0.\n 20 | limit (int, optional): Defaults to 100.\n 21 | #### Returns: 22 | list[Group] 23 | """ 24 | return await Group.get_all(offset=offset, limit=limit) 25 | 26 | @staticmethod 27 | @AAA_endpoint_oauth2() 28 | async def get_group(request: Request, id: str = str()): 29 | """ 30 | ### READ one {Group} by id 31 | #### Args:\n 32 | id (str): UUID4 PK 33 | #### Returns: 34 | Group | None 35 | """ 36 | return await Group.get_by_id(id) 37 | 38 | @staticmethod 39 | @AAA_endpoint_oauth2() 40 | async def create_group(request: Request, user: GroupCreate): 41 | """ 42 | ### CREATE group 43 | #### Args:\n 44 | group (Group): { 45 | name: str (Unique) 46 | active: bool = True 47 | } 48 | #### Returns: 49 | Group 50 | """ 51 | return await Group.add(**user.dict()) 52 | 53 | @staticmethod 54 | @AAA_endpoint_oauth2() 55 | async def update_group(request: Request, id: str, group: GroupUpdate): 56 | """ 57 | ### Update one group (full or partial) 58 | Args:\n 59 | group (Group): { 60 | name: str (Unique) 61 | active: bool 62 | } 63 | Returns: 64 | Group 65 | """ 66 | return await Group.update_by_id(id=id, data=group.dict(exclude_unset=True)) 67 | 68 | @staticmethod 69 | @AAA_endpoint_oauth2() 70 | async def delete_group(request: Request, id: str): 71 | """ 72 | ### DELETE one group by ID 73 | #### Args:\n 74 | id (str): UUID4 PK 75 | #### Returns: 76 | None, code=204 77 | """ 78 | await Group.delete_by_id(id) 79 | return Response(status_code=204) 80 | -------------------------------------------------------------------------------- /src/app.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI as _FastAPI 2 | 3 | 4 | class FastAPI(_FastAPI): 5 | def __init__(self, **kwargs): 6 | super().__init__(**kwargs) 7 | 8 | 9 | def create_app() -> FastAPI: 10 | """ 11 | Creates and returns FastAPI application object 12 | Loads all configuration and executes events 13 | 14 | Returns: 15 | FastAPI: app object 16 | """ 17 | from starlette_exporter import PrometheusMiddleware 18 | from utils.logger import setup_logging 19 | from utils.telemetry import enable_tracing 20 | from utils.id_propagation import IDPropagationMiddleware 21 | from utils import events 22 | from configuration import config 23 | __title__ = "FastAPI boilerplate" 24 | __doc__ = "Your project description" 25 | __version__ = "1.0.0" 26 | __doc_url__ = config.Main.doc_url 27 | __redoc_url__ = config.Main.redoc_url 28 | 29 | # You should import logger from loguru after setup_logging() 30 | # for right logger initialization 31 | setup_logging() 32 | from loguru import logger 33 | 34 | app = FastAPI( 35 | title=__title__, 36 | description=__doc__, 37 | version=__version__, 38 | redoc_url=__redoc_url__, 39 | docs_url=__doc_url__, 40 | # swagger_ui_init_oauth={"realm":"qqq"} 41 | 42 | ) 43 | events.load_endpoints(app) 44 | 45 | @app.on_event("startup") 46 | async def startup_event(): 47 | # We don`t test startup_event directly 48 | # Tests are written for each event function 49 | app.add_middleware(PrometheusMiddleware) # pragma: no cover 50 | app.add_middleware(IDPropagationMiddleware) # pragma: no cover 51 | if config.Telemetry.is_active: # pragma: no cover 52 | enable_tracing(app) # pragma: no cover 53 | 54 | if config.AdminGUI.is_admin_gui_enable: # pragma: no cover 55 | events.create_admin_gui( # pragma: no cover 56 | app=app, 57 | admin_url=config.AdminGUI.admin_url, 58 | site_name=__title__ 59 | ) 60 | await events.init_vault() # pragma: no cover 61 | await events.load_vault_db_creds() # pragma: no cover 62 | await events.load_endpoint_permissions(app) # pragma: no cover 63 | await events.load_base_jwt_secret() # pragma: no cover 64 | 65 | @app.on_event("shutdown") 66 | async def shutdown_event(): 67 | logger.warning('Application is shutting down') # pragma: no cover 68 | 69 | return app 70 | 71 | 72 | app = create_app() 73 | -------------------------------------------------------------------------------- /src/accounting/authentication/jwt.py: -------------------------------------------------------------------------------- 1 | from fastapi.security import OAuth2PasswordBearer 2 | from fastapi import Depends 3 | from jose import JWTError, jwt 4 | from datetime import timedelta, datetime 5 | from accounting.users import User 6 | from typing import Type 7 | from utils.exceptions import UnauthorizedException 8 | from configuration import config 9 | SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" 10 | 11 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") 12 | 13 | 14 | def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: 15 | """ 16 | Creates JWT signed token from any payload, with expires time 17 | 18 | Args: 19 | data (dict): payload for token 20 | expires_delta (timedelta | None, optional): exp time in timedelta format. 21 | Defaults to None. 
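        Example (illustrative only; assumes the caller keeps the username under the
        "sub" claim, which get_user_by_token() below reads back):
            token = create_access_token({"sub": user.username}, expires_delta=timedelta(minutes=15))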
22 | 23 | Returns: 24 | str: JWT 25 | """ 26 | to_encode = data.copy() 27 | if expires_delta: 28 | expire = datetime.utcnow() + expires_delta 29 | else: 30 | expire = datetime.utcnow() + timedelta(seconds=config.Security.jwt_ttl) 31 | to_encode.update({"exp": expire}) 32 | encoded_jwt = jwt.encode(to_encode, SECRET_KEY, 33 | algorithm=config.Security.jwt_algorithm) 34 | return encoded_jwt 35 | 36 | 37 | def decode_access_token(token: str) -> dict: 38 | """ 39 | Decode string JWT token 40 | 41 | Args: 42 | token (str): JWT 43 | 44 | Raises: 45 | UnauthorizedException: when token is invalid 46 | 47 | Returns: 48 | dict: extracted payload 49 | """ 50 | try: 51 | payload = jwt.decode( 52 | token, SECRET_KEY, algorithms=config.Security._available_jwt_algorithms) 53 | except JWTError: 54 | raise UnauthorizedException('Cannot decode token') 55 | else: 56 | return payload 57 | 58 | 59 | async def get_user_by_token(token: str = Depends(oauth2_scheme)) -> User: 60 | """ 61 | Returns USER data for username from token, if exists 62 | 63 | Args: 64 | token (str, optional): JWT 65 | 66 | Returns: 67 | User: see accounting.users 68 | """ 69 | payload: dict = decode_access_token(token) 70 | username: str = payload.get('sub', str()) 71 | return await User.get_by_username(username) # type: ignore 72 | 73 | 74 | def decode_auth_header(header: str) -> tuple[str, str]: 75 | try: 76 | chunks: list = header.split(' ') 77 | assert len(chunks) == 2, 'Bad header' 78 | return (chunks[0], chunks[1]) 79 | except AssertionError as ex: 80 | raise UnauthorizedException(str(ex)) 81 | except IndexError: 82 | raise UnauthorizedException('Wrong header payload') 83 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from loguru import logger 4 | from configuration import config 5 | from .id_propagation import TraceIdFilter 6 | import traceback 7 | 8 | TRACE_ID_LENGTH: int = 12 # Replace to config file 9 | 10 | 11 | class InterceptHandler(logging.Handler): 12 | def emit(self, record): 13 | extra_data: dict = dict() 14 | try: 15 | # Trying to catch `trace_id` and exclude None, if cought 16 | assert record.trace_id 17 | extra_data['trace_id'] = record.trace_id 18 | except (AttributeError, AssertionError): 19 | pass 20 | if record.exc_info: 21 | extra_data['exc_info'] = record.exc_info 22 | # Get corresponding Loguru level if it exists 23 | try: 24 | level = logger.level(record.levelname).name 25 | except ValueError: 26 | level = record.levelno 27 | 28 | # Find caller from where originated the logged message 29 | frame, depth = logging.currentframe(), 2 30 | while frame.f_code.co_filename == logging.__file__: 31 | frame = frame.f_back 32 | depth += 1 33 | # Inject `extra` payload to `message` dict 34 | log = logger.bind(**extra_data) 35 | 36 | log.opt(depth=depth, exception=record.exc_info).log( 37 | level, record.getMessage()) 38 | 39 | 40 | def setup_logging(): 41 | logging.root.handlers = [InterceptHandler()] 42 | 43 | for name in logging.root.manager.loggerDict.keys(): 44 | _logger = logging.getLogger(name) 45 | _logger.handlers = [] 46 | _logger.propagate = True 47 | _logger.setLevel(config.Main.log_level) 48 | if name.startswith('uvicorn'): 49 | _logger.addFilter(TraceIdFilter( 50 | uuid_length=config.Telemetry.trace_id_length)) 51 | 52 | def formatter(record): 53 | base_fmt = "{time:YYYY-MM-DDTHH:mm:ss} | {level: <8} | {module: <16}" 54 | extra: dict = 
record.get('extra', dict()) 55 | exception = record.get('exception') 56 | try: 57 | trace_id = extra['trace_id'] 58 | base_fmt = base_fmt + f" | [{trace_id}]" 59 | except KeyError: 60 | pass 61 | if exception: 62 | extra["traceback"] = "\n" + \ 63 | "".join(traceback.format_exception(extra['exc_info'][1])) 64 | return base_fmt + f"{extra['traceback']}" 65 | return base_fmt + " | {message}\n" 66 | 67 | logger.configure( 68 | handlers=[ 69 | { 70 | "sink": config.Main.log_sink, 71 | "serialize": config.Main.log_in_json, 72 | "level": config.Main.log_level, 73 | "format": formatter, 74 | "colorize": True 75 | } 76 | ] 77 | ) 78 | logger.add(lambda _: os._exit(0), level="CRITICAL") 79 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | schedule: 21 | - cron: '42 23 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | # 55 | # 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
66 | 67 | # - run: | 68 | #   echo "Run, Build Application using script" 69 | #   ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![CodeFactor](https://www.codefactor.io/repository/github/northpowered/fastapi-boilerplate/badge/master)](https://www.codefactor.io/repository/github/northpowered/fastapi-boilerplate/overview/master) 2 | [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=northpowered_fastapi-boilerplate&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=northpowered_fastapi-boilerplate) 3 | [![CI](https://github.com/northpowered/fastapi-boilerplate/actions/workflows/ci.yml/badge.svg)](https://github.com/northpowered/fastapi-boilerplate/actions/workflows/ci.yml) 4 | [![codecov](https://codecov.io/gh/northpowered/fastapi-boilerplate/branch/master/graph/badge.svg?token=2E6WMLULD7)](https://codecov.io/gh/northpowered/fastapi-boilerplate) 5 | # FastAPI boilerplate 6 | 7 | > Version: 1.1.2 8 | 9 | Work in progress, please read [issues](https://github.com/northpowered/fastapi-boilerplate/issues) 10 | 11 | Full documentation is available on [GitHub Pages](https://northpowered.github.io/fastapi-boilerplate/) 12 | 13 | ## Another [FastAPI](https://github.com/tiangolo/fastapi) Boilerplate with: 14 | * [FastAPI](https://github.com/tiangolo/fastapi) as a base ASGI app 15 | * [Piccolo ORM](https://github.com/piccolo-orm/piccolo) for database operations 16 | * [Piccolo Admin GUI](https://github.com/piccolo-orm/piccolo_admin) for convenient database management 17 | * [Hashicorp Vault](https://github.com/hashicorp/vault) integration for DB credentials (with auto-rotation), JWT secrets and more 18 | * Custom `Accounting` CRUD application for managing 19 | * Users 20 | * Roles 21 | * Groups 22 | * Permissions 23 | * Security policies 24 | * JWT authentication 25 | * API versioning 26 | * [Pydantic](https://github.com/samuelcolvin/pydantic)-based flexible config file parser (`toml` and `yaml` formats supported) 27 | * [Typer](https://github.com/tiangolo/typer)-based CLI management 28 | * [Prometheus](https://github.com/prometheus/prometheus) endpoint based on [Starlette exporter](https://github.com/stephenhillier/starlette_exporter) 29 | * [OpenTelemetry](https://github.com/orgs/open-telemetry) collector 30 | * Request ID propagation for logger, Request and Response (injection to Headers) 31 | * CI pipeline for linting and testing (with coverage) 32 | 33 | ## Usage 34 | #### Base usage 35 | ``` 36 | Usage: main.py [OPTIONS] COMMAND [ARGS]... 37 | 38 | Options: 39 | --help Show this message and exit.
40 | 41 | Commands: 42 | aaa Operations with users and other AAA objects 43 | db Operations with DB 44 | run Run application in uvicorn server with defined config file 45 | ``` 46 | All CLI commands with descriptions are documented [here](https://northpowered.github.io/fastapi-boilerplate/cli/) 47 | 48 | ## Installation 49 | 50 | We strongly recommend using [Poetry](https://python-poetry.org/) 51 | 52 | Create a new virtual environment or enter an existing one 53 | 54 | >poetry shell 55 | 56 | Install all dependencies from *pyproject.toml* 57 | 58 | >poetry install 59 | 60 | Run your app 61 | 62 | >python3 main.py run -------------------------------------------------------------------------------- /.github/workflows/apisec-scan.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | 6 | # APIsec addresses the critical need to secure APIs before they reach production. 7 | # APIsec provides the industry’s only automated and continuous API testing platform that uncovers security vulnerabilities and logic flaws in APIs. 8 | # Clients rely on APIsec to evaluate every update and release, ensuring that no APIs go to production with vulnerabilities. 9 | 10 | # How to Get Started with APIsec.ai 11 | # 1. Schedule a demo at https://www.apisec.ai/request-a-demo . 12 | # 13 | # 2. Register your account at https://cloud.fxlabs.io/#/signup . 14 | # 15 | # 3. Register your API . See the video (https://www.youtube.com/watch?v=MK3Xo9Dbvac) to get up and running with APIsec quickly. 16 | # 17 | # 4. Get GitHub Actions scan attributes from APIsec Project -> Configurations -> Integrations -> CI-CD -> GitHub Actions 18 | # 19 | # apisec-run-scan 20 | # 21 | # This action triggers the on-demand scans for projects registered in APIsec. 22 | # If your GitHub account allows code scanning alerts, you can then upload the sarif file generated by this action to show the scan findings. 23 | # Else you can view the scan results from the project home page in APIsec Platform. 24 | # The link to view the scan results is also displayed on the console on successful completion of action. 25 | 26 | # This is a starter workflow to help you get started with APIsec-Scan Actions 27 | 28 | name: APIsec 29 | 30 | # Controls when the workflow will run 31 | on: 32 | # Triggers the workflow on push or pull request events but only for the "master" branch 33 | # Customize trigger events based on your DevSecOps processes.
34 | push: 35 | branches: [ "master" ] 36 | pull_request: 37 | branches: [ "master" ] 38 | schedule: 39 | - cron: '31 15 * * 0' 40 | 41 | # Allows you to run this workflow manually from the Actions tab 42 | workflow_dispatch: 43 | 44 | 45 | permissions: 46 | contents: read 47 | 48 | jobs: 49 | Trigger APIsec scan: 50 | permissions: 51 | security-events: write # for github/codeql-action/upload-sarif to upload SARIF results 52 | runs-on: ubuntu-latest 53 | 54 | steps: 55 | - name: APIsec scan 56 | uses: apisec-inc/apisec-run-scan@f62d0c6fae8a80f97b091a323befdb56e6ad9993 57 | with: 58 | # The APIsec username with which the scans will be executed 59 | apisec-username: ${{ secrets.apisec_username }} 60 | # The Password of the APIsec user with which the scans will be executed 61 | apisec-password: ${{ secrets.apisec_password}} 62 | # The name of the project for security scan 63 | apisec-project: "VAmPI" 64 | # The name of the sarif format result file The file is written only if this property is provided. 65 | sarif-result-file: "apisec-results.sarif" 66 | - name: Import results 67 | uses: github/codeql-action/upload-sarif@v2 68 | with: 69 | sarif_file: ./apisec-results.sarif 70 | -------------------------------------------------------------------------------- /src/accounting/users/endpoints.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from .models import User 3 | from accounting.schemas import ( 4 | UserUpdate, 5 | UserCreate, 6 | UserPasswordChange 7 | ) 8 | from fastapi import Request, Response 9 | from accounting.decorators import AAA_endpoint_oauth2 10 | 11 | 12 | class UserCRUD(): 13 | 14 | @staticmethod 15 | @AAA_endpoint_oauth2() 16 | async def get_all_users(request: Request, offset: int = 0, limit: int = 100): 17 | """ 18 | ### READ list[User] with offset and limit 19 | #### Args:\n 20 | offset (int, optional): Defaults to 0.\n 21 | limit (int, optional): Defaults to 100.\n 22 | #### Returns: 23 | list[User] 24 | """ 25 | return await User.get_all(offset=offset, limit=limit) 26 | 27 | @staticmethod 28 | @AAA_endpoint_oauth2() 29 | async def get_user(request: Request, id: str): 30 | """ 31 | ### READ one {User} by id 32 | #### Args:\n 33 | id (str): UUID4 PK 34 | #### Returns: 35 | User | None 36 | """ 37 | return await User.get_by_id(id) 38 | 39 | @staticmethod 40 | @AAA_endpoint_oauth2() 41 | async def create_user(request: Request, user: UserCreate): 42 | """ 43 | ### CREATE user 44 | #### Args:\n 45 | user (User): { 46 | username: str (Unique) 47 | password: str 48 | email: str 49 | active: bool 50 | } 51 | #### Returns: 52 | User 53 | """ 54 | return await User.add(**user.dict()) 55 | 56 | @staticmethod 57 | @AAA_endpoint_oauth2() 58 | async def update_user(request: Request, id: str, user: UserUpdate): 59 | """ 60 | ### Update one user (full or partial) 61 | Args:\n 62 | user (User): { 63 | name: str (Unique) 64 | price: int 65 | } 66 | Returns: 67 | User 68 | """ 69 | return await User.update_by_id(id=id, data=user.dict(exclude_none=True)) 70 | 71 | @staticmethod 72 | @AAA_endpoint_oauth2() 73 | async def patch_user(request: Request, id: str, user: UserPasswordChange): 74 | """ 75 | ### User password change 76 | Args:\n 77 | PasswordChange (UserPasswordChange): { 78 | old_password: Optional[str | None] #Admin can change password without old_password 79 | new_password: str 80 | } 81 | Returns: 82 | User 83 | """ 84 | return await User.change_password( 85 | id=id, 86 | old_plaintext_password=user.old_password, 87 
| new_plaintext_password=user.new_password 88 | ) 89 | 90 | @staticmethod 91 | @AAA_endpoint_oauth2() 92 | async def delete_user(request: Request, id: str): 93 | """ 94 | ### DELETE one user by ID 95 | #### Args:\n 96 | id (str): UUID4 PK 97 | #### Returns: 98 | None, code=204 99 | """ 100 | await User.delete_by_id(id) 101 | return Response(status_code=204) 102 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | 9 | jobs: 10 | linters: 11 | runs-on: ubuntu-latest 12 | timeout-minutes: 30 13 | strategy: 14 | matrix: 15 | python-version: ["3.10"] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Installing flake8 24 | run: | 25 | pip install poetry 26 | poetry config virtualenvs.create false 27 | poetry add flake8 28 | - name: Lint 29 | run: | 30 | # stop the build if there are Python syntax errors or undefined names 31 | flake8 src --count --select=E9,F63,F7,F82 --show-source --statistics 32 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 33 | flake8 src --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 34 | 35 | integration: 36 | runs-on: ubuntu-latest 37 | timeout-minutes: 30 38 | strategy: 39 | matrix: 40 | python-version: ["3.10"] 41 | postgres-version: [14] 42 | services: 43 | postgres: 44 | image: postgres:${{ matrix.postgres-version }} 45 | env: 46 | DB_SERVER_HOST: postgres 47 | DB_SERVER_PORT: 5432 48 | POSTGRES_USER: test 49 | POSTGRES_PASSWORD: test 50 | POSTGRES_DB: test 51 | options: >- 52 | --health-cmd pg_isready 53 | --health-interval 10s 54 | --health-timeout 5s 55 | --health-retries 5 56 | ports: 57 | - 5432:5432 58 | vault: 59 | image: vault:1.10.2 60 | env: 61 | VAULT_DEV_ROOT_TOKEN_ID: test 62 | VAULT_DEV_LISTEN_ADDRESS: 0.0.0.0:8200 63 | ports: 64 | - 8200:8200 65 | 66 | steps: 67 | - uses: actions/checkout@v3 68 | - name: Set up Python ${{ matrix.python-version }} 69 | uses: actions/setup-python@v3 70 | with: 71 | python-version: ${{ matrix.python-version }} 72 | - name: Install dependencies 73 | run: | 74 | pip install poetry 75 | poetry config virtualenvs.create false 76 | poetry install 77 | poetry add pytest pytest-cov flake8 78 | - name: Init DB roles for Vault 79 | run: python3 ci/init_db.py 80 | - name: Init Vault schema 81 | run: python3 ci/init_vault.py 82 | - name: Run integration tests 83 | run: poetry run pytest src --cov 84 | - name: Upload Codecov 85 | uses: codecov/codecov-action@v3 86 | - name: Building docs 87 | run: mkdocs gh-deploy --force 88 | 89 | -------------------------------------------------------------------------------- /src/tests/01_cli_test.py: -------------------------------------------------------------------------------- 1 | from typer.testing import CliRunner 2 | from main import app 3 | from .shared import clear_migrations_files 4 | from .payloads import ( 5 | test_superuser_1, 6 | test_superuser_2, 7 | test_user_1, 8 | test_user_2 9 | ) 10 | 11 | runner = CliRunner() 12 | 13 | 14 | def test_cli_db_drop_all(): 15 | result = runner.invoke(app, ["db", "drop", "all"], input="y\n") 16 | assert result.exit_code == 0 17 | 18 | 19 | def test_cli_db_drop_all_again(): 20 | result = 
runner.invoke(app, ["db", "drop", "all"], input="y\n") 21 | assert result.exit_code == 0 22 | assert "Ignore" in result.stdout 23 | assert "Ignored" in result.stdout 24 | 25 | 26 | def test_cli_db_init_all(): 27 | result = runner.invoke(app, ["db", "init", "all"]) 28 | assert result.exit_code == 0 29 | assert "Create" in result.stdout 30 | assert "Created" in result.stdout 31 | 32 | 33 | def test_cli_db_init_all_again(): 34 | result = runner.invoke(app, ["db", "init", "all"]) 35 | assert result.exit_code == 0 36 | assert "Ignore" in result.stdout 37 | assert "Ignored" in result.stdout 38 | test_cli_db_drop_all() # Let`s drop db for a right migration tests 39 | 40 | 41 | def test_cli_db_mg_create_all(): 42 | test_cli_db_drop_all() # Double check - Let`s drop db for a right mg tests 43 | clear_migrations_files() # Delete all mg in */piccolo_migrations dirs 44 | result = runner.invoke(app, ["db", "mg", "create", "all"]) 45 | assert result.exit_code == 0 46 | 47 | 48 | def test_cli_db_mg_run_all(): 49 | test_cli_db_drop_all() 50 | result = runner.invoke(app, ["db", "mg", "run", "all"]) 51 | assert result.exit_code == 0 52 | 53 | 54 | def test_cli_db_show_all(): 55 | result = runner.invoke(app, ["db", "show", "all"]) 56 | assert result.exit_code == 0 57 | assert "1" in result.stdout 58 | 59 | 60 | def test_cli_db_show_accounting(): 61 | result = runner.invoke(app, ["db", "show", "accounting"]) 62 | assert result.exit_code == 0 63 | assert "1" in result.stdout 64 | 65 | 66 | def test_cli_db_show_wrong_app(): 67 | result = runner.invoke(app, ["db", "show", "non-existing-app"]) 68 | assert result.exit_code == 0 69 | assert "1" not in result.stdout 70 | 71 | 72 | def test_cli_aaa_create_superuser(): 73 | result = runner.invoke( 74 | app, ["aaa", "create", "superuser"], input=test_superuser_1.to_cli_input()) 75 | assert result.exit_code == 0 76 | assert f"Superuser {test_superuser_1.username} was created with id" in result.stdout 77 | result = runner.invoke( 78 | app, ["aaa", "create", "superuser"], input=test_superuser_2.to_cli_input()) 79 | assert result.exit_code == 0 80 | assert f"Superuser {test_superuser_2.username} was created with id" in result.stdout 81 | 82 | 83 | def test_cli_aaa_create_user(): 84 | result = runner.invoke( 85 | app, ["aaa", "create", "user"], input=test_user_1.to_cli_input()) 86 | assert result.exit_code == 0 87 | assert f"User {test_user_1.username} was created with id" in result.stdout 88 | result = runner.invoke( 89 | app, ["aaa", "create", "user"], input=test_user_2.to_cli_input()) 90 | assert result.exit_code == 0 91 | assert f"User {test_user_2.username} was created with id" in result.stdout 92 | 93 | 94 | def test_cli_aaa_create_secret_from_config(): 95 | result = runner.invoke(app, ["aaa", "create", "secret"]) 96 | assert result.exit_code == 0 97 | assert "Secret generation completed" in result.stdout 98 | assert "All checks successfully passed" in result.stdout 99 | -------------------------------------------------------------------------------- /src/configuration/model.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from .sections import ( 3 | MainSectionConfiguration, 4 | AdminGUISectionConfiguration, 5 | SecuritySectionConfiguration, 6 | ServerSectionConfiguration, 7 | VaultSectionConfiguration, 8 | DatabaseSectionConfiguration, 9 | TelemetrySectionConfiguration 10 | ) 11 | from .base import BaseSectionModel 12 | from pydantic import BaseSettings, ValidationError 13 | import configparser 14 | import toml 
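# Illustrative sketch (not part of the original module; the section and key
# names below are assumptions based on the section models imported above):
# Configuration.load(), defined below, dispatches on the config file extension,
# so a minimal TOML config is read roughly like this:
#
#   # config.toml
#   [Main]
#   log_level = "INFO"
#
#   from configuration.model import Configuration
#   cfg = Configuration()
#   cfg.load("config.toml")      # the '.toml' extension selects toml_reader()
#   print(cfg.Main.log_level)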
15 | from loguru import logger 16 | import os 17 | 18 | 19 | class Configuration(BaseSettings): 20 | 21 | class Config: 22 | load_failed: bool = False 23 | 24 | def __repr__(self) -> str: 25 | return f"" 26 | 27 | Main: MainSectionConfiguration = MainSectionConfiguration() 28 | AdminGUI: AdminGUISectionConfiguration = AdminGUISectionConfiguration() 29 | Server: ServerSectionConfiguration = ServerSectionConfiguration() 30 | Vault: VaultSectionConfiguration = VaultSectionConfiguration() 31 | Database: DatabaseSectionConfiguration = DatabaseSectionConfiguration() 32 | Telemetry: TelemetrySectionConfiguration = TelemetrySectionConfiguration() 33 | Security: SecuritySectionConfiguration = SecuritySectionConfiguration() 34 | 35 | def load(self, filename: str, filetype: str | None = None): 36 | raw_data: dict = dict() 37 | try: 38 | file_extention = filename.split('.')[1] 39 | except IndexError: 40 | self.Config.load_failed = True 41 | logger.critical('Cannot find config file extention') 42 | else: 43 | match file_extention: 44 | case 'ini': raw_data = Configuration.ini_reader(filename) 45 | case 'toml': raw_data = Configuration.toml_reader(filename) 46 | case 'yaml': raw_data = Configuration.yaml_reader(filename) 47 | case _: 48 | self.Config.load_failed = True 49 | logger.critical('Cannot define config file extention') 50 | self.read_from_dict(raw_data) 51 | logger.info(f'Configuration was successfully loaded from {filename}') 52 | 53 | @staticmethod 54 | def ini_reader(filename: str) -> dict: 55 | config = configparser.ConfigParser() 56 | parced_data: dict = dict() 57 | with open(filename, 'r') as f: 58 | config.read_file(f) 59 | for section_name in dict(config).keys(): 60 | if section_name != 'DEFAULT': 61 | section_data: dict = dict( 62 | dict(config).get(section_name)) # type: ignore 63 | parced_data.update({section_name: section_data}) 64 | return parced_data 65 | 66 | @staticmethod 67 | def toml_reader(filename: str) -> dict: 68 | with open(filename, 'r') as f: 69 | return toml.loads(f.read(), _dict=dict) 70 | 71 | @staticmethod 72 | def yaml_reader(filename: str) -> dict: 73 | with open(filename, 'r') as f: 74 | return yaml.load(f, yaml.loader.SafeLoader) 75 | 76 | def read_from_dict(self, raw_data: dict): 77 | for section_name in self.__fields__: 78 | section_data: dict = raw_data.get(section_name, dict()) 79 | section: BaseSectionModel = self.__getattribute__(section_name) 80 | loaded_section: BaseSectionModel = section.load( 81 | section_data, 82 | section_name 83 | ) 84 | if not loaded_section: 85 | os._exit(1) 86 | self.__setattr__( 87 | section_name, 88 | loaded_section 89 | ) 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | # pytype static type analyzer 143 | .pytype/ 144 | 145 | # Cython debug symbols 146 | cython_debug/ 147 | 148 | # PyCharm 149 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 150 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 151 | # and can be added to the global gitignore or merged into this file. For a more nuclear 152 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
153 | #.idea/ 154 | 155 | .vscode/ 156 | *.code-workspace 157 | .infra/var/ 158 | .infra/vault/data/ 159 | #.infra/vault/data 160 | #*.pem 161 | #*.key 162 | #*.crt 163 | emoji 164 | src/*/piccolo_migrations -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "fastapi-boilerplate" 3 | version = "1.1.2" 4 | description = "" 5 | authors = ["northpowered "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.10" 9 | aiofiles = "0.8.0" 10 | aiohttp = "3.8.1" 11 | aiosignal = "1.2.0" 12 | aiosqlite = "0.17.0" 13 | anyio = "3.6.1" 14 | asgiref = "3.5.2" 15 | asttokens = "2.0.5" 16 | async-hvac-fork = "0.6.1" 17 | async-timeout = "4.0.2" 18 | asyncpg = "0.25.0" 19 | attrs = "21.4.0" 20 | autopep8 = "1.6.0" 21 | backcall = "0.2.0" 22 | black = "22.6.0" 23 | certifi = "2022.6.15" 24 | cffi = "1.15.1" 25 | chardet = "3.0.4" 26 | charset-normalizer = "2.1.0" 27 | click = "8.1.3" 28 | colorama = "0.4.5" 29 | commonmark = "0.9.1" 30 | cryptography = "37.0.2" 31 | decorator = "5.1.1" 32 | Deprecated = "1.2.13" 33 | dnspython = "2.2.1" 34 | docstring-parser = "0.12" 35 | ecdsa = "0.17.0" 36 | email-validator = "1.2.1" 37 | executing = "0.8.3" 38 | fastapi = "0.78.0" 39 | frozenlist = "1.3.0" 40 | googleapis-common-protos = "1.56.1" 41 | greenlet = "1.1.2" 42 | grpcio = "1.47.0" 43 | h11 = "0.12.0" 44 | h2 = "4.1.0" 45 | hpack = "4.0.0" 46 | httpcore = "0.15.0" 47 | httptools = "0.4.0" 48 | httpx = "0.23.0" 49 | hypercorn = "0.13.2" 50 | hyperframe = "6.0.1" 51 | idna = "3.3" 52 | inflection = "0.5.1" 53 | ipython = "8.4.0" 54 | itsdangerous = "2.1.2" 55 | jedi = "0.18.1" 56 | Jinja2 = "3.1.2" 57 | loguru = "0.6.0" 58 | MarkupSafe = "2.1.1" 59 | matplotlib-inline = "0.1.3" 60 | multidict = "4.7.6" 61 | mypy = "0.961" 62 | mypy-extensions = "0.4.3" 63 | opentelemetry-api = "1.12.0rc2" 64 | opentelemetry-exporter-jaeger = "1.12.0rc2" 65 | opentelemetry-exporter-jaeger-proto-grpc = "1.12.0rc2" 66 | opentelemetry-exporter-jaeger-thrift = "1.12.0rc2" 67 | opentelemetry-instrumentation = "0.32b0" 68 | opentelemetry-instrumentation-asgi = "0.32b0" 69 | opentelemetry-instrumentation-fastapi = "0.32b0" 70 | opentelemetry-sdk = "1.12.0rc2" 71 | opentelemetry-semantic-conventions = "0.32b0" 72 | opentelemetry-util-http = "0.32b0" 73 | orjson = "3.7.5" 74 | parso = "0.8.3" 75 | pathspec = "0.9.0" 76 | pexpect = "4.8.0" 77 | piccolo = "0.82.0" 78 | piccolo-admin = "0.27.0" 79 | piccolo-api = "0.40.0" 80 | pickleshare = "0.7.5" 81 | platformdirs = "2.5.2" 82 | priority = "2.0.0" 83 | prometheus-client = "0.14.1" 84 | prompt-toolkit = "3.0.30" 85 | protobuf = "4.21.2" 86 | psycopg2 = "2.9.3" 87 | ptyprocess = "0.7.0" 88 | pure-eval = "0.2.2" 89 | pyasn1 = "0.4.8" 90 | pycodestyle = "2.9.0" 91 | pycparser = "2.21" 92 | pydantic = "1.9.1" 93 | pydot = "1.4.2" 94 | Pygments = "2.12.0" 95 | PyJWT = "2.4.0" 96 | pyparsing = "3.0.9" 97 | python-dotenv = "0.20.0" 98 | python-jose = "3.3.0" 99 | python-multipart = "0.0.5" 100 | PyYAML = "6.0" 101 | requests = "2.28.1" 102 | rfc3986 = "1.5.0" 103 | rich = "12.5.1" 104 | rsa = "4.8" 105 | six = "1.16.0" 106 | sniffio = "1.2.0" 107 | stack-data = "0.3.0" 108 | starlette = "0.19.1" 109 | starlette-exporter = "0.13.0" 110 | targ = "0.3.7" 111 | thrift = "0.16.0" 112 | toml = "0.10.2" 113 | tomli = "2.0.1" 114 | traitlets = "5.3.0" 115 | typer = "0.5.0" 116 | typing-extensions = "4.2.0" 117 | ujson = "5.4.0" 118 | urllib3 = 
"1.26.9" 119 | uvicorn = "0.17.6" 120 | uvloop = "0.16.0" 121 | watchgod = "0.8.2" 122 | wcwidth = "0.2.5" 123 | websockets = "10.3" 124 | wrapt = "1.14.1" 125 | wsproto = "1.1.0" 126 | yarl = "1.7.2" 127 | types-PyYAML = "^6.0.11" 128 | types-toml = "^0.10.8" 129 | pytest = "^7.1.2" 130 | mimesis = "^6.0.0" 131 | pytest-cov = "^3.0.0" 132 | hvac = {extras = ["parcer"], version = "^0.11.2"} 133 | mkdocs-material = "^8.4.2" 134 | trio = "^0.21.0" 135 | 136 | [tool.poetry.dev-dependencies] 137 | flake8 = "^5.0.4" 138 | 139 | [build-system] 140 | requires = ["poetry-core>=1.0.0"] 141 | build-backend = "poetry.core.masonry.api" 142 | -------------------------------------------------------------------------------- /src/accounting/groups/models.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from typing import TypeVar, Type 3 | from utils.piccolo import Table, uuid4_for_PK 4 | from piccolo.columns import m2m 5 | from piccolo.columns.column_types import ( 6 | Text, Boolean, LazyTableReference 7 | ) 8 | from piccolo.columns.readable import Readable 9 | from asyncpg.exceptions import UniqueViolationError 10 | from utils.exceptions import IntegrityException, ObjectNotFoundException, BaseBadRequestException 11 | 12 | T_G = TypeVar('T_G', bound='Group') 13 | 14 | 15 | class Group(Table, tablename="groups"): 16 | 17 | id = Text(primary_key=True, index=True, default=uuid4_for_PK) 18 | name = Text(unique=True, index=True, null=False) 19 | active = Boolean(nullable=False, default=True) 20 | users = m2m.M2M(LazyTableReference( 21 | "M2MUserGroup", module_path='accounting')) 22 | 23 | @classmethod 24 | def get_readable(cls): 25 | return Readable(template="%s", columns=[cls.name]) 26 | 27 | @classmethod 28 | async def get_all(cls: Type[T_G], offset: int, limit: int) -> list[Type[T_G]]: 29 | return await cls.objects().limit(limit).offset(offset) 30 | 31 | @classmethod 32 | async def get_by_id(cls: Type[T_G], id: str) -> Type[T_G]: 33 | group: Type[T_G] = await cls.objects().where(cls.id == id).first() 34 | try: 35 | assert group 36 | except AssertionError: 37 | raise ObjectNotFoundException(object_name=__name__, object_id=id) 38 | else: 39 | return group 40 | 41 | @classmethod 42 | async def get_by_name(cls: Type[T_G], name: str) -> Type[T_G]: 43 | group: Type[T_G] = await cls.objects().where(cls.name == name).first() 44 | try: 45 | assert group 46 | except AssertionError: 47 | raise ObjectNotFoundException(object_name=__name__, object_id=name) 48 | else: 49 | return group 50 | 51 | @classmethod 52 | async def add(cls: Type[T_G], name: str, active: bool) -> Type[T_G]: 53 | 54 | new_id: str = uuid4_for_PK() 55 | group: T_G = cls( 56 | id=new_id, 57 | name=name, 58 | active=active 59 | ) 60 | try: 61 | resp = await cls.insert(group) 62 | except UniqueViolationError as ex: 63 | raise IntegrityException(ex) 64 | else: 65 | inserted_pk = resp[0].get('id') 66 | return await cls.get_by_id(inserted_pk) 67 | 68 | @classmethod 69 | async def update_by_id(cls: Type[T_G], id: str, data: dict) -> Type[T_G]: 70 | await cls.update(**data).where(cls.id == id) 71 | return await cls.get_by_id(id) 72 | 73 | @classmethod 74 | async def delete_by_id(cls: Type[T_G], id: str) -> None: 75 | await cls.get_by_id(id) 76 | await cls.delete().where(cls.id == id) 77 | 78 | @classmethod 79 | async def add_users(cls: Type[T_G], group_id: str, user_ids: list[str]): 80 | from accounting import User # CircularImport error 81 | group: T_G = await cls.objects().get(cls.id == 
group_id) 82 | for user_id in user_ids: 83 | user = await User.get_by_id(user_id) 84 | await group.add_m2m( 85 | user, # type: ignore 86 | m2m=cls.users 87 | ) 88 | return await cls.get_by_id(group_id) 89 | 90 | @classmethod 91 | async def delete_users(cls: Type[T_G], group_id: str, user_ids: list[str]): 92 | from accounting import User # CircularImport error 93 | group: T_G = await cls.objects().get(cls.id == group_id) 94 | for user_id in user_ids: 95 | user = await User.get_by_id(user_id) 96 | await group.remove_m2m( 97 | user, # type: ignore 98 | m2m=cls.users 99 | ) 100 | return await cls.get_by_id(group_id) 101 | -------------------------------------------------------------------------------- /src/accounting/schemas.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, EmailStr 2 | from typing import Optional 3 | import datetime 4 | 5 | """ Base pydantic models """ 6 | 7 | 8 | class UserBase(BaseModel): 9 | """User schema without joined fields""" 10 | id: Optional[str] 11 | username: Optional[str] 12 | email: Optional[EmailStr] 13 | active: Optional[bool] = True 14 | birthdate: Optional[datetime.datetime | None] 15 | created_at: Optional[datetime.datetime] 16 | updated_at: Optional[datetime.datetime] 17 | last_login: Optional[datetime.datetime | None] 18 | 19 | class Config: 20 | orm_mode = True 21 | 22 | 23 | class RoleBase(BaseModel): 24 | """Role schema without joined fields""" 25 | id: Optional[str] 26 | name: Optional[str] 27 | active: Optional[bool] = True 28 | 29 | class Config: 30 | orm_mode = True 31 | 32 | 33 | class GroupBase(BaseModel): 34 | """Group schema without joined fields""" 35 | name: Optional[str] 36 | active: Optional[bool] = True 37 | 38 | class Config: 39 | orm_mode = True 40 | 41 | 42 | class PermissionBase(BaseModel): 43 | """Permission schema without joined fields""" 44 | name: str 45 | object: str 46 | description: Optional[str] 47 | 48 | class Config: 49 | orm_mode = True 50 | 51 | 52 | class PolicyBase(BaseModel): 53 | """Policy schema without joined fields""" 54 | id: Optional[str] 55 | name: Optional[str] 56 | active: Optional[bool] 57 | description: Optional[str] 58 | 59 | class Config: 60 | orm_mode = True 61 | 62 | 63 | """ CRUD pydantic models """ 64 | 65 | 66 | class UserRead(UserBase): 67 | """ 68 | READ model for USER subject, with M2M fields 69 | """ 70 | roles: list[RoleBase] 71 | groups: list[GroupBase] 72 | 73 | 74 | class UserUpdate(UserBase): 75 | pass 76 | 77 | 78 | class UserCreate(BaseModel): 79 | """ 80 | CREATE model for USER subject with required fields 81 | """ 82 | username: str 83 | password: str 84 | email: EmailStr 85 | 86 | 87 | class UserPasswordChange(BaseModel): 88 | # TODO Admin can change password without old_password 89 | old_password: str 90 | new_password: str 91 | 92 | class Config: 93 | orm_mode = True 94 | 95 | 96 | class RoleRead(RoleBase): 97 | """ 98 | READ model for ROLE subject 99 | """ 100 | users: list[UserBase] 101 | 102 | 103 | class RoleCreate(BaseModel): 104 | """ 105 | CREATE model for USER subject with required fields 106 | """ 107 | name: str 108 | active: bool = True 109 | 110 | 111 | class RoleUpdate(RoleBase): 112 | pass 113 | 114 | 115 | class GroupRead(GroupBase): 116 | """ 117 | READ model for GROUP subject 118 | """ 119 | id: str 120 | 121 | 122 | class GroupCreate(BaseModel): 123 | """ 124 | CREATE model for GROUP subject with required fields 125 | """ 126 | name: str 127 | active: bool = True 128 | 129 | 130 | class 
GroupUpdate(GroupBase): 131 | pass 132 | 133 | 134 | class PermissionCreate(PermissionBase): 135 | id: Optional[str] 136 | 137 | 138 | class RolesToUser(BaseModel): 139 | user_id: str 140 | role_ids: list[str] 141 | 142 | 143 | class UsersToRole(BaseModel): 144 | role_id: str 145 | user_ids: list[str] 146 | 147 | 148 | class GroupesToUser(BaseModel): 149 | user_id: str 150 | group_ids: list[str] 151 | 152 | 153 | class UsersToGroup(BaseModel): 154 | group_id: str 155 | user_ids: list[str] 156 | 157 | 158 | class PermissionRead(PermissionBase): 159 | policies: list[PolicyBase] 160 | 161 | 162 | class PolicyRead(PolicyBase): 163 | permission: PermissionBase 164 | role: RoleBase 165 | 166 | 167 | class PolicyCreate(BaseModel): 168 | permission_id: str 169 | role_id: str 170 | name: Optional[str] 171 | description: Optional[str] 172 | active: bool = True 173 | 174 | 175 | class PolicyUpdate(PolicyBase): 176 | pass 177 | -------------------------------------------------------------------------------- /src/accounting/roles/models.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from typing import TypeVar, Type 3 | from utils.piccolo import Table, uuid4_for_PK, get_pk_from_resp 4 | from piccolo.columns import m2m 5 | from piccolo.columns.column_types import ( 6 | Text, Boolean, LazyTableReference 7 | ) 8 | from piccolo.columns.readable import Readable 9 | from asyncpg.exceptions import UniqueViolationError 10 | from utils.exceptions import IntegrityException, ObjectNotFoundException, BaseBadRequestException 11 | 12 | T_R = TypeVar('T_R', bound='Role') 13 | 14 | 15 | class Role(Table, tablename="roles"): 16 | 17 | id = Text(primary_key=True, index=True, default=uuid4_for_PK) 18 | name = Text(unique=True, index=True, null=False) 19 | active = Boolean(nullable=False, default=True) 20 | users = m2m.M2M(LazyTableReference( 21 | "M2MUserRole", module_path='accounting')) 22 | policies = m2m.M2M(LazyTableReference("Policy", module_path='accounting')) 23 | 24 | @classmethod 25 | def get_readable(cls): 26 | return Readable(template="%s", columns=[cls.name]) 27 | 28 | @classmethod 29 | async def get_all(cls: Type[T_R], offset: int, limit: int) -> list[T_R]: 30 | resp: list[T_R] = await cls.objects().limit(limit).offset(offset) 31 | for r in resp: 32 | await r.join_m2m() 33 | return resp 34 | 35 | @classmethod 36 | async def get_by_id(cls: Type[T_R], id: str) -> T_R: 37 | role: T_R = await cls.objects().where(cls.id == id).first() 38 | try: 39 | assert role 40 | except AssertionError: 41 | raise ObjectNotFoundException(object_name=__name__, object_id=id) 42 | else: 43 | await role.join_m2m() 44 | return role 45 | 46 | @classmethod 47 | async def get_by_name(cls: Type[T_R], name: str) -> T_R: 48 | role: T_R = await cls.objects().where(cls.name == name).first() 49 | try: 50 | assert role 51 | except AssertionError: 52 | raise ObjectNotFoundException(object_name=__name__, object_id=name) 53 | else: 54 | await role.join_m2m() 55 | return role 56 | 57 | @classmethod 58 | async def add(cls: Type[T_R], name: str, active: bool) -> T_R: 59 | 60 | new_id: str = uuid4_for_PK() 61 | role: T_R = cls( 62 | id=new_id, 63 | name=name, 64 | active=active 65 | ) 66 | try: 67 | resp = await cls.insert(role) 68 | except UniqueViolationError as ex: 69 | raise IntegrityException(ex) 70 | else: 71 | inserted_pk = get_pk_from_resp(resp, 'id') 72 | return await cls.get_by_id(inserted_pk) # type: ignore 73 | 74 | @classmethod 75 | async def update_by_id(cls: 
Type[T_R], id: str, data: dict) -> T_R: 76 | await cls.update(**data).where(cls.id == id) 77 | return await cls.get_by_id(id) 78 | 79 | @classmethod 80 | async def delete_by_id(cls: Type[T_R], id: str) -> None: 81 | await cls.get_by_id(id) 82 | await cls.delete().where(cls.id == id) 83 | 84 | @classmethod 85 | async def add_users(cls: Type[T_R], role_id: str, user_ids: list[str]): 86 | from accounting import User # CircularImport error 87 | role: T_R = await cls.objects().get(cls.id == role_id) 88 | for user_id in user_ids: 89 | user = await User.get_by_id(user_id) 90 | await role.add_m2m( 91 | user, # type: ignore 92 | m2m=cls.users 93 | ) 94 | return await cls.get_by_id(role_id) 95 | 96 | @classmethod 97 | async def delete_users(cls: Type[T_R], role_id: str, user_ids: list[str]): 98 | from accounting import User # CircularImport error 99 | role: T_R = await cls.objects().get(cls.id == role_id) 100 | for user_id in user_ids: 101 | user = await User.get_by_id(user_id) 102 | await role.remove_m2m( 103 | user, # type: ignore 104 | m2m=cls.users 105 | ) 106 | return await cls.get_by_id(role_id) 107 | -------------------------------------------------------------------------------- /src/piccolo_conf.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from piccolo.conf.apps import AppRegistry 3 | from piccolo.engine.postgres import PostgresEngine as _PostgresEngine 4 | from configuration import config 5 | from typing import Any, Sequence, Dict 6 | from piccolo.querystring import QueryString 7 | from asyncpg.exceptions import InvalidPasswordError, UndefinedTableError 8 | from utils.events import load_vault_db_creds, reload_db_creds 9 | 10 | 11 | # TODO asyncio warning about db auth fail, redone in _run_in_new_connection 12 | class PostgresEngine(_PostgresEngine): 13 | """ 14 | Implemetation of base PostgresEngine class of Piccolo ORM 15 | Added Vault integration for obtaining new DB credentials after expiration 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config: Dict[str, Any], 21 | extensions: Sequence[str] = (), 22 | log_queries: bool = False, 23 | extra_nodes: Dict[str, _PostgresEngine] = None 24 | ) -> None: 25 | super().__init__(config, extensions, log_queries, extra_nodes) 26 | 27 | async def run_inside_the_transaction(self, query: str, query_args: list, in_pool: bool): 28 | """ 29 | Simple code wrapper, just for DRY 30 | 31 | Args: 32 | query: str 33 | query_args: list 34 | in_pool: bool 35 | 36 | """ 37 | connection = self.transaction_connection.get() 38 | # We don`t test this block, becouse it`s already tested in original piccolo-orm 39 | if connection: 40 | return await connection.fetch(query, *query_args) # pragma: no cover 41 | elif in_pool and self.pool: 42 | return await self._run_in_pool(query, query_args) # pragma: no cover 43 | else: 44 | return await self._run_in_new_connection(query, query_args) 45 | 46 | async def run_querystring( 47 | self, querystring: QueryString, in_pool: bool = True 48 | ): 49 | 50 | query, query_args = querystring.compile_string( 51 | engine_type=self.engine_type 52 | ) 53 | 54 | if self.log_queries: 55 | # No reason to test `print` 56 | print(querystring) # pragma: no cover 57 | 58 | # If running inside a transaction: 59 | try: 60 | return await self.run_inside_the_transaction( 61 | query=query, 62 | query_args=query_args, 63 | in_pool=in_pool 64 | ) 65 | except InvalidPasswordError: 66 | logger.warning( # pragma: no cover 67 | 'Failed to authenticate in DB server through Vault, 
obtaining new credentials') 68 | # Reloading db creds from vault, with adding to CONFIG instance 69 | # Was already tested directly 70 | await reload_db_creds() # pragma: no cover 71 | # Setting new 'config' for PostgresEngine 72 | self.config = {'dsn': config.Database.get_connection_string()} 73 | # Retry of query with new creds 74 | return await self.run_inside_the_transaction( # pragma: no cover 75 | query=query, 76 | query_args=query_args, 77 | in_pool=in_pool 78 | ) 79 | except UndefinedTableError as ex: 80 | logger.error( # pragma: no cover 81 | "Table not found, did you forget to `init` db or `run` migration?") 82 | logger.critical(f"Database error: {ex}") # pragma: no cover 83 | 84 | 85 | # First time building DB Engine 86 | # Engine object is storing in CONFIG instance 87 | # config.Database.set_connection_string('postgresql://test:test@127.0.0.1:5432/test') 88 | # print(config.Database.build_connection_string) 89 | config.Database.set_engine( 90 | PostgresEngine( 91 | config={ 92 | 'dsn': config.Database.get_connection_string(), 93 | }, 94 | log_queries=bool(config.Main.log_sql) 95 | ) 96 | ) 97 | DB = config.Database.get_engine() 98 | 99 | 100 | # A list of paths to piccolo apps 101 | # e.g. ['blog.piccolo_app'] 102 | APP_REGISTRY = AppRegistry( 103 | apps=[ 104 | "accounting.piccolo_app", 105 | "piccolo_admin.piccolo_app", 106 | "configuration.piccolo_app" 107 | ] 108 | ) 109 | -------------------------------------------------------------------------------- /src/accounting/authentication/models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import secrets 3 | from loguru import logger 4 | from typing import TypeVar, Type, Optional, cast 5 | from piccolo.columns.defaults.timestamp import TimestampOffset 6 | from piccolo.columns.column_types import Text, Timestamp 7 | from piccolo_api.session_auth.tables import SessionsBase 8 | from piccolo.utils.sync import run_sync 9 | from configuration import config 10 | from accounting.users import T_U 11 | 12 | T_S = TypeVar('T_S', bound='Sessions') 13 | 14 | 15 | class Sessions(SessionsBase, tablename="sessions"): 16 | """ 17 | INHERITED from SessionsBase 18 | Use this table, or inherit from it, to create for a session store. 19 | 20 | We set a hard limit on the expiry date - it can keep on getting extended 21 | up until this value, after which it's best to invalidate it, and either 22 | require login again, or just create a new session token. 
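    Example (an illustrative sketch, not in the original docstring):
    ``session = await Sessions.create_session(user_id=user.id)``; the optional
    expiry arguments fall back to the column defaults declared just below.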
23 | """ 24 | 25 | token = Text(length=100, null=False) # type: ignore 26 | user_id = Text(null=False) # type: ignore 27 | expiry_date: Timestamp | datetime.datetime = Timestamp( 28 | default=TimestampOffset(hours=1), null=False 29 | ) 30 | max_expiry_date: Timestamp | datetime.datetime = Timestamp( 31 | default=TimestampOffset(days=7), null=False 32 | ) 33 | 34 | @classmethod 35 | async def create_session( # type: ignore 36 | cls: Type[T_S], 37 | user_id: str, 38 | expiry_date: Optional[datetime.datetime] = None, 39 | max_expiry_date: Optional[datetime.datetime] = None, 40 | ) -> Type[T_S]: 41 | while True: 42 | token: str = secrets.token_urlsafe(nbytes=32) 43 | if not await cls.exists().where(cls.token == token).run(): # type: ignore 44 | break 45 | 46 | session = cls(token=token, user_id=user_id) 47 | if expiry_date: 48 | session.expiry_date = expiry_date 49 | if max_expiry_date: 50 | session.max_expiry_date = max_expiry_date 51 | 52 | await session.save().run() 53 | return session # type: ignore 54 | 55 | @classmethod 56 | def create_session_sync( # type: ignore 57 | cls, user_id: str, expiry_date: Optional[datetime.datetime] = None 58 | ) -> Type[T_S]: 59 | return run_sync(cls.create_session(user_id, expiry_date)) 60 | 61 | @classmethod 62 | async def get_user_id( # type: ignore 63 | cls, token: str, increase_expiry: Optional[datetime.timedelta] = None 64 | ) -> Optional[str]: 65 | """ 66 | Returns the user_id if the given token is valid, otherwise None. 67 | 68 | :param increase_expiry: 69 | If set, the `expiry_date` will be increased by the given amount 70 | if it's close to expiring. If it has already expired, nothing 71 | happens. The `max_expiry_date` remains the same, so there's a hard 72 | limit on how long a session can be used for. 73 | """ 74 | session: Type[T_S] = ( # type: ignore 75 | await cls.objects().where(cls.token == token).first().run() # type: ignore 76 | ) 77 | if not session: 78 | return None 79 | now = datetime.datetime.now() 80 | if (session.expiry_date > now) and (session.max_expiry_date > now): # type: ignore 81 | if increase_expiry and ( 82 | cast(datetime.datetime, session.expiry_date) - 83 | now < increase_expiry # type: ignore 84 | ): 85 | session.expiry_date = ( # type: ignore 86 | cast(datetime.datetime, session.expiry_date) + \ 87 | increase_expiry # type: ignore 88 | ) 89 | await session.save().run() # type: ignore 90 | 91 | return cast(Optional[str], session.user_id) # type: ignore 92 | else: 93 | return None 94 | 95 | @classmethod 96 | def get_user_id_sync(cls, token: str) -> Optional[str]: # type: ignore 97 | return run_sync(cls.get_user_id(token)) 98 | 99 | @classmethod 100 | async def remove_session(cls, token: str): 101 | await cls.delete().where(cls.token == token).run() # type: ignore 102 | 103 | @classmethod 104 | def remove_session_sync(cls, token: str): 105 | return run_sync(cls.remove_session(token)) 106 | -------------------------------------------------------------------------------- /src/accounting/rbac/endpoints.py: -------------------------------------------------------------------------------- 1 | from fastapi import Request, Response 2 | from accounting.schemas import ( 3 | RolesToUser, 4 | UsersToRole, 5 | UsersToGroup, 6 | GroupesToUser, 7 | PolicyCreate, 8 | PolicyUpdate) 9 | from accounting.users import User 10 | from accounting.groups import Group 11 | from accounting.roles import Role 12 | from .models import Policy, Permission 13 | from accounting.decorators import AAA_endpoint_oauth2 14 | 15 | 16 | class UserRoleCRUD(): 17 | 
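# Illustrative request body (hypothetical UUIDs, not part of the original file):
# RolesToUser and the other schemas imported above are thin pydantic models, so
# add_roles_to_user() below expects JSON shaped roughly like
#   {"user_id": "2f0c...", "role_ids": ["a1b2...", "c3d4..."]}
# and simply forwards it to User.add_roles().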
18 | @staticmethod 19 | @AAA_endpoint_oauth2() 20 | async def add_roles_to_user(request: Request, data: RolesToUser): 21 | return await User.add_roles(user_id=data.user_id, role_ids=data.role_ids) 22 | 23 | @staticmethod 24 | @AAA_endpoint_oauth2() 25 | async def add_users_to_role(request: Request, data: UsersToRole): 26 | return await Role.add_users(role_id=data.role_id, user_ids=data.user_ids) 27 | 28 | @staticmethod 29 | @AAA_endpoint_oauth2() 30 | async def delete_roles_from_user(request: Request, data: RolesToUser): 31 | return await User.delete_roles(user_id=data.user_id, role_ids=data.role_ids) 32 | 33 | @staticmethod 34 | @AAA_endpoint_oauth2() 35 | async def delete_users_from_role(request: Request, data: UsersToRole): 36 | return await Role.delete_users(role_id=data.role_id, user_ids=data.user_ids) 37 | 38 | 39 | class UserGroupCRUD(): 40 | 41 | @staticmethod 42 | @AAA_endpoint_oauth2() 43 | async def add_groups_to_user(request: Request, data: GroupesToUser): 44 | return await User.add_groups(user_id=data.user_id, group_ids=data.group_ids) 45 | 46 | @staticmethod 47 | @AAA_endpoint_oauth2() 48 | async def add_users_to_group(request: Request, data: UsersToGroup): 49 | return await Group.add_users(group_id=data.group_id, user_ids=data.user_ids) 50 | 51 | @staticmethod 52 | @AAA_endpoint_oauth2() 53 | async def delete_groups_from_user(request: Request, data: GroupesToUser): 54 | return await User.delete_groups(user_id=data.user_id, group_ids=data.group_ids) 55 | 56 | @staticmethod 57 | @AAA_endpoint_oauth2() 58 | async def delete_users_from_group(request: Request, data: UsersToGroup): 59 | return await Group.delete_users(group_id=data.group_id, user_ids=data.user_ids) 60 | 61 | 62 | class PermissionCRUD(): 63 | 64 | @staticmethod 65 | @AAA_endpoint_oauth2() 66 | async def get_all_permissions(request: Request, offset: int = 0, limit: int = 100): 67 | return await Permission.get_all(offset=offset, limit=limit) 68 | 69 | @staticmethod 70 | @AAA_endpoint_oauth2() 71 | async def get_permission(request: Request, id: str): 72 | return await Permission.get_by_id(id=id) 73 | 74 | 75 | class PolicyCRUD(): 76 | 77 | @staticmethod 78 | @AAA_endpoint_oauth2() 79 | async def get_all_policies(request: Request, offset: int = 0, limit: int = 100): 80 | """ 81 | ### READ list[Policy] with offset and limit 82 | #### Args:\n 83 | offset (int, optional): Defaults to 0.\n 84 | limit (int, optional): Defaults to 100.\n 85 | #### Returns: 86 | list[Policy] 87 | """ 88 | return await Policy.get_all(offset=offset, limit=limit) 89 | 90 | @staticmethod 91 | @AAA_endpoint_oauth2() 92 | async def get_policy(id: str): 93 | """ 94 | ### READ one {Policy} by id 95 | #### Args:\n 96 | id (str): UUID4 PK 97 | #### Returns: 98 | Policy | None 99 | """ 100 | return await Policy.get_by_id(id) 101 | 102 | @staticmethod 103 | @AAA_endpoint_oauth2() 104 | async def add_policy(request: Request, data: PolicyCreate): 105 | return await Policy.add(**data.dict(exclude_unset=True)) 106 | 107 | @staticmethod 108 | @AAA_endpoint_oauth2() 109 | async def update_policy(id: str, policy: PolicyUpdate): 110 | """ 111 | ### Update one policy (full or partial) 112 | Args:\n 113 | policy (Policy): { 114 | name: str (Unique) 115 | price: int 116 | } 117 | Returns: 118 | Policy 119 | """ 120 | return await Policy.update_by_id(id=id, data=policy.dict(exclude_unset=True)) 121 | 122 | @staticmethod 123 | @AAA_endpoint_oauth2() 124 | async def delete_policy(id: str): 125 | """ 126 | ### DELETE one policy by ID 127 | #### Args:\n 128 | id (str): 
UUID4 PK 129 | #### Returns: 130 | None, code=204 131 | """ 132 | await Policy.delete_by_id(id) 133 | return Response(status_code=204) 134 | -------------------------------------------------------------------------------- /src/utils/id_propagation.py: -------------------------------------------------------------------------------- 1 | from starlette.requests import Request 2 | from starlette.responses import Response, StreamingResponse 3 | from starlette.types import ASGIApp, Receive, Scope, Send 4 | from starlette.datastructures import MutableHeaders 5 | from logging import Filter, LogRecord 6 | from contextvars import ContextVar 7 | from uuid import uuid4 8 | import typing 9 | import anyio 10 | 11 | trace_id_context: ContextVar[typing.Optional[str] 12 | ] = ContextVar('trace_id', default=None) 13 | RequestResponseEndpoint = typing.Callable[[ 14 | Request], typing.Awaitable[Response]] 15 | DispatchFunction = typing.Callable[ 16 | [Request, RequestResponseEndpoint], typing.Awaitable[Response] 17 | ] 18 | 19 | 20 | class IDPropagationMiddleware(): 21 | """ 22 | TraceId propagation middleware, inspired with https://github.com/snok/asgi-correlation-id 23 | """ 24 | 25 | def __init__(self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None) -> None: 26 | self.app = app 27 | self.dispatch_func = self.dispatch if dispatch is None else dispatch 28 | self.tracing_header = "X-TRACE-ID" 29 | 30 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 31 | if scope["type"] != "http": 32 | await self.app(scope, receive, send) 33 | return 34 | 35 | async def call_next(request: Request) -> Response: 36 | app_exc: typing.Optional[Exception] = None 37 | send_stream, recv_stream = anyio.create_memory_object_stream() 38 | 39 | async def coro() -> None: 40 | nonlocal app_exc 41 | async with send_stream: 42 | try: 43 | await self.app(scope, request.receive, send_stream.send) 44 | except Exception as exc: 45 | app_exc = exc 46 | 47 | task_group.start_soon(coro) 48 | 49 | try: 50 | message = await recv_stream.receive() 51 | except anyio.EndOfStream: 52 | if app_exc is not None: 53 | raise app_exc 54 | raise RuntimeError("No response returned.") 55 | 56 | assert message["type"] == "http.response.start" 57 | 58 | async def body_stream() -> typing.AsyncGenerator[bytes, None]: 59 | async with recv_stream: 60 | async for message in recv_stream: 61 | assert message["type"] == "http.response.body" 62 | yield message.get("body", b"") 63 | 64 | if app_exc is not None: 65 | raise app_exc 66 | 67 | response = StreamingResponse( 68 | status_code=message["status"], content=body_stream() 69 | ) 70 | response.raw_headers = message["headers"] 71 | return response 72 | 73 | async with anyio.create_task_group() as task_group: 74 | request = Request(scope, receive=receive) 75 | headers = MutableHeaders(scope=scope) 76 | # trace_id generates once at Request-Responce pair to propagate itself 77 | # to Request object of any endpoint and returns to Responce 78 | trace_id: str = uuid4().hex 79 | trace_id_context.set(trace_id) 80 | headers.append(self.tracing_header, trace_id) 81 | response = await self.dispatch_func(request, call_next) 82 | response.headers.append(self.tracing_header, trace_id) 83 | await response(scope, receive, send) 84 | task_group.cancel_scope.cancel() 85 | 86 | async def dispatch( 87 | self, request: Request, call_next: RequestResponseEndpoint 88 | ) -> Response: 89 | response = await call_next(request) 90 | return response 91 | 92 | 93 | class TraceIdFilter(Filter): 94 | 
"""Logging filter to attached trace IDs to log records""" 95 | 96 | def __init__(self, name: str = '', uuid_length: typing.Optional[int] = None): 97 | super().__init__(name=name) 98 | self.uuid_length = uuid_length 99 | 100 | def filter(self, record: LogRecord) -> bool: 101 | """ 102 | Attach a trace (correlation) ID to the log record. 103 | Since the trace ID is defined in the middleware layer, any 104 | log generated from a request after this point can easily be searched 105 | for, if the trace ID is added to the message, or included as 106 | metadata. 107 | """ 108 | trace_id: str | None = trace_id_context.get() 109 | if self.uuid_length is not None and trace_id: 110 | record.__setattr__('trace_id', trace_id[: self.uuid_length]) 111 | else: 112 | record.__setattr__('trace_id', trace_id) 113 | return True 114 | -------------------------------------------------------------------------------- /src/utils/piccolo.py: -------------------------------------------------------------------------------- 1 | from piccolo.table import Table as BaseTable 2 | from piccolo.columns import m2m 3 | import inspect 4 | from uuid import uuid4 5 | from typing import Any 6 | 7 | 8 | def uuid4_for_PK() -> str: 9 | """ 10 | Just returns UUID4 in string format 11 | Using for 'default' kwarg in PK TEXT column 12 | 13 | default=uuid4_for_PK #for column 14 | uuid4_for_PK() #for other cases 15 | 16 | Returns: 17 | str: uuid4 string 18 | """ 19 | return str(uuid4()) 20 | 21 | 22 | def get_pk_from_resp(resp: Any, attr: str) -> str | None: 23 | try: 24 | return resp[0].get('id') 25 | except (IndexError, TypeError, ValueError, AttributeError): 26 | return None 27 | 28 | 29 | class Table(BaseTable): 30 | def __init__(self, ignore_missing: bool = False, exists_in_db: bool = False, **kwargs): 31 | super().__init__(ignore_missing, exists_in_db, **kwargs) # type: ignore 32 | 33 | async def __join_field(self, field: str, ignore: bool = False) -> list: 34 | """ 35 | Runs get_m2m for a FIELD of object. Catches ValueError, when there are 36 | no relations in M2M table and returns empty list(). If ignore flag is 37 | True, also returns empty list 38 | Args: 39 | field (str): Attr name 40 | ignore (bool, optional): Flag for include and exclude logic of join_m2m. Defaults to False. 41 | 42 | Returns: 43 | list: list of related objects, or an empty one 44 | """ 45 | if ignore: 46 | return list() 47 | try: 48 | return await self.get_m2m(self.__getattribute__(field)).run() 49 | except ValueError: 50 | return list() 51 | 52 | async def join_m2m( 53 | self, 54 | include_fields: set[str] | list[str] | None = None, 55 | exclude_fields: set[str] | list[str] | None = None 56 | ): 57 | """ 58 | Runs get_m2m() method for all M2M fields of object. Can be useful for 59 | complex PyDantic models in READ actions. Returns empty list() for an 60 | attribute, if there are no relations to this object. 61 | 62 | Optional, you can include or exclude fields to define which attrs should 63 | be joined. Setting either include_fields, and exclude_fields will raise 64 | AssertionError. 65 | 66 | .. code-block:: python 67 | 68 | >>> band = await Band.objects().get(Band.name == "Pythonistas") 69 | >>> await band.join_m2m() 70 | >>> band.genres 71 | [, ] 72 | >>> band.tours 73 | [,,] 74 | 75 | include_fields example: 76 | 77 | .. code-block:: python 78 | 79 | >>> await band.join_m2m(include_fields=['genres']) 80 | >>> band.genres 81 | [, ] 82 | >>> band.tours 83 | [] 84 | 85 | exclude_fields example: 86 | 87 | .. 
code-block:: python 88 | 89 | >>> await band.join_m2m(exclude_fields=['genres']) 90 | >>> band.genres 91 | [] 92 | >>> band.tours 93 | [,,] 94 | 95 | Args: 96 | include_fields (set[str] | list[str] | None, optional): 97 | Only this fields will be joined to base model`s object. Defaults to None. 98 | exclude_fields (set[str] | list[str] | None, optional): 99 | This fields will be excluded from join. Defaults to None. 100 | """ 101 | assert (include_fields is None) or (exclude_fields is None), "Only one of FIELDS arguments can exist" 102 | if include_fields is not None: 103 | assert isinstance(include_fields, set | list), "include_fields MUST be set, list or None" 104 | if exclude_fields is not None: 105 | assert isinstance(exclude_fields, set | list), "exclude_fields MUST be set, list or None" 106 | m2m_fields: set = set([field for field, object in inspect.getmembers( 107 | self, 108 | lambda a:( 109 | isinstance( 110 | a, 111 | m2m.M2M 112 | ) 113 | ) 114 | ) 115 | ]) 116 | ignore_fields: list = list() 117 | if include_fields: 118 | ignore_fields = list(m2m_fields.difference(set(include_fields))) 119 | if exclude_fields: 120 | ignore_fields = list(m2m_fields.intersection(set(exclude_fields))) 121 | for field in list(m2m_fields): 122 | ignore: bool = False 123 | if field in ignore_fields: 124 | ignore = True 125 | self.__setattr__( 126 | field, # M2M attr name 127 | await self.__join_field( 128 | field=field, 129 | ignore=ignore 130 | ) 131 | ) 132 | -------------------------------------------------------------------------------- /src/accounting/rbac/routing.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends 2 | from utils.api_versioning import APIRouter, APIVersion 3 | from .endpoints import ( 4 | UserRoleCRUD, 5 | UserGroupCRUD, 6 | PermissionCRUD, 7 | PolicyCRUD 8 | ) 9 | from accounting.schemas import ( 10 | UserRead, 11 | RoleRead, 12 | GroupRead, 13 | PermissionRead, 14 | PolicyRead 15 | ) 16 | from accounting.authentication.jwt import get_user_by_token 17 | 18 | rbac_user_router = APIRouter( 19 | prefix="/accounting/rbac/user", 20 | tags=["AAA->Accounting->RBAC->User"], 21 | responses={ 22 | 404: {"description": "URL not found"}, 23 | 400: {"description": "Bad request"} 24 | }, 25 | dependencies=[Depends(get_user_by_token)], 26 | version=APIVersion(1) 27 | ) 28 | 29 | rbac_role_router = APIRouter( 30 | prefix="/accounting/rbac/role", 31 | tags=["AAA->Accounting->RBAC->Role"], 32 | responses={ 33 | 404: {"description": "URL not found"}, 34 | 400: {"description": "Bad request"} 35 | }, 36 | dependencies=[Depends(get_user_by_token)], 37 | version=APIVersion(1) 38 | ) 39 | 40 | rbac_group_router = APIRouter( 41 | prefix="/accounting/rbac/group", 42 | tags=["AAA->Accounting->RBAC->Group"], 43 | responses={ 44 | 404: {"description": "URL not found"}, 45 | 400: {"description": "Bad request"} 46 | }, 47 | dependencies=[Depends(get_user_by_token)], 48 | version=APIVersion(1) 49 | ) 50 | 51 | rbac_permissions_router = APIRouter( 52 | prefix="/accounting/rbac/permissions", 53 | tags=["AAA->Accounting->RBAC->Permissions"], 54 | responses={ 55 | 404: {"description": "URL not found"}, 56 | 400: {"description": "Bad request"} 57 | }, 58 | dependencies=[Depends(get_user_by_token)], 59 | version=APIVersion(1) 60 | ) 61 | 62 | rbac_policies_router = APIRouter( 63 | prefix="/accounting/rbac/policies", 64 | tags=["AAA->Accounting->RBAC->Policies"], 65 | responses={ 66 | 404: {"description": "URL not found"}, 67 | 400: {"description": "Bad 
request"} 68 | }, 69 | dependencies=[Depends(get_user_by_token)], 70 | version=APIVersion(1) 71 | ) 72 | 73 | rbac_user_router.add_api_route( 74 | '/roles/', 75 | UserRoleCRUD.add_roles_to_user, 76 | response_model=UserRead, 77 | summary='Add roles to user', 78 | methods=['put']) 79 | 80 | rbac_user_router.add_api_route( 81 | '/roles/', 82 | UserRoleCRUD.delete_roles_from_user, 83 | response_model=UserRead, 84 | summary='Remove roles from user', 85 | methods=['patch']) 86 | 87 | rbac_user_router.add_api_route( 88 | '/groups/', 89 | UserGroupCRUD.add_groups_to_user, 90 | response_model=UserRead, 91 | summary='Add groups to user', 92 | methods=['put']) 93 | 94 | rbac_user_router.add_api_route( 95 | '/groups/', 96 | UserGroupCRUD.delete_groups_from_user, 97 | response_model=UserRead, 98 | summary='Remove groups from user', 99 | methods=['patch']) 100 | 101 | rbac_role_router.add_api_route( 102 | '/users/', 103 | UserRoleCRUD.add_users_to_role, 104 | response_model=RoleRead, 105 | summary='Add users to role', 106 | methods=['put']) 107 | 108 | rbac_role_router.add_api_route( 109 | '/users/', 110 | UserRoleCRUD.delete_users_from_role, 111 | response_model=RoleRead, 112 | summary='Remove users from role', 113 | methods=['patch']) 114 | 115 | rbac_group_router.add_api_route( 116 | '/users/', 117 | UserGroupCRUD.add_users_to_group, 118 | response_model=GroupRead, 119 | summary='Add users to group', 120 | methods=['put']) 121 | 122 | rbac_group_router.add_api_route( 123 | '/users/', 124 | UserGroupCRUD.delete_users_from_group, 125 | response_model=GroupRead, 126 | summary='Remove users from group', 127 | methods=['patch']) 128 | 129 | rbac_permissions_router.add_api_route( 130 | '/', 131 | PermissionCRUD.get_all_permissions, 132 | response_model=list[PermissionRead], 133 | summary='Get all permissions', 134 | methods=['get']) 135 | 136 | rbac_permissions_router.add_api_route( 137 | '/{id}', 138 | PermissionCRUD.get_permission, 139 | response_model=PermissionRead, 140 | summary='Get permission by id', 141 | methods=['get']) 142 | 143 | rbac_policies_router.add_api_route( 144 | '/', 145 | PolicyCRUD.get_all_policies, 146 | response_model=list[PolicyRead], 147 | summary='Get all policies', 148 | methods=['get']) 149 | 150 | rbac_policies_router.add_api_route( 151 | '/{id}', 152 | PolicyCRUD.get_policy, 153 | response_model=PolicyRead, 154 | summary='Get policy by id', 155 | methods=['get']) 156 | 157 | rbac_policies_router.add_api_route( 158 | '/', 159 | PolicyCRUD.add_policy, 160 | response_model=PolicyRead, 161 | status_code=201, 162 | summary='Create policy', 163 | methods=['post']) 164 | 165 | rbac_policies_router.add_api_route( 166 | '/{id}', 167 | PolicyCRUD.update_policy, 168 | response_model=PolicyRead, 169 | summary='Update policy (full or partial)', 170 | methods=['put']) 171 | 172 | rbac_policies_router.add_api_route( 173 | '/{id}', 174 | PolicyCRUD.delete_policy, 175 | status_code=204, 176 | summary='Delete policy', 177 | methods=['delete']) 178 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, 
education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /src/accounting/rbac/models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from loguru import logger 3 | from typing import TypeVar, Type, Tuple 4 | from utils.piccolo import Table, uuid4_for_PK 5 | from piccolo.columns import m2m 6 | from piccolo.columns.column_types import ( 7 | Text, Boolean, ForeignKey, LazyTableReference 8 | ) 9 | from asyncpg.exceptions import UniqueViolationError 10 | from utils.exceptions import IntegrityException, ObjectNotFoundException 11 | from configuration import config 12 | from piccolo.columns.readable import Readable 13 | from accounting.users import User 14 | from accounting.groups import Group 15 | from accounting.roles import Role 16 | from accounting.schemas import PermissionCreate 17 | 18 | T_P = TypeVar('T_P', bound='Policy') 19 | T_Pm = TypeVar('T_Pm', bound='Permission') 20 | 21 | tz: datetime.timezone = config.Main.tz 22 | 23 | 24 | class Permission(Table, tablename="permissions"): 25 | 26 | id = Text(primary_key=True, index=True, default=uuid4_for_PK) 27 | name = Text(unique=True, index=False, null=False) 28 | object = Text(unique=True, index=True, null=False) 29 | description = Text(unique=False, index=False, null=True) 30 | policies = m2m.M2M(LazyTableReference("Policy", module_path=__name__)) 31 | 32 | @classmethod 33 | def get_readable(cls): 34 | return Readable(template="%s", columns=[cls.name]) 35 | 36 | @classmethod 37 | async def get_all(cls: Type[T_Pm], offset: int, limit: int) -> list[T_Pm]: 38 | resp: list[T_Pm] = await cls.objects().limit(limit).offset(offset) 39 | for r in resp: 40 | await r.join_m2m() 41 | return resp 42 | 43 | @classmethod 44 | async def get_by_id(cls: Type[T_Pm], id: str) -> T_Pm: 45 | permission: T_Pm = await cls.objects().where(cls.id == id).first() 46 | try: 47 | assert permission 48 | except AssertionError: 49 | raise ObjectNotFoundException(object_name=__name__, object_id=id) 50 | else: 51 | await permission.join_m2m() 52 | return permission 53 | 54 | @classmethod 55 | async def add(cls: Type[T_Pm], name: str, object: str, description: str = str()) -> T_Pm: 56 | 57 | new_id = uuid4_for_PK() 58 | permission: T_Pm = cls( 59 | id=new_id, 60 | name=name, 61 | object=object, 62 | description=description 63 | ) 64 | try: 65 | resp = await cls.insert(permission) 66 | except UniqueViolationError as ex: 67 | raise IntegrityException(ex) 68 | else: 69 | inserted_pk = resp[0].get('id') 70 | return await cls.get_by_id(inserted_pk) 71 | 72 | @classmethod 73 | async def add_from_list(cls: Type[T_Pm], objects: list[PermissionCreate]) -> Tuple[int, int]: 74 | existing_permissions: int = int() 75 | inserted_permissions: int = int() 76 | for permission in objects: 77 | permission.id = uuid4_for_PK() 78 | p: T_Pm = cls( 79 | **permission.dict() 80 | ) 81 | try: 82 | await cls.insert(p) 83 | except UniqueViolationError: 84 | existing_permissions = existing_permissions + 1 85 | else: 86 | inserted_permissions = inserted_permissions + 1 87 | return (existing_permissions, inserted_permissions) 88 | 89 | 90 | class M2MUserRole(Table): 91 | user = ForeignKey(User) 92 | role = ForeignKey(Role) 93 | 94 | 95 | class M2MUserGroup(Table): 96 | user = ForeignKey(User) 97 | group = ForeignKey(Group) 98 | 99 | 100 | class Policy(Table, tablename="policies"): 101 | id = Text(primary_key=True, index=True, default=uuid4_for_PK) 102 | permission = ForeignKey(Permission, null=False) 103 | role = 
ForeignKey(Role, null=False) 104 | active = Boolean(nullable=False, default=True) 105 | name = Text(unique=False, index=False, null=False) 106 | description = Text(unique=False, index=False, null=True) 107 | 108 | @classmethod 109 | async def get_all(cls: Type[T_P], offset: int, limit: int) -> list[T_P]: 110 | # type: ignore 111 | resp: list[T_P] = await cls.objects(cls.all_related()).limit(limit).offset(offset) 112 | return resp 113 | 114 | @classmethod 115 | async def get_by_id(cls: Type[T_P], id: str) -> T_P: 116 | # type: ignore 117 | policy: T_P = await cls.objects(cls.all_related()).where(cls.id == id).first() 118 | try: 119 | assert policy 120 | except AssertionError: 121 | raise ObjectNotFoundException(object_name=__name__, object_id=id) 122 | else: 123 | return policy 124 | 125 | @classmethod 126 | async def add( 127 | cls: Type[T_P], 128 | permission_id: str, 129 | role_id: str, 130 | name: str | None = None, 131 | active: bool = True, 132 | description: str = str() 133 | ) -> T_P: 134 | permission: Permission = await Permission.get_by_id(id=permission_id) 135 | role: Role = await Role.get_by_id(id=role_id) 136 | new_id: str = uuid4_for_PK() 137 | if not name: 138 | name = f"{role.name}->{permission.object}" 139 | policy: T_P = cls( 140 | id=new_id, 141 | permission=permission, 142 | role=role, 143 | active=active, 144 | name=name, 145 | description=description 146 | ) 147 | try: 148 | resp = await cls.insert(policy) 149 | except UniqueViolationError as ex: 150 | raise IntegrityException(ex) 151 | else: 152 | inserted_pk = resp[0].get('id') 153 | return await cls.get_by_id(inserted_pk) 154 | 155 | @classmethod 156 | async def update_by_id(cls: Type[T_P], id: str, data: dict) -> T_P: 157 | await cls.update(**data).where(cls.id == id) 158 | return await cls.get_by_id(id) 159 | 160 | @classmethod 161 | async def delete_by_id(cls: Type[T_P], id: str) -> None: 162 | await cls.get_by_id(id) 163 | await cls.delete().where(cls.id == id) 164 | -------------------------------------------------------------------------------- /src/cli/accounting.py: -------------------------------------------------------------------------------- 1 | import typer 2 | import asyncio 3 | from enum import Enum 4 | from fastapi.exceptions import HTTPException 5 | from .config_loader import set_config, config_default 6 | from loguru import logger 7 | from email_validator import validate_email, EmailUndeliverableError 8 | 9 | from .console import print, info, warning, success, error 10 | 11 | 12 | from .shared import prepare_db_through_vault 13 | 14 | 15 | async def create_user(username: str, password: str, email: str, superuser: bool = False): 16 | from accounting.users.models import User 17 | try: 18 | validate_email(email, check_deliverability=False) 19 | except EmailUndeliverableError: 20 | error(f"Email {email} is not valid") 21 | # Raising HTTP error to propagate error to another thread 22 | raise HTTPException(status_code=422) 23 | else: 24 | return await User.add(username, password, email, as_superuser=superuser) 25 | 26 | app = typer.Typer( 27 | no_args_is_help=True, 28 | short_help='Operations with users and other AAA objects' 29 | ) 30 | 31 | 32 | class CreatingObjects(str, Enum): 33 | superuser: str = "superuser" 34 | user: str = "user" 35 | secret: str = "secret" 36 | 37 | 38 | @app.command(short_help='Creating objects', no_args_is_help=True) 39 | def create(object: CreatingObjects, c: str = config_default): 40 | """ 41 | Creating AAA objects 42 | """ 43 | set_config(c) 44 | from configuration import 
config 45 | from utils import vault, events 46 | from utils.security import generate_random_string_token 47 | prepare_db_through_vault() 48 | match object: 49 | case 'superuser': 50 | username = typer.prompt("Username") 51 | password = typer.prompt("Password") 52 | email = typer.prompt("Email") 53 | try: 54 | resp = asyncio.run(create_user( 55 | username, password, email, True)) 56 | except HTTPException as ex: 57 | error(f'Unable to create superuser: [code]{ex.detail}[/ code]') 58 | else: 59 | success( 60 | f'Superuser [bold]{username}[/ bold] was created with id [bold]{resp.id}[/ bold]') 61 | case 'user': 62 | username = typer.prompt("Username") 63 | password = typer.prompt("Password") 64 | email = typer.prompt("Email") 65 | try: 66 | resp = asyncio.run(create_user(username, password, email)) 67 | except HTTPException as ex: 68 | 69 | error(f'Unable to create user: [code]{ex.detail}[/ code]') 70 | else: 71 | success( 72 | f'User [bold]{username}[/ bold] was created with id [bold]{resp.id}[/ bold]') 73 | case 'secret': 74 | if config.Security.jwt_base_secret: 75 | warning( 76 | """ 77 | jwt_base_secret is defined in config file. 78 | Comment this line in [bold]Security[/bold] section to load secret from external storage 79 | """ 80 | ) 81 | info( 82 | f"Using [bold]{config.Security.jwt_base_secret_storage}[/bold] external storage") 83 | new_secret: str = generate_random_string_token() 84 | secret_to_check: str = str() 85 | match config.Security.jwt_base_secret_storage: 86 | case 'local': 87 | try: 88 | with open(config.Security.jwt_base_secret_filename, 'w') as f: 89 | f.write(new_secret) 90 | f.close() 91 | with open(config.Security.jwt_base_secret_filename, 'r') as f: 92 | secret_to_check = f.readline() 93 | config.Security.set_jwt_base_secret( 94 | secret_to_check) 95 | except (FileNotFoundError, PermissionError): 96 | error( 97 | f"Cannot write jwt secret to file {config.Security.jwt_base_secret_filename}") 98 | case 'vault': 99 | asyncio.run(events.init_vault()) 100 | vault_subkey: str = 'base_secret' 101 | try: 102 | response_write: dict = asyncio.run( 103 | vault.write_kv_data( 104 | secret_name=config.Security.jwt_base_secret_vault_secret_name, 105 | payload={'data': {vault_subkey: new_secret}}, 106 | storage_name=config.Security.jwt_base_secret_vault_storage_name 107 | ) 108 | ) 109 | assert response_write, "Vault write operation failed" 110 | response_read: dict | None = asyncio.run( 111 | vault.read_kv_data( 112 | secret_name=config.Security.jwt_base_secret_vault_secret_name, 113 | storage_name=config.Security.jwt_base_secret_vault_storage_name 114 | ) 115 | ) 116 | assert (response_read and isinstance( 117 | response_read, dict)), "Cannot load secret from Vault" 118 | secret_to_check = response_read.get( 119 | vault_subkey, str()) 120 | config.Security.set_jwt_base_secret(secret_to_check) 121 | except AssertionError as ex: 122 | error(str(ex)) 123 | else: 124 | success("Secret generation completed") 125 | case _: 126 | pass 127 | try: 128 | assert new_secret == secret_to_check, "Wroted secret is broken" 129 | assert secret_to_check == config.Security.get_jwt_base_secret( 130 | ), "Loading to CONFIG is broken" 131 | except AssertionError as ex: 132 | error(str(ex)) 133 | else: 134 | success("All checks successfully passed") 135 | 136 | case _: 137 | pass 138 | 139 | 140 | if __name__ == "__main__": 141 | app() 142 | -------------------------------------------------------------------------------- /src/utils/events.py: 
-------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from configuration import config 3 | 4 | 5 | def load_endpoints(app): 6 | from accounting import ( 7 | user_router, 8 | role_router, 9 | group_router, 10 | rbac_user_router, 11 | rbac_permissions_router, 12 | rbac_group_router, 13 | rbac_policies_router, 14 | rbac_role_router 15 | ) 16 | from accounting.authentication.routing import auth_router 17 | from utils.routing import misc_router 18 | 19 | app.include_router(user_router) 20 | app.include_router(role_router) 21 | app.include_router(group_router) 22 | app.include_router(auth_router) 23 | app.include_router(rbac_user_router) 24 | app.include_router(rbac_role_router) 25 | app.include_router(rbac_group_router) 26 | app.include_router(rbac_permissions_router) 27 | app.include_router(rbac_policies_router) 28 | app.include_router(misc_router) 29 | 30 | 31 | def create_admin_gui(app, admin_url: str, site_name: str): 32 | from piccolo_admin.endpoints import create_admin 33 | from fastapi import routing 34 | from accounting import User, Role, Group, Permission, Policy, Sessions, M2MUserRole, M2MUserGroup 35 | app.routes.append( 36 | routing.Mount( 37 | admin_url, 38 | create_admin( 39 | [ 40 | User, 41 | Role, 42 | Group, 43 | Permission, 44 | Policy, 45 | M2MUserGroup, 46 | M2MUserRole 47 | ], 48 | auth_table=User, # type: ignore 49 | session_table=Sessions, 50 | allowed_hosts=['localhost'], 51 | production=False, 52 | site_name=site_name 53 | ) 54 | ) 55 | ) 56 | 57 | 58 | async def init_vault(): 59 | from . import vault 60 | await vault.init() 61 | 62 | 63 | async def load_vault_db_creds(): 64 | from . import vault 65 | from configuration import config 66 | from loguru import logger 67 | if config.Database.is_vault_enable: 68 | logger.info('Using Vault for DB credentials') 69 | creds = await vault.get_db_creds( 70 | config.Database.db_vault_role, 71 | static=config.Database.is_vault_static, 72 | storage_name=config.Database.db_vault_storage 73 | ) 74 | config.Database.set_connection_string( 75 | config.Database.build_connection_string( 76 | username=creds.username, password=creds.password) 77 | ) 78 | logger.debug(f'DB engine will be created from user {creds.username}') 79 | else: 80 | # Was already tested directly 81 | config.Database.set_connection_string( # pragma: no cover 82 | config.Database.build_connection_string() 83 | ) 84 | 85 | 86 | async def reload_db_creds(): 87 | from . 
import vault 88 | from configuration import config 89 | from loguru import logger 90 | creds = await vault.get_db_creds( 91 | config.Database.db_vault_role, 92 | static=config.Database.is_vault_static, 93 | storage_name=config.Database.db_vault_storage 94 | ) 95 | logger.info(f'Obtained new DB credentials from Vault for {creds.username}') 96 | config.Database.set_connection_string( 97 | config.Database.build_connection_string( 98 | username=creds.username, password=creds.password) 99 | ) 100 | 101 | 102 | async def load_endpoint_permissions(app): 103 | from accounting import Permission 104 | from accounting.schemas import PermissionCreate 105 | from loguru import logger 106 | BASE_PERMISSIONS = list() 107 | for r in app.routes: 108 | try: 109 | BASE_PERMISSIONS.append( 110 | PermissionCreate( 111 | object=r.endpoint.__name__, 112 | name=r.summary, 113 | description=r.endpoint.__doc__ 114 | ) 115 | ) 116 | if r.endpoint.__getattribute__('rbac_enable'): 117 | r.summary = f'{r.summary} | RBAC enabled' 118 | except AttributeError: 119 | continue 120 | (existing_permissions, inserted_permissions) = await Permission.add_from_list(BASE_PERMISSIONS) 121 | logger.info( 122 | f"Base permissions were loaded. {inserted_permissions} entries were inserted, {existing_permissions} entries already existed" 123 | ) 124 | 125 | 126 | async def load_base_jwt_secret( 127 | jwt_base_secret: str | None = config.Security.jwt_base_secret, 128 | jwt_base_secret_storage: str | None = config.Security.jwt_base_secret_storage, 129 | jwt_base_secret_filename: str | None = config.Security.jwt_base_secret_filename, 130 | jwt_base_secret_vault_secret_name: str | None = config.Security.jwt_base_secret_vault_secret_name, 131 | jwt_base_secret_vault_storage_name: str | None = config.Security.jwt_base_secret_vault_storage_name, 132 | vault=None 133 | ): 134 | if not vault: 135 | from . import vault 136 | # If the secret is defined in the config file with `jwt_base_secret =`, 137 | # we will use that one 138 | if jwt_base_secret: 139 | return 140 | # otherwise try to load it from the configured external storage 141 | match jwt_base_secret_storage: 142 | case 'local': 143 | try: 144 | with open(jwt_base_secret_filename, 'r') as f: 145 | key: str = f.readline() 146 | assert key, 'File is empty!'
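# The statement below caches the secret that was just read so that later JWT signing can
# fetch it via config.Security.get_jwt_base_secret(). A hypothetical sketch of the
# `[Security]` keys this branch consumes (key names come from the function parameters
# above; the exact TOML layout and the values shown are assumptions):
#
#   [Security]
#   jwt_base_secret_storage = "local"                     # or "vault"
#   jwt_base_secret_filename = "/path/to/jwt_secret.txt"  # placeholder path
#   # jwt_base_secret = "..."  # set this key to skip external storage entirely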
147 | config.Security.set_jwt_base_secret(key) 148 | except (FileNotFoundError, PermissionError): 149 | logger.critical( 150 | f"Cannot open jwt secret file {jwt_base_secret_filename}") 151 | except AssertionError as ex: 152 | logger.critical(str(ex)) 153 | case 'vault': 154 | vault_subkey: str = 'base_secret' 155 | try: 156 | resp: dict = await vault.read_kv_data( 157 | secret_name=jwt_base_secret_vault_secret_name, 158 | storage_name=jwt_base_secret_vault_storage_name 159 | ) 160 | assert (resp and isinstance(resp, dict) 161 | ), "Cannot load secret from Vault" 162 | key: str | None = resp.get(vault_subkey) 163 | assert key, f"{vault_subkey} not found" 164 | config.Security.set_jwt_base_secret(key) 165 | except AssertionError as ex: 166 | logger.critical(str(ex)) 167 | case _: 168 | logger.critical("Cannot load jwt secret") 169 | -------------------------------------------------------------------------------- /src/tests/04_accounting_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | from fastapi.testclient import TestClient 3 | from app import create_app 4 | from .payloads import ( 5 | test_superuser_1, 6 | test_user_1, 7 | test_user_2, 8 | test_role_1, 9 | test_group_1, 10 | ) 11 | from .payload_models import UserModel 12 | from .shared import prepare_db_with_users 13 | 14 | app = create_app() 15 | client = TestClient(app) 16 | api_version: str = 'v1' 17 | application_name: str = 'accounting' 18 | base_url: str = f"/{api_version}/{application_name}/" 19 | 20 | 21 | def authenticate_as(user) -> str: 22 | response = client.post( 23 | "/auth/token", 24 | data={ 25 | 'username': user.username, 26 | 'password': user.password 27 | }, 28 | headers={ 29 | 'Content-Type': 'application/x-www-form-urlencoded' 30 | } 31 | ) 32 | assert response.status_code == 200 33 | auth_data: dict = response.json() 34 | return f"{auth_data.get('token_type')} {auth_data.get('access_token')}" 35 | 36 | 37 | def test_auth_as_superuser(): 38 | prepare_db_with_users(test_superuser_1, test_user_1) 39 | response = authenticate_as(test_superuser_1) 40 | assert isinstance(response, str) 41 | 42 | 43 | def test_user_get_all_unauth(): 44 | response = client.get("/v1/accounting/users/") 45 | assert response.status_code == 401 46 | 47 | 48 | def test_user_get_all(): 49 | prepare_db_with_users(test_superuser_1, test_user_1) 50 | response = client.get( 51 | f"{base_url}users", 52 | headers={ 53 | 'Authorization': authenticate_as(test_superuser_1) 54 | } 55 | ) 56 | assert response.status_code == 200 57 | assert len(response.json()) == 2 58 | 59 | 60 | def test_create_user(): 61 | prepare_db_with_users(test_superuser_1, test_user_1) 62 | response = client.post( 63 | f"{base_url}users/", 64 | headers={ 65 | 'Authorization': authenticate_as(test_superuser_1) 66 | }, 67 | data=test_user_2.json() 68 | ) 69 | assert response.status_code == 201 70 | assert response.json().get('id') 71 | return response.json().get('id') 72 | 73 | 74 | def test_get_user_by_id(): 75 | prepare_db_with_users(test_superuser_1, test_user_1) 76 | user_id: str = test_create_user() 77 | response = client.get( 78 | f"{base_url}users/{user_id}", 79 | headers={ 80 | 'Authorization': authenticate_as(test_superuser_1) 81 | } 82 | ) 83 | assert response.status_code == 200 84 | assert response.json().get('id') == user_id 85 | assert response.json().get('username') == test_user_2.username 86 | 87 | 88 | def test_update_user_by_id(): 89 | prepare_db_with_users(test_superuser_1, test_user_1) 90 | user_id: str = 
test_create_user() 91 | response = client.put( 92 | f"{base_url}users/{user_id}", 93 | headers={ 94 | 'Authorization': authenticate_as(test_superuser_1), 95 | 'Content-Type': 'application/json' 96 | }, 97 | json={'username': 'new_username'} 98 | ) 99 | assert response.status_code == 200 100 | assert response.json().get('id') == user_id 101 | assert response.json().get('username') != test_user_2.username 102 | assert response.json().get('username') == 'new_username' 103 | 104 | 105 | def test_change_user_password(): 106 | prepare_db_with_users(test_superuser_1, test_user_1) 107 | user_id: str = test_create_user() 108 | assert authenticate_as(test_user_2) 109 | response = client.patch( 110 | f"{base_url}users/{user_id}", 111 | headers={ 112 | 'Authorization': authenticate_as(test_superuser_1), 113 | 'Content-Type': 'application/json' 114 | }, 115 | json={ 116 | 'old_password': test_user_2.password, 117 | 'new_password': 'new_password' 118 | } 119 | ) 120 | assert response.status_code == 200 121 | assert authenticate_as( 122 | UserModel( 123 | username=test_user_2.username, 124 | password='new_password', 125 | email='foobar@mail.com' 126 | ) 127 | ) 128 | 129 | 130 | def test_delete_user(): 131 | prepare_db_with_users(test_superuser_1, test_user_1) 132 | user_id: str = test_create_user() 133 | response = client.delete( 134 | f"{base_url}users/{user_id}", 135 | headers={ 136 | 'Authorization': authenticate_as(test_superuser_1) 137 | } 138 | ) 139 | assert response.status_code == 204 140 | response = client.get( 141 | f"{base_url}users/{user_id}", 142 | headers={ 143 | 'Authorization': authenticate_as(test_superuser_1) 144 | } 145 | ) 146 | assert response.status_code == 404 147 | 148 | 149 | def test_create_role(): 150 | prepare_db_with_users(test_superuser_1, test_user_1) 151 | response = client.post( 152 | f"{base_url}roles/", 153 | headers={ 154 | 'Authorization': authenticate_as(test_superuser_1) 155 | }, 156 | data=test_role_1.json() 157 | ) 158 | assert response.status_code == 201 159 | assert response.json().get('id') 160 | return response.json().get('id') 161 | 162 | 163 | def test_get_all_roles(): 164 | test_create_role() 165 | response = client.get( 166 | f"{base_url}roles/", 167 | headers={ 168 | 'Authorization': authenticate_as(test_superuser_1) 169 | } 170 | ) 171 | assert response.status_code == 200 172 | assert isinstance(response.json(), list) 173 | assert len(response.json()) == 1 174 | 175 | 176 | def test_get_role_by_id(): 177 | role_id: str = test_create_role() 178 | response = client.get( 179 | f"{base_url}roles/{role_id}", 180 | headers={ 181 | 'Authorization': authenticate_as(test_superuser_1) 182 | } 183 | ) 184 | assert response.status_code == 200 185 | assert isinstance(response.json(), dict) 186 | assert response.json().get('id') == role_id 187 | assert response.json().get('name') == test_role_1.name 188 | 189 | 190 | def test_update_role_by_id(): 191 | role_id: str = test_create_role() 192 | response = client.put( 193 | f"{base_url}roles/{role_id}", 194 | headers={ 195 | 'Authorization': authenticate_as(test_superuser_1), 196 | 'Content-Type': 'application/json' 197 | }, 198 | json={'name': 'new_name'} 199 | ) 200 | assert response.status_code == 200 201 | assert isinstance(response.json(), dict) 202 | assert response.json().get('id') == role_id 203 | assert response.json().get('name') == 'new_name' 204 | 205 | 206 | def test_delete_role(): 207 | role_id: str = test_create_role() 208 | response = client.delete( 209 | f"{base_url}roles/{role_id}", 210 | headers={ 211 
| 'Authorization': authenticate_as(test_superuser_1) 212 | } 213 | ) 214 | assert response.status_code == 204 215 | response = client.get( 216 | f"{base_url}roles/{role_id}", 217 | headers={ 218 | 'Authorization': authenticate_as(test_superuser_1) 219 | } 220 | ) 221 | assert response.status_code == 404 222 | 223 | 224 | def test_create_group(): 225 | prepare_db_with_users(test_superuser_1, test_user_1) 226 | response = client.post( 227 | f"{base_url}groups/", 228 | headers={ 229 | 'Authorization': authenticate_as(test_superuser_1) 230 | }, 231 | data=test_group_1.json() 232 | ) 233 | assert response.status_code == 201 234 | assert response.json().get('id') 235 | return response.json().get('id') 236 | 237 | 238 | def test_get_all_groups(): 239 | test_create_group() 240 | response = client.get( 241 | f"{base_url}groups/", 242 | headers={ 243 | 'Authorization': authenticate_as(test_superuser_1) 244 | } 245 | ) 246 | assert response.status_code == 200 247 | assert isinstance(response.json(), list) 248 | assert len(response.json()) == 1 249 | 250 | 251 | def test_get_group_by_id(): 252 | group_id: str = test_create_group() 253 | response = client.get( 254 | f"{base_url}groups/{group_id}", 255 | headers={ 256 | 'Authorization': authenticate_as(test_superuser_1) 257 | } 258 | ) 259 | assert response.status_code == 200 260 | assert isinstance(response.json(), dict) 261 | assert response.json().get('id') == group_id 262 | assert response.json().get('name') == test_group_1.name 263 | 264 | 265 | def test_update_group_by_id(): 266 | group_id: str = test_create_group() 267 | response = client.put( 268 | f"{base_url}groups/{group_id}", 269 | headers={ 270 | 'Authorization': authenticate_as(test_superuser_1), 271 | 'Content-Type': 'application/json' 272 | }, 273 | json={'name': 'new_name'} 274 | ) 275 | assert response.status_code == 200 276 | assert isinstance(response.json(), dict) 277 | assert response.json().get('id') == group_id 278 | assert response.json().get('name') == 'new_name' 279 | 280 | 281 | def test_delete_group(): 282 | group_id: str = test_create_group() 283 | response = client.delete( 284 | f"{base_url}groups/{group_id}", 285 | headers={ 286 | 'Authorization': authenticate_as(test_superuser_1) 287 | } 288 | ) 289 | assert response.status_code == 204 290 | response = client.get( 291 | f"{base_url}groups/{group_id}", 292 | headers={ 293 | 'Authorization': authenticate_as(test_superuser_1) 294 | } 295 | ) 296 | assert response.status_code == 404 297 | -------------------------------------------------------------------------------- /src/cli/db.py: -------------------------------------------------------------------------------- 1 | import typer 2 | import asyncio 3 | from .console import print, info, success, error, warning 4 | from rich.table import Table as CLI_Table 5 | from rich.console import Console 6 | from .config_loader import set_config, config_default 7 | from dataclasses import dataclass 8 | from .shared import prepare_db_through_vault 9 | app = typer.Typer(no_args_is_help=True, short_help='Operations with DB') 10 | console = Console() 11 | migrations_app = typer.Typer(short_help='DB migrations', no_args_is_help=True) 12 | app.add_typer(migrations_app, name='mg') 13 | 14 | 15 | @dataclass 16 | class TableScan(): 17 | from piccolo.table import Table # CircularImport error 18 | application: str 19 | table: Table 20 | exists: bool | None = None 21 | action: str | None = None 22 | result: bool | None = None 23 | rows: int = int() 24 | 25 | 26 | def get_tables_list( 27 | apps: list 
| None = None, 28 | check_for_existance: bool = False, 29 | count_rows: bool = False 30 | ) -> list[TableScan]: 31 | """ 32 | Scan APP_REGISTRY for registered tables and returns list of TableScan objects 33 | 34 | Args: 35 | apps (list | None, optional): Applications to include or ALL of them. Defaults to None. 36 | check_for_existance (bool, optional): Check existing in db and fill `exists` field. Defaults to False. 37 | count_rows (bool, optional): Count rows for every table. Defaults to False 38 | Returns: 39 | list[TableScan]: list of scanned tables 40 | """ 41 | from piccolo_conf import APP_REGISTRY 42 | tables: list = list() 43 | for app in APP_REGISTRY.apps: 44 | app_name: str = app.rstrip('.piccolo_app') 45 | if apps is not None: 46 | if app_name not in apps: 47 | continue 48 | app_tables: list = APP_REGISTRY.get_table_classes( 49 | app_name 50 | ) 51 | for app_table in app_tables: 52 | exists: bool | None = None 53 | if check_for_existance: 54 | exists = app_table.table_exists().run_sync() 55 | rows: int = int() 56 | if count_rows and exists: 57 | rows = app_table.count().run_sync() 58 | tables.append( 59 | TableScan( 60 | application=app_name, 61 | table=app_table, 62 | exists=exists, 63 | rows=rows 64 | ) 65 | ) 66 | return tables 67 | 68 | 69 | @app.command(help="Show current state of tables") 70 | def show( 71 | app_name: str = typer.Argument( 72 | 'all', help='Application name, ex. `accounting` or `all` for all registered apps'), 73 | c: str = config_default 74 | ): 75 | """ 76 | Show scanned tables 77 | 78 | Args: 79 | app_name (str, optional): Specify an application. Defaults to 'all'. 80 | """ 81 | set_config(c) 82 | prepare_db_through_vault() 83 | apps: list[str] | None = None 84 | if app_name != 'all': 85 | apps = [app_name] 86 | tables: list = get_tables_list( 87 | apps=apps, check_for_existance=True, count_rows=True) 88 | cli_table: CLI_Table = CLI_Table( 89 | "#", "Application", "Table name", "Exists", "Rows") 90 | counter: int = 1 91 | for table in tables: 92 | exists_str: str = ':no_entry:' 93 | if table.exists: 94 | exists_str = ':green_circle:' 95 | cli_table.add_row( 96 | str(counter), 97 | table.application, 98 | table.table.__name__, 99 | exists_str, 100 | str(table.rows) 101 | ) 102 | counter = counter + 1 103 | console.print(cli_table) 104 | 105 | 106 | @app.command(help="Create all or app specified tables for application, existing tables will be ignored") 107 | def init( 108 | app_name: str = typer.Argument( 109 | 'all', help='Application name, ex. `accounting` or `all` for all registered apps'), 110 | c: str = config_default 111 | ): 112 | """ 113 | Create tables from scanned apps, all or for selected application 114 | 115 | Args: 116 | app_name (str, optional): _description_. Defaults to 'all'. 
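 A rough programmatic sketch of what this command does under the hood (``get_tables_list`` and ``create_db_tables_sync`` are the same helpers used in the function body; the ``'accounting'`` app name is only an example):

 .. code-block:: python

     >>> from piccolo.table import create_db_tables_sync
     >>> for scan in get_tables_list(apps=['accounting'], check_for_existance=True):
     ...     if not scan.exists:
     ...         create_db_tables_sync(scan.table, if_not_exists=True)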
117 | """ 118 | set_config(c) 119 | prepare_db_through_vault() 120 | from piccolo.table import create_db_tables_sync 121 | apps: list[str] | None = None 122 | if app_name != 'all': 123 | apps = [app_name] 124 | tables: list = get_tables_list(apps=apps, check_for_existance=True) 125 | cli_table: CLI_Table = CLI_Table( 126 | "#", "Application", "Table name", "Already exists", "Action", "Result") 127 | counter: int = 1 128 | for table in tables: 129 | if table.exists: 130 | table.action = "[yellow]Ignore[/yellow]" 131 | else: 132 | table.action = "[green]Create[/green]" 133 | create_db_tables_sync(table.table, if_not_exists=True) 134 | if table.table.table_exists().run_sync(): 135 | if table.exists: 136 | table.result = "[green]Ignored[/green]" 137 | else: 138 | table.result = "[green]Created[/green]" 139 | else: 140 | table.result = "[red]Error[/red]" 141 | exists_str: str = ':no_entry:' 142 | if table.exists: 143 | exists_str = ':green_circle:' 144 | cli_table.add_row( 145 | str(counter), 146 | table.application, 147 | table.table.__name__, 148 | exists_str, 149 | table.action, 150 | table.result 151 | ) 152 | counter = counter + 1 153 | console.print(cli_table) 154 | 155 | 156 | @app.command(help="Drop all or app specified tables for application, existing tables will be ignored") 157 | def drop( 158 | app_name: str = typer.Argument( 159 | 'all', help='Application name, ex. `accounting` or `all` for all registered apps'), 160 | c: str = config_default 161 | ): 162 | """ 163 | Drop tables from scanned apps, all or for selected application 164 | 165 | Args: 166 | app_name (str, optional): _description_. Defaults to 'all'. 167 | """ 168 | set_config(c) 169 | prepare_db_through_vault() 170 | from piccolo.table import drop_db_tables_sync 171 | from piccolo_conf import APP_REGISTRY 172 | apps: list[str] | None = None 173 | if app_name != 'all': 174 | apps = [app_name] 175 | tables: list = get_tables_list(apps=apps, check_for_existance=True) 176 | typer.confirm( 177 | f"Are you sure you want to delete {app_name} tables?", abort=True) 178 | cli_table: CLI_Table = CLI_Table( 179 | "#", "Application", "Table name", "Already exists", "Action", "Result") 180 | counter: int = 1 181 | for table in tables: 182 | if table.exists: 183 | table.action = "[red bold]Drop[/red bold]" 184 | else: 185 | table.action = "[yellow]Ignore[/yellow]" 186 | drop_db_tables_sync(table.table) 187 | if not table.table.table_exists().run_sync(): 188 | if table.exists: 189 | table.result = "[green]Dropped[/green]" 190 | else: 191 | table.result = "[green]Ignored[/green]" 192 | else: 193 | table.result = "[red]Error[/red]" 194 | exists_str: str = ':no_entry:' 195 | if table.exists: 196 | exists_str = ':green_circle:' 197 | cli_table.add_row( 198 | str(counter), 199 | table.application, 200 | table.table.__name__, 201 | exists_str, 202 | table.action, 203 | table.result 204 | ) 205 | counter = counter + 1 206 | console.print(cli_table) 207 | 208 | 209 | """Migrations commands""" 210 | 211 | 212 | @migrations_app.command(help='Create migrations without running') 213 | def create( 214 | app_name: str = typer.Argument( 215 | 'all', help='Application name, ex. 
`accounting` or `all` for all registered apps'), 216 | c: str = config_default 217 | ) -> None: 218 | set_config(c) 219 | from piccolo.apps.migrations.commands.new import new 220 | from piccolo_conf import APP_REGISTRY 221 | apps: list = list() 222 | if app_name == 'all': 223 | apps = APP_REGISTRY.apps 224 | else: 225 | apps.append(app_name) 226 | for app in apps: 227 | app_name = app.rstrip('.piccolo_app') 228 | info(f'Running for [bold]{app_name}[/bold] app') 229 | asyncio.run( 230 | new( 231 | app_name=app_name, 232 | auto=True 233 | ) 234 | ) 235 | 236 | 237 | @migrations_app.command(help='Run created migrations') 238 | def run( 239 | app_name: str = typer.Argument( 240 | 'all', help='Application name, ex. `accounting` or `all` for all registered apps'), 241 | c: str = config_default, 242 | m: str = typer.Option('all', help='Migration id to run'), 243 | fake: bool = typer.Option( 244 | False, is_flag=True, help='Runs migrations in FAKE mode') 245 | ) -> None: 246 | set_config(c) 247 | prepare_db_through_vault() 248 | from piccolo.apps.migrations.commands.forwards import run_forwards 249 | from piccolo_conf import APP_REGISTRY 250 | apps: list = list() 251 | if app_name == 'all': 252 | apps = APP_REGISTRY.apps 253 | else: 254 | apps.append(app_name) 255 | for app in apps: 256 | app_name = app.rstrip('.piccolo_app') 257 | info(f'Running for [bold]{app_name}[/bold] app') 258 | asyncio.run( 259 | run_forwards( 260 | app_name=app_name, 261 | migration_id=m, 262 | fake=fake 263 | ) 264 | ) 265 | 266 | 267 | if __name__ == "__main__": 268 | app() 269 | -------------------------------------------------------------------------------- /src/accounting/users/models.py: -------------------------------------------------------------------------------- 1 | from utils.piccolo import Table, uuid4_for_PK, get_pk_from_resp 2 | from piccolo.columns.column_types import ( 3 | Text, Boolean, Timestamp, LazyTableReference 4 | ) 5 | from piccolo.columns import m2m 6 | import datetime 7 | from utils.crypto import create_password_hash 8 | from typing import TypeVar, Type, Optional 9 | from loguru import logger 10 | from asyncpg.exceptions import UniqueViolationError, SyntaxOrAccessError 11 | from utils.exceptions import IntegrityException, ObjectNotFoundException, BaseBadRequestException 12 | from piccolo.columns.readable import Readable 13 | from configuration import config 14 | 15 | T_U = TypeVar('T_U', bound='User') 16 | 17 | 18 | def foo() -> datetime.datetime: 19 | return datetime.datetime.now() 20 | 21 | 22 | class User(Table, tablename="users"): 23 | 24 | # Main section 25 | id = Text(primary_key=True, index=True, default=uuid4_for_PK) 26 | username = Text(unique=True, index=True, null=False) 27 | email = Text(unique=True, index=False, nullable=True) 28 | password = Text(unique=False, index=False, null=False) 29 | 30 | first_name = Text(null=True) 31 | last_name = Text(null=True) 32 | # Flags 33 | active = Boolean(nullable=False, default=True) 34 | admin = Boolean( 35 | default=False, help_text="An admin can log into the Piccolo admin GUI." 36 | ) 37 | superuser = Boolean( 38 | default=False, 39 | help_text=( 40 | "If True, this user can manage other users's passwords in the " 41 | "Piccolo admin GUI." 
42 | ), 43 | ) 44 | # Dates 45 | created_at = Timestamp(null=True) 46 | updated_at = Timestamp(null=True) 47 | last_login = Timestamp(null=True) 48 | birthdate = Timestamp(null=True) 49 | # Relations 50 | roles = m2m.M2M(LazyTableReference( 51 | "M2MUserRole", module_path='accounting')) 52 | groups = m2m.M2M(LazyTableReference( 53 | "M2MUserGroup", module_path='accounting')) 54 | 55 | def is_valid_password(self, plain_password) -> bool: 56 | return self.password == create_password_hash(plain_password) 57 | 58 | def is_active(self) -> bool: 59 | return self.active 60 | 61 | def get_user_id(self) -> str: 62 | return str(self.id) 63 | 64 | async def update_login_ts(self) -> None: 65 | data = { 66 | 'last_login': datetime.datetime.now() 67 | } 68 | await self.update_by_id( 69 | self.id, 70 | data, 71 | update_ts=False 72 | ) 73 | return None 74 | 75 | @classmethod 76 | async def add(cls: Type[T_U], username: str, password: str, email: str, as_superuser: bool = False) -> T_U: 77 | 78 | new_id = uuid4_for_PK() 79 | password_hash: str = create_password_hash(password) 80 | user: T_U = cls( 81 | id=new_id, 82 | username=username, 83 | password=password_hash, 84 | email=email, 85 | superuser=as_superuser, 86 | admin=as_superuser 87 | ) 88 | try: 89 | resp = await cls.insert(user) 90 | except UniqueViolationError as ex: 91 | raise IntegrityException(ex) 92 | else: 93 | inserted_pk: str | None = get_pk_from_resp(resp, 'id') 94 | return await cls.get_by_id(inserted_pk) # type: ignore 95 | 96 | @classmethod 97 | async def get_all(cls: Type[T_U], offset: int, limit: int) -> list[T_U]: 98 | resp: list[T_U] = await cls.objects().limit(limit).offset(offset) 99 | # Running JOIN for m2m relations, I don`t now how to do this shit better 100 | for r in resp: 101 | await r.join_m2m() 102 | return resp 103 | 104 | @classmethod 105 | async def get_by_id(cls: Type[T_U], id: str) -> T_U: 106 | user: T_U = await cls.objects().where(cls.id == id).first() 107 | try: 108 | assert user 109 | except AssertionError: 110 | raise ObjectNotFoundException(object_name=__name__, object_id=id) 111 | else: 112 | await user.join_m2m() 113 | return user 114 | 115 | @classmethod 116 | async def get_by_username(cls: Type[T_U], username: str, raise_404: bool = True) -> T_U | None: 117 | user: T_U = await cls.objects().where(cls.username == username).first() 118 | try: 119 | assert user 120 | except AssertionError: 121 | if raise_404: 122 | raise ObjectNotFoundException( 123 | object_name=__name__, object_id=username) 124 | else: 125 | return None 126 | else: 127 | await user.join_m2m() 128 | return user 129 | 130 | @classmethod 131 | async def get_by_email(cls: Type[T_U], email: str, raise_404: bool = True) -> T_U | None: 132 | user: T_U = await cls.objects().where(cls.email == email).first() 133 | try: 134 | assert user 135 | except AssertionError: 136 | if raise_404: 137 | raise ObjectNotFoundException( 138 | object_name=__name__, object_id=email) 139 | else: 140 | return None 141 | else: 142 | await user.join_m2m() 143 | return user 144 | 145 | @classmethod 146 | async def update_by_id(cls: Type[T_U], id: str, data: dict, update_ts: bool = True) -> T_U: 147 | if update_ts: 148 | data['updated_at'] = datetime.datetime.now() 149 | await cls.update(**data).where(cls.id == id) 150 | return await cls.get_by_id(id) 151 | 152 | @classmethod 153 | async def change_password(cls: Type[T_U], id: str, old_plaintext_password: str, new_plaintext_password: str) -> T_U: 154 | user: T_U = await cls.get_by_id(id) 155 | try: 156 | assert 
old_plaintext_password != new_plaintext_password, 'Passwords are equal' 157 | assert user.password == create_password_hash( 158 | old_plaintext_password), 'Invalid old password' 159 | data: dict = { 160 | 'password': create_password_hash(new_plaintext_password) 161 | } 162 | return await cls.update_by_id(id, data) 163 | except AssertionError as ex: 164 | raise BaseBadRequestException(str(ex)) 165 | 166 | @classmethod 167 | async def delete_by_id(cls: Type[T_U], id: str) -> None: 168 | await cls.get_by_id(id) 169 | await cls.delete().where(cls.id == id) 170 | 171 | @classmethod 172 | async def authenticate_user(cls: Type[T_U], username: str, password: str) -> T_U | None: 173 | user: T_U | None = None 174 | login_fields: list[str] = config.Security.available_login_fields 175 | # Look the field name up in the configured list instead of using raw attribute names, to avoid SQL injection 176 | for field in login_fields: 177 | try: 178 | match field: 179 | case "username": user = await cls.get_by_username(username, raise_404=False) 180 | case "email": user = await cls.get_by_email(username, raise_404=False) 181 | case _: raise ObjectNotFoundException(object_name='User', object_id=username) 182 | # raise_404=False suppresses the 404 exception so the next login field can be tried 183 | assert user 184 | except AssertionError: 185 | # Move on and try the next field 186 | continue 187 | else: 188 | # Break the loop if a user was found 189 | break 190 | try: 191 | # Check the result from the last field 192 | assert user 193 | except AssertionError: 194 | raise ObjectNotFoundException( 195 | object_name='User', object_id=username) 196 | try: 197 | assert user.is_valid_password( 198 | plain_password=password), 'Bad credentials' # type: ignore 199 | assert user.is_active(), 'User was deactivated' # type: ignore 200 | except AssertionError as ex: 201 | logger.warning(f'AUTH | {ex} | {username}') 202 | raise BaseBadRequestException(str(ex)) 203 | 204 | else: 205 | await user.update_login_ts() # type: ignore 206 | logger.info(f'AUTH | SUCCESS | {username}') 207 | return user 208 | 209 | @classmethod 210 | async def login(cls, username: str, password: str) -> Optional[int]: 211 | """ 212 | Implementation of 'login' method for piccolo admin (session auth) 213 | 214 | :returns: 215 | The id of the user if a match is found, otherwise ``None``.
216 | """ 217 | user = await cls.authenticate_user(username, password) 218 | if not user: 219 | return None 220 | else: 221 | return user.id # type: ignore 222 | 223 | @classmethod 224 | def get_readable(cls): 225 | return Readable(template="%s", columns=[cls.username]) 226 | 227 | @classmethod 228 | async def add_roles(cls: Type[T_U], user_id: str, role_ids: list[str]): 229 | from accounting import Role # CircularImport error 230 | user: T_U = await cls.objects().get(cls.id == user_id) 231 | for role_id in role_ids: 232 | role = await Role.get_by_id(role_id) 233 | await user.add_m2m( 234 | role, # type: ignore 235 | m2m=cls.roles 236 | ) 237 | return await cls.get_by_id(user_id) 238 | 239 | @classmethod 240 | async def delete_roles(cls: Type[T_U], user_id: str, role_ids: list[str]): 241 | from accounting import Role # CircularImport error 242 | user: T_U = await cls.objects().get(cls.id == user_id) 243 | for role_id in role_ids: 244 | role = await Role.get_by_id(role_id) 245 | await user.remove_m2m( 246 | role, # type: ignore 247 | m2m=cls.roles 248 | ) 249 | return await cls.get_by_id(user_id) 250 | 251 | @classmethod 252 | async def add_groups(cls: Type[T_U], user_id: str, group_ids: list[str]): 253 | from accounting import Group # CircularImport error 254 | user: T_U = await cls.objects().get(cls.id == user_id) 255 | for group_id in group_ids: 256 | group = await Group.get_by_id(group_id) 257 | await user.add_m2m( 258 | group, # type: ignore 259 | m2m=cls.groups 260 | ) 261 | return await cls.get_by_id(user_id) 262 | 263 | @classmethod 264 | async def delete_groups(cls: Type[T_U], user_id: str, group_ids: list[str]): 265 | from accounting import Group # CircularImport error 266 | user: T_U = await cls.objects().get(cls.id == user_id) 267 | for group_id in group_ids: 268 | group = await Group.get_by_id(group_id) 269 | await user.remove_m2m( 270 | group, # type: ignore 271 | m2m=cls.groups 272 | ) 273 | return await cls.get_by_id(user_id) 274 | 275 | async def get_all_user_roles(self): 276 | from accounting import Role 277 | await self.join_m2m() 278 | roles: list[Role] = self.roles 279 | return roles 280 | -------------------------------------------------------------------------------- /src/utils/vault.py: -------------------------------------------------------------------------------- 1 | from configuration import config 2 | from loguru import logger 3 | from utils.telemetry import tracer 4 | from async_hvac import AsyncClient, exceptions 5 | import aiohttp 6 | from pydantic import BaseModel 7 | import os 8 | import json 9 | 10 | 11 | class Vault(): 12 | 13 | class DBCredsModel(BaseModel): 14 | username: str 15 | password: str 16 | 17 | class UnsealingKeys(BaseModel): 18 | keys: list[str] | None 19 | keys_base64: list[str] = list() 20 | root_token: str | None 21 | 22 | class VaultAuth(BaseModel): 23 | auth_method: str | None = None 24 | token: str | None = None 25 | credentials: str | None = None 26 | 27 | def __init__(self, auth=None): 28 | self.unsealing_keys: self.UnsealingKeys = self.UnsealingKeys() 29 | if auth: 30 | self.auth: self.VaultAuth = auth 31 | else: 32 | self.auth: self.VaultAuth = self.VaultAuth() 33 | 34 | def __repr__(self) -> str: 35 | return f"" 36 | 37 | async def init(self): 38 | if config.Vault.is_enabled: 39 | self.load_auth_data() 40 | if await self.check_vault_state(): 41 | logger.info("Vault instance is ready") 42 | else: 43 | logger.critical("Vault instance creation failed") 44 | else: 45 | logger.info("Vault module is inactive") 46 | 47 | def 
get_auth_token(self) -> str | None: 48 | if config.Vault.vault_token: 49 | return config.Vault.vault_token 50 | if self.unsealing_keys: 51 | return self.unsealing_keys.root_token 52 | logger.critical('Cannot obtain Vault auth token') 53 | return None 54 | 55 | def load_auth_data(self) -> None: 56 | auth_method: str = config.Vault.vault_auth_method 57 | self.auth.auth_method = auth_method 58 | if config.Vault.vault_keyfile_type: 59 | self.unsealing_keys = self.open_keys_file( 60 | filetype=config.Vault.vault_keyfile_type, 61 | # type: ignore #TODO prev field check in config-loader 62 | filename=config.Vault.vault_unseal_keys 63 | ) 64 | match auth_method: 65 | case 'token': 66 | token = self.get_auth_token() 67 | self.auth.token = token 68 | case _: 69 | logger.critical('Unknown vault auth method') 70 | return None 71 | 72 | def open_keys_file(self, filetype: str, filename: str) -> UnsealingKeys: # type: ignore 73 | try: 74 | with open(filename, 'r') as f: 75 | if filetype == 'json': 76 | data = json.load(f) 77 | return self.UnsealingKeys(**data) 78 | if filetype == 'keys': 79 | data = f.readlines() 80 | return self.UnsealingKeys(keys_base64=data) 81 | except FileNotFoundError: 82 | logger.critical( 83 | f'File with unsealing vault keys {filename} not found') 84 | except PermissionError: 85 | logger.critical(f'Cannot open {filename}, permission denied') 86 | 87 | async def get_db_creds(self, role_name: str, static: bool = True, storage_name: str = 'database') -> DBCredsModel: 88 | with tracer.start_as_current_span("security:Vault:get_db_creds") as span: 89 | span.set_attribute("role.name", role_name) 90 | span.set_attribute("role.static", static) 91 | span.set_attribute("role.storage", storage_name) 92 | creds_type = 'static-creds' 93 | if not static: 94 | creds_type = 'creds' 95 | data = await self._action('read', f'{storage_name}/{creds_type}/{role_name}') 96 | try: 97 | assert data, "Unable to obtain db credential from Vault" 98 | except AssertionError as ex: 99 | logger.critical(ex) 100 | os._exit(0) 101 | creds = self.DBCredsModel.parse_obj(data) 102 | return creds 103 | 104 | async def read_kv_data(self, secret_name: str, storage_name: str = 'kv') -> dict | None: # type: ignore 105 | with tracer.start_as_current_span("security:Vault:read_kv_data") as span: 106 | span.set_attribute("vault.storage", storage_name) 107 | span.set_attribute("vault.secret", secret_name) 108 | resp: dict = await self._action('read', f'{storage_name}/data/{secret_name}') 109 | try: 110 | assert resp, "Unable to obtain KV data from Vault" 111 | except AssertionError as ex: 112 | logger.error(ex) 113 | return None 114 | else: 115 | return resp.get('data', dict()) 116 | 117 | async def write_kv_data(self, secret_name: str, payload: dict, storage_name: str = 'kv') -> dict: 118 | with tracer.start_as_current_span("security:Vault:write_kv_data") as span: 119 | span.set_attribute("vault.storage", storage_name) 120 | span.set_attribute("vault.secret", secret_name) 121 | resp: dict = await self._action('write', f'{storage_name}/data/{secret_name}', payload=payload) 122 | try: 123 | assert resp, "Unable to write KV data to Vault" 124 | except AssertionError as ex: 125 | logger.error(ex) 126 | except exceptions.InvalidPath: 127 | logger.error(f"Invalid vault path {storage_name}") 128 | return resp 129 | 130 | async def request_certificate( 131 | self, 132 | role_name: str, 133 | storage_name: str, 134 | common_name: str, 135 | cert_ttl: str 136 | ) -> dict | None: 137 | with 
tracer.start_as_current_span("security:Vault:request_certificate") as span: 138 | span.set_attribute("role.name", role_name) 139 | span.set_attribute("role.storage", storage_name) 140 | payload: dict = { 141 | 'common_name': common_name, 142 | 'ttl': cert_ttl 143 | } 144 | data = await self._action('write', f'{storage_name}/issue/{role_name}', payload=payload) 145 | try: 146 | assert data, "Unable to request certificate from Vault" 147 | except AssertionError as ex: 148 | logger.warning(ex) 149 | return None 150 | return data 151 | 152 | async def _action(self, action_type, route, payload=None): 153 | with tracer.start_as_current_span("security:Vault:_action") as span: 154 | span.set_attribute("action.type", action_type) 155 | span.set_attribute("action.route", route) 156 | try: 157 | async with self.get_instance() as client: 158 | try: 159 | assert not await client.is_sealed(), "Vault storage is sealed" 160 | assert await client.is_initialized(), "Vault storage is not initialized" 161 | assert await client.is_authenticated(), "Vault authentication error" 162 | match action_type: 163 | case 'read': 164 | return await self._read(client, route) 165 | case 'write': 166 | return await self._write(client, route, **payload) 167 | case _: 168 | logger.error( 169 | f"Unknown Vault operation {action_type}") 170 | except (AssertionError, aiohttp.client_exceptions.ClientConnectorError) as ex: 171 | logger.error(ex) 172 | raise AttributeError 173 | except AttributeError: 174 | logger.error("Vault instance creation failed") 175 | return False 176 | 177 | async def _read(self, instance: AsyncClient, route: str): 178 | with tracer.start_as_current_span("security:Vault:_read"): 179 | try: 180 | resp = await instance.read(route) 181 | assert resp, "Empty responce" 182 | data = resp.get('data') 183 | assert data, "Empty data field" 184 | except (exceptions.InvalidRequest, AssertionError) as ex: 185 | logger.error(f"Vault read operation error: {ex}") 186 | return None 187 | else: 188 | return data 189 | 190 | async def _write(self, instance: AsyncClient, route: str, **kwargs): 191 | with tracer.start_as_current_span("security:Vault:_write"): 192 | try: 193 | resp = await instance.write(route, **kwargs) 194 | assert resp, "Empty responce" 195 | data = resp.get('data') 196 | assert data, "Empty data field" 197 | except (exceptions.InvalidRequest, AssertionError, exceptions.InvalidPath) as ex: 198 | logger.error(f"Vault write operation error: {ex}") 199 | return None 200 | else: 201 | return data 202 | 203 | async def unseal_vault(self, vault_instance: AsyncClient) -> bool: 204 | logger.warning('Vault instance is sealed, trying to unseal...') 205 | try: 206 | keys: list = self.unsealing_keys.keys_base64 207 | await vault_instance.unseal_multi(keys) 208 | except exceptions.InvalidRequest: 209 | logger.critical('Broken keys structure in unsealing_keys file') 210 | 211 | if await vault_instance.is_sealed(): 212 | logger.critical('Vault unsealing process failed') 213 | return False 214 | else: 215 | logger.info('Vault instance was successfully unsealed') 216 | return True 217 | 218 | async def check_vault_state(self) -> bool: 219 | """ 220 | Method for checking vault service state with separate 221 | client session 222 | """ 223 | 224 | try: 225 | async with self.get_instance() as client: 226 | try: 227 | if config.Vault.is_unsealing_available and await client.is_sealed(): 228 | assert await self.unseal_vault(client), "Unable to unseal vault" 229 | assert not await client.is_sealed(), "Vault storage is sealed" 230 
| assert await client.is_initialized(), "Vault storage is not initialized" 231 | assert await client.is_authenticated(), "Vault authentication error" 232 | except AssertionError as ex: 233 | logger.error(ex) 234 | raise AttributeError 235 | except aiohttp.client_exceptions.ClientConnectorError as ex: 236 | logger.error(ex) 237 | raise AttributeError 238 | except AttributeError: 239 | logger.error("Vault instance creation failed") 240 | return False 241 | else: 242 | logger.debug("Vault instance is ready") 243 | return True 244 | 245 | def get_instance(self) -> AsyncClient: 246 | instance = AsyncClient() 247 | scheme = "http" 248 | if config.Vault.is_tls: 249 | scheme = "https" 250 | url = f"{scheme}://{config.Vault.vault_host}:{config.Vault.vault_port}" 251 | match self.auth.auth_method: 252 | case "token": 253 | token = self.auth.token 254 | instance = AsyncClient(url=url, token=token) 255 | case _: 256 | pass 257 | return instance 258 | -------------------------------------------------------------------------------- /src/configuration/sections.py: -------------------------------------------------------------------------------- 1 | from .base import BaseSectionModel 2 | from pydantic import (validator, PostgresDsn) 3 | import ipaddress 4 | import datetime 5 | from loguru import logger 6 | from typing import Any 7 | import sys 8 | import re 9 | import os 10 | 11 | 12 | class MainSectionConfiguration(BaseSectionModel): 13 | 14 | application_mode: str = 'prod' 15 | log_level: str = 'INFO' 16 | log_destination: str = 'stdout' 17 | log_in_json: int = 0 18 | log_sql: int = 0 19 | timezone: int = +3 20 | enable_swagger: int = 1 21 | swagger_doc_url: str = '/doc' 22 | swagger_redoc_url: str = '/redoc' 23 | 24 | @validator('application_mode') 25 | def check_appmode(cls, v): 26 | assert isinstance(v, str) 27 | assert v in ['prod', 'dev'], f"Unknown app_mode {v}" 28 | return v 29 | 30 | @validator('log_level') 31 | def check_loglevel(cls, v): 32 | assert isinstance(v, str) 33 | assert v.upper() in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] 34 | return v.upper() 35 | 36 | @validator('log_destination') 37 | def check_logdestination(cls, v): 38 | assert isinstance(v, str) 39 | assert v == 'stdout' or os.path.isfile(v) 40 | return v 41 | 42 | @validator('log_in_json', 'log_sql', 'enable_swagger') 43 | def check_int_as_bool(cls, v): 44 | assert isinstance(v, int) 45 | assert v in [0, 1] 46 | return v 47 | 48 | @validator('swagger_doc_url', 'swagger_redoc_url') 49 | def check_admin_url_slashes(cls, v): 50 | assert isinstance(v, str) 51 | assert bool(re.match('/.*', v)), 'url MUST starts with /' 52 | return v 53 | 54 | @property 55 | def log_sink(self): 56 | if self.log_destination == 'stdout': 57 | return sys.stdout 58 | else: 59 | return self.log_destination 60 | 61 | @property 62 | def tz(self) -> datetime.timezone: 63 | return ( 64 | datetime.timezone( 65 | datetime.timedelta( 66 | hours=self.timezone 67 | ) 68 | ) 69 | ) 70 | 71 | @property 72 | def is_swagger_enabled(self) -> bool: 73 | return bool(self.enable_swagger) 74 | 75 | @property 76 | def doc_url(self) -> str | None: 77 | if self.is_swagger_enabled: 78 | return self.swagger_doc_url 79 | else: 80 | return None 81 | 82 | @property 83 | def redoc_url(self) -> str | None: 84 | if self.is_swagger_enabled: 85 | return self.swagger_redoc_url 86 | else: 87 | return None 88 | 89 | @property 90 | def is_prod_mode(self) -> bool: 91 | return self.application_mode == 'prod' 92 | 93 | 94 | class AdminGUISectionConfiguration(BaseSectionModel): 95 | 
96 | admin_enable: int = 1 97 | admin_url: str = '/admin/' 98 | 99 | @validator('admin_enable') 100 | def check_admin_enable(cls, v): 101 | assert isinstance(v, int) 102 | assert v in [0, 1] 103 | return v 104 | 105 | @validator('admin_url') 106 | def check_admin_url_slashes(cls, v): 107 | assert isinstance(v, str) 108 | assert bool(re.match('/.*/', v)), 'url MUST starts and end with /' 109 | return v 110 | 111 | @property 112 | def is_admin_gui_enable(self) -> bool: 113 | return bool(self.admin_enable) 114 | 115 | 116 | class ServerSectionConfiguration(BaseSectionModel): 117 | 118 | bind_address: str = '127.0.0.1' 119 | bind_port: int = 8000 120 | base_url: str = 'localhost' 121 | 122 | @validator('bind_port') 123 | def check_port(cls, v): 124 | assert isinstance(v, int) 125 | assert v in range(0, 65535) 126 | return v 127 | 128 | @validator('bind_address') 129 | def check_address(cls, v): 130 | try: 131 | ipaddress.ip_address(v) 132 | except ValueError: 133 | assert v == 'localhost' 134 | v = '127.0.0.1' 135 | return v 136 | 137 | 138 | class DatabaseSectionConfiguration(BaseSectionModel): 139 | 140 | db_driver: str = 'postgresql+asyncpg' 141 | db_host: str = '127.0.0.1' 142 | db_port: int = 5432 143 | db_name: str = 'enforcer' 144 | db_username: str = 'enforcer' 145 | db_password: str = 'enforcer' 146 | 147 | db_vault_enable: int = 0 148 | db_vault_role: str = 'myrole' 149 | db_vault_static: int = 1 150 | db_vault_storage: str = 'database' 151 | 152 | connection_string: str = "empty" 153 | 154 | engine: Any | None = None 155 | 156 | def set_connection_string(self, s: str): 157 | self.connection_string = s 158 | 159 | def get_connection_string(self) -> str: 160 | return self.connection_string 161 | 162 | def get_engine(self): 163 | return self.engine 164 | 165 | def set_engine(self, new_engine): 166 | self.engine = new_engine 167 | 168 | @validator('db_driver') 169 | def check_driver(cls, v): 170 | assert v in ['postgresql+asyncpg', 'postgresql'] 171 | return v 172 | 173 | @validator('db_port') 174 | def check_port(cls, v): 175 | assert isinstance(v, int) 176 | assert v in range(0, 65535) 177 | return v 178 | 179 | @validator('db_host') 180 | def check_address(cls, v): 181 | try: 182 | ipaddress.ip_address(v) 183 | except ValueError: 184 | assert v == 'localhost' 185 | v = '127.0.0.1' 186 | return v 187 | 188 | @validator('db_vault_enable', 'db_vault_static') 189 | def check_int_as_bool(cls, v): 190 | assert v in [0, 1] 191 | return v 192 | 193 | @property 194 | def is_vault_enable(self) -> bool: 195 | return bool(self.db_vault_enable) 196 | 197 | @property 198 | def is_vault_static(self) -> bool: 199 | return bool(self.db_vault_static) 200 | 201 | def build_connection_string(self, username: str | None = None, password: str | None = None) -> str: 202 | if not username or not password: 203 | logger.debug('Using plaintext credentials') 204 | username = self.db_username 205 | password = self.db_password 206 | return PostgresDsn.build( 207 | scheme=self.db_driver, 208 | host=self.db_host, 209 | port=str(self.db_port), 210 | path=f'/{self.db_name}', 211 | user=username, 212 | password=password, 213 | ) 214 | 215 | 216 | class VaultSectionConfiguration(BaseSectionModel): 217 | 218 | vault_enable: int = 0 219 | vault_host: str = 'localhost' 220 | vault_port: int = 8200 221 | vault_disable_tls: int = 0 222 | vault_auth_method: str = 'token' 223 | vault_token: str | None = None 224 | vault_credentials: str | None = None 225 | vault_keyfile_type: str | None = None 226 | vault_try_to_unseal: int = 
0 227 | vault_unseal_keys: str | None = None 228 | 229 | @validator('vault_enable', 'vault_disable_tls', 'vault_try_to_unseal') 230 | def check_int_as_bool(cls, v): 231 | assert v in [0, 1] 232 | return v 233 | 234 | @validator('vault_host') 235 | def check_address(cls, v): 236 | try: 237 | ipaddress.ip_address(v) 238 | except ValueError: 239 | assert v == 'localhost' 240 | v = '127.0.0.1' 241 | return v 242 | 243 | @validator('vault_port') 244 | def check_port(cls, v): 245 | assert isinstance(v, int) 246 | assert v in range(0, 65535) 247 | return v 248 | 249 | @validator('vault_auth_method') 250 | def check_auth_method(cls, v): 251 | assert v in ['token'] 252 | return v 253 | 254 | @validator('vault_keyfile_type') 255 | def check_keyfile_type(cls, v): 256 | if v: 257 | assert v in ['json', 'keys'] 258 | return v 259 | 260 | @property 261 | def is_enabled(self) -> bool: 262 | return bool(self.vault_enable) 263 | 264 | @property 265 | def is_tls(self) -> bool: 266 | return not bool(self.vault_disable_tls) 267 | 268 | @property 269 | def is_unsealing_available(self) -> bool: 270 | return bool(self.vault_try_to_unseal) 271 | 272 | 273 | class TelemetrySectionConfiguration(BaseSectionModel): 274 | 275 | enable: int = 1 276 | agent_type: str = 'jaeger' 277 | agent_host: str = '127.0.0.1' 278 | agent_port: int = 6831 279 | trace_id_length: int = 0 280 | 281 | @validator('enable') 282 | def check_status(cls, v): 283 | assert v in [0, 1] 284 | return v 285 | 286 | @validator('agent_type') 287 | def check_type(cls, v): 288 | assert v in ['jaeger'] 289 | return v 290 | 291 | @validator('agent_port') 292 | def check_port(cls, v): 293 | assert isinstance(v, int) 294 | assert v in range(0, 65535) 295 | return v 296 | 297 | @validator('agent_host') 298 | def check_address(cls, v): 299 | try: 300 | ipaddress.ip_address(v) 301 | except ValueError: 302 | assert v == 'localhost' 303 | v = '127.0.0.1' 304 | return v 305 | 306 | @validator('trace_id_length') 307 | def check_trace_id_len(cls, v): 308 | assert v in range(0, 32) 309 | return v 310 | 311 | @property 312 | def is_active(self) -> bool: 313 | return bool(self.enable) 314 | 315 | @property 316 | def is_trace_id_enabled(self) -> bool: 317 | return self.trace_id_length > 0 318 | 319 | 320 | class SecuritySectionConfiguration(BaseSectionModel): 321 | 322 | _available_jwt_algorithms: list[str] = ['HS256'] 323 | _available_jwt_base_secret_storage_types: list[str] = ['local', 'vault'] 324 | jwt_base_secret: str | None = None 325 | 326 | enable_rbac: int = 1 327 | login_with_username: int = 1 328 | login_with_email: int = 0 329 | jwt_algorithm: str = "HS256" 330 | jwt_ttl: int = 3600 331 | 332 | jwt_base_secret_storage: str | None = 'local' 333 | jwt_base_secret_filename: str = 'secret.key' 334 | jwt_base_secret_vault_storage_name: str = 'kv' 335 | jwt_base_secret_vault_secret_name: str = 'jwt' 336 | 337 | @validator('enable_rbac', 'login_with_username', 'login_with_email') 338 | def check_int_as_bool(cls, v): 339 | assert v in [0, 1] 340 | return v 341 | 342 | @validator('jwt_algorithm') 343 | def check_jwt_algo(cls, v): 344 | assert v in cls._available_jwt_algorithms, f"JWT algorithm {v} is unavailable or unknown" 345 | return v 346 | 347 | @validator('jwt_base_secret_storage') 348 | def check_jwt_storage(cls, v): 349 | assert v in cls._available_jwt_base_secret_storage_types, f"JWT storage type {v} is unavailable or unknown" 350 | return v 351 | 352 | @validator('jwt_ttl') 353 | def check_jwt_ttl(cls, v): 354 | assert v > 0, "JWT ttl MUST be greater then 
0 seconds" 355 | return v 356 | 357 | @property 358 | def is_rbac_enabled(self) -> bool: 359 | return bool(self.enable_rbac) 360 | 361 | @property 362 | def is_username_login_enabled(self) -> bool: 363 | return bool(self.login_with_username) 364 | 365 | @property 366 | def is_email_login_enabled(self) -> bool: 367 | return bool(self.login_with_email) 368 | 369 | @property 370 | def available_login_fields(self) -> list[str]: 371 | login_fields: list[str] = [] 372 | if self.is_username_login_enabled: 373 | login_fields.append('username') 374 | if self.is_email_login_enabled: 375 | login_fields.append('email') 376 | return login_fields 377 | 378 | def set_jwt_base_secret(self, secret: str) -> None: 379 | self.jwt_base_secret = secret 380 | 381 | def get_jwt_base_secret(self) -> str | None: 382 | return self.jwt_base_secret 383 | --------------------------------------------------------------------------------