├── .dockerignore ├── .env.example ├── .github └── workflows │ └── main.yaml ├── .gitignore ├── .vscode └── settings.json ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── alembic.ini ├── app ├── api │ ├── endpoints │ │ └── auth.py │ ├── middleware.py │ └── router.py ├── auth │ ├── jwt.py │ └── key.py ├── cache │ ├── redis.py │ └── time_cache.py ├── core │ └── config.py ├── db │ ├── connection.py │ ├── dep.py │ └── util.py ├── discord │ ├── client.py │ ├── config.py │ ├── exceptions.py │ └── models │ │ ├── guild.py │ │ ├── role.py │ │ └── user.py ├── lmbd │ ├── invoke.py │ ├── sample_data │ │ └── snsCloudWatchEvent.json │ └── sample_lambda │ │ └── main.py ├── log │ └── setup.py ├── migrations │ ├── env.py │ └── script.py.mako ├── models │ ├── mysql.py │ └── mysql_no_fk.py ├── types │ ├── cache.py │ ├── jwt.py │ ├── lmbd.py │ ├── server.py │ └── time.py └── util │ ├── common.py │ ├── json.py │ └── time.py ├── dev-requirements.txt ├── docker-compose.yaml ├── docker-entrypoint.sh ├── gunicorn_conf.py ├── loguru ├── __init__.py ├── __init__.pyi ├── _asyncio_loop.py ├── _better_exceptions.py ├── _colorama.py ├── _colorizer.py ├── _contextvars.py ├── _ctime_functions.py ├── _datetime.py ├── _defaults.py ├── _error_interceptor.py ├── _file_sink.py ├── _filters.py ├── _get_frame.py ├── _handler.py ├── _locks_machinery.py ├── _logger.py ├── _recattrs.py ├── _simple_sinks.py ├── _string_parsers.py └── py.typed ├── main.py ├── make-env-example.sh ├── mypy.ini ├── pyproject.toml ├── requirements.txt └── shell.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Version control system files 2 | .git 3 | .gitignore 4 | 5 | # Logs and temporary files 6 | *.log 7 | logs/ 8 | __pycache__/ 9 | .mypy_cache/ 10 | .ruff_cache/ 11 | .cache/ 12 | 13 | # Virtual environment and dependency files 14 | env/ 15 | .venv/ 16 | 17 | # Environment configuration files 18 | .env.prod 19 | .env.production 20 | .env.dev 21 | 
.env.development 22 | .env.staging 23 | .env.stage 24 | .env.test 25 | .env.testing 26 | .env.local 27 | .env.example 28 | 29 | # Build and setup scripts 30 | make-env-example.sh 31 | make-env.sh 32 | 33 | # Test files and directories 34 | test.py 35 | test.json 36 | tests/ 37 | 38 | # IDE and editor specific files 39 | .vscode/ 40 | 41 | # GitHub workflow and contribution files 42 | .github/ 43 | CHANGELOG.md 44 | CONTRIBUTING.md 45 | 46 | # Documentation and license files 47 | README.md 48 | 49 | # Docker Compose file (if not needed in build context) 50 | docker-compose.y*ml 51 | 52 | # Package and archive files 53 | *.zip 54 | *.7zip 55 | *.dmg 56 | *.gz 57 | *.iso 58 | *.jar 59 | *.rar 60 | 61 | # Other 62 | .DS_Store 63 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | DEBUG=0 2 | ENABLE_METRICS=1 3 | DOMAIN=localhost 4 | SECRET_KEY= 5 | REFRESH_KEY= 6 | PROFILING=0 7 | JWT_USE_NONCE=0 8 | 9 | # Backend 10 | BACKEND_CORS_ORIGINS=["http://localhost:8000","http://localhost:5000"] 11 | 12 | # MySQL 13 | MYSQL_HOST=localhost:3306 14 | MYSQL_USER=root 15 | MYSQL_PASSWORD=root 16 | MYSQL_DATABASE=fastapi-db 17 | MYSQL_SSL=/etc/ssl/cert.pem 18 | 19 | # Redis 20 | REDIS_HOST=localhost 21 | REDIS_PORT=6379 22 | REDIS_PASSWORD=test 23 | 24 | # AWS 25 | AWS_ACCESS_KEY_ID=test 26 | AWS_SECRET_ACCESS_KEY=test 27 | AWS_DEFAULT_REGION=us-east-1 28 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v5 21 |
with: 22 | python-version: "3.11" 23 | 24 | - name: Install linter 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install ruff 28 | 29 | - name: Linting code 30 | run: | 31 | ruff check . 32 | 33 | type-check: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - name: Checkout 37 | uses: actions/checkout@v4 38 | 39 | - name: Set up Python 3.11 40 | uses: actions/setup-python@v5 41 | with: 42 | python-version: "3.11" 43 | 44 | - name: Install dependencies 45 | run: | 46 | python -m pip install --upgrade pip 47 | pip install -r requirements.txt 48 | pip install -r dev-requirements.txt 49 | 50 | - name: Type checking code 51 | run: | 52 | mypy . --explicit-package-bases --exclude loguru 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source code files 2 | __pycache__/ 3 | *.c 4 | *.so 5 | *.dll 6 | *.pyd 7 | *.exe 8 | *.pyc 9 | compiled/ 10 | 11 | # Build and environment files 12 | build/ 13 | env/ 14 | .venv/ 15 | .docker/ 16 | docker/ 17 | make-env.sh 18 | 19 | # Configuration and environment variable files 20 | .env 21 | .env.prod 22 | .env.production 23 | .env.dev 24 | .env.development 25 | .env.staging 26 | .env.stage 27 | .env.test 28 | .env.testing 29 | .env.local 30 | *prod.env 31 | *production.env 32 | *dev.env 33 | *development.env 34 | *staging.env 35 | *stage.env 36 | *test.env 37 | *testing.env 38 | *local.env 39 | 40 | # Logs and temporary files 41 | *.log 42 | logs/ 43 | .ruff_cache/ 44 | .cache/ 45 | .mypy_cache/ 46 | 47 | # Package and archive files 48 | *.zip 49 | *.7zip 50 | *.dmg 51 | *.gz 52 | *.iso 53 | *.jar 54 | *.rar 55 | 56 | # IDE and editor specific files 57 | !.vscode/settings.json 58 | .DS_Store 59 | 60 | # Test files 61 | test.py 62 | test.json 63 | 64 | # Specific includes and excludes 65 | *.txt 66 | !requirements.txt 67 | !dev-requirements.txt 68 | chromedriver* 69 |
-------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "basic", 3 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Creating a python base with shared environment variables 2 | FROM python:3.11-slim-bullseye AS python-base 3 | ENV PYTHONUNBUFFERED=1 \ 4 | PYTHONDONTWRITEBYTECODE=1 \ 5 | PIP_NO_CACHE_DIR=off \ 6 | PIP_DISABLE_PIP_VERSION_CHECK=on \ 7 | PIP_DEFAULT_TIMEOUT=100 \ 8 | UV_HOME="/opt/uv" \ 9 | PYSETUP_PATH="/opt/pysetup" \ 10 | VENV_PATH="/opt/pysetup/.venv" 11 | 12 | ENV PATH="$UV_HOME/bin:$VENV_PATH/bin:$PATH" 13 | 14 | # builder-base is used to build dependencies 15 | FROM python-base AS builder-base 16 | RUN buildDeps="build-essential" \ 17 | && apt-get update \ 18 | && apt-get install --no-install-recommends -y \ 19 | curl \ 20 | vim \ 21 | netcat \ 22 | && apt-get install -y --no-install-recommends $buildDeps \ 23 | && rm -rf /var/lib/apt/lists/* 24 | 25 | # Install UV 26 | RUN curl -LsSf https://astral.sh/uv/install.sh | sh 27 | 28 | # We copy our Python requirements here to cache them 29 | # and install only runtime deps using uv 30 | WORKDIR $PYSETUP_PATH 31 | COPY ./requirements.txt ./ 32 | RUN uv venv $VENV_PATH && \ 33 | . $VENV_PATH/bin/activate && \ 34 | uv pip install -r requirements.txt 35 | 36 | # 'development' stage installs all dev deps and can be used to develop code. 
37 | # For example using docker-compose to mount local volume under /app 38 | FROM python-base as development 39 | ENV FASTAPI_ENV=development 40 | 41 | # Copying uv and venv into image 42 | COPY --from=builder-base $UV_HOME $UV_HOME 43 | COPY --from=builder-base $PYSETUP_PATH $PYSETUP_PATH 44 | 45 | # Copying in our entrypoint 46 | COPY docker-entrypoint.sh /docker-entrypoint.sh 47 | RUN chmod +x /docker-entrypoint.sh 48 | 49 | # venv already has runtime deps installed we get a quicker install 50 | WORKDIR $PYSETUP_PATH 51 | COPY ./dev-requirements.txt ./ 52 | RUN . $VENV_PATH/bin/activate && \ 53 | uv pip install -r dev-requirements.txt 54 | 55 | WORKDIR /app 56 | COPY . . 57 | 58 | # Needs to be consistent with gunicorn_conf.py 59 | EXPOSE 8000 60 | ENTRYPOINT ["/docker-entrypoint.sh"] 61 | 62 | # 'production' stage uses the clean 'python-base' stage and copies 63 | # in only our runtime deps that were installed in the 'builder-base' 64 | FROM python-base AS production 65 | ENV FASTAPI_ENV=production 66 | ENV PROD=true 67 | 68 | COPY --from=builder-base $VENV_PATH $VENV_PATH 69 | COPY gunicorn_conf.py /gunicorn_conf.py 70 | 71 | COPY docker-entrypoint.sh /docker-entrypoint.sh 72 | RUN chmod +x /docker-entrypoint.sh 73 | 74 | COPY main.py /main.py 75 | COPY .env /.env 76 | 77 | # Create user with the name appuser 78 | RUN groupadd -g 1500 appuser && \ 79 | useradd -m -u 1500 -g appuser appuser 80 | 81 | COPY --chown=appuser:appuser ./app /app 82 | USER appuser 83 | 84 | ENTRYPOINT ["/docker-entrypoint.sh"] 85 | CMD ["gunicorn", "--worker-class", "uvicorn.workers.UvicornWorker", "--config", "/gunicorn_conf.py", "main:server"] 86 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Abe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build-container: ## Build container 2 | docker buildx build --platform=linux/amd64 -t $(CONTAINER_NAME) . 
3 | 4 | run-container: ## Run container 5 | docker run --name=$(CONTAINER_NAME) -d -p 5500:8000 $(CONTAINER_NAME) 6 | 7 | upload-container: ## Upload container to ECR 8 | aws ecr get-login-password --region $(AWS_REGION) | docker login --username AWS --password-stdin $(AWS_ACCOUNT_ID).dkr.ecr.$(AWS_REGION).amazonaws.com/$(CONTAINER_REPO_NAME) 9 | docker tag $(CONTAINER_NAME):latest $(AWS_ACCOUNT_ID).dkr.ecr.$(AWS_REGION).amazonaws.com/$(CONTAINER_REPO_NAME):latest 10 | docker push $(AWS_ACCOUNT_ID).dkr.ecr.$(AWS_REGION).amazonaws.com/$(CONTAINER_REPO_NAME):latest 11 | 12 | cutover-ecs: ## Cutover to new ECS task definition 13 | aws ecs update-service --cluster $(ECS_CLUSTER_NAME) --service $(ECS_SERVICE_NAME) --force-new-deployment 14 | 15 | deploy-cloudwatch-lambda: ## Deploy the cloudwatch lambda ## TODO: UPDATE FOR SAMPLE LAMBDA 16 | zip -j sns_cloudwatch_webhook.zip app/lmbd/sns_cloudwatch_webhook/main.py 17 | aws lambda update-function-code --function-name sns-cloudwatch-webhook --zip-file fileb://sns_cloudwatch_webhook.zip 18 | rm sns_cloudwatch_webhook.zip 19 | 20 | help: ## Display help 21 | @awk -F ':|##' '/^[^\t].+?:.*?##/ {printf "\033[36m%-30s\033[0m %s\n", $$1, $$NF}' $(MAKEFILE_LIST) | sort 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

⚡️ Fastest FastAPI

3 | 4 |
5 | 6 | ![Build](https://github.com/FastestMolasses/Fast-Python-Server-Template/actions/workflows/main.yaml/badge.svg) [![GitHub license](https://badgen.net/github/license/FastestMolasses/Fast-Python-Server-Template)](https://github.com/FastestMolasses/Fast-Python-Server-Template/blob/main/LICENSE) 7 | 8 |
9 | 10 |

11 | A production-ready FastAPI server template, emphasizing performance and type safety. It includes a configurable set of features and options, allowing customization to retain or exclude components. 12 |
13 |
14 | Built with FastAPI, Pydantic, Ruff, and MyPy. 15 |
16 | Report Bug 17 | · 18 | Request Feature 19 |

20 |

21 | 22 | ## Features 23 | 24 | - ⚡ Async and type safety by default 25 | - 🛠️ CI/CD and tooling setup 26 | - 🚀 High performance libraries integrated ([orjson](https://github.com/ijl/orjson), [uvloop](https://github.com/MagicStack/uvloop), [pydantic2](https://github.com/pydantic/pydantic)) 27 | - 📝 [Loguru](https://github.com/Delgan/loguru) + [picologging](https://github.com/microsoft/picologging) for simplified and performant logging 28 | - 🐳 Dockerized and includes AWS deployment flow 29 | - 🗃️ Several database implementations with sample ORM models (MySQL, Postgres, Timescale) & migrations 30 | - 🔐 Optional JWT authentication and authorization 31 | - 🌐 AWS Lambda functions support 32 | - 🧩 Modularized features 33 | - 📊 Prometheus metrics 34 | - 📜 Makefile commands 35 | - 🗺️ Route profiling 36 | 37 | ## Table of Contents 38 | 39 | - [Requirements](#requirements) 40 | - [Installation](#installation) 41 | - [Environment Specific Configuration](#environment-specific-configuration) 42 | - [Upgrading Dependencies](#installation) 43 | - [Databases](#databases) 44 | - [Shell](#shell) 45 | - [Migrations](#migrations) 46 | - [Downgrade Migration](#downgrade-migration) 47 | - [JWT Implementation](#jwt-implementation) 48 | - [JWT Overview](#jwt-overview) 49 | - [Modifying JWT Payload Fields](#modifying-jwt-payload-fields) 50 | - [Project Structure](#project-structure) 51 | - [Makefile Commands](#makefile-commands) 52 | - [Contributing](#contributing) 53 | 54 | ## Requirements 55 | 56 | - [Python 3.11+](https://www.python.org/downloads/) 57 | - [Docker](https://www.docker.com/get-started/) 58 | 59 | ## Installation 60 | 61 | 1. Fork this repo ([How to create a private fork](https://gist.github.com/0xjac/85097472043b697ab57ba1b1c7530274)) 62 | 63 | 2. Install UV Package Manager: 64 | Follow the installation instructions at https://github.com/astral-sh/uv 65 | 66 | 3.
Set up the virtual environment and install dependencies: 67 | 68 | ```bash 69 | uv venv 70 | 71 | # On macOS and Linux 72 | source .venv/bin/activate 73 | 74 | # On Windows 75 | .venv\Scripts\activate 76 | 77 | # Install main dependencies 78 | uv pip install -r requirements.txt 79 | 80 | # Install development dependencies (optional) 81 | uv pip install -r dev-requirements.txt 82 | ``` 83 | 84 | 4. Install [Docker](https://www.docker.com/get-started/) 85 | 86 | 5. Start your Docker services: 87 | 88 | ```bash 89 | docker compose up 90 | ``` 91 | 92 | 6. Clone `.env.example` to `.env` and update the values: 93 | 94 | ```bash 95 | # macOS and Linux 96 | cp .env.example .env 97 | 98 | # Windows (PowerShell) 99 | Copy-Item .env.example .env 100 | ``` 101 | 102 | You can use this command to generate secret keys: 103 | 104 | ```bash 105 | # macOS and Linux 106 | openssl rand -hex 128 107 | 108 | # Windows (PowerShell) 109 | $bytes = New-Object byte[] 128; (New-Object Security.Cryptography.RNGCryptoServiceProvider).GetBytes($bytes); [System.BitConverter]::ToString($bytes) -Replace '-' 110 | ``` 111 | 112 | 7. Run the server: 113 | 114 | ```bash 115 | uvicorn main:server --reload 116 | ``` 117 | 118 | Note: If you need to update dependencies, you can modify the `requirements.txt` or `dev-requirements.txt` files directly and then run `uv pip install -r requirements.txt` or `uv pip install -r dev-requirements.txt` respectively. 119 | 120 | ## Environment Specific Configuration 121 | 122 | This project uses environment-specific configuration files and symbolic links to manage different environments such as development, production, and staging. Follow the steps below for your operating system to set up the desired environment. 
123 | 124 | ```bash 125 | # macOS, Linux 126 | ln -s <env-file> .env 127 | # example: ln -s prod.env .env 128 | 129 | # Windows (cmd; requires an administrator prompt or Developer Mode) 130 | mklink .env <env-file> 131 | # example: mklink .env prod.env 132 | ``` 133 | 134 | ## Databases 135 | 136 | ### Shell 137 | 138 | To access the database shell, run this command: 139 | 140 | ```bash 141 | python -i shell.py 142 | ``` 143 | 144 | The `shell.py` script will be loaded including the database session and models. 145 | 146 | ### Migrations 147 | 148 | To do a database migration, follow the steps below. 149 | 150 | 1. Update the ORM models in `app/models/` (e.g. `app/models/mysql.py`) with the changes you want 151 | 2. Run this command to generate the migration file in `app/migrations/versions` 152 | 153 | ```bash 154 | alembic revision --autogenerate -m "Describe your migration" 155 | ``` 156 | 157 | 3. Check the newly generated migration file and verify that it generated correctly. 158 | 4. Run this command to apply the migration 159 | ```bash 160 | alembic upgrade head 161 | ``` 162 | 163 | ⛔️ Autogenerated migrations cannot detect these changes: 164 | 165 | - Changes of table name 166 | - Changes of column name 167 | - Anonymously named constraints 168 | - Special SQLAlchemy types such as Enum when generated on a backend which doesn’t support ENUM directly 169 | 170 | [Reference](https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect) 171 | 172 | These changes will need to be migrated manually by creating an empty migration file and then writing the code to create the changes. 173 | 174 | ```bash 175 | # Manual creation of empty migration file 176 | alembic revision -m "Describe your migration" 177 | ``` 178 | 179 | ### Downgrade Migration 180 | 181 | Run this command to revert every migration back to the beginning.
This section describes the JWT implementation and related functionality. 190 | 191 | ### JWT Overview 192 | 193 | The JWT implementation can be found in `app/auth/jwt.py`. Its primary functions are: 194 | 195 | - Creating access and refresh JWT tokens. 196 | - Verifying and decoding a given JWT token. 197 | - Handling JWT-based authentication for FastAPI routes. 198 | 199 | #### User Management 200 | 201 | If a user associated with a JWT token is not found in the database, a new user can be created with that ID. This is intended to be handled by a get_or_create_user function; note that the token helpers in `app/auth/jwt.py` do not query the database themselves, so call it from your route or dependency after the token is decoded and the user ID (`sub` field in the token) is known. 202 | 203 | #### Nonce Usage 204 | 205 | A nonce is an arbitrary value that can be used just once. It's an optional field in the JWT token that adds security. If a nonce is used: 206 | 207 | - It is stored in Redis for the duration of the refresh token's validity. 208 | - It must match between access and refresh tokens to ensure their pairing. 209 | - Its presence in Redis is verified before the token is considered valid. 210 | 211 | Enabling nonce usage provides an additional layer of security against token reuse, but requires Redis to function. 212 | 213 | ### Modifying JWT Payload Fields 214 | 215 | The JWT token payload structure is defined in `app/types/jwt.py` under the JWTPayload class. If you wish to add more fields to the JWT token payload: 216 | 217 | 1. Update the TokenData and JWTPayload classes in `app/types/jwt.py` by adding the desired fields. 218 | 219 | ```python 220 | class JWTPayload(BaseModel): 221 | # ... existing fields ... 222 | new_field: Type 223 | 224 | class TokenData(BaseModel): 225 | # ... existing fields ... 226 | new_field: Type 227 | ``` 228 | 229 | TokenData is separated from JWTPayload to make it clear what is automatically filled in and what is manually added.
Both classes must be updated to include the new fields. 230 | 231 | 2. Wherever the token is created, update the payload to include the new fields. 232 | 233 | ```python 234 | from app.auth.jwt import create_jwt 235 | from app.types.jwt import TokenData 236 | 237 | payload = TokenData( 238 | sub='user_id_1', 239 | field1='value1', 240 | # ... all fields ... 241 | ) 242 | access_token, refresh_token = create_jwt(payload) 243 | ``` 244 | 245 | Remember that tokens are sent with every request and the refresh token is stored in a cookie (browsers typically cap cookies at about 4 KB). The more data you include, the bigger your token becomes, so only include essential data in the token payload. 246 | 247 | ## Project Structure 248 | 249 | ``` 250 | 📄 main.py - Server entry point 251 | 📁 .github/ - GitHub specific files 252 | 📁 app/ - Application code 253 | ├── 📁 api - API endpoints and middleware 254 | ├── 📁 auth - Authentication / authorization 255 | ├── 📁 cache - Redis code and caching functions 256 | ├── 📁 core - Core configuration 257 | ├── 📁 db - Database connections 258 | ├── 📁 discord - Discord library for auth (optional) 259 | ├── 📁 lmbd - Holds AWS lambda functions 260 | ├── 📁 migrations - Database migrations 261 | ├── 📁 models - Database ORM models 262 | ├── 📁 types - Type definitions 263 | └── 📁 util - Helper functions 264 | ``` 265 | 266 | ## Makefile Commands 267 | 268 | Makefiles are used to run common commands. You can find the list of commands in the `Makefile`. 269 | To use these commands, first copy `make-env-example.sh` to `make-env.sh` and update the values. 270 | 271 | ```bash 272 | # macOS and Linux 273 | cp make-env-example.sh make-env.sh 274 | 275 | # Windows (PowerShell) 276 | Copy-Item make-env-example.sh make-env.sh 277 | ``` 278 | 279 | Remember to make the file executable: 280 | 281 | ```bash 282 | chmod +x make-env.sh 283 | ``` 284 | 285 | Then you can run the commands like this: 286 | 287 | ```bash 288 | ./make-env.sh <command> 289 | ``` 290 | 291 | Try it with the help command, which will list all the available commands.
292 | 293 | ```bash 294 | ./make-env.sh help 295 | ``` 296 | 297 | ## Contributing 298 | 299 | 1. **Fork the Repository**: Start by forking the repository to your own GitHub account. 300 | 2. **Clone the Forked Repository**: Clone the fork to your local machine. 301 | 3. **Create a New Branch**: Always create a new branch for your changes. 302 | 4. **Make Your Changes**: Implement your changes. 303 | 5. **Run Tests**: Make sure to test your changes locally. 304 | 6. **Submit a Pull Request**: Commit and push your changes, then create a pull request against the main branch. 305 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | [alembic] 2 | script_location = app/migrations 3 | 4 | # Template used to generate migration file names 5 | file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(slug)s 6 | 7 | # sys.path path, will be prepended to sys.path if present 8 | prepend_sys_path = . 
9 | 10 | # Max length of characters to apply to the "slug" field 11 | truncate_slug_length = 8 12 | 13 | # Set to 'true' to run the environment during 14 | # the 'revision' command, regardless of autogenerate 15 | revision_environment = false 16 | 17 | # The output encoding used when revision files 18 | # are written from script.py.mako 19 | output_encoding = utf-8 20 | 21 | # Logging configuration 22 | [loggers] 23 | keys = root,sqlalchemy,alembic 24 | 25 | [handlers] 26 | keys = console 27 | 28 | [formatters] 29 | keys = generic 30 | 31 | [logger_root] 32 | level = WARN 33 | handlers = console 34 | qualname = 35 | 36 | [logger_sqlalchemy] 37 | level = WARN 38 | handlers = 39 | qualname = sqlalchemy.engine 40 | 41 | [logger_alembic] 42 | level = INFO 43 | handlers = 44 | qualname = alembic 45 | 46 | [handler_console] 47 | class = StreamHandler 48 | args = (sys.stderr,) 49 | level = NOTSET 50 | formatter = generic 51 | 52 | [formatter_generic] 53 | format = %(levelname)-5.5s [%(name)s] %(message)s 54 | datefmt = %H:%M:%S 55 | -------------------------------------------------------------------------------- /app/api/endpoints/auth.py: -------------------------------------------------------------------------------- 1 | from jose import JWTError 2 | from loguru import logger 3 | from fastapi import APIRouter, Depends 4 | from fastapi.responses import ORJSONResponse 5 | 6 | from app.models.mysql import User 7 | from app.core.config import settings 8 | from app.db.connection import MySqlSession 9 | from app.auth.jwt import RequireRefreshToken, RequireJWT, create_jwt 10 | 11 | from app.types.jwt import TokenData, JWTPayload 12 | from app.types.server import ServerResponse, Cookie 13 | 14 | router = APIRouter(prefix='/auth') 15 | 16 | 17 | @router.get('/signup') 18 | async def signup() -> ServerResponse[str]: 19 | """ 20 | Example route to create a user. 
21 | """ 22 | with MySqlSession() as session: 23 | user = User() 24 | session.add(user) 25 | session.commit() 26 | session.refresh(user) 27 | 28 | return ServerResponse() 29 | 30 | 31 | @router.get('/login') 32 | async def login(response: ORJSONResponse) -> ServerResponse[str]: 33 | """ 34 | Example route to login a user. Will just grab the first user from the database 35 | and create a JWT for them. 36 | """ 37 | # Grab the first user from the database 38 | with MySqlSession() as session: 39 | user = session.query(User).first() 40 | 41 | if not user: 42 | return ServerResponse(status='error', message='No user found') 43 | 44 | token = TokenData(sub=str(user.id)) 45 | try: 46 | accessToken, refreshToken = create_jwt(token) 47 | except JWTError as e: 48 | logger.error(f'JWT Error during login: {e}') 49 | return ServerResponse(status='error', message='JWT Error, try again') 50 | 51 | # Save the refresh token in an HTTPOnly cookie 52 | response.set_cookie( 53 | Cookie.REFRESH_TOKEN, 54 | value=refreshToken, 55 | httponly=True, 56 | max_age=settings.REFRESH_TOKEN_EXPIRE_MINUTES, 57 | expires=settings.REFRESH_TOKEN_EXPIRE_MINUTES, 58 | ) 59 | return ServerResponse[str](data=accessToken) 60 | 61 | 62 | @router.get('/refresh') 63 | async def refresh( 64 | response: ORJSONResponse, payload: JWTPayload = Depends(RequireRefreshToken) 65 | ) -> ServerResponse[str]: 66 | """ 67 | Example route to refresh a JWT. 
68 | """ 69 | token = TokenData(sub=payload.sub) 70 | 71 | try: 72 | accessToken, refreshToken = create_jwt(token) 73 | except JWTError as e: 74 | logger.error(f'JWT Error during login: {e}') 75 | return ServerResponse(status='error', message='JWT Error, try again.') 76 | 77 | # Save the refresh token in an HTTPOnly cookie 78 | response.set_cookie( 79 | Cookie.REFRESH_TOKEN, 80 | value=refreshToken, 81 | httponly=True, 82 | max_age=settings.REFRESH_TOKEN_EXPIRE_MINUTES, 83 | expires=settings.REFRESH_TOKEN_EXPIRE_MINUTES, 84 | ) 85 | return ServerResponse[str](data=accessToken) 86 | 87 | 88 | @router.get('/decodeToken') 89 | async def decodeToken(payload: JWTPayload = Depends(RequireJWT())): 90 | """ 91 | Example route to decode a JWT. 92 | """ 93 | return ServerResponse[dict](data=payload.model_dump()) 94 | -------------------------------------------------------------------------------- /app/api/middleware.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from typing import Callable 3 | from fastapi import Request 4 | # from pyinstrument import Profiler 5 | from app.core.config import settings 6 | from sqlalchemy.exc import IntegrityError 7 | from app.types.server import ServerResponse 8 | from fastapi.responses import ORJSONResponse 9 | from starlette.middleware.base import BaseHTTPMiddleware 10 | from sqlalchemy.exc import NoResultFound, MultipleResultsFound 11 | 12 | 13 | class DBExceptionsMiddleware(BaseHTTPMiddleware): 14 | """ 15 | Middleware to catch and handle database exceptions. 
16 | """ 17 | async def dispatch(self, request: Request, call_next: Callable): 18 | try: 19 | response = await call_next(request) 20 | return response 21 | 22 | except NoResultFound as e: 23 | logger.exception(f'NoResultFound: {e}') 24 | response = ServerResponse(status='error', message='Row not found') 25 | 26 | except MultipleResultsFound as e: 27 | logger.exception(f'MultipleResultsFound: {e}') 28 | response = ServerResponse( 29 | status='error', message='Multiple rows found') 30 | 31 | except IntegrityError as e: 32 | e.hide_parameters = True 33 | logger.exception(f'IntegrityError: {e}') 34 | response = ServerResponse( 35 | status='error', message=str(e)) 36 | 37 | return ORJSONResponse(response.dict(), status_code=400) 38 | 39 | 40 | class CatchAllMiddleware(BaseHTTPMiddleware): 41 | """ 42 | Middleware to catch errors. 43 | """ 44 | async def dispatch(self, request: Request, call_next: Callable): 45 | try: 46 | response = await call_next(request) 47 | return response 48 | 49 | except Exception as e: 50 | # TODO: SEND NOTIFICATION HERE 51 | logger.exception(e) 52 | response = ServerResponse( 53 | status='error', message=str(e)) 54 | return ORJSONResponse(response.dict(), status_code=400) 55 | 56 | 57 | class ProfilingMiddleware(BaseHTTPMiddleware): 58 | """ 59 | Middleware to catch errors. 
60 | """ 61 | async def dispatch(self, request: Request, call_next: Callable): 62 | if settings.PROFILING: 63 | # profiler = Profiler(interval=settings.profiling_interval, async_mode='enabled') 64 | 65 | # profiler.start() 66 | # response = await call_next(request) 67 | # profiler.stop() 68 | 69 | # return response 70 | return await call_next(request) 71 | else: 72 | return await call_next(request) 73 | -------------------------------------------------------------------------------- /app/api/router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from app.api.endpoints import auth 3 | 4 | apiRouter = APIRouter() 5 | apiRouter.include_router(auth.router, tags=['auth']) 6 | -------------------------------------------------------------------------------- /app/auth/jwt.py: -------------------------------------------------------------------------------- 1 | from jose import jwt 2 | from loguru import logger 3 | from jose.constants import Algorithms 4 | from fastapi.security import HTTPBearer 5 | from datetime import datetime, timedelta 6 | from fastapi import Request, HTTPException 7 | from jose.exceptions import JWTClaimsError, JWTError, ExpiredSignatureError 8 | 9 | from app.core.config import settings 10 | from app.cache.redis import SessionStore 11 | from app.util.common import generateNonce 12 | 13 | from app.types.server import Cookie 14 | from app.types.jwt import TokenData, JWTPayload 15 | from app.types.cache import RedisTokenPrefix, UserKey 16 | 17 | ALGORITHM = Algorithms.HS256 18 | 19 | 20 | def create_jwt(data: TokenData) -> tuple[str, str]: 21 | """ 22 | Create access and refresh JWT tokens. 23 | If the user ID is provided, the database won't be queried. 24 | If the user does not exist in the database and no user ID is provided, a new user will be created. 25 | The nonce is stored in the cache for refresh token invalidation. 
26 | """ 27 | nonce = None 28 | if settings.JWT_USE_NONCE: 29 | nonce = generateNonce() 30 | 31 | access_token = create_token(data, nonce, settings.ACCESS_TOKEN_EXPIRE_MINUTES) 32 | refresh_token = create_token(data, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES) 33 | 34 | # Save the nonce in the cache for refresh token invalidation, only if using nonce 35 | if nonce: 36 | set_nonce_in_cache(data.sub, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60) 37 | 38 | return access_token, refresh_token 39 | 40 | 41 | def verify_token(token: str) -> JWTPayload | None: 42 | """ 43 | Decode a JWT token. 44 | """ 45 | try: 46 | payload = JWTPayload( 47 | **jwt.decode( 48 | token, 49 | settings.SECRET_KEY, 50 | algorithms=[ALGORITHM], 51 | options={ 52 | 'require_iat': True, 53 | 'require_exp': True, 54 | 'require_sub': True, 55 | }, 56 | ) 57 | ) 58 | if settings.JWT_USE_NONCE and not payload.nonce: 59 | logger.error('Nonce not found in JWT payload.') 60 | return None 61 | return payload 62 | except (JWTError, ExpiredSignatureError, JWTClaimsError) as e: 63 | logger.error(f'Error while verifying JWT: {e}') 64 | return None 65 | 66 | 67 | class RequireJWT(HTTPBearer): 68 | """ 69 | Custom FastAPI dependency for JWT authentication. 70 | Returns the decoded JWT payload if the token is valid. 
71 | """ 72 | 73 | async def __call__(self, request: Request): 74 | credentials = await super(RequireJWT, self).__call__(request) 75 | 76 | if credentials and credentials.credentials: 77 | payload = verify_token(credentials.credentials) 78 | if not payload: 79 | raise HTTPException(status_code=403, detail='Invalid token or expired token.') 80 | 81 | validate_nonce(payload) 82 | return payload 83 | else: 84 | raise HTTPException(status_code=403, detail='Invalid authorization code.') 85 | 86 | 87 | def RequireRefreshToken(request: Request) -> JWTPayload: 88 | refreshToken = request.cookies.get(Cookie.REFRESH_TOKEN, '') 89 | payload = verify_token(refreshToken) 90 | if not payload: 91 | raise HTTPException(status_code=403, detail='Invalid token or expired token.') 92 | 93 | validate_nonce(payload) 94 | return payload 95 | 96 | 97 | def create_token(data: TokenData, nonce: str | None, expire_minutes: int) -> str: 98 | payload = { 99 | **data.model_dump(), 100 | 'exp': datetime.utcnow() + timedelta(minutes=expire_minutes), 101 | 'iat': datetime.utcnow(), 102 | } 103 | if settings.JWT_USE_NONCE and nonce: 104 | payload['nonce'] = nonce 105 | return jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM) 106 | 107 | 108 | def set_nonce_in_cache(user_id: str, nonce: str, expiration_time: int): 109 | """ 110 | Store nonce in cache with a specified expiration time. 111 | """ 112 | if settings.JWT_USE_NONCE: 113 | cache = SessionStore(RedisTokenPrefix.USER, user_id, ttl=expiration_time) 114 | cache.set(UserKey.NONCE, nonce) 115 | 116 | 117 | def validate_nonce(payload: JWTPayload): 118 | if not settings.JWT_USE_NONCE: 119 | return 120 | if not is_nonce_in_cache(payload.sub, payload.nonce): 121 | raise HTTPException(status_code=403, detail='Invalid authorization code.') 122 | 123 | 124 | def is_nonce_in_cache(user_id: str, nonce: str | None) -> bool: 125 | """ 126 | Check if the nonce is in the cache. 
127 | """ 128 | if not nonce: 129 | return False 130 | cache = SessionStore(RedisTokenPrefix.USER, user_id) 131 | return cache.get(UserKey.NONCE) == nonce 132 | -------------------------------------------------------------------------------- /app/auth/key.py: -------------------------------------------------------------------------------- 1 | import os 2 | import binascii 3 | 4 | 5 | def generate_key(): 6 | """ 7 | Generate a random hex key for the application 8 | """ 9 | return '0x' + binascii.hexlify(os.urandom(6)).decode() 10 | -------------------------------------------------------------------------------- /app/cache/redis.py: -------------------------------------------------------------------------------- 1 | import redis 2 | 3 | from app.core.config import settings 4 | 5 | 6 | class SessionStore: 7 | _pool = None 8 | 9 | @classmethod 10 | def get_pool(cls): 11 | if cls._pool is None: 12 | cls._pool = redis.ConnectionPool( 13 | host=settings.REDIS_HOST, 14 | port=settings.REDIS_PORT, 15 | password=settings.REDIS_PASSWORD, 16 | connection_class=redis.SSLConnection, 17 | ) 18 | return cls._pool 19 | 20 | def __init__(self, *tokens: str, ttl: int = (60 * 60 * 4)): 21 | """ 22 | Params:\n 23 | tokens - Used to create a session in redis. Key/value pairs are unique to this token. 24 | Pass in multiple tokens and they will be joined into one.\n 25 | ttl - Time to live in seconds. Defaults to 4 hours. 
26 | """ 27 | self.token = ':'.join(tokens) 28 | self.redis = redis.StrictRedis(connection_pool=self.get_pool()) 29 | self.ttl = ttl 30 | 31 | def set(self, key: str, value: str) -> int: 32 | self._refresh() 33 | return self.redis.hset(self.token, key, value) 34 | 35 | def get(self, key: str) -> str: 36 | self._refresh() 37 | val = self.redis.hget(self.token, key) 38 | if not val: 39 | return '' 40 | return val.decode('utf-8') 41 | 42 | def delete(self, key: str) -> int: 43 | self._refresh() 44 | return self.redis.hdel(self.token, key) 45 | 46 | def deleteSelf(self): 47 | self.redis.delete(self.token) 48 | 49 | def incr(self, key: str, amount: int = 1) -> int: 50 | self._refresh() 51 | return self.redis.hincrby(self.token, key, amount) 52 | 53 | def _refresh(self): 54 | self.redis.expire(self.token, self.ttl) 55 | -------------------------------------------------------------------------------- /app/cache/time_cache.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio 3 | import functools 4 | 5 | from collections import OrderedDict 6 | 7 | 8 | def time_cache(max_age_seconds, maxsize=128, typed=False): 9 | """ 10 | Least-recently-used cache decorator with time-based cache invalidation. 11 | 12 | Args: 13 | max_age_seconds: Time to live for cached results (in seconds). 14 | maxsize: Maximum cache size (see `functools.lru_cache`). 15 | typed: Cache on distinct input types (see `functools.lru_cache`). 
16 | """ 17 | def _decorator(fn): 18 | @functools.lru_cache(maxsize=maxsize, typed=typed) 19 | def _new(*args, __time_salt, **kwargs): 20 | return fn(*args, **kwargs) 21 | 22 | @functools.wraps(fn) 23 | def _wrapped(*args, **kwargs): 24 | return _new(*args, **kwargs, __time_salt=int(time.time() / max_age_seconds)) 25 | 26 | return _wrapped 27 | 28 | return _decorator 29 | 30 | 31 | def aio_time_cache(max_age_seconds, maxsize=128, typed=False): 32 | """Least-recently-used cache decorator with time-based cache invalidation for async functions. 33 | 34 | Args: 35 | max_age_seconds: Time to live for cached results (in seconds). 36 | maxsize: Maximum cache size. 37 | typed: Cache on distinct input types. 38 | """ 39 | cache = OrderedDict() 40 | lock = asyncio.Lock() 41 | 42 | def _key(args, kwargs): 43 | return args, tuple(sorted(kwargs.items())) 44 | 45 | def _decorator(fn): 46 | @functools.wraps(fn) 47 | async def _wrapped(*args, **kwargs): 48 | async with lock: 49 | key = _key(args, kwargs) 50 | if typed: 51 | key += (tuple(type(arg) for arg in args), 52 | tuple(type(value) for value in kwargs.values())) 53 | now = time.time() 54 | 55 | # Cache hit and check if value is still fresh 56 | if key in cache: 57 | result, timestamp = cache.pop(key) 58 | if now - timestamp <= max_age_seconds: 59 | # Move to end to show that it was recently used 60 | cache[key] = (result, timestamp) 61 | return result 62 | 63 | # Cache miss or value has expired 64 | result = await fn(*args, **kwargs) 65 | cache[key] = (result, now) 66 | 67 | # Remove oldest items if cache is full 68 | while len(cache) > maxsize: 69 | cache.popitem(last=False) 70 | 71 | return result 72 | 73 | return _wrapped 74 | 75 | return _decorator 76 | -------------------------------------------------------------------------------- /app/core/config.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pydantic_settings import BaseSettings 3 | from 
pydantic import AnyHttpUrl, validator 4 | 5 | 6 | class EnvConfigSettings(BaseSettings): 7 | """ 8 | Settings for the app. Reads from environment variables. 9 | """ 10 | DEBUG: bool 11 | API_V1_STR: str = '/api/v1' 12 | SECRET_KEY: str 13 | REFRESH_KEY: str 14 | PROFILING: bool = False 15 | JWT_USE_NONCE: bool 16 | ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # 30 minutes 17 | REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 3 # 3 days 18 | BACKEND_CORS_ORIGINS: list[AnyHttpUrl] = [] 19 | 20 | @validator('BACKEND_CORS_ORIGINS', pre=True) 21 | def assemble_cors_origins(cls, v: str | list[str]) -> list[str] | str: 22 | if isinstance(v, str) and not v.startswith('['): 23 | return [i.strip() for i in v.split(',')] 24 | elif isinstance(v, (list, str)): 25 | return v 26 | raise ValueError(v) 27 | 28 | # MySQL 29 | MYSQL_HOST: str 30 | MYSQL_USER: str 31 | MYSQL_PASSWORD: str 32 | MYSQL_DATABASE: str 33 | MYSQL_DATABASE_URI: str | None = None 34 | MYSQL_SSL: str 35 | 36 | @validator('MYSQL_DATABASE_URI', pre=True) 37 | def assemble_mysql_connection(cls, v: str | None, values: dict[str, Any]) -> Any: 38 | if isinstance(v, str): 39 | return v 40 | return (f"mysql+pymysql://{values.get('MYSQL_USER')}:{values.get('MYSQL_PASSWORD')}" 41 | f"@{values.get('MYSQL_HOST')}/{values.get('MYSQL_DATABASE')}") 42 | 43 | # Redis 44 | REDIS_HOST: str 45 | REDIS_PORT: int 46 | REDIS_PASSWORD: str 47 | 48 | class Config: 49 | case_sensitive = True 50 | env_file = '.env' 51 | extra = 'ignore' 52 | 53 | 54 | settings = EnvConfigSettings(**{}) 55 | -------------------------------------------------------------------------------- /app/db/connection.py: -------------------------------------------------------------------------------- 1 | from app.core.config import settings 2 | from sqlalchemy.engine import Engine 3 | from sqlalchemy import create_engine 4 | from sqlalchemy.orm import sessionmaker, Session, DeclarativeBase 5 | 6 | if not settings.MYSQL_DATABASE_URI: 7 | raise ValueError('Missing database 
URI') 8 | 9 | mysqlEngine: Engine = create_engine( 10 | settings.MYSQL_DATABASE_URI, 11 | echo=settings.DEBUG, 12 | pool_pre_ping=True, 13 | connect_args={ 14 | 'ssl_ca': settings.MYSQL_SSL, 15 | }, 16 | ) 17 | 18 | 19 | class MySQLTableBase(DeclarativeBase): 20 | pass 21 | 22 | 23 | def MySqlSession(expireOnCommit: bool = False) -> Session: 24 | return sessionmaker(bind=mysqlEngine, expire_on_commit=expireOnCommit)() 25 | 26 | 27 | def createMySQLTables(): 28 | MySQLTableBase.metadata.create_all(mysqlEngine) 29 | 30 | 31 | def dropMySQLTables(): 32 | MySQLTableBase.metadata.drop_all(mysqlEngine) 33 | -------------------------------------------------------------------------------- /app/db/dep.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | from sqlalchemy.orm import Session 3 | from app.db.connection import MySqlSession 4 | from app.cache.time_cache import time_cache 5 | 6 | 7 | def get_db() -> Iterator[Session]: 8 | """ 9 | Dependency that gets a cached database session. 
10 | """ 11 | session = _get_session() 12 | try: 13 | yield session 14 | session.commit() 15 | except Exception as exc: 16 | session.rollback() 17 | raise exc 18 | finally: 19 | session.close() 20 | 21 | 22 | @time_cache(max_age_seconds=60 * 10) 23 | def _get_session(): 24 | return MySqlSession() 25 | -------------------------------------------------------------------------------- /app/db/util.py: -------------------------------------------------------------------------------- 1 | from app.db.connection import mysqlEngine 2 | 3 | 4 | def printTables(): 5 | tables = mysqlEngine.table_names() 6 | print('Tables:') 7 | for table in tables: 8 | print('-', table) 9 | -------------------------------------------------------------------------------- /app/discord/client.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | 3 | from typing import Optional 4 | from fastapi import Depends, Request 5 | from app.cache.time_cache import aio_time_cache 6 | from typing_extensions import TypedDict, Literal 7 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 8 | 9 | from app.discord.models.user import User 10 | from app.discord.models.guild import Guild, GuildPreview 11 | from app.discord.config import DISCORD_API_URL, DISCORD_OAUTH_AUTHENTICATION_URL, DISCORD_TOKEN_URL 12 | from app.discord.exceptions import RateLimited, ScopeMissing, Unauthorized, InvalidToken, \ 13 | ClientSessionNotInitialized 14 | 15 | 16 | class RefreshTokenPayload(TypedDict): 17 | client_id: str 18 | client_secret: str 19 | grant_type: Literal['refresh_token'] 20 | refresh_token: str 21 | 22 | 23 | class TokenGrantPayload(TypedDict): 24 | client_id: str 25 | client_secret: str 26 | grant_type: Literal['authorization_code'] 27 | code: str 28 | redirect_uri: str 29 | 30 | 31 | class TokenResponse(TypedDict): 32 | access_token: str 33 | token_type: str 34 | expires_in: int 35 | refresh_token: str 36 | scope: str 37 | 38 | 39 | PAYLOAD = 
TokenGrantPayload | RefreshTokenPayload 40 | 41 | 42 | def _tokens(resp: TokenResponse) -> tuple[str, str]: 43 | """ 44 | Extracts tokens from TokenResponse 45 | 46 | Parameters 47 | ---------- 48 | resp: TokenResponse 49 | Response 50 | 51 | Returns 52 | ------- 53 | Tuple[str, str] 54 | An union of access_token and refresh_token 55 | 56 | Raises 57 | ------ 58 | InvalidToken 59 | If tokens are `None` 60 | """ 61 | access_token, refresh_token = resp.get( 62 | 'access_token'), resp.get('refresh_token') 63 | if access_token is None or refresh_token is None: 64 | raise InvalidToken('Tokens can\'t be None') 65 | return access_token, refresh_token 66 | 67 | 68 | class DiscordOAuthClient: 69 | """ 70 | Client for Discord Oauth2. 71 | 72 | Parameters 73 | ---------- 74 | client_id: 75 | Discord application client ID. 76 | client_secret: 77 | Discord application client secret. 78 | redirect_uri: 79 | Discord application redirect URI. 80 | scopes: 81 | Optional Discord Oauth scopes 82 | proxy: 83 | Optional proxy url 84 | proxy_auth: 85 | Optional aiohttp.BasicAuth proxy authentification 86 | """ 87 | client_id: str 88 | client_secret: str 89 | redirect_uri: str 90 | scopes: str 91 | proxy: Optional[str] = None 92 | proxy_auth: Optional[aiohttp.BasicAuth] = None 93 | client_session: Optional[aiohttp.ClientSession] = None 94 | 95 | def __init__( 96 | self, 97 | client_id, 98 | client_secret, 99 | redirect_uri, 100 | scopes=("identify",), 101 | proxy=None, 102 | proxy_auth: Optional[aiohttp.BasicAuth] = None, 103 | ): 104 | self.client_id = client_id 105 | self.client_secret = client_secret 106 | self.redirect_uri = redirect_uri 107 | self.scopes = '%20'.join(scope for scope in scopes) 108 | self.proxy = proxy 109 | self.proxy_auth = proxy_auth 110 | 111 | async def init(self): 112 | """ 113 | Initialized the connection to the discord api 114 | """ 115 | if self.client_session is not None: 116 | return 117 | self.client_session = aiohttp.ClientSession() 118 | 119 | def 
get_oauth_login_url(self, state: Optional[str] = None): 120 | """ 121 | 122 | Returns a Discord Login URL 123 | 124 | """ 125 | client_id = f'client_id={self.client_id}' 126 | redirect_uri = f'redirect_uri={self.redirect_uri}' 127 | scopes = f'scope={self.scopes}' 128 | response_type = 'response_type=code' 129 | state = f'&state={state}' if state else '' 130 | return f'{DISCORD_OAUTH_AUTHENTICATION_URL}?{client_id}&{redirect_uri}&{scopes}' + \ 131 | f'&{response_type}{state}' 132 | 133 | oauth_login_url = property(get_oauth_login_url) 134 | 135 | @aio_time_cache(max_age_seconds=550) 136 | async def request(self, route: str, token: Optional[str] = None, 137 | method: Literal['GET', 'POST'] = 'GET'): 138 | if self.client_session is None: 139 | raise ClientSessionNotInitialized 140 | headers: dict = {} 141 | if token: 142 | headers = {'Authorization': f'Bearer {token}'} 143 | if method == 'GET': 144 | async with self.client_session.get( 145 | f'{DISCORD_API_URL}{route}', 146 | headers=headers, 147 | proxy=self.proxy, 148 | proxy_auth=self.proxy_auth, 149 | ) as resp: 150 | data = await resp.json() 151 | elif method == 'POST': 152 | async with self.client_session.post( 153 | f'{DISCORD_API_URL}{route}', 154 | headers=headers, 155 | proxy=self.proxy, 156 | proxy_auth=self.proxy_auth, 157 | ) as resp: 158 | data = await resp.json() 159 | else: 160 | raise Exception( 161 | 'Other HTTP than GET and POST are currently not Supported') 162 | if resp.status == 401: 163 | raise Unauthorized 164 | if resp.status == 429: 165 | raise RateLimited(data, resp.headers) 166 | return data 167 | 168 | async def get_token_response(self, payload: PAYLOAD) -> TokenResponse: 169 | if self.client_session is None: 170 | raise ClientSessionNotInitialized 171 | async with self.client_session.post( 172 | DISCORD_TOKEN_URL, 173 | data=payload, 174 | proxy=self.proxy, 175 | proxy_auth=self.proxy_auth, 176 | ) as resp: 177 | return await resp.json() 178 | 179 | async def get_access_token(self, code: 
str) -> tuple[str, str]: 180 | payload: TokenGrantPayload = { 181 | 'client_id': self.client_id, 182 | 'client_secret': self.client_secret, 183 | 'grant_type': 'authorization_code', 184 | 'code': code, 185 | 'redirect_uri': self.redirect_uri, 186 | } 187 | resp = await self.get_token_response(payload) 188 | return _tokens(resp) 189 | 190 | async def refresh_access_token(self, refresh_token: str) -> tuple[str, str]: 191 | payload: RefreshTokenPayload = { 192 | 'client_id': self.client_id, 193 | 'client_secret': self.client_secret, 194 | 'grant_type': 'refresh_token', 195 | 'refresh_token': refresh_token, 196 | } 197 | resp = await self.get_token_response(payload) 198 | return _tokens(resp) 199 | 200 | async def user(self, request: Request): 201 | if 'identify' not in self.scopes: 202 | raise ScopeMissing('identify') 203 | route = '/users/@me' 204 | token = self.get_token(request) 205 | return User(**(await self.request(route, token))) 206 | 207 | async def guilds(self, request: Request) -> list[GuildPreview]: 208 | if 'guilds' not in self.scopes: 209 | raise ScopeMissing('guilds') 210 | route = '/users/@me/guilds' 211 | token = self.get_token(request) 212 | return [Guild(**guild) for guild in await self.request(route, token)] 213 | 214 | def get_token(self, request: Request): 215 | authorization_header = request.headers.get('Authorization') 216 | if not authorization_header: 217 | raise Unauthorized 218 | all_headers = authorization_header.split(' ') 219 | if not all_headers[0] == 'Bearer' or len(all_headers) > 2: 220 | raise Unauthorized 221 | 222 | token = all_headers[1] 223 | return token 224 | 225 | async def isAuthenticated(self, token: str): 226 | route = '/oauth2/@me' 227 | try: 228 | await self.request(route, token) 229 | return True 230 | except Unauthorized: 231 | return False 232 | 233 | async def requires_authorization(self, bearer: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer())): # noqa: E501 234 | if bearer is None: 235 | raise 
Unauthorized 236 | if not await self.isAuthenticated(bearer.credentials): 237 | raise Unauthorized 238 | -------------------------------------------------------------------------------- /app/discord/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discord module code taken and updated from: 3 | https://github.com/Tert0/fastapi-discord 4 | """ 5 | 6 | DISCORD_URL = 'https://discord.com' 7 | DISCORD_API_URL = f'{DISCORD_URL}/api/v10' 8 | DISCORD_OAUTH_URL = f'{DISCORD_URL}/api/oauth2' 9 | DISCORD_TOKEN_URL = f'{DISCORD_OAUTH_URL}/token' 10 | DISCORD_OAUTH_AUTHENTICATION_URL = f'{DISCORD_OAUTH_URL}/authorize' 11 | -------------------------------------------------------------------------------- /app/discord/exceptions.py: -------------------------------------------------------------------------------- 1 | class Unauthorized(Exception): 2 | """Raised when user is not authorized.""" 3 | 4 | 5 | class InvalidRequest(Exception): 6 | """Raised when a Request is not Valid""" 7 | 8 | 9 | class RateLimited(Exception): 10 | """Raised when a Request is not Valid""" 11 | 12 | def __init__(self, json, headers): 13 | self.json = json 14 | self.headers = headers 15 | self.message = json['message'] 16 | self.retry_after = json['retry_after'] 17 | super().__init__(self.message) 18 | 19 | 20 | class InvalidToken(Exception): 21 | """Raised when a Response has invalid tokens""" 22 | 23 | 24 | class ScopeMissing(Exception): 25 | scope: str 26 | 27 | def __init__(self, scope: str): 28 | self.scope = scope 29 | super().__init__(self.scope) 30 | 31 | 32 | class ClientSessionNotInitialized(Exception): 33 | """Raised when no Client Session is initialized but one would be needed""" 34 | pass 35 | -------------------------------------------------------------------------------- /app/discord/models/guild.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import 
BaseModel 3 | from app.discord.models.role import Role 4 | 5 | 6 | class GuildPreview(BaseModel): 7 | id: str 8 | name: str 9 | icon: Optional[str] = None 10 | owner: bool 11 | permissions: int 12 | features: list[str] 13 | 14 | 15 | class Guild(GuildPreview): 16 | owner_id: Optional[int] = None 17 | verification_level: Optional[int] = None 18 | default_message_notifications: Optional[int] = None 19 | roles: Optional[list[Role]] = None 20 | -------------------------------------------------------------------------------- /app/discord/models/role.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Role(BaseModel): 5 | id: int 6 | name: str 7 | color: int 8 | position: int 9 | permissions: int 10 | managed: bool 11 | mentionable: bool 12 | -------------------------------------------------------------------------------- /app/discord/models/user.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import BaseModel 3 | 4 | 5 | class User(BaseModel): 6 | id: str 7 | username: str 8 | discriminator: Optional[int] = None 9 | global_name: Optional[str] = None 10 | avatar: Optional[str] 11 | avatar_url: Optional[str] = 'https://cdn.discordapp.com/embed/avatars/1.png' 12 | locale: str 13 | email: Optional[str] = None 14 | mfa_enabled: Optional[bool] = None 15 | flags: Optional[int] = None 16 | premium_type: Optional[int] = None 17 | public_flags: Optional[int] = None 18 | banner: Optional[str] = None 19 | accent_color: Optional[int] = None 20 | verified: Optional[bool] = None 21 | avatar_decoration: Optional[str] = None 22 | 23 | def __init__(self, **data): 24 | super().__init__(**data) 25 | if self.avatar: 26 | self.avatar_url = f'https://cdn.discordapp.com/avatars/{self.id}/{self.avatar}.png' 27 | else: 28 | self.avatar_url = 'https://cdn.discordapp.com/embed/avatars/1.png' 29 | if self.discriminator == 0: 30 
| self.discriminator = None 31 | -------------------------------------------------------------------------------- /app/lmbd/invoke.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | from typing import Callable 4 | from app.types.lmbd import InvokeResponse 5 | 6 | LAMBDA = boto3.client('lambda') 7 | 8 | 9 | def catchLambdaError(func: Callable): 10 | def wrapper(*args, **kwargs): 11 | try: 12 | return func(*args, **kwargs) 13 | # TODO: Catch specific Lambda errors 14 | except Exception: 15 | # TODO: SEND ERROR NOTIFICATION 16 | return {} 17 | return wrapper 18 | 19 | 20 | @catchLambdaError 21 | def invokeEvent(payload: str, functionName: str) -> InvokeResponse: 22 | return LAMBDA.invoke( 23 | FunctionName=functionName, 24 | InvocationType='Event', 25 | LogType='None', 26 | Payload=payload, 27 | ) 28 | -------------------------------------------------------------------------------- /app/lmbd/sample_data/snsCloudWatchEvent.json: -------------------------------------------------------------------------------- 1 | { 2 | "Records": [ 3 | { 4 | "EventVersion": "1.0", 5 | "EventSubscriptionArn": "arn:aws:sns:EXAMPLE", 6 | "EventSource": "aws:sns", 7 | "Sns": { 8 | "SignatureVersion": "1", 9 | "Timestamp": "1970-01-01T00:00:00.000Z", 10 | "Signature": "EXAMPLE", 11 | "SigningCertUrl": "EXAMPLE", 12 | "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e", 13 | "Message": { 14 | "AlarmName": "Saffron-Octopus-RDS", 15 | "AlarmDescription": null, 16 | "AWSAccountId": "498849832712", 17 | "NewStateValue": "ALARM", 18 | "NewStateReason": "Threshold Crossed: 1 datapoint [2.1533759377604764 (20/07/20 21: 07: 00)] was greater than or equal to the threshold (0.0175).", 19 | "StateChangeTime": "2020-07-20T21: 12: 01.544+0000", 20 | "Region": "US East (N. 
Virginia)", 21 | "AlarmArn": "arn:aws:cloudwatch:us-east-1: 498849832712:alarm:Saffron-Octopus-RDS", 22 | "OldStateValue": "INSUFFICIENT_DATA", 23 | "Trigger": { 24 | "MetricName": "CPUUtilization", 25 | "Namespace": "AWS/RDS", 26 | "StatisticType": "Statistic", 27 | "Statistic": "AVERAGE", 28 | "Unit": null, 29 | "Dimensions": [ 30 | { 31 | "value": "sm16lm1jrrjf0rk", 32 | "name": "DBInstanceIdentifier" 33 | } 34 | ], 35 | "Period": 300, 36 | "EvaluationPeriods": 1, 37 | "ComparisonOperator": "GreaterThanOrEqualToThreshold", 38 | "Threshold": 0.0175, 39 | "TreatMissingData": "", 40 | "EvaluateLowSampleCountPercentile": "" 41 | } 42 | }, 43 | "MessageAttributes": { 44 | "Test": { 45 | "Type": "String", 46 | "Value": "TestString" 47 | }, 48 | "TestBinary": { 49 | "Type": "Binary", 50 | "Value": "TestBinary" 51 | } 52 | }, 53 | "Type": "Notification", 54 | "UnsubscribeUrl": "EXAMPLE", 55 | "TopicArn": "arn:aws:sns:EXAMPLE", 56 | "Subject": "TestInvoke" 57 | } 58 | } 59 | ] 60 | } -------------------------------------------------------------------------------- /app/lmbd/sample_lambda/main.py: -------------------------------------------------------------------------------- 1 | # TODO: SAMPLE LAMBDA 2 | -------------------------------------------------------------------------------- /app/log/setup.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | 4 | def setup_logging(): 5 | # Format of the logs 6 | log_format = ( 7 | '{time:YYYY-MM-DD HH:mm:ss!UTC} | ' 8 | '{level: <8} | ' 9 | '{name}:{function}:{line} - ' 10 | '{message}' 11 | ) 12 | 13 | # General logs 14 | logger.add( 15 | 'logs/server.log', 16 | rotation='10 MB', 17 | retention='30 days', 18 | level='INFO', # INFO logs will log everything from INFO and above (WARNING, ERROR, etc.) 
19 | format=log_format, 20 | enqueue=True, 21 | filter=lambda record: not bool(record['exception']), 22 | ) 23 | 24 | # Exception logs 25 | logger.add( 26 | 'logs/exceptions.log', 27 | rotation='10 MB', 28 | retention='30 days', 29 | level='ERROR', # ERROR will only log errors 30 | format=log_format, 31 | enqueue=True, 32 | backtrace=True, 33 | diagnose=True, 34 | filter=lambda record: bool(record['exception']), 35 | ) 36 | -------------------------------------------------------------------------------- /app/migrations/env.py: -------------------------------------------------------------------------------- 1 | from alembic import context 2 | from sqlalchemy import pool 3 | from app.core.config import settings 4 | from logging.config import fileConfig 5 | from sqlalchemy import engine_from_config 6 | from app.db.connection import MySQLTableBase 7 | 8 | # This is the Alembic Config object, which provides 9 | # access to the values within the .ini file in use. 10 | config = context.config 11 | 12 | # Interpret the config file for Python logging. 13 | # This line sets up loggers basically. 14 | if config.config_file_name is not None: 15 | fileConfig(config.config_file_name) 16 | 17 | TARGET_METADATA = MySQLTableBase.metadata 18 | 19 | 20 | def run_migrations_offline(): 21 | """ 22 | Run migrations in 'offline' mode. 23 | 24 | This configures the context with just a URL 25 | and not an Engine, though an Engine is acceptable 26 | here as well. By skipping the Engine creation 27 | we don't even need a DBAPI to be available. 28 | 29 | Calls to context.execute() here emit the given string to the 30 | script output. 31 | """ 32 | context.configure( 33 | url=settings.MYSQL_DATABASE_URI, 34 | target_metadata=TARGET_METADATA, 35 | literal_binds=True, 36 | dialect_opts={'paramstyle': 'named'}, 37 | ) 38 | 39 | with context.begin_transaction(): 40 | context.run_migrations() 41 | 42 | 43 | def run_migrations_online(): 44 | """ 45 | Run migrations in 'online' mode. 
46 | 47 | In this scenario we need to create an Engine 48 | and associate a connection with the context. 49 | """ 50 | connectable = engine_from_config( 51 | config.get_section(config.config_ini_section), 52 | prefix='sqlalchemy.', 53 | poolclass=pool.NullPool, 54 | ) 55 | 56 | with connectable.connect() as connection: 57 | context.configure( 58 | connection=connection, 59 | target_metadata=TARGET_METADATA 60 | ) 61 | 62 | with context.begin_transaction(): 63 | context.run_migrations() 64 | 65 | 66 | if context.is_offline_mode(): 67 | run_migrations_offline() 68 | else: 69 | run_migrations_online() 70 | -------------------------------------------------------------------------------- /app/migrations/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade() -> None: 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade() -> None: 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /app/models/mysql.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from app.db.connection import MySQLTableBase 3 | 4 | from sqlalchemy.sql import func 5 | from sqlalchemy.dialects.mysql import BIGINT 6 | from sqlalchemy import String, ForeignKey, DateTime, Boolean 7 | from sqlalchemy.orm import relationship, Mapped, mapped_column 8 | 9 | 10 | class User(MySQLTableBase): 11 | __tablename__ = 'User' 12 | 13 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True) 14 | lastActive: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) 15 | joinDate: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) 16 | 17 | # One-to-many relationship with Post 18 | posts: Mapped[list['Post']] = relationship(back_populates='user', 19 | lazy='select') 20 | # One-to-one relationship with Profile 21 | profile: Mapped['Profile'] = relationship(back_populates='user', 22 | lazy='select') 23 | 24 | # Many-to-many with association object UserNotification 25 | notifications: Mapped[list['UserNotification']] = relationship( 26 | 'UserNotification', back_populates='user', lazy='select', cascade='all, delete-orphan' 27 | ) 28 | 29 | def __repr__(self) -> str: 30 | return f'' 31 | 32 | 33 | class UserNotification(MySQLTableBase): 34 | """ 35 | Association Object for User and Notification. 36 | Represents the notifications received by a user. 
37 | """ 38 | __tablename__ = 'UserNotification' 39 | 40 | userID: Mapped[int] = mapped_column( 41 | BIGINT(unsigned=True), ForeignKey('User.id', ondelete='CASCADE'), primary_key=True 42 | ) 43 | notificationID: Mapped[int] = mapped_column( 44 | BIGINT(unsigned=True), ForeignKey('Notification.id', ondelete='CASCADE'), primary_key=True 45 | ) 46 | read: Mapped[bool] = mapped_column(Boolean) 47 | 48 | # Many-to-one relationship with User and Notification 49 | user: Mapped['User'] = relationship('User', back_populates='notifications') 50 | notification: Mapped['Notification'] = relationship('Notification', back_populates='users') 51 | 52 | def __repr__(self) -> str: 53 | return f'' 54 | 55 | 56 | class Notification(MySQLTableBase): 57 | """ 58 | Notification class representing the notifications sent to the users. 59 | """ 60 | 61 | __tablename__ = 'Notification' 62 | 63 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True) 64 | message: Mapped[str] = mapped_column(String(length=256)) 65 | time: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) 66 | 67 | # One-to-many relationship with UserNotification 68 | users: Mapped[list['UserNotification']] = relationship( 69 | 'UserNotification', back_populates='notification', lazy='select', passive_deletes=True 70 | ) 71 | 72 | def __repr__(self) -> str: 73 | return f'' 74 | 75 | 76 | class Post(MySQLTableBase): 77 | """ 78 | Post class representing the posts created by the users. 
79 | """ 80 | 81 | __tablename__ = 'Post' 82 | 83 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True) 84 | content: Mapped[str] = mapped_column(String(length=1024)) 85 | created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) 86 | user_id: Mapped[int] = mapped_column(BIGINT(unsigned=True), ForeignKey('User.id')) 87 | 88 | # Relationship with User 89 | user: Mapped['User'] = relationship('User', back_populates='posts') 90 | 91 | def __repr__(self) -> str: 92 | return f'' 93 | 94 | 95 | class Profile(MySQLTableBase): 96 | """ 97 | Profile class representing the profile of each user. 98 | """ 99 | 100 | __tablename__ = 'Profile' 101 | 102 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True) 103 | bio: Mapped[str] = mapped_column(String(length=256)) 104 | user_id: Mapped[int] = mapped_column(BIGINT(unsigned=True), ForeignKey('User.id'), unique=True) 105 | 106 | # Relationship with User 107 | user: Mapped['User'] = relationship('User', back_populates='profile') 108 | 109 | def __repr__(self) -> str: 110 | return f'' 111 | -------------------------------------------------------------------------------- /app/models/mysql_no_fk.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this example, we use SQLAlchemy's ORM to create a MySQL tables *without* foreign keys. 3 | Instead, we create relationships in place of foreign keys. There are many reasons why 4 | someone might want to not use foreign keys: 5 | 1. Schema flexibility 6 | 2. Cross-database relationships 7 | 3. Performance 8 | 4. 
database agnostic code 9 | """ 10 | 11 | from datetime import datetime 12 | from sqlalchemy.sql import func 13 | from app.db.connection import MySQLTableBase 14 | from sqlalchemy.dialects.mysql import BIGINT 15 | from sqlalchemy import String, Boolean, DateTime 16 | from sqlalchemy.orm import relationship, Session, mapped_column, Mapped 17 | 18 | # TODO: CHECK IF THESE RELATIONSHIPS LOAD LAZILY OR EAGERLY 19 | # TODO: FIX WARNING ABOUT RELATIONSHIP COPYING COLUMN 20 | 21 | 22 | class User(MySQLTableBase): 23 | """ 24 | User class representing the user. 25 | """ 26 | __tablename__ = 'User' 27 | 28 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), 29 | primary_key=True, 30 | autoincrement=True) 31 | address: Mapped[str] = mapped_column(String(length=42), 32 | primary_key=True, 33 | unique=True) 34 | lastActive: Mapped[datetime] = mapped_column(DateTime, 35 | server_default=func.now(), 36 | onupdate=func.now()) 37 | joinDate: Mapped[datetime] = mapped_column(DateTime, 38 | server_default=func.now()) 39 | isBetaUser: Mapped[bool] = mapped_column(Boolean, default=False) 40 | 41 | # Relationships 42 | notifications: Mapped[list['UserNotification']] = relationship( 43 | primaryjoin='User.id == foreign(UserNotification.userID)', 44 | ) 45 | 46 | def load(self, session: Session) -> 'User': 47 | """ 48 | Load the user's data from the database and return the user instance. 49 | """ 50 | return session.query(User).filter_by(address=self.address).scalar() 51 | 52 | def __repr__(self) -> str: 53 | return f'' 54 | 55 | 56 | class UserNotification(MySQLTableBase): 57 | """ 58 | Association Table for User and Notification. 59 | Represents the notifications received by a user. 
60 | """ 61 | __tablename__ = 'UserNotification' 62 | 63 | userID: Mapped[int] = mapped_column( 64 | BIGINT(unsigned=True), primary_key=True) 65 | notificationID: Mapped[int] = mapped_column( 66 | BIGINT(unsigned=True), primary_key=True) 67 | read: Mapped[bool] = mapped_column(Boolean, default=False) 68 | 69 | # Relationships 70 | user: Mapped[User] = relationship( 71 | primaryjoin='UserNotification.userID == foreign(User.id)' 72 | ) 73 | notification: Mapped['Notification'] = relationship( 74 | primaryjoin='UserNotification.notificationID == foreign(Notification.id)' 75 | ) 76 | 77 | def __repr__(self) -> str: 78 | return f'' 79 | 80 | 81 | class Notification(MySQLTableBase): 82 | """ 83 | Notification class representing the notification sent to the users. 84 | """ 85 | __tablename__ = 'Notification' 86 | 87 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), 88 | primary_key=True, 89 | autoincrement=True) 90 | message: Mapped[str] = mapped_column(String(length=256), primary_key=True) 91 | time: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) 92 | 93 | # Relationships 94 | users: Mapped[list['UserNotification']] = relationship( 95 | primaryjoin='Notification.id == foreign(UserNotification.notificationID)', 96 | ) 97 | 98 | def __repr__(self) -> str: 99 | return f'' 100 | -------------------------------------------------------------------------------- /app/types/cache.py: -------------------------------------------------------------------------------- 1 | # Don't use enums because usually places where we use these values require 2 | # string values, not Enum values. 3 | 4 | 5 | class RedisTokenPrefix: 6 | """ 7 | When creating a SessionStore, it is useful to have a prefix to avoid 8 | collisions with other keys in Redis. 9 | """ 10 | USER = 'user' 11 | 12 | 13 | class UserKey: 14 | """ 15 | Keys used to store user data in Redis. 
16 | """ 17 | NONCE = 'nonce' 18 | -------------------------------------------------------------------------------- /app/types/jwt.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pydantic import BaseModel 3 | 4 | 5 | class JWTPayload(BaseModel): 6 | """ 7 | The payload of a JWT token. 8 | """ 9 | # Subject 10 | sub: str 11 | # Expiration 12 | exp: datetime 13 | # Issued at 14 | iat: datetime 15 | # Unique hex number 16 | nonce: str | None = None 17 | 18 | 19 | class TokenData(BaseModel): 20 | """ 21 | Data passed to create a JWT token. 22 | """ 23 | # Subject 24 | sub: str 25 | -------------------------------------------------------------------------------- /app/types/lmbd.py: -------------------------------------------------------------------------------- 1 | # AWS Lambda 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class InvokeResponse(BaseModel): 7 | StatusCode: int 8 | FunctionError: str 9 | Payload: bytes 10 | LogResult: str 11 | ExecutedVersion: str 12 | -------------------------------------------------------------------------------- /app/types/server.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | from pydantic import BaseModel 3 | from typing import Generic, TypeVar, Optional, Any, Literal 4 | 5 | DataT = TypeVar('DataT') 6 | 7 | ServerStatus = Literal['ok', 'error', 'missing_parameters'] 8 | 9 | 10 | class ServerCode(IntEnum): 11 | OK = 200 12 | CREATED = 201 13 | NO_CONTENT = 204 14 | BAD_REQUEST = 400 15 | UNAUTHORIZED = 401 16 | FORBIDDEN = 403 17 | NOT_FOUND = 404 18 | INTERNAL_SERVER_ERROR = 500 19 | 20 | 21 | # Not an enum because it's not a finite set of values 22 | class Cookie: 23 | REFRESH_TOKEN = 'refresh_token' 24 | 25 | 26 | class ServerResponse(BaseModel, Generic[DataT]): 27 | data: Optional[DataT] = None 28 | status: ServerStatus = 'ok' 29 | message: Optional[str] = None 30 | 31 | def 
dict(self, *args, **kwargs) -> dict[str, Any]: 32 | """ 33 | Override the default dict method to exclude None values in the response 34 | """ 35 | kwargs.pop('exclude_none', None) 36 | return super().model_dump(*args, exclude_none=True, **kwargs) 37 | -------------------------------------------------------------------------------- /app/types/time.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TimeBucketInterval(str, Enum): 5 | FIVE_MIN = '5 min' 6 | FIFTEEN_MIN = '15 min' 7 | ONE_HOUR = '1 hour' 8 | ONE_DAY = '1 day' 9 | ONE_WEEK = '1 week' 10 | ONE_MONTH = '1 month' 11 | ONE_YEAR = '1 year' 12 | -------------------------------------------------------------------------------- /app/util/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Callable 4 | 5 | 6 | def getFunctionArguments(func: Callable[..., str]) -> tuple[str, ...]: 7 | """ 8 | Gets the arguments of a function. 9 | """ 10 | return func.__code__.co_varnames[:func.__code__.co_argcount] 11 | 12 | 13 | def generateNonce() -> str: 14 | """ 15 | Generates a 32 length nonce hex string. 16 | """ 17 | return os.urandom(16).hex() 18 | -------------------------------------------------------------------------------- /app/util/json.py: -------------------------------------------------------------------------------- 1 | import orjson 2 | 3 | 4 | def orjson_dumps_str(obj): 5 | return orjson.dumps(obj).decode('utf-8') 6 | -------------------------------------------------------------------------------- /app/util/time.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | 3 | 4 | def iso8601ToTimestamp(iso8601: str) -> int: 5 | """ 6 | Converts an ISO8601 string to a timestamp. 
7 | """ 8 | if iso8601.count('.') == 1: 9 | dt = datetime.strptime(iso8601, '%Y-%m-%dT%H:%M:%S.%f') 10 | else: 11 | dt = datetime.strptime(iso8601, '%Y-%m-%dT%H:%M:%S') 12 | 13 | dt = dt.replace(tzinfo=timezone.utc) 14 | return int(dt.timestamp()) 15 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | autopep8==2.3.1 2 | mypy==1.11.1 3 | types-redis==4.6.0.20240806 4 | ruff==0.5.6 -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | mysql-server: 3 | image: mysql/mysql-server:latest 4 | container_name: fastapi-mysql 5 | environment: 6 | MYSQL_ROOT_PASSWORD: "root" 7 | MYSQL_DATABASE: "fastapi-db" 8 | MYSQL_USER: "root" 9 | MYSQL_PASSWORD: "root" 10 | ports: 11 | - "3306:3306" 12 | volumes: 13 | - "./docker/mysql:/var/lib/mysql" 14 | 15 | redis-server: 16 | image: redis:latest 17 | container_name: fastapi-redis 18 | ports: 19 | - "6379:6379" 20 | -------------------------------------------------------------------------------- /docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | # Activate virtual environment 6 | . 
/opt/pysetup/.venv/bin/activate 7 | 8 | # You can put other setup logic here 9 | 10 | # Execute passed in command 11 | exec "$@" 12 | -------------------------------------------------------------------------------- /gunicorn_conf.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/tiangolo/uvicorn-gunicorn-docker/blob/315f04413114e938ff37a410b5979126facc90af/python3.7/gunicorn_conf.py 2 | 3 | import os 4 | import orjson 5 | import multiprocessing 6 | 7 | workers_per_core_str = os.getenv('WORKERS_PER_CORE', '1') 8 | web_concurrency_str = os.getenv('WEB_CONCURRENCY', None) 9 | host = '0.0.0.0' 10 | port = '8000' 11 | bind_env = os.getenv('BIND', None) 12 | use_loglevel = os.getenv('LOG_LEVEL', 'info') 13 | if bind_env: 14 | use_bind = bind_env 15 | else: 16 | use_bind = f'{host}:{port}' 17 | 18 | cores = multiprocessing.cpu_count() 19 | workers_per_core = float(workers_per_core_str) 20 | default_web_concurrency = workers_per_core * cores 21 | if web_concurrency_str: 22 | web_concurrency = int(web_concurrency_str) 23 | assert web_concurrency > 0 24 | else: 25 | web_concurrency = max(int(default_web_concurrency), 2) 26 | 27 | # Gunicorn config variables 28 | loglevel = use_loglevel 29 | workers = 2 # web_concurrency 30 | bind = use_bind 31 | keepalive = 120 32 | errorlog = '-' 33 | timeout = 600 34 | 35 | # For debugging and testing 36 | log_data = { 37 | 'loglevel': loglevel, 38 | 'workers': workers, 39 | 'bind': bind, 40 | # Additional, non-gunicorn variables 41 | 'workers_per_core': workers_per_core, 42 | 'host': host, 43 | 'port': port, 44 | } 45 | print(orjson.dumps(log_data)) 46 | -------------------------------------------------------------------------------- /loguru/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Loguru library provides a pre-instanced logger to facilitate dealing with logging in Python. 
3 | 4 | Just ``from loguru import logger``. 5 | """ 6 | 7 | import atexit as _atexit 8 | import sys as _sys 9 | 10 | from . import _defaults 11 | from ._logger import Core as _Core 12 | from ._logger import Logger as _Logger 13 | 14 | __version__ = "0.7.2" 15 | 16 | __all__ = ["logger"] 17 | 18 | logger = _Logger( 19 | core=_Core(), 20 | exception=None, 21 | depth=0, 22 | record=False, 23 | lazy=False, 24 | colors=False, 25 | raw=False, 26 | capture=True, 27 | patchers=[], 28 | extra={}, 29 | ) 30 | 31 | if _defaults.LOGURU_AUTOINIT and _sys.stderr: 32 | logger.add(_sys.stderr) 33 | 34 | _atexit.register(logger.remove) 35 | -------------------------------------------------------------------------------- /loguru/__init__.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | .. |str| replace:: :class:`str` 3 | .. |namedtuple| replace:: :func:`namedtuple` 4 | .. |dict| replace:: :class:`dict` 5 | 6 | .. |Logger| replace:: :class:`~loguru._logger.Logger` 7 | .. |catch| replace:: :meth:`~loguru._logger.Logger.catch()` 8 | .. |contextualize| replace:: :meth:`~loguru._logger.Logger.contextualize()` 9 | .. |complete| replace:: :meth:`~loguru._logger.Logger.complete()` 10 | .. |bind| replace:: :meth:`~loguru._logger.Logger.bind()` 11 | .. |patch| replace:: :meth:`~loguru._logger.Logger.patch()` 12 | .. |opt| replace:: :meth:`~loguru._logger.Logger.opt()` 13 | .. |level| replace:: :meth:`~loguru._logger.Logger.level()` 14 | 15 | .. _stub file: https://www.python.org/dev/peps/pep-0484/#stub-files 16 | .. _string literals: https://www.python.org/dev/peps/pep-0484/#forward-references 17 | .. _postponed evaluation of annotations: https://www.python.org/dev/peps/pep-0563/ 18 | .. |future| replace:: ``__future__`` 19 | .. _future: https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7 20 | .. |loguru-mypy| replace:: ``loguru-mypy`` 21 | .. _loguru-mypy: https://github.com/kornicameister/loguru-mypy 22 | .. 
|documentation of loguru-mypy| replace:: documentation of ``loguru-mypy`` 23 | .. _documentation of loguru-mypy: 24 | https://github.com/kornicameister/loguru-mypy/blob/master/README.md 25 | .. _@kornicameister: https://github.com/kornicameister 26 | 27 | Loguru relies on a `stub file`_ to document its types. This implies that these types are not 28 | accessible during execution of your program, however they can be used by type checkers and IDE. 29 | Also, this means that your Python interpreter has to support `postponed evaluation of annotations`_ 30 | to prevent error at runtime. This is achieved with a |future|_ import in Python 3.7+ or by using 31 | `string literals`_ for earlier versions. 32 | 33 | A basic usage example could look like this: 34 | 35 | .. code-block:: python 36 | 37 | from __future__ import annotations 38 | 39 | import loguru 40 | from loguru import logger 41 | 42 | def good_sink(message: loguru.Message): 43 | print("My name is", message.record["name"]) 44 | 45 | def bad_filter(record: loguru.Record): 46 | return record["invalid"] 47 | 48 | logger.add(good_sink, filter=bad_filter) 49 | 50 | 51 | .. code-block:: bash 52 | 53 | $ mypy test.py 54 | test.py:8: error: TypedDict "Record" has no key 'invalid' 55 | Found 1 error in 1 file (checked 1 source file) 56 | 57 | There are several internal types to which you can be exposed using Loguru's public API, they are 58 | listed here and might be useful to type hint your code: 59 | 60 | - ``Logger``: the usual |logger| object (also returned by |opt|, |bind| and |patch|). 61 | - ``Message``: the formatted logging message sent to the sinks (a |str| with ``record`` 62 | attribute). 63 | - ``Record``: the |dict| containing all contextual information of the logged message. 64 | - ``Level``: the |namedtuple| returned by |level| (with ``name``, ``no``, ``color`` and ``icon`` 65 | attributes). 66 | - ``Catcher``: the context decorator returned by |catch|. 
67 | - ``Contextualizer``: the context decorator returned by |contextualize|. 68 | - ``AwaitableCompleter``: the awaitable object returned by |complete|. 69 | - ``RecordFile``: the ``record["file"]`` with ``name`` and ``path`` attributes. 70 | - ``RecordLevel``: the ``record["level"]`` with ``name``, ``no`` and ``icon`` attributes. 71 | - ``RecordThread``: the ``record["thread"]`` with ``id`` and ``name`` attributes. 72 | - ``RecordProcess``: the ``record["process"]`` with ``id`` and ``name`` attributes. 73 | - ``RecordException``: the ``record["exception"]`` with ``type``, ``value`` and ``traceback`` 74 | attributes. 75 | 76 | If that is not enough, one can also use the |loguru-mypy|_ library developed by `@kornicameister`_. 77 | Plugin can be installed separately using:: 78 | 79 | pip install loguru-mypy 80 | 81 | It helps to catch several possible runtime errors by performing additional checks like: 82 | 83 | - ``opt(lazy=True)`` loggers accepting only ``typing.Callable[[], typing.Any]`` arguments 84 | - ``opt(record=True)`` loggers wrongly calling log handler like so ``logger.info(..., record={})`` 85 | - and even more... 86 | 87 | For more details, go to official |documentation of loguru-mypy|_. 
88 | """ 89 | 90 | import sys 91 | from asyncio import AbstractEventLoop 92 | from datetime import datetime, time, timedelta 93 | from picologging import Handler 94 | from multiprocessing.context import BaseContext 95 | from types import TracebackType 96 | from typing import ( 97 | Any, 98 | BinaryIO, 99 | Callable, 100 | Dict, 101 | Generator, 102 | Generic, 103 | List, 104 | NamedTuple, 105 | NewType, 106 | Optional, 107 | Pattern, 108 | Sequence, 109 | TextIO, 110 | Tuple, 111 | Type, 112 | TypeVar, 113 | Union, 114 | overload, 115 | ) 116 | 117 | if sys.version_info >= (3, 5, 3): 118 | from typing import Awaitable 119 | else: 120 | from typing_extensions import Awaitable 121 | 122 | if sys.version_info >= (3, 6): 123 | from os import PathLike 124 | from typing import ContextManager 125 | 126 | PathLikeStr = PathLike[str] 127 | else: 128 | from pathlib import PurePath as PathLikeStr 129 | 130 | from typing_extensions import ContextManager 131 | 132 | if sys.version_info >= (3, 8): 133 | from typing import Protocol, TypedDict 134 | else: 135 | from typing_extensions import Protocol, TypedDict 136 | 137 | _T = TypeVar("_T") 138 | _F = TypeVar("_F", bound=Callable[..., Any]) 139 | ExcInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]] 140 | 141 | class _GeneratorContextManager(ContextManager[_T], Generic[_T]): 142 | def __call__(self, func: _F) -> _F: ... 143 | def __exit__( 144 | self, 145 | typ: Optional[Type[BaseException]], 146 | value: Optional[BaseException], 147 | traceback: Optional[TracebackType], 148 | ) -> Optional[bool]: ... 149 | 150 | Catcher = NewType("Catcher", _GeneratorContextManager[None]) 151 | Contextualizer = NewType("Contextualizer", _GeneratorContextManager[None]) 152 | AwaitableCompleter = Awaitable[None] 153 | 154 | class Level(NamedTuple): 155 | name: str 156 | no: int 157 | color: str 158 | icon: str 159 | 160 | class _RecordAttribute: 161 | def __repr__(self) -> str: ... 
162 | def __format__(self, spec: str) -> str: ... 163 | 164 | class RecordFile(_RecordAttribute): 165 | name: str 166 | path: str 167 | 168 | class RecordLevel(_RecordAttribute): 169 | name: str 170 | no: int 171 | icon: str 172 | 173 | class RecordThread(_RecordAttribute): 174 | id: int 175 | name: str 176 | 177 | class RecordProcess(_RecordAttribute): 178 | id: int 179 | name: str 180 | 181 | class RecordException(NamedTuple): 182 | type: Optional[Type[BaseException]] 183 | value: Optional[BaseException] 184 | traceback: Optional[TracebackType] 185 | 186 | class Record(TypedDict): 187 | elapsed: timedelta 188 | exception: Optional[RecordException] 189 | extra: Dict[Any, Any] 190 | file: RecordFile 191 | function: str 192 | level: RecordLevel 193 | line: int 194 | message: str 195 | module: str 196 | name: Union[str, None] 197 | process: RecordProcess 198 | thread: RecordThread 199 | time: datetime 200 | 201 | class Message(str): 202 | record: Record 203 | 204 | class Writable(Protocol): 205 | def write(self, message: Message) -> None: ... 
206 | 207 | FilterDict = Dict[Union[str, None], Union[str, int, bool]] 208 | FilterFunction = Callable[[Record], bool] 209 | FormatFunction = Callable[[Record], str] 210 | PatcherFunction = Callable[[Record], None] 211 | RotationFunction = Callable[[Message, TextIO], bool] 212 | RetentionFunction = Callable[[List[str]], None] 213 | CompressionFunction = Callable[[str], None] 214 | 215 | # Actually unusable because TypedDict can't allow extra keys: python/mypy#4617 216 | class _HandlerConfig(TypedDict, total=False): 217 | sink: Union[str, PathLikeStr, TextIO, Writable, Callable[[Message], None], Handler] 218 | level: Union[str, int] 219 | format: Union[str, FormatFunction] 220 | filter: Optional[Union[str, FilterFunction, FilterDict]] 221 | colorize: Optional[bool] 222 | serialize: bool 223 | backtrace: bool 224 | diagnose: bool 225 | enqueue: bool 226 | catch: bool 227 | 228 | class LevelConfig(TypedDict, total=False): 229 | name: str 230 | no: int 231 | color: str 232 | icon: str 233 | 234 | ActivationConfig = Tuple[Union[str, None], bool] 235 | 236 | class Logger: 237 | @overload 238 | def add( 239 | self, 240 | sink: Union[TextIO, Writable, Callable[[Message], None], Handler], 241 | *, 242 | level: Union[str, int] = ..., 243 | format: Union[str, FormatFunction] = ..., 244 | filter: Optional[Union[str, FilterFunction, FilterDict]] = ..., 245 | colorize: Optional[bool] = ..., 246 | serialize: bool = ..., 247 | backtrace: bool = ..., 248 | diagnose: bool = ..., 249 | enqueue: bool = ..., 250 | context: Optional[Union[str, BaseContext]] = ..., 251 | catch: bool = ... 252 | ) -> int: ... 
253 | @overload 254 | def add( 255 | self, 256 | sink: Callable[[Message], Awaitable[None]], 257 | *, 258 | level: Union[str, int] = ..., 259 | format: Union[str, FormatFunction] = ..., 260 | filter: Optional[Union[str, FilterFunction, FilterDict]] = ..., 261 | colorize: Optional[bool] = ..., 262 | serialize: bool = ..., 263 | backtrace: bool = ..., 264 | diagnose: bool = ..., 265 | enqueue: bool = ..., 266 | context: Optional[Union[str, BaseContext]] = ..., 267 | catch: bool = ..., 268 | loop: Optional[AbstractEventLoop] = ... 269 | ) -> int: ... 270 | @overload 271 | def add( 272 | self, 273 | sink: Union[str, PathLikeStr], 274 | *, 275 | level: Union[str, int] = ..., 276 | format: Union[str, FormatFunction] = ..., 277 | filter: Optional[Union[str, FilterFunction, FilterDict]] = ..., 278 | colorize: Optional[bool] = ..., 279 | serialize: bool = ..., 280 | backtrace: bool = ..., 281 | diagnose: bool = ..., 282 | enqueue: bool = ..., 283 | context: Optional[Union[str, BaseContext]] = ..., 284 | catch: bool = ..., 285 | rotation: Optional[Union[str, int, time, timedelta, RotationFunction]] = ..., 286 | retention: Optional[Union[str, int, timedelta, RetentionFunction]] = ..., 287 | compression: Optional[Union[str, CompressionFunction]] = ..., 288 | delay: bool = ..., 289 | watch: bool = ..., 290 | mode: str = ..., 291 | buffering: int = ..., 292 | encoding: str = ..., 293 | **kwargs: Any 294 | ) -> int: ... 295 | def remove(self, handler_id: Optional[int] = ...) -> None: ... 296 | def complete(self) -> AwaitableCompleter: ... 297 | @overload 298 | def catch( 299 | self, 300 | exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = ..., 301 | *, 302 | level: Union[str, int] = ..., 303 | reraise: bool = ..., 304 | onerror: Optional[Callable[[BaseException], None]] = ..., 305 | exclude: Optional[Union[Type[BaseException], Tuple[Type[BaseException], ...]]] = ..., 306 | default: Any = ..., 307 | message: str = ... 308 | ) -> Catcher: ... 
309 | @overload 310 | def catch(self, function: _F) -> _F: ... 311 | def opt( 312 | self, 313 | *, 314 | exception: Optional[Union[bool, ExcInfo, BaseException]] = ..., 315 | record: bool = ..., 316 | lazy: bool = ..., 317 | colors: bool = ..., 318 | raw: bool = ..., 319 | capture: bool = ..., 320 | depth: int = ..., 321 | ansi: bool = ... 322 | ) -> Logger: ... 323 | def bind(__self, **kwargs: Any) -> Logger: ... # noqa: N805 324 | def contextualize(__self, **kwargs: Any) -> Contextualizer: ... # noqa: N805 325 | def patch(self, patcher: PatcherFunction) -> Logger: ... 326 | @overload 327 | def level(self, name: str) -> Level: ... 328 | @overload 329 | def level( 330 | self, name: str, no: int = ..., color: Optional[str] = ..., icon: Optional[str] = ... 331 | ) -> Level: ... 332 | @overload 333 | def level( 334 | self, 335 | name: str, 336 | no: Optional[int] = ..., 337 | color: Optional[str] = ..., 338 | icon: Optional[str] = ..., 339 | ) -> Level: ... 340 | def disable(self, name: Union[str, None]) -> None: ... 341 | def enable(self, name: Union[str, None]) -> None: ... 342 | def configure( 343 | self, 344 | *, 345 | handlers: Sequence[Dict[str, Any]] = ..., 346 | levels: Optional[Sequence[LevelConfig]] = ..., 347 | extra: Optional[Dict[Any, Any]] = ..., 348 | patcher: Optional[PatcherFunction] = ..., 349 | activation: Optional[Sequence[ActivationConfig]] = ... 350 | ) -> List[int]: ... 351 | # @staticmethod cannot be used with @overload in mypy (python/mypy#7781). 352 | # However Logger is not exposed and logger is an instance of Logger 353 | # so for type checkers it is all the same whether it is defined here 354 | # as a static method or an instance method. 355 | @overload 356 | def parse( 357 | self, 358 | file: Union[str, PathLikeStr, TextIO], 359 | pattern: Union[str, Pattern[str]], 360 | *, 361 | cast: Union[Dict[str, Callable[[str], Any]], Callable[[Dict[str, str]], None]] = ..., 362 | chunk: int = ... 
363 | ) -> Generator[Dict[str, Any], None, None]: ... 364 | @overload 365 | def parse( 366 | self, 367 | file: BinaryIO, 368 | pattern: Union[bytes, Pattern[bytes]], 369 | *, 370 | cast: Union[Dict[str, Callable[[bytes], Any]], Callable[[Dict[str, bytes]], None]] = ..., 371 | chunk: int = ... 372 | ) -> Generator[Dict[str, Any], None, None]: ... 373 | @overload 374 | def trace(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 375 | @overload 376 | def trace(__self, __message: Any) -> None: ... # noqa: N805 377 | @overload 378 | def debug(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 379 | @overload 380 | def debug(__self, __message: Any) -> None: ... # noqa: N805 381 | @overload 382 | def info(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 383 | @overload 384 | def info(__self, __message: Any) -> None: ... # noqa: N805 385 | @overload 386 | def success(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 387 | @overload 388 | def success(__self, __message: Any) -> None: ... # noqa: N805 389 | @overload 390 | def warning(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 391 | @overload 392 | def warning(__self, __message: Any) -> None: ... # noqa: N805 393 | @overload 394 | def error(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 395 | @overload 396 | def error(__self, __message: Any) -> None: ... # noqa: N805 397 | @overload 398 | def critical(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 399 | @overload 400 | def critical(__self, __message: Any) -> None: ... # noqa: N805 401 | @overload 402 | def exception(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805 403 | @overload 404 | def exception(__self, __message: Any) -> None: ... 
# noqa: N805 405 | @overload 406 | def log( 407 | __self, __level: Union[int, str], __message: str, *args: Any, **kwargs: Any # noqa: N805 408 | ) -> None: ... 409 | @overload 410 | def log(__self, __level: Union[int, str], __message: Any) -> None: ... # noqa: N805 411 | def start(self, *args: Any, **kwargs: Any) -> int: ... 412 | def stop(self, *args: Any, **kwargs: Any) -> None: ... 413 | 414 | logger: Logger 415 | -------------------------------------------------------------------------------- /loguru/_asyncio_loop.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import sys 3 | 4 | 5 | def load_loop_functions(): 6 | if sys.version_info >= (3, 7): 7 | 8 | def get_task_loop(task): 9 | return task.get_loop() 10 | 11 | get_running_loop = asyncio.get_running_loop 12 | 13 | else: 14 | 15 | def get_task_loop(task): 16 | return task._loop 17 | 18 | def get_running_loop(): 19 | loop = asyncio.get_event_loop() 20 | if not loop.is_running(): 21 | raise RuntimeError("There is no running event loop") 22 | return loop 23 | 24 | return get_task_loop, get_running_loop 25 | 26 | 27 | get_task_loop, get_running_loop = load_loop_functions() 28 | -------------------------------------------------------------------------------- /loguru/_better_exceptions.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import inspect 3 | import io 4 | import keyword 5 | import linecache 6 | import os 7 | import re 8 | import sys 9 | import sysconfig 10 | import tokenize 11 | import traceback 12 | 13 | if sys.version_info >= (3, 11): 14 | 15 | def is_exception_group(exc): 16 | return isinstance(exc, ExceptionGroup) 17 | 18 | else: 19 | try: 20 | from exceptiongroup import ExceptionGroup 21 | except ImportError: 22 | 23 | def is_exception_group(exc): 24 | return False 25 | 26 | else: 27 | 28 | def is_exception_group(exc): 29 | return isinstance(exc, ExceptionGroup) 30 | 31 | 32 | class 
SyntaxHighlighter: 33 | _default_style = { 34 | "comment": "\x1b[30m\x1b[1m{}\x1b[0m", 35 | "keyword": "\x1b[35m\x1b[1m{}\x1b[0m", 36 | "builtin": "\x1b[1m{}\x1b[0m", 37 | "string": "\x1b[36m{}\x1b[0m", 38 | "number": "\x1b[34m\x1b[1m{}\x1b[0m", 39 | "operator": "\x1b[35m\x1b[1m{}\x1b[0m", 40 | "punctuation": "\x1b[1m{}\x1b[0m", 41 | "constant": "\x1b[36m\x1b[1m{}\x1b[0m", 42 | "identifier": "\x1b[1m{}\x1b[0m", 43 | "other": "{}", 44 | } 45 | 46 | _builtins = set(dir(builtins)) 47 | _constants = {"True", "False", "None"} 48 | _punctuation = {"(", ")", "[", "]", "{", "}", ":", ",", ";"} 49 | _strings = {tokenize.STRING} 50 | _fstring_middle = None 51 | 52 | if sys.version_info >= (3, 12): 53 | _strings.update({tokenize.FSTRING_START, tokenize.FSTRING_MIDDLE, tokenize.FSTRING_END}) 54 | _fstring_middle = tokenize.FSTRING_MIDDLE 55 | 56 | def __init__(self, style=None): 57 | self._style = style or self._default_style 58 | 59 | def highlight(self, source): 60 | style = self._style 61 | row, column = 0, 0 62 | output = "" 63 | 64 | for token in self.tokenize(source): 65 | type_, string, (start_row, start_column), (_, end_column), line = token 66 | 67 | if type_ == self._fstring_middle: 68 | # When an f-string contains "{{" or "}}", they appear as "{" or "}" in the "string" 69 | # attribute of the token. However, they do not count in the column position. 
70 | end_column += string.count("{") + string.count("}") 71 | 72 | if type_ == tokenize.NAME: 73 | if string in self._constants: 74 | color = style["constant"] 75 | elif keyword.iskeyword(string): 76 | color = style["keyword"] 77 | elif string in self._builtins: 78 | color = style["builtin"] 79 | else: 80 | color = style["identifier"] 81 | elif type_ == tokenize.OP: 82 | if string in self._punctuation: 83 | color = style["punctuation"] 84 | else: 85 | color = style["operator"] 86 | elif type_ == tokenize.NUMBER: 87 | color = style["number"] 88 | elif type_ in self._strings: 89 | color = style["string"] 90 | elif type_ == tokenize.COMMENT: 91 | color = style["comment"] 92 | else: 93 | color = style["other"] 94 | 95 | if start_row != row: 96 | source = source[column:] 97 | row, column = start_row, 0 98 | 99 | if type_ != tokenize.ENCODING: 100 | output += line[column:start_column] 101 | output += color.format(line[start_column:end_column]) 102 | 103 | column = end_column 104 | 105 | output += source[column:] 106 | 107 | return output 108 | 109 | @staticmethod 110 | def tokenize(source): 111 | # Worth reading: https://www.asmeurer.com/brown-water-python/ 112 | source = source.encode("utf-8") 113 | source = io.BytesIO(source) 114 | 115 | try: 116 | yield from tokenize.tokenize(source.readline) 117 | except tokenize.TokenError: 118 | return 119 | 120 | 121 | class ExceptionFormatter: 122 | _default_theme = { 123 | "introduction": "\x1b[33m\x1b[1m{}\x1b[0m", 124 | "cause": "\x1b[1m{}\x1b[0m", 125 | "context": "\x1b[1m{}\x1b[0m", 126 | "dirname": "\x1b[32m{}\x1b[0m", 127 | "basename": "\x1b[32m\x1b[1m{}\x1b[0m", 128 | "line": "\x1b[33m{}\x1b[0m", 129 | "function": "\x1b[35m{}\x1b[0m", 130 | "exception_type": "\x1b[31m\x1b[1m{}\x1b[0m", 131 | "exception_value": "\x1b[1m{}\x1b[0m", 132 | "arrows": "\x1b[36m{}\x1b[0m", 133 | "value": "\x1b[36m\x1b[1m{}\x1b[0m", 134 | } 135 | 136 | def __init__( 137 | self, 138 | colorize=False, 139 | backtrace=False, 140 | diagnose=True, 141 
| theme=None, 142 | style=None, 143 | max_length=128, 144 | encoding="ascii", 145 | hidden_frames_filename=None, 146 | prefix="", 147 | ): 148 | self._colorize = colorize 149 | self._diagnose = diagnose 150 | self._theme = theme or self._default_theme 151 | self._backtrace = backtrace 152 | self._syntax_highlighter = SyntaxHighlighter(style) 153 | self._max_length = max_length 154 | self._encoding = encoding 155 | self._hidden_frames_filename = hidden_frames_filename 156 | self._prefix = prefix 157 | self._lib_dirs = self._get_lib_dirs() 158 | self._pipe_char = self._get_char("\u2502", "|") 159 | self._cap_char = self._get_char("\u2514", "->") 160 | self._catch_point_identifier = " " 161 | 162 | @staticmethod 163 | def _get_lib_dirs(): 164 | schemes = sysconfig.get_scheme_names() 165 | names = ["stdlib", "platstdlib", "platlib", "purelib"] 166 | paths = {sysconfig.get_path(name, scheme) for scheme in schemes for name in names} 167 | return [os.path.abspath(path).lower() + os.sep for path in paths if path in sys.path] 168 | 169 | @staticmethod 170 | def _indent(text, count, *, prefix="| "): 171 | if count == 0: 172 | yield text 173 | return 174 | for line in text.splitlines(True): 175 | indented = " " * count + prefix + line 176 | yield indented.rstrip() + "\n" 177 | 178 | def _get_char(self, char, default): 179 | try: 180 | char.encode(self._encoding) 181 | except (UnicodeEncodeError, LookupError): 182 | return default 183 | else: 184 | return char 185 | 186 | def _is_file_mine(self, file): 187 | filepath = os.path.abspath(file).lower() 188 | if not filepath.endswith(".py"): 189 | return False 190 | return not any(filepath.startswith(d) for d in self._lib_dirs) 191 | 192 | def _extract_frames(self, tb, is_first, *, limit=None, from_decorator=False): 193 | frames, final_source = [], None 194 | 195 | if tb is None or (limit is not None and limit <= 0): 196 | return frames, final_source 197 | 198 | def is_valid(frame): 199 | return frame.f_code.co_filename != 
self._hidden_frames_filename 200 | 201 | def get_info(frame, lineno): 202 | filename = frame.f_code.co_filename 203 | function = frame.f_code.co_name 204 | source = linecache.getline(filename, lineno).strip() 205 | return filename, lineno, function, source 206 | 207 | infos = [] 208 | 209 | if is_valid(tb.tb_frame): 210 | infos.append((get_info(tb.tb_frame, tb.tb_lineno), tb.tb_frame)) 211 | 212 | get_parent_only = from_decorator and not self._backtrace 213 | 214 | if (self._backtrace and is_first) or get_parent_only: 215 | frame = tb.tb_frame.f_back 216 | while frame: 217 | if is_valid(frame): 218 | infos.insert(0, (get_info(frame, frame.f_lineno), frame)) 219 | if get_parent_only: 220 | break 221 | frame = frame.f_back 222 | 223 | if infos and not get_parent_only: 224 | (filename, lineno, function, source), frame = infos[-1] 225 | function += self._catch_point_identifier 226 | infos[-1] = ((filename, lineno, function, source), frame) 227 | 228 | tb = tb.tb_next 229 | 230 | while tb: 231 | if is_valid(tb.tb_frame): 232 | infos.append((get_info(tb.tb_frame, tb.tb_lineno), tb.tb_frame)) 233 | tb = tb.tb_next 234 | 235 | if limit is not None: 236 | infos = infos[-limit:] 237 | 238 | for (filename, lineno, function, source), frame in infos: 239 | final_source = source 240 | if source: 241 | colorize = self._colorize and self._is_file_mine(filename) 242 | lines = [] 243 | if colorize: 244 | lines.append(self._syntax_highlighter.highlight(source)) 245 | else: 246 | lines.append(source) 247 | if self._diagnose: 248 | relevant_values = self._get_relevant_values(source, frame) 249 | values = self._format_relevant_values(list(relevant_values), colorize) 250 | lines += list(values) 251 | source = "\n ".join(lines) 252 | frames.append((filename, lineno, function, source)) 253 | 254 | return frames, final_source 255 | 256 | def _get_relevant_values(self, source, frame): 257 | value = None 258 | pending = None 259 | is_attribute = False 260 | is_valid_value = False 261 | 
is_assignment = True 262 | 263 | for token in self._syntax_highlighter.tokenize(source): 264 | type_, string, (_, col), *_ = token 265 | 266 | if pending is not None: 267 | # Keyword arguments are ignored 268 | if type_ != tokenize.OP or string != "=" or is_assignment: 269 | yield pending 270 | pending = None 271 | 272 | if type_ == tokenize.NAME and not keyword.iskeyword(string): 273 | if not is_attribute: 274 | for variables in (frame.f_locals, frame.f_globals): 275 | try: 276 | value = variables[string] 277 | except KeyError: 278 | continue 279 | else: 280 | is_valid_value = True 281 | pending = (col, self._format_value(value)) 282 | break 283 | elif is_valid_value: 284 | try: 285 | value = inspect.getattr_static(value, string) 286 | except AttributeError: 287 | is_valid_value = False 288 | else: 289 | yield (col, self._format_value(value)) 290 | elif type_ == tokenize.OP and string == ".": 291 | is_attribute = True 292 | is_assignment = False 293 | elif type_ == tokenize.OP and string == ";": 294 | is_assignment = True 295 | is_attribute = False 296 | is_valid_value = False 297 | else: 298 | is_attribute = False 299 | is_valid_value = False 300 | is_assignment = False 301 | 302 | if pending is not None: 303 | yield pending 304 | 305 | def _format_relevant_values(self, relevant_values, colorize): 306 | for i in reversed(range(len(relevant_values))): 307 | col, value = relevant_values[i] 308 | pipe_cols = [pcol for pcol, _ in relevant_values[:i]] 309 | pre_line = "" 310 | index = 0 311 | 312 | for pc in pipe_cols: 313 | pre_line += (" " * (pc - index)) + self._pipe_char 314 | index = pc + 1 315 | 316 | pre_line += " " * (col - index) 317 | value_lines = value.split("\n") 318 | 319 | for n, value_line in enumerate(value_lines): 320 | if n == 0: 321 | arrows = pre_line + self._cap_char + " " 322 | else: 323 | arrows = pre_line + " " * (len(self._cap_char) + 1) 324 | 325 | if colorize: 326 | arrows = self._theme["arrows"].format(arrows) 327 | value_line = 
self._theme["value"].format(value_line) 328 | 329 | yield arrows + value_line 330 | 331 | def _format_value(self, v): 332 | try: 333 | v = repr(v) 334 | except Exception: 335 | v = "" % type(v).__name__ 336 | 337 | max_length = self._max_length 338 | if max_length is not None and len(v) > max_length: 339 | v = v[: max_length - 3] + "..." 340 | return v 341 | 342 | def _format_locations(self, frames_lines, *, has_introduction): 343 | prepend_with_new_line = has_introduction 344 | regex = r'^ File "(?P.*?)", line (?P[^,]+)(?:, in (?P.*))?\n' 345 | 346 | for frame in frames_lines: 347 | match = re.match(regex, frame) 348 | 349 | if match: 350 | file, line, function = match.group("file", "line", "function") 351 | 352 | is_mine = self._is_file_mine(file) 353 | 354 | if function is not None: 355 | pattern = ' File "{}", line {}, in {}\n' 356 | else: 357 | pattern = ' File "{}", line {}\n' 358 | 359 | if self._backtrace and function and function.endswith(self._catch_point_identifier): 360 | function = function[: -len(self._catch_point_identifier)] 361 | pattern = ">" + pattern[1:] 362 | 363 | if self._colorize and is_mine: 364 | dirname, basename = os.path.split(file) 365 | if dirname: 366 | dirname += os.sep 367 | dirname = self._theme["dirname"].format(dirname) 368 | basename = self._theme["basename"].format(basename) 369 | file = dirname + basename 370 | line = self._theme["line"].format(line) 371 | function = self._theme["function"].format(function) 372 | 373 | if self._diagnose and (is_mine or prepend_with_new_line): 374 | pattern = "\n" + pattern 375 | 376 | location = pattern.format(file, line, function) 377 | frame = location + frame[match.end() :] 378 | prepend_with_new_line = is_mine 379 | 380 | yield frame 381 | 382 | def _format_exception( 383 | self, value, tb, *, seen=None, is_first=False, from_decorator=False, group_nesting=0 384 | ): 385 | # Implemented from built-in traceback module: 386 | # 
https://github.com/python/cpython/blob/a5b76167/Lib/traceback.py#L468 387 | exc_type, exc_value, exc_traceback = type(value), value, tb 388 | 389 | if seen is None: 390 | seen = set() 391 | 392 | seen.add(id(exc_value)) 393 | 394 | if exc_value: 395 | if exc_value.__cause__ is not None and id(exc_value.__cause__) not in seen: 396 | yield from self._format_exception( 397 | exc_value.__cause__, 398 | exc_value.__cause__.__traceback__, 399 | seen=seen, 400 | group_nesting=group_nesting, 401 | ) 402 | cause = "The above exception was the direct cause of the following exception:" 403 | if self._colorize: 404 | cause = self._theme["cause"].format(cause) 405 | if self._diagnose: 406 | yield from self._indent("\n\n" + cause + "\n\n\n", group_nesting) 407 | else: 408 | yield from self._indent("\n" + cause + "\n\n", group_nesting) 409 | 410 | elif ( 411 | exc_value.__context__ is not None 412 | and id(exc_value.__context__) not in seen 413 | and not exc_value.__suppress_context__ 414 | ): 415 | yield from self._format_exception( 416 | exc_value.__context__, 417 | exc_value.__context__.__traceback__, 418 | seen=seen, 419 | group_nesting=group_nesting, 420 | ) 421 | context = "During handling of the above exception, another exception occurred:" 422 | if self._colorize: 423 | context = self._theme["context"].format(context) 424 | if self._diagnose: 425 | yield from self._indent("\n\n" + context + "\n\n\n", group_nesting) 426 | else: 427 | yield from self._indent("\n" + context + "\n\n", group_nesting) 428 | 429 | is_grouped = is_exception_group(value) 430 | 431 | if is_grouped and group_nesting == 0: 432 | yield from self._format_exception( 433 | value, 434 | tb, 435 | seen=seen, 436 | group_nesting=1, 437 | is_first=is_first, 438 | from_decorator=from_decorator, 439 | ) 440 | return 441 | 442 | try: 443 | traceback_limit = sys.tracebacklimit 444 | except AttributeError: 445 | traceback_limit = None 446 | 447 | frames, final_source = self._extract_frames( 448 | exc_traceback, 
is_first, limit=traceback_limit, from_decorator=from_decorator 449 | ) 450 | exception_only = traceback.format_exception_only(exc_type, exc_value) 451 | 452 | # Determining the correct index for the "Exception: message" part in the formatted exception 453 | # is challenging. This is because it might be preceded by multiple lines specific to 454 | # "SyntaxError" or followed by various notes. However, we can make an educated guess based 455 | # on the indentation; the preliminary context for "SyntaxError" is always indented, while 456 | # the Exception itself is not. This allows us to identify the correct index for the 457 | # exception message. 458 | no_indented_indexes = (i for i, p in enumerate(exception_only) if not p.startswith(" ")) 459 | error_message_index = next(no_indented_indexes, None) 460 | 461 | if error_message_index is not None: 462 | # Remove final new line temporarily. 463 | error_message = exception_only[error_message_index][:-1] 464 | 465 | if self._colorize: 466 | if ":" in error_message: 467 | exception_type, exception_value = error_message.split(":", 1) 468 | exception_type = self._theme["exception_type"].format(exception_type) 469 | exception_value = self._theme["exception_value"].format(exception_value) 470 | error_message = exception_type + ":" + exception_value 471 | else: 472 | error_message = self._theme["exception_type"].format(error_message) 473 | 474 | if self._diagnose and frames: 475 | if issubclass(exc_type, AssertionError) and not str(exc_value) and final_source: 476 | if self._colorize: 477 | final_source = self._syntax_highlighter.highlight(final_source) 478 | error_message += ": " + final_source 479 | 480 | error_message = "\n" + error_message 481 | 482 | exception_only[error_message_index] = error_message + "\n" 483 | 484 | if is_first: 485 | yield self._prefix 486 | 487 | has_introduction = bool(frames) 488 | 489 | if has_introduction: 490 | if is_grouped: 491 | introduction = "Exception Group Traceback (most recent call 
last):" 492 | else: 493 | introduction = "Traceback (most recent call last):" 494 | if self._colorize: 495 | introduction = self._theme["introduction"].format(introduction) 496 | if group_nesting == 1: # Implies we're processing the root ExceptionGroup. 497 | yield from self._indent(introduction + "\n", group_nesting, prefix="+ ") 498 | else: 499 | yield from self._indent(introduction + "\n", group_nesting) 500 | 501 | frames_lines = self._format_list(frames) + exception_only 502 | if self._colorize or self._backtrace or self._diagnose: 503 | frames_lines = self._format_locations(frames_lines, has_introduction=has_introduction) 504 | 505 | yield from self._indent("".join(frames_lines), group_nesting) 506 | 507 | if is_grouped: 508 | exc = None 509 | for n, exc in enumerate(value.exceptions, start=1): 510 | ruler = "+" + (" %s " % ("..." if n > 15 else n)).center(35, "-") 511 | yield from self._indent(ruler, group_nesting, prefix="+-" if n == 1 else " ") 512 | if n > 15: 513 | message = "and %d more exceptions\n" % (len(value.exceptions) - 15) 514 | yield from self._indent(message, group_nesting + 1) 515 | break 516 | elif group_nesting == 10 and is_exception_group(exc): 517 | message = "... 
(max_group_depth is 10)\n" 518 | yield from self._indent(message, group_nesting + 1) 519 | else: 520 | yield from self._format_exception( 521 | exc, 522 | exc.__traceback__, 523 | seen=seen, 524 | group_nesting=group_nesting + 1, 525 | ) 526 | if not is_exception_group(exc) or group_nesting == 10: 527 | yield from self._indent("-" * 35, group_nesting + 1, prefix="+-") 528 | 529 | def _format_list(self, frames): 530 | result = [] 531 | for filename, lineno, name, line in frames: 532 | row = [] 533 | row.append(' File "{}", line {}, in {}\n'.format(filename, lineno, name)) 534 | if line: 535 | row.append(" {}\n".format(line.strip())) 536 | result.append("".join(row)) 537 | return result 538 | 539 | def format_exception(self, type_, value, tb, *, from_decorator=False): 540 | yield from self._format_exception(value, tb, is_first=True, from_decorator=from_decorator) 541 | -------------------------------------------------------------------------------- /loguru/_colorama.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import os 3 | import sys 4 | 5 | 6 | def should_colorize(stream): 7 | if stream is None: 8 | return False 9 | 10 | if getattr(builtins, "__IPYTHON__", False) and (stream is sys.stdout or stream is sys.stderr): 11 | try: 12 | import ipykernel 13 | import IPython 14 | 15 | ipython = IPython.get_ipython() 16 | is_jupyter_stream = isinstance(stream, ipykernel.iostream.OutStream) 17 | is_jupyter_shell = isinstance(ipython, ipykernel.zmqshell.ZMQInteractiveShell) 18 | except Exception: 19 | pass 20 | else: 21 | if is_jupyter_stream and is_jupyter_shell: 22 | return True 23 | 24 | if stream is sys.__stdout__ or stream is sys.__stderr__: 25 | if "CI" in os.environ and any( 26 | ci in os.environ 27 | for ci in ["TRAVIS", "CIRCLECI", "APPVEYOR", "GITLAB_CI", "GITHUB_ACTIONS"] 28 | ): 29 | return True 30 | if "PYCHARM_HOSTED" in os.environ: 31 | return True 32 | if os.name == "nt" and "TERM" in os.environ: 33 
| return True 34 | 35 | try: 36 | return stream.isatty() 37 | except Exception: 38 | return False 39 | 40 | 41 | def should_wrap(stream): 42 | if os.name != "nt": 43 | return False 44 | 45 | if stream is not sys.__stdout__ and stream is not sys.__stderr__: 46 | return False 47 | 48 | from colorama.win32 import winapi_test 49 | 50 | if not winapi_test(): 51 | return False 52 | 53 | try: 54 | from colorama.winterm import enable_vt_processing 55 | except ImportError: 56 | return True 57 | 58 | try: 59 | return not enable_vt_processing(stream.fileno()) 60 | except Exception: 61 | return True 62 | 63 | 64 | def wrap(stream): 65 | from colorama import AnsiToWin32 66 | 67 | return AnsiToWin32(stream, convert=True, strip=True, autoreset=False).stream 68 | -------------------------------------------------------------------------------- /loguru/_colorizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | from string import Formatter 3 | 4 | 5 | class Style: 6 | RESET_ALL = 0 7 | BOLD = 1 8 | DIM = 2 9 | ITALIC = 3 10 | UNDERLINE = 4 11 | BLINK = 5 12 | REVERSE = 7 13 | HIDE = 8 14 | STRIKE = 9 15 | NORMAL = 22 16 | 17 | 18 | class Fore: 19 | BLACK = 30 20 | RED = 31 21 | GREEN = 32 22 | YELLOW = 33 23 | BLUE = 34 24 | MAGENTA = 35 25 | CYAN = 36 26 | WHITE = 37 27 | RESET = 39 28 | 29 | LIGHTBLACK_EX = 90 30 | LIGHTRED_EX = 91 31 | LIGHTGREEN_EX = 92 32 | LIGHTYELLOW_EX = 93 33 | LIGHTBLUE_EX = 94 34 | LIGHTMAGENTA_EX = 95 35 | LIGHTCYAN_EX = 96 36 | LIGHTWHITE_EX = 97 37 | 38 | 39 | class Back: 40 | BLACK = 40 41 | RED = 41 42 | GREEN = 42 43 | YELLOW = 43 44 | BLUE = 44 45 | MAGENTA = 45 46 | CYAN = 46 47 | WHITE = 47 48 | RESET = 49 49 | 50 | LIGHTBLACK_EX = 100 51 | LIGHTRED_EX = 101 52 | LIGHTGREEN_EX = 102 53 | LIGHTYELLOW_EX = 103 54 | LIGHTBLUE_EX = 104 55 | LIGHTMAGENTA_EX = 105 56 | LIGHTCYAN_EX = 106 57 | LIGHTWHITE_EX = 107 58 | 59 | 60 | def ansi_escape(codes): 61 | return {name: "\033[%dm" % code for name, 
code in codes.items()} 62 | 63 | 64 | class TokenType: 65 | TEXT = 1 66 | ANSI = 2 67 | LEVEL = 3 68 | CLOSING = 4 69 | 70 | 71 | class AnsiParser: 72 | _style = ansi_escape( 73 | { 74 | "b": Style.BOLD, 75 | "d": Style.DIM, 76 | "n": Style.NORMAL, 77 | "h": Style.HIDE, 78 | "i": Style.ITALIC, 79 | "l": Style.BLINK, 80 | "s": Style.STRIKE, 81 | "u": Style.UNDERLINE, 82 | "v": Style.REVERSE, 83 | "bold": Style.BOLD, 84 | "dim": Style.DIM, 85 | "normal": Style.NORMAL, 86 | "hide": Style.HIDE, 87 | "italic": Style.ITALIC, 88 | "blink": Style.BLINK, 89 | "strike": Style.STRIKE, 90 | "underline": Style.UNDERLINE, 91 | "reverse": Style.REVERSE, 92 | } 93 | ) 94 | 95 | _foreground = ansi_escape( 96 | { 97 | "k": Fore.BLACK, 98 | "r": Fore.RED, 99 | "g": Fore.GREEN, 100 | "y": Fore.YELLOW, 101 | "e": Fore.BLUE, 102 | "m": Fore.MAGENTA, 103 | "c": Fore.CYAN, 104 | "w": Fore.WHITE, 105 | "lk": Fore.LIGHTBLACK_EX, 106 | "lr": Fore.LIGHTRED_EX, 107 | "lg": Fore.LIGHTGREEN_EX, 108 | "ly": Fore.LIGHTYELLOW_EX, 109 | "le": Fore.LIGHTBLUE_EX, 110 | "lm": Fore.LIGHTMAGENTA_EX, 111 | "lc": Fore.LIGHTCYAN_EX, 112 | "lw": Fore.LIGHTWHITE_EX, 113 | "black": Fore.BLACK, 114 | "red": Fore.RED, 115 | "green": Fore.GREEN, 116 | "yellow": Fore.YELLOW, 117 | "blue": Fore.BLUE, 118 | "magenta": Fore.MAGENTA, 119 | "cyan": Fore.CYAN, 120 | "white": Fore.WHITE, 121 | "light-black": Fore.LIGHTBLACK_EX, 122 | "light-red": Fore.LIGHTRED_EX, 123 | "light-green": Fore.LIGHTGREEN_EX, 124 | "light-yellow": Fore.LIGHTYELLOW_EX, 125 | "light-blue": Fore.LIGHTBLUE_EX, 126 | "light-magenta": Fore.LIGHTMAGENTA_EX, 127 | "light-cyan": Fore.LIGHTCYAN_EX, 128 | "light-white": Fore.LIGHTWHITE_EX, 129 | } 130 | ) 131 | 132 | _background = ansi_escape( 133 | { 134 | "K": Back.BLACK, 135 | "R": Back.RED, 136 | "G": Back.GREEN, 137 | "Y": Back.YELLOW, 138 | "E": Back.BLUE, 139 | "M": Back.MAGENTA, 140 | "C": Back.CYAN, 141 | "W": Back.WHITE, 142 | "LK": Back.LIGHTBLACK_EX, 143 | "LR": Back.LIGHTRED_EX, 144 | "LG": 
Back.LIGHTGREEN_EX, 145 | "LY": Back.LIGHTYELLOW_EX, 146 | "LE": Back.LIGHTBLUE_EX, 147 | "LM": Back.LIGHTMAGENTA_EX, 148 | "LC": Back.LIGHTCYAN_EX, 149 | "LW": Back.LIGHTWHITE_EX, 150 | "BLACK": Back.BLACK, 151 | "RED": Back.RED, 152 | "GREEN": Back.GREEN, 153 | "YELLOW": Back.YELLOW, 154 | "BLUE": Back.BLUE, 155 | "MAGENTA": Back.MAGENTA, 156 | "CYAN": Back.CYAN, 157 | "WHITE": Back.WHITE, 158 | "LIGHT-BLACK": Back.LIGHTBLACK_EX, 159 | "LIGHT-RED": Back.LIGHTRED_EX, 160 | "LIGHT-GREEN": Back.LIGHTGREEN_EX, 161 | "LIGHT-YELLOW": Back.LIGHTYELLOW_EX, 162 | "LIGHT-BLUE": Back.LIGHTBLUE_EX, 163 | "LIGHT-MAGENTA": Back.LIGHTMAGENTA_EX, 164 | "LIGHT-CYAN": Back.LIGHTCYAN_EX, 165 | "LIGHT-WHITE": Back.LIGHTWHITE_EX, 166 | } 167 | ) 168 | 169 | _regex_tag = re.compile(r"\\?\s]*)>") 170 | 171 | def __init__(self): 172 | self._tokens = [] 173 | self._tags = [] 174 | self._color_tokens = [] 175 | 176 | @staticmethod 177 | def strip(tokens): 178 | output = "" 179 | for type_, value in tokens: 180 | if type_ == TokenType.TEXT: 181 | output += value 182 | return output 183 | 184 | @staticmethod 185 | def colorize(tokens, ansi_level): 186 | output = "" 187 | 188 | for type_, value in tokens: 189 | if type_ == TokenType.LEVEL: 190 | if ansi_level is None: 191 | raise ValueError( 192 | "The '' color tag is not allowed in this context, " 193 | "it has not yet been associated to any color value." 
194 | ) 195 | value = ansi_level 196 | output += value 197 | 198 | return output 199 | 200 | @staticmethod 201 | def wrap(tokens, *, ansi_level, color_tokens): 202 | output = "" 203 | 204 | for type_, value in tokens: 205 | if type_ == TokenType.LEVEL: 206 | value = ansi_level 207 | output += value 208 | if type_ == TokenType.CLOSING: 209 | for subtype, subvalue in color_tokens: 210 | if subtype == TokenType.LEVEL: 211 | subvalue = ansi_level 212 | output += subvalue 213 | 214 | return output 215 | 216 | def feed(self, text, *, raw=False): 217 | if raw: 218 | self._tokens.append((TokenType.TEXT, text)) 219 | return 220 | 221 | position = 0 222 | 223 | for match in self._regex_tag.finditer(text): 224 | markup, tag = match.group(0), match.group(1) 225 | 226 | self._tokens.append((TokenType.TEXT, text[position : match.start()])) 227 | 228 | position = match.end() 229 | 230 | if markup[0] == "\\": 231 | self._tokens.append((TokenType.TEXT, markup[1:])) 232 | continue 233 | 234 | if markup[1] == "/": 235 | if self._tags and (tag == "" or tag == self._tags[-1]): 236 | self._tags.pop() 237 | self._color_tokens.pop() 238 | self._tokens.append((TokenType.CLOSING, "\033[0m")) 239 | self._tokens.extend(self._color_tokens) 240 | continue 241 | if tag in self._tags: 242 | raise ValueError('Closing tag "%s" violates nesting rules' % markup) 243 | raise ValueError('Closing tag "%s" has no corresponding opening tag' % markup) 244 | 245 | if tag in {"lvl", "level"}: 246 | token = (TokenType.LEVEL, None) 247 | else: 248 | ansi = self._get_ansicode(tag) 249 | 250 | if ansi is None: 251 | raise ValueError( 252 | 'Tag "%s" does not correspond to any known color directive, ' 253 | "make sure you did not misspelled it (or prepend '\\' to escape it)" 254 | % markup 255 | ) 256 | 257 | token = (TokenType.ANSI, ansi) 258 | 259 | self._tags.append(tag) 260 | self._color_tokens.append(token) 261 | self._tokens.append(token) 262 | 263 | self._tokens.append((TokenType.TEXT, text[position:])) 
264 | 265 | def done(self, *, strict=True): 266 | if strict and self._tags: 267 | faulty_tag = self._tags.pop(0) 268 | raise ValueError('Opening tag "<%s>" has no corresponding closing tag' % faulty_tag) 269 | return self._tokens 270 | 271 | def current_color_tokens(self): 272 | return list(self._color_tokens) 273 | 274 | def _get_ansicode(self, tag): 275 | style = self._style 276 | foreground = self._foreground 277 | background = self._background 278 | 279 | # Substitute on a direct match. 280 | if tag in style: 281 | return style[tag] 282 | if tag in foreground: 283 | return foreground[tag] 284 | if tag in background: 285 | return background[tag] 286 | 287 | # An alternative syntax for setting the color (e.g. , ). 288 | if tag.startswith("fg ") or tag.startswith("bg "): 289 | st, color = tag[:2], tag[3:] 290 | code = "38" if st == "fg" else "48" 291 | 292 | if st == "fg" and color.lower() in foreground: 293 | return foreground[color.lower()] 294 | if st == "bg" and color.upper() in background: 295 | return background[color.upper()] 296 | if color.isdigit() and int(color) <= 255: 297 | return "\033[%s;5;%sm" % (code, color) 298 | if re.match(r"#(?:[a-fA-F0-9]{3}){1,2}$", color): 299 | hex_color = color[1:] 300 | if len(hex_color) == 3: 301 | hex_color *= 2 302 | rgb = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) 303 | return "\033[%s;2;%s;%s;%sm" % ((code,) + rgb) 304 | if color.count(",") == 2: 305 | colors = tuple(color.split(",")) 306 | if all(x.isdigit() and int(x) <= 255 for x in colors): 307 | return "\033[%s;2;%s;%s;%sm" % ((code,) + colors) 308 | 309 | return None 310 | 311 | 312 | class ColoringMessage(str): 313 | __fields__ = ("_messages",) 314 | 315 | def __format__(self, spec): 316 | return next(self._messages).__format__(spec) 317 | 318 | 319 | class ColoredMessage: 320 | def __init__(self, tokens): 321 | self.tokens = tokens 322 | self.stripped = AnsiParser.strip(tokens) 323 | 324 | def colorize(self, ansi_level): 325 | return 
AnsiParser.colorize(self.tokens, ansi_level) 326 | 327 | 328 | class ColoredFormat: 329 | def __init__(self, tokens, messages_color_tokens): 330 | self._tokens = tokens 331 | self._messages_color_tokens = messages_color_tokens 332 | 333 | def strip(self): 334 | return AnsiParser.strip(self._tokens) 335 | 336 | def colorize(self, ansi_level): 337 | return AnsiParser.colorize(self._tokens, ansi_level) 338 | 339 | def make_coloring_message(self, message, *, ansi_level, colored_message): 340 | messages = [ 341 | ( 342 | message 343 | if color_tokens is None 344 | else AnsiParser.wrap( 345 | colored_message.tokens, ansi_level=ansi_level, color_tokens=color_tokens 346 | ) 347 | ) 348 | for color_tokens in self._messages_color_tokens 349 | ] 350 | coloring = ColoringMessage(message) 351 | coloring._messages = iter(messages) 352 | return coloring 353 | 354 | 355 | class Colorizer: 356 | @staticmethod 357 | def prepare_format(string): 358 | tokens, messages_color_tokens = Colorizer._parse_without_formatting(string) 359 | return ColoredFormat(tokens, messages_color_tokens) 360 | 361 | @staticmethod 362 | def prepare_message(string, args=(), kwargs={}): # noqa: B006 363 | tokens = Colorizer._parse_with_formatting(string, args, kwargs) 364 | return ColoredMessage(tokens) 365 | 366 | @staticmethod 367 | def prepare_simple_message(string): 368 | parser = AnsiParser() 369 | parser.feed(string) 370 | tokens = parser.done() 371 | return ColoredMessage(tokens) 372 | 373 | @staticmethod 374 | def ansify(text): 375 | parser = AnsiParser() 376 | parser.feed(text.strip()) 377 | tokens = parser.done(strict=False) 378 | return AnsiParser.colorize(tokens, None) 379 | 380 | @staticmethod 381 | def _parse_with_formatting( 382 | string, args, kwargs, *, recursion_depth=2, auto_arg_index=0, recursive=False 383 | ): 384 | # This function re-implements Formatter._vformat() 385 | 386 | if recursion_depth < 0: 387 | raise ValueError("Max string recursion exceeded") 388 | 389 | formatter = 
Formatter() 390 | parser = AnsiParser() 391 | 392 | for literal_text, field_name, format_spec, conversion in formatter.parse(string): 393 | parser.feed(literal_text, raw=recursive) 394 | 395 | if field_name is not None: 396 | if field_name == "": 397 | if auto_arg_index is False: 398 | raise ValueError( 399 | "cannot switch from manual field " 400 | "specification to automatic field " 401 | "numbering" 402 | ) 403 | field_name = str(auto_arg_index) 404 | auto_arg_index += 1 405 | elif field_name.isdigit(): 406 | if auto_arg_index: 407 | raise ValueError( 408 | "cannot switch from manual field " 409 | "specification to automatic field " 410 | "numbering" 411 | ) 412 | auto_arg_index = False 413 | 414 | obj, _ = formatter.get_field(field_name, args, kwargs) 415 | obj = formatter.convert_field(obj, conversion) 416 | 417 | format_spec, auto_arg_index = Colorizer._parse_with_formatting( 418 | format_spec, 419 | args, 420 | kwargs, 421 | recursion_depth=recursion_depth - 1, 422 | auto_arg_index=auto_arg_index, 423 | recursive=True, 424 | ) 425 | 426 | formatted = formatter.format_field(obj, format_spec) 427 | parser.feed(formatted, raw=True) 428 | 429 | tokens = parser.done() 430 | 431 | if recursive: 432 | return AnsiParser.strip(tokens), auto_arg_index 433 | 434 | return tokens 435 | 436 | @staticmethod 437 | def _parse_without_formatting(string, *, recursion_depth=2, recursive=False): 438 | if recursion_depth < 0: 439 | raise ValueError("Max string recursion exceeded") 440 | 441 | formatter = Formatter() 442 | parser = AnsiParser() 443 | 444 | messages_color_tokens = [] 445 | 446 | for literal_text, field_name, format_spec, conversion in formatter.parse(string): 447 | if literal_text and literal_text[-1] in "{}": 448 | literal_text += literal_text[-1] 449 | 450 | parser.feed(literal_text, raw=recursive) 451 | 452 | if field_name is not None: 453 | if field_name == "message": 454 | if recursive: 455 | messages_color_tokens.append(None) 456 | else: 457 | color_tokens = 
parser.current_color_tokens() 458 | messages_color_tokens.append(color_tokens) 459 | field = "{%s" % field_name 460 | if conversion: 461 | field += "!%s" % conversion 462 | if format_spec: 463 | field += ":%s" % format_spec 464 | field += "}" 465 | parser.feed(field, raw=True) 466 | 467 | _, color_tokens = Colorizer._parse_without_formatting( 468 | format_spec, recursion_depth=recursion_depth - 1, recursive=True 469 | ) 470 | messages_color_tokens.extend(color_tokens) 471 | 472 | return parser.done(), messages_color_tokens 473 | -------------------------------------------------------------------------------- /loguru/_contextvars.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def load_contextvar_class(): 5 | if sys.version_info >= (3, 7): 6 | from contextvars import ContextVar 7 | elif sys.version_info >= (3, 5, 3): 8 | from aiocontextvars import ContextVar 9 | else: 10 | from contextvars import ContextVar 11 | 12 | return ContextVar 13 | 14 | 15 | ContextVar = load_contextvar_class() 16 | -------------------------------------------------------------------------------- /loguru/_ctime_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_ctime_functions(): 5 | if os.name == "nt": 6 | import win32_setctime 7 | 8 | def get_ctime_windows(filepath): 9 | return os.stat(filepath).st_ctime 10 | 11 | def set_ctime_windows(filepath, timestamp): 12 | if not win32_setctime.SUPPORTED: 13 | return 14 | 15 | try: 16 | win32_setctime.setctime(filepath, timestamp) 17 | except (OSError, ValueError): 18 | pass 19 | 20 | return get_ctime_windows, set_ctime_windows 21 | 22 | if hasattr(os.stat_result, "st_birthtime"): 23 | 24 | def get_ctime_macos(filepath): 25 | return os.stat(filepath).st_birthtime 26 | 27 | def set_ctime_macos(filepath, timestamp): 28 | pass 29 | 30 | return get_ctime_macos, set_ctime_macos 31 | 32 | if hasattr(os, "getxattr") 
and hasattr(os, "setxattr"): 33 | 34 | def get_ctime_linux(filepath): 35 | try: 36 | return float(os.getxattr(filepath, b"user.loguru_crtime")) 37 | except OSError: 38 | return os.stat(filepath).st_mtime 39 | 40 | def set_ctime_linux(filepath, timestamp): 41 | try: 42 | os.setxattr(filepath, b"user.loguru_crtime", str(timestamp).encode("ascii")) 43 | except OSError: 44 | pass 45 | 46 | return get_ctime_linux, set_ctime_linux 47 | 48 | def get_ctime_fallback(filepath): 49 | return os.stat(filepath).st_mtime 50 | 51 | def set_ctime_fallback(filepath, timestamp): 52 | pass 53 | 54 | return get_ctime_fallback, set_ctime_fallback 55 | 56 | 57 | get_ctime, set_ctime = load_ctime_functions() 58 | -------------------------------------------------------------------------------- /loguru/_datetime.py: -------------------------------------------------------------------------------- 1 | import re 2 | from calendar import day_abbr, day_name, month_abbr, month_name 3 | from datetime import datetime as datetime_ 4 | from datetime import timedelta, timezone 5 | from time import localtime, strftime 6 | 7 | tokens = r"H{1,2}|h{1,2}|m{1,2}|s{1,2}|S+|YYYY|YY|M{1,4}|D{1,4}|Z{1,2}|zz|A|X|x|E|Q|dddd|ddd|d" 8 | 9 | pattern = re.compile(r"(?:{0})|\[(?:{0}|!UTC|)\]".format(tokens)) 10 | 11 | 12 | class datetime(datetime_): # noqa: N801 13 | def __format__(self, spec): 14 | if spec.endswith("!UTC"): 15 | dt = self.astimezone(timezone.utc) 16 | spec = spec[:-4] 17 | else: 18 | dt = self 19 | 20 | if not spec: 21 | spec = "%Y-%m-%dT%H:%M:%S.%f%z" 22 | 23 | if "%" in spec: 24 | return datetime_.__format__(dt, spec) 25 | 26 | if "SSSSSSS" in spec: 27 | raise ValueError( 28 | "Invalid time format: the provided format string contains more than six successive " 29 | "'S' characters. This may be due to an attempt to use nanosecond precision, which " 30 | "is not supported." 
31 | ) 32 | 33 | year, month, day, hour, minute, second, weekday, yearday, _ = dt.timetuple() 34 | microsecond = dt.microsecond 35 | timestamp = dt.timestamp() 36 | tzinfo = dt.tzinfo or timezone(timedelta(seconds=0)) 37 | offset = tzinfo.utcoffset(dt).total_seconds() 38 | sign = ("-", "+")[offset >= 0] 39 | (h, m), s = divmod(abs(offset // 60), 60), abs(offset) % 60 40 | 41 | rep = { 42 | "YYYY": "%04d" % year, 43 | "YY": "%02d" % (year % 100), 44 | "Q": "%d" % ((month - 1) // 3 + 1), 45 | "MMMM": month_name[month], 46 | "MMM": month_abbr[month], 47 | "MM": "%02d" % month, 48 | "M": "%d" % month, 49 | "DDDD": "%03d" % yearday, 50 | "DDD": "%d" % yearday, 51 | "DD": "%02d" % day, 52 | "D": "%d" % day, 53 | "dddd": day_name[weekday], 54 | "ddd": day_abbr[weekday], 55 | "d": "%d" % weekday, 56 | "E": "%d" % (weekday + 1), 57 | "HH": "%02d" % hour, 58 | "H": "%d" % hour, 59 | "hh": "%02d" % ((hour - 1) % 12 + 1), 60 | "h": "%d" % ((hour - 1) % 12 + 1), 61 | "mm": "%02d" % minute, 62 | "m": "%d" % minute, 63 | "ss": "%02d" % second, 64 | "s": "%d" % second, 65 | "S": "%d" % (microsecond // 100000), 66 | "SS": "%02d" % (microsecond // 10000), 67 | "SSS": "%03d" % (microsecond // 1000), 68 | "SSSS": "%04d" % (microsecond // 100), 69 | "SSSSS": "%05d" % (microsecond // 10), 70 | "SSSSSS": "%06d" % microsecond, 71 | "A": ("AM", "PM")[hour // 12], 72 | "Z": "%s%02d:%02d%s" % (sign, h, m, (":%09.06f" % s)[: 11 if s % 1 else 3] * (s > 0)), 73 | "ZZ": "%s%02d%02d%s" % (sign, h, m, ("%09.06f" % s)[: 10 if s % 1 else 2] * (s > 0)), 74 | "zz": tzinfo.tzname(dt) or "", 75 | "X": "%d" % timestamp, 76 | "x": "%d" % (int(timestamp) * 1000000 + microsecond), 77 | } 78 | 79 | def get(m): 80 | try: 81 | return rep[m.group(0)] 82 | except KeyError: 83 | return m.group(0)[1:-1] 84 | 85 | return pattern.sub(get, spec) 86 | 87 | 88 | def aware_now(): 89 | now = datetime_.now() 90 | timestamp = now.timestamp() 91 | local = localtime(timestamp) 92 | 93 | try: 94 | seconds = local.tm_gmtoff 95 
| zone = local.tm_zone 96 | except AttributeError: 97 | # Workaround for Python 3.5. 98 | utc_naive = datetime_.fromtimestamp(timestamp, tz=timezone.utc).replace(tzinfo=None) 99 | offset = datetime_.fromtimestamp(timestamp) - utc_naive 100 | seconds = offset.total_seconds() 101 | zone = strftime("%Z") 102 | 103 | tzinfo = timezone(timedelta(seconds=seconds), zone) 104 | 105 | return datetime.combine(now.date(), now.time().replace(tzinfo=tzinfo)) 106 | -------------------------------------------------------------------------------- /loguru/_defaults.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | 3 | 4 | def env(key, type_, default=None): 5 | if key not in environ: 6 | return default 7 | 8 | val = environ[key] 9 | 10 | if isinstance(type_, str): 11 | return val 12 | if isinstance(type_, bool): 13 | if val.lower() in ["1", "true", "yes", "y", "ok", "on"]: 14 | return True 15 | if val.lower() in ["0", "false", "no", "n", "nok", "off"]: 16 | return False 17 | raise ValueError( 18 | "Invalid environment variable '%s' (expected a boolean): '%s'" % (key, val) 19 | ) 20 | if isinstance(type_, int): 21 | try: 22 | return int(val) 23 | except ValueError: 24 | raise ValueError( 25 | "Invalid environment variable '%s' (expected an integer): '%s'" % (key, val) 26 | ) from None 27 | raise ValueError("The requested type '%r' is not supported" % type_) 28 | 29 | 30 | LOGURU_AUTOINIT = env("LOGURU_AUTOINIT", bool, True) 31 | 32 | LOGURU_FORMAT = env( 33 | "LOGURU_FORMAT", 34 | str, 35 | "{time:YYYY-MM-DD HH:mm:ss.SSS} | " 36 | "{level: <8} | " 37 | "{name}:{function}:{line} - {message}", 38 | ) 39 | LOGURU_FILTER = env("LOGURU_FILTER", str, None) 40 | LOGURU_LEVEL = env("LOGURU_LEVEL", str, "DEBUG") 41 | LOGURU_COLORIZE = env("LOGURU_COLORIZE", bool, None) 42 | LOGURU_SERIALIZE = env("LOGURU_SERIALIZE", bool, False) 43 | LOGURU_BACKTRACE = env("LOGURU_BACKTRACE", bool, True) 44 | LOGURU_DIAGNOSE = 
env("LOGURU_DIAGNOSE", bool, True) 45 | LOGURU_ENQUEUE = env("LOGURU_ENQUEUE", bool, False) 46 | LOGURU_CONTEXT = env("LOGURU_CONTEXT", str, None) 47 | LOGURU_CATCH = env("LOGURU_CATCH", bool, True) 48 | 49 | LOGURU_TRACE_NO = env("LOGURU_TRACE_NO", int, 5) 50 | LOGURU_TRACE_COLOR = env("LOGURU_TRACE_COLOR", str, "") 51 | LOGURU_TRACE_ICON = env("LOGURU_TRACE_ICON", str, "\u270F\uFE0F") # Pencil 52 | 53 | LOGURU_DEBUG_NO = env("LOGURU_DEBUG_NO", int, 10) 54 | LOGURU_DEBUG_COLOR = env("LOGURU_DEBUG_COLOR", str, "") 55 | LOGURU_DEBUG_ICON = env("LOGURU_DEBUG_ICON", str, "\U0001F41E") # Lady Beetle 56 | 57 | LOGURU_INFO_NO = env("LOGURU_INFO_NO", int, 20) 58 | LOGURU_INFO_COLOR = env("LOGURU_INFO_COLOR", str, "") 59 | LOGURU_INFO_ICON = env("LOGURU_INFO_ICON", str, "\u2139\uFE0F") # Information 60 | 61 | LOGURU_SUCCESS_NO = env("LOGURU_SUCCESS_NO", int, 25) 62 | LOGURU_SUCCESS_COLOR = env("LOGURU_SUCCESS_COLOR", str, "") 63 | LOGURU_SUCCESS_ICON = env("LOGURU_SUCCESS_ICON", str, "\u2705") # White Heavy Check Mark 64 | 65 | LOGURU_WARNING_NO = env("LOGURU_WARNING_NO", int, 30) 66 | LOGURU_WARNING_COLOR = env("LOGURU_WARNING_COLOR", str, "") 67 | LOGURU_WARNING_ICON = env("LOGURU_WARNING_ICON", str, "\u26A0\uFE0F") # Warning 68 | 69 | LOGURU_ERROR_NO = env("LOGURU_ERROR_NO", int, 40) 70 | LOGURU_ERROR_COLOR = env("LOGURU_ERROR_COLOR", str, "") 71 | LOGURU_ERROR_ICON = env("LOGURU_ERROR_ICON", str, "\u274C") # Cross Mark 72 | 73 | LOGURU_CRITICAL_NO = env("LOGURU_CRITICAL_NO", int, 50) 74 | LOGURU_CRITICAL_COLOR = env("LOGURU_CRITICAL_COLOR", str, "") 75 | LOGURU_CRITICAL_ICON = env("LOGURU_CRITICAL_ICON", str, "\u2620\uFE0F") # Skull and Crossbones 76 | -------------------------------------------------------------------------------- /loguru/_error_interceptor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | 4 | 5 | class ErrorInterceptor: 6 | def __init__(self, should_catch, handler_id): 7 | 
self._should_catch = should_catch 8 | self._handler_id = handler_id 9 | 10 | def should_catch(self): 11 | return self._should_catch 12 | 13 | def print(self, record=None, *, exception=None): 14 | if not sys.stderr: 15 | return 16 | 17 | if exception is None: 18 | type_, value, traceback_ = sys.exc_info() 19 | else: 20 | type_, value, traceback_ = (type(exception), exception, exception.__traceback__) 21 | 22 | try: 23 | sys.stderr.write("--- Logging error in Loguru Handler #%d ---\n" % self._handler_id) 24 | try: 25 | record_repr = str(record) 26 | except Exception: 27 | record_repr = "/!\\ Unprintable record /!\\" 28 | sys.stderr.write("Record was: %s\n" % record_repr) 29 | traceback.print_exception(type_, value, traceback_, None, sys.stderr) 30 | sys.stderr.write("--- End of logging error ---\n") 31 | except OSError: 32 | pass 33 | finally: 34 | del type_, value, traceback_ 35 | -------------------------------------------------------------------------------- /loguru/_file_sink.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import decimal 3 | import glob 4 | import numbers 5 | import os 6 | import shutil 7 | import string 8 | from functools import partial 9 | from stat import ST_DEV, ST_INO 10 | 11 | from . 
import _string_parsers as string_parsers 12 | from ._ctime_functions import get_ctime, set_ctime 13 | from ._datetime import aware_now 14 | 15 | 16 | def generate_rename_path(root, ext, creation_time): 17 | creation_datetime = datetime.datetime.fromtimestamp(creation_time) 18 | date = FileDateFormatter(creation_datetime) 19 | 20 | renamed_path = "{}.{}{}".format(root, date, ext) 21 | counter = 1 22 | 23 | while os.path.exists(renamed_path): 24 | counter += 1 25 | renamed_path = "{}.{}.{}{}".format(root, date, counter, ext) 26 | 27 | return renamed_path 28 | 29 | 30 | class FileDateFormatter: 31 | def __init__(self, datetime=None): 32 | self.datetime = datetime or aware_now() 33 | 34 | def __format__(self, spec): 35 | if not spec: 36 | spec = "%Y-%m-%d_%H-%M-%S_%f" 37 | return self.datetime.__format__(spec) 38 | 39 | 40 | class Compression: 41 | @staticmethod 42 | def add_compress(path_in, path_out, opener, **kwargs): 43 | with opener(path_out, **kwargs) as f_comp: 44 | f_comp.add(path_in, os.path.basename(path_in)) 45 | 46 | @staticmethod 47 | def write_compress(path_in, path_out, opener, **kwargs): 48 | with opener(path_out, **kwargs) as f_comp: 49 | f_comp.write(path_in, os.path.basename(path_in)) 50 | 51 | @staticmethod 52 | def copy_compress(path_in, path_out, opener, **kwargs): 53 | with open(path_in, "rb") as f_in: 54 | with opener(path_out, **kwargs) as f_out: 55 | shutil.copyfileobj(f_in, f_out) 56 | 57 | @staticmethod 58 | def compression(path_in, ext, compress_function): 59 | path_out = "{}{}".format(path_in, ext) 60 | 61 | if os.path.exists(path_out): 62 | creation_time = get_ctime(path_out) 63 | root, ext_before = os.path.splitext(path_in) 64 | renamed_path = generate_rename_path(root, ext_before + ext, creation_time) 65 | os.rename(path_out, renamed_path) 66 | compress_function(path_in, path_out) 67 | os.remove(path_in) 68 | 69 | 70 | class Retention: 71 | @staticmethod 72 | def retention_count(logs, number): 73 | def key_log(log): 74 | return 
(-os.stat(log).st_mtime, log) 75 | 76 | for log in sorted(logs, key=key_log)[number:]: 77 | os.remove(log) 78 | 79 | @staticmethod 80 | def retention_age(logs, seconds): 81 | t = datetime.datetime.now().timestamp() 82 | for log in logs: 83 | if os.stat(log).st_mtime <= t - seconds: 84 | os.remove(log) 85 | 86 | 87 | class Rotation: 88 | @staticmethod 89 | def forward_day(t): 90 | return t + datetime.timedelta(days=1) 91 | 92 | @staticmethod 93 | def forward_weekday(t, weekday): 94 | while True: 95 | t += datetime.timedelta(days=1) 96 | if t.weekday() == weekday: 97 | return t 98 | 99 | @staticmethod 100 | def forward_interval(t, interval): 101 | return t + interval 102 | 103 | @staticmethod 104 | def rotation_size(message, file, size_limit): 105 | file.seek(0, 2) 106 | return file.tell() + len(message) > size_limit 107 | 108 | class RotationTime: 109 | def __init__(self, step_forward, time_init=None): 110 | self._step_forward = step_forward 111 | self._time_init = time_init 112 | self._limit = None 113 | 114 | def __call__(self, message, file): 115 | record_time = message.record["time"] 116 | 117 | if self._limit is None: 118 | filepath = os.path.realpath(file.name) 119 | creation_time = get_ctime(filepath) 120 | set_ctime(filepath, creation_time) 121 | start_time = datetime.datetime.fromtimestamp( 122 | creation_time, tz=datetime.timezone.utc 123 | ) 124 | 125 | time_init = self._time_init 126 | 127 | if time_init is None: 128 | limit = start_time.astimezone(record_time.tzinfo).replace(tzinfo=None) 129 | limit = self._step_forward(limit) 130 | else: 131 | tzinfo = record_time.tzinfo if time_init.tzinfo is None else time_init.tzinfo 132 | limit = start_time.astimezone(tzinfo).replace( 133 | hour=time_init.hour, 134 | minute=time_init.minute, 135 | second=time_init.second, 136 | microsecond=time_init.microsecond, 137 | ) 138 | 139 | if limit <= start_time: 140 | limit = self._step_forward(limit) 141 | 142 | if time_init.tzinfo is None: 143 | limit = 
limit.replace(tzinfo=None) 144 | 145 | self._limit = limit 146 | 147 | if self._limit.tzinfo is None: 148 | record_time = record_time.replace(tzinfo=None) 149 | 150 | if record_time >= self._limit: 151 | while self._limit <= record_time: 152 | self._limit = self._step_forward(self._limit) 153 | return True 154 | return False 155 | 156 | 157 | class FileSink: 158 | def __init__( 159 | self, 160 | path, 161 | *, 162 | rotation=None, 163 | retention=None, 164 | compression=None, 165 | delay=False, 166 | watch=False, 167 | mode="a", 168 | buffering=1, 169 | encoding="utf8", 170 | **kwargs 171 | ): 172 | self.encoding = encoding 173 | 174 | self._kwargs = {**kwargs, "mode": mode, "buffering": buffering, "encoding": self.encoding} 175 | self._path = str(path) 176 | 177 | self._glob_patterns = self._make_glob_patterns(self._path) 178 | self._rotation_function = self._make_rotation_function(rotation) 179 | self._retention_function = self._make_retention_function(retention) 180 | self._compression_function = self._make_compression_function(compression) 181 | 182 | self._file = None 183 | self._file_path = None 184 | 185 | self._watch = watch 186 | self._file_dev = -1 187 | self._file_ino = -1 188 | 189 | if not delay: 190 | path = self._create_path() 191 | self._create_dirs(path) 192 | self._create_file(path) 193 | 194 | def write(self, message): 195 | if self._file is None: 196 | path = self._create_path() 197 | self._create_dirs(path) 198 | self._create_file(path) 199 | 200 | if self._watch: 201 | self._reopen_if_needed() 202 | 203 | if self._rotation_function is not None and self._rotation_function(message, self._file): 204 | self._terminate_file(is_rotating=True) 205 | 206 | self._file.write(message) 207 | 208 | def stop(self): 209 | if self._watch: 210 | self._reopen_if_needed() 211 | 212 | self._terminate_file(is_rotating=False) 213 | 214 | def tasks_to_complete(self): 215 | return [] 216 | 217 | def _create_path(self): 218 | path = self._path.format_map({"time": 
FileDateFormatter()}) 219 | return os.path.abspath(path) 220 | 221 | def _create_dirs(self, path): 222 | dirname = os.path.dirname(path) 223 | os.makedirs(dirname, exist_ok=True) 224 | 225 | def _create_file(self, path): 226 | self._file = open(path, **self._kwargs) 227 | self._file_path = path 228 | 229 | if self._watch: 230 | fileno = self._file.fileno() 231 | result = os.fstat(fileno) 232 | self._file_dev = result[ST_DEV] 233 | self._file_ino = result[ST_INO] 234 | 235 | def _close_file(self): 236 | self._file.flush() 237 | self._file.close() 238 | 239 | self._file = None 240 | self._file_path = None 241 | self._file_dev = -1 242 | self._file_ino = -1 243 | 244 | def _reopen_if_needed(self): 245 | # Implemented based on standard library: 246 | # https://github.com/python/cpython/blob/cb589d1b/Lib/logging/handlers.py#L486 247 | if not self._file: 248 | return 249 | 250 | filepath = self._file_path 251 | 252 | try: 253 | result = os.stat(filepath) 254 | except FileNotFoundError: 255 | result = None 256 | 257 | if not result or result[ST_DEV] != self._file_dev or result[ST_INO] != self._file_ino: 258 | self._close_file() 259 | self._create_dirs(filepath) 260 | self._create_file(filepath) 261 | 262 | def _terminate_file(self, *, is_rotating=False): 263 | old_path = self._file_path 264 | 265 | if self._file is not None: 266 | self._close_file() 267 | 268 | if is_rotating: 269 | new_path = self._create_path() 270 | self._create_dirs(new_path) 271 | 272 | if new_path == old_path: 273 | creation_time = get_ctime(old_path) 274 | root, ext = os.path.splitext(old_path) 275 | renamed_path = generate_rename_path(root, ext, creation_time) 276 | os.rename(old_path, renamed_path) 277 | old_path = renamed_path 278 | 279 | if is_rotating or self._rotation_function is None: 280 | if self._compression_function is not None and old_path is not None: 281 | self._compression_function(old_path) 282 | 283 | if self._retention_function is not None: 284 | logs = { 285 | file 286 | for 
pattern in self._glob_patterns 287 | for file in glob.glob(pattern) 288 | if os.path.isfile(file) 289 | } 290 | self._retention_function(list(logs)) 291 | 292 | if is_rotating: 293 | self._create_file(new_path) 294 | set_ctime(new_path, datetime.datetime.now().timestamp()) 295 | 296 | @staticmethod 297 | def _make_glob_patterns(path): 298 | formatter = string.Formatter() 299 | tokens = formatter.parse(path) 300 | escaped = "".join(glob.escape(text) + "*" * (name is not None) for text, name, *_ in tokens) 301 | 302 | root, ext = os.path.splitext(escaped) 303 | 304 | if not ext: 305 | return [escaped, escaped + ".*"] 306 | 307 | return [escaped, escaped + ".*", root + ".*" + ext, root + ".*" + ext + ".*"] 308 | 309 | @staticmethod 310 | def _make_rotation_function(rotation): 311 | if rotation is None: 312 | return None 313 | if isinstance(rotation, str): 314 | size = string_parsers.parse_size(rotation) 315 | if size is not None: 316 | return FileSink._make_rotation_function(size) 317 | interval = string_parsers.parse_duration(rotation) 318 | if interval is not None: 319 | return FileSink._make_rotation_function(interval) 320 | frequency = string_parsers.parse_frequency(rotation) 321 | if frequency is not None: 322 | return Rotation.RotationTime(frequency) 323 | daytime = string_parsers.parse_daytime(rotation) 324 | if daytime is not None: 325 | day, time = daytime 326 | if day is None: 327 | return FileSink._make_rotation_function(time) 328 | if time is None: 329 | time = datetime.time(0, 0, 0) 330 | step_forward = partial(Rotation.forward_weekday, weekday=day) 331 | return Rotation.RotationTime(step_forward, time) 332 | raise ValueError("Cannot parse rotation from: '%s'" % rotation) 333 | if isinstance(rotation, (numbers.Real, decimal.Decimal)): 334 | return partial(Rotation.rotation_size, size_limit=rotation) 335 | if isinstance(rotation, datetime.time): 336 | return Rotation.RotationTime(Rotation.forward_day, rotation) 337 | if isinstance(rotation, 
datetime.timedelta): 338 | step_forward = partial(Rotation.forward_interval, interval=rotation) 339 | return Rotation.RotationTime(step_forward) 340 | if callable(rotation): 341 | return rotation 342 | raise TypeError("Cannot infer rotation for objects of type: '%s'" % type(rotation).__name__) 343 | 344 | @staticmethod 345 | def _make_retention_function(retention): 346 | if retention is None: 347 | return None 348 | if isinstance(retention, str): 349 | interval = string_parsers.parse_duration(retention) 350 | if interval is None: 351 | raise ValueError("Cannot parse retention from: '%s'" % retention) 352 | return FileSink._make_retention_function(interval) 353 | if isinstance(retention, int): 354 | return partial(Retention.retention_count, number=retention) 355 | if isinstance(retention, datetime.timedelta): 356 | return partial(Retention.retention_age, seconds=retention.total_seconds()) 357 | if callable(retention): 358 | return retention 359 | raise TypeError( 360 | "Cannot infer retention for objects of type: '%s'" % type(retention).__name__ 361 | ) 362 | 363 | @staticmethod 364 | def _make_compression_function(compression): 365 | if compression is None: 366 | return None 367 | if isinstance(compression, str): 368 | ext = compression.strip().lstrip(".") 369 | 370 | if ext == "gz": 371 | import gzip 372 | 373 | compress = partial(Compression.copy_compress, opener=gzip.open, mode="wb") 374 | elif ext == "bz2": 375 | import bz2 376 | 377 | compress = partial(Compression.copy_compress, opener=bz2.open, mode="wb") 378 | 379 | elif ext == "xz": 380 | import lzma 381 | 382 | compress = partial( 383 | Compression.copy_compress, opener=lzma.open, mode="wb", format=lzma.FORMAT_XZ 384 | ) 385 | 386 | elif ext == "lzma": 387 | import lzma 388 | 389 | compress = partial( 390 | Compression.copy_compress, opener=lzma.open, mode="wb", format=lzma.FORMAT_ALONE 391 | ) 392 | elif ext == "tar": 393 | import tarfile 394 | 395 | compress = partial(Compression.add_compress, 
opener=tarfile.open, mode="w:") 396 | elif ext == "tar.gz": 397 | import gzip 398 | import tarfile 399 | 400 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:gz") 401 | elif ext == "tar.bz2": 402 | import bz2 403 | import tarfile 404 | 405 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:bz2") 406 | 407 | elif ext == "tar.xz": 408 | import lzma 409 | import tarfile 410 | 411 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:xz") 412 | elif ext == "zip": 413 | import zipfile 414 | 415 | compress = partial( 416 | Compression.write_compress, 417 | opener=zipfile.ZipFile, 418 | mode="w", 419 | compression=zipfile.ZIP_DEFLATED, 420 | ) 421 | else: 422 | raise ValueError("Invalid compression format: '%s'" % ext) 423 | 424 | return partial(Compression.compression, ext="." + ext, compress_function=compress) 425 | if callable(compression): 426 | return compression 427 | raise TypeError( 428 | "Cannot infer compression for objects of type: '%s'" % type(compression).__name__ 429 | ) 430 | -------------------------------------------------------------------------------- /loguru/_filters.py: -------------------------------------------------------------------------------- 1 | def filter_none(record): 2 | return record["name"] is not None 3 | 4 | 5 | def filter_by_name(record, parent, length): 6 | name = record["name"] 7 | if name is None: 8 | return False 9 | return (name + ".")[:length] == parent 10 | 11 | 12 | def filter_by_level(record, level_per_module): 13 | name = record["name"] 14 | 15 | while True: 16 | level = level_per_module.get(name, None) 17 | if level is False: 18 | return False 19 | if level is not None: 20 | return record["level"].no >= level 21 | if not name: 22 | return True 23 | index = name.rfind(".") 24 | name = name[:index] if index != -1 else "" 25 | -------------------------------------------------------------------------------- /loguru/_get_frame.py: 
# --------------------------------------------------------------------------------
import sys
from sys import exc_info


def get_frame_fallback(n):
    """Return the frame ``n`` levels above the caller without using ``sys._getframe``.

    Mirrors ``sys._getframe(n)``: ``n == 0`` is the frame that called this function. The
    current frame is recovered from the traceback of a deliberately raised-and-caught
    exception, then ``f_back`` is followed ``n`` more times.
    """
    try:
        raise Exception
    except Exception:
        # tb_frame is this function's own frame, so f_back is the caller's frame.
        frame = exc_info()[2].tb_frame.f_back
        for _ in range(n):
            frame = frame.f_back
        return frame


def load_get_frame_function():
    """Return ``sys._getframe`` when the interpreter provides it, else ``get_frame_fallback``."""
    if hasattr(sys, "_getframe"):
        get_frame = sys._getframe
    else:
        get_frame = get_frame_fallback
    return get_frame


get_frame = load_get_frame_function()

# --------------------------------------------------------------------------------
# /loguru/_handler.py:
# --------------------------------------------------------------------------------
import functools
import json
import multiprocessing
import os
import threading
from contextlib import contextmanager
from threading import Thread

from ._colorizer import Colorizer
from ._locks_machinery import create_handler_lock


def prepare_colored_format(format_, ansi_level):
    """Prepare ``format_`` with ``Colorizer`` and return ``(prepared, prepared colorized for ansi_level)``."""
    colored = Colorizer.prepare_format(format_)
    return colored, colored.colorize(ansi_level)


def prepare_stripped_format(format_):
    """Prepare ``format_`` with ``Colorizer`` and return it with the color markup stripped."""
    colored = Colorizer.prepare_format(format_)
    return colored.strip()


def memoize(function):
    """Wrap ``function`` in a bounded LRU cache (64 entries)."""
    return functools.lru_cache(maxsize=64)(function)


class Message(str):
    """Formatted text handed to a sink; the originating record is attached as ``.record``."""

    __slots__ = ("record",)


class Handler:
    """One logging destination: filters, formats and writes records to its sink.

    The sink object must expose ``write(message)``, ``stop()`` and ``tasks_to_complete()``.
    With ``enqueue=True`` messages are put on a multiprocessing ``SimpleQueue`` and written by a
    daemon thread started in the process that created the handler (the "owner" process),
    instead of being written synchronously by the logging caller.
    """

    def __init__(
        self,
        *,
        sink,
        name,
        levelno,
        formatter,
        is_formatter_dynamic,
        filter_,
        colorize,
        serialize,
        enqueue,
        multiprocessing_context,
        error_interceptor,
        exception_formatter,
        id_,
        levels_ansi_codes
    ):
        """Store the configuration and precompute the formatting that can be prepared upfront.

        All arguments are keyword-only:
            sink: Destination object (``write`` / ``stop`` / ``tasks_to_complete``).
            name: Sink description shown in ``repr()``.
            levelno: Minimum severity; records whose ``level.no`` is lower are ignored.
            formatter: A prepared format (static case) or a callable returning a format for
                each record (when ``is_formatter_dynamic`` is true).
            is_formatter_dynamic: Whether ``formatter`` must be called per record.
            filter_: Optional predicate ``record -> bool``; ``None`` accepts every record.
            colorize: Whether output contains ANSI color codes.
            serialize: Whether each message is wrapped in a JSON document.
            enqueue: Whether messages go through a queue to a background writer thread.
            multiprocessing_context: Context used to build the queue/event/lock; the
                ``multiprocessing`` module itself is used when ``None``.
            error_interceptor: Decides whether errors are caught (``should_catch``) and
                reports them (``print``).
            exception_formatter: Renders ``record["exception"]`` via ``format_exception``.
            id_: Numeric identifier used in ``repr()`` and in the writer thread's name.
            levels_ansi_codes: Mapping ``level id -> ANSI code``. It is shared with other
                handlers, so it is stored by reference, never copied.
        """
        self._name = name
        self._sink = sink
        self._levelno = levelno
        self._formatter = formatter
        self._is_formatter_dynamic = is_formatter_dynamic
        self._filter = filter_
        self._colorize = colorize
        self._serialize = serialize
        self._enqueue = enqueue
        self._multiprocessing_context = multiprocessing_context
        self._error_interceptor = error_interceptor
        self._exception_formatter = exception_formatter
        self._id = id_
        self._levels_ansi_codes = levels_ansi_codes  # Warning, reference shared among handlers

        self._decolorized_format = None
        self._precolorized_formats = {}
        self._memoize_dynamic_format = None

        self._stopped = False
        self._lock = create_handler_lock()
        self._lock_acquired = threading.local()
        self._queue = None
        self._queue_lock = None
        self._confirmation_event = None
        self._confirmation_lock = None
        self._owner_process_pid = None
        self._thread = None

        if self._is_formatter_dynamic:
            # The format is only known per record: cache its preparation instead.
            if self._colorize:
                self._memoize_dynamic_format = memoize(prepare_colored_format)
            else:
                self._memoize_dynamic_format = memoize(prepare_stripped_format)
        else:
            # Static format: colorize once per known level, or strip colors once.
            if self._colorize:
                for level_name in self._levels_ansi_codes:
                    self.update_format(level_name)
            else:
                self._decolorized_format = self._formatter.strip()

        if self._enqueue:
            if self._multiprocessing_context is None:
                self._queue = multiprocessing.SimpleQueue()
                self._confirmation_event = multiprocessing.Event()
                self._confirmation_lock = multiprocessing.Lock()
            else:
                self._queue = self._multiprocessing_context.SimpleQueue()
                self._confirmation_event = self._multiprocessing_context.Event()
                self._confirmation_lock = self._multiprocessing_context.Lock()
            self._queue_lock = create_handler_lock()
            self._owner_process_pid = os.getpid()
            self._thread = Thread(
                target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id
            )
            self._thread.start()

    def __repr__(self):
        return "(id=%d, level=%d, sink=%s)" % (self._id, self._levelno, self._name)

    @contextmanager
    def _protected_lock(self):
        """Acquire the lock, but fail fast if it is already acquired by the current thread."""
        if getattr(self._lock_acquired, "acquired", False):
            raise RuntimeError(
                "Could not acquire internal lock because it was already in use (deadlock avoided). "
                "This likely happened because the logger was re-used inside a sink, a signal "
                "handler or a '__del__' method. This is not permitted because the logger and its "
                "handlers are not re-entrant."
            )
        self._lock_acquired.acquired = True
        try:
            with self._lock:
                yield
        finally:
            self._lock_acquired.acquired = False

    def emit(self, record, level_id, from_decorator, is_raw, colored_message):
        """Format ``record`` and deliver it to the sink, or to the queue when enqueued.

        Args:
            record: Log record dict built by the logger.
            level_id: Key into ``levels_ansi_codes`` and the precolorized formats.
            from_decorator: Forwarded to the exception formatter.
            is_raw: Emit the bare message, bypassing the format.
            colored_message: Pre-parsed colored message, or ``None``; discarded when its
                stripped text no longer equals ``record["message"]``.

        Any exception is passed to the error interceptor, which either re-raises it or
        prints it.
        """
        try:
            if self._levelno > record["level"].no:
                return

            if self._filter is not None:
                if not self._filter(record):
                    return

            if self._is_formatter_dynamic:
                dynamic_format = self._formatter(record)

            # Work on a copy so the "exception"/"message" overrides below do not leak into
            # the record attached to the emitted Message.
            formatter_record = record.copy()

            if not record["exception"]:
                formatter_record["exception"] = ""
            else:
                type_, value, tb = record["exception"]
                formatter = self._exception_formatter
                lines = formatter.format_exception(type_, value, tb, from_decorator=from_decorator)
                formatter_record["exception"] = "".join(lines)

            # Color markup parsed from a different text is stale: fall back to plain text.
            if colored_message is not None and colored_message.stripped != record["message"]:
                colored_message = None

            if is_raw:
                if colored_message is None or not self._colorize:
                    formatted = record["message"]
                else:
                    ansi_level = self._levels_ansi_codes[level_id]
                    formatted = colored_message.colorize(ansi_level)
            elif self._is_formatter_dynamic:
                if not self._colorize:
                    precomputed_format = self._memoize_dynamic_format(dynamic_format)
                    formatted = precomputed_format.format_map(formatter_record)
                elif colored_message is None:
                    ansi_level = self._levels_ansi_codes[level_id]
                    _, precomputed_format = self._memoize_dynamic_format(dynamic_format, ansi_level)
                    formatted = precomputed_format.format_map(formatter_record)
                else:
                    ansi_level = self._levels_ansi_codes[level_id]
                    formatter, precomputed_format = self._memoize_dynamic_format(
                        dynamic_format, ansi_level
                    )
                    coloring_message = formatter.make_coloring_message(
                        record["message"], ansi_level=ansi_level, colored_message=colored_message
                    )
                    formatter_record["message"] = coloring_message
                    formatted = precomputed_format.format_map(formatter_record)

            else:
                if not self._colorize:
                    precomputed_format = self._decolorized_format
                    formatted = precomputed_format.format_map(formatter_record)
                elif colored_message is None:
                    ansi_level = self._levels_ansi_codes[level_id]
                    precomputed_format = self._precolorized_formats[level_id]
                    formatted = precomputed_format.format_map(formatter_record)
                else:
                    ansi_level = self._levels_ansi_codes[level_id]
                    precomputed_format = self._precolorized_formats[level_id]
                    coloring_message = self._formatter.make_coloring_message(
                        record["message"], ansi_level=ansi_level, colored_message=colored_message
                    )
                    formatter_record["message"] = coloring_message
                    formatted = precomputed_format.format_map(formatter_record)

            if self._serialize:
                formatted = self._serialize_record(formatted, record)

            str_record = Message(formatted)
            str_record.record = record

            with self._protected_lock():
                # Checked under the lock so that nothing is written once stop() has run.
                if self._stopped:
                    return
                if self._enqueue:
                    self._queue.put(str_record)
                else:
                    self._sink.write(str_record)
        except Exception:
            if not self._error_interceptor.should_catch():
                raise
            self._error_interceptor.print(record)

    def stop(self):
        """Mark the handler stopped and shut it down.

        In the owner process of an enqueued handler, the writer thread is sent the ``None``
        sentinel and joined, and the queue is closed when it supports it. The sink is then
        stopped. In any other process (e.g. after the handler was pickled to a child) the
        method returns right after marking the handler stopped, leaving the queue and sink to
        the owner.
        """
        with self._protected_lock():
            self._stopped = True
            if self._enqueue:
                if self._owner_process_pid != os.getpid():
                    return
                self._queue.put(None)
                self._thread.join()
                if hasattr(self._queue, "close"):
                    self._queue.close()

            self._sink.stop()

    def complete_queue(self):
        """Block until the writer thread has handled every message enqueued before this call.

        A ``True`` marker is queued behind pending messages; the writer sets the confirmation
        event when it reaches it. The confirmation lock serializes concurrent callers.
        No-op when the handler is not enqueued.
        """
        if not self._enqueue:
            return

        with self._confirmation_lock:
            self._queue.put(True)
            self._confirmation_event.wait()
            self._confirmation_event.clear()

    def tasks_to_complete(self):
        """Return the sink's pending awaitables, collected under the lock guarding sink writes.

        Returns an empty list in a non-owner process of an enqueued handler.
        """
        if self._enqueue and self._owner_process_pid != os.getpid():
            return []
        lock = self._queue_lock if self._enqueue else self._protected_lock()
        with lock:
            return self._sink.tasks_to_complete()

    def update_format(self, level_id):
        """(Re)build the precolorized static format for ``level_id`` from its current ANSI code.

        No-op when colors are disabled or the format is dynamic.
        """
        if not self._colorize or self._is_formatter_dynamic:
            return
        ansi_code = self._levels_ansi_codes[level_id]
        self._precolorized_formats[level_id] = self._formatter.colorize(ansi_code)

    @property
    def levelno(self):
        """Minimum severity number accepted by this handler."""
        return self._levelno

    @staticmethod
    def _serialize_record(text, record):
        """Return ``text`` and a JSON-friendly view of ``record`` as one JSON line.

        Values json cannot encode natively (e.g. the exception value, ``time``, ``elapsed``)
        are stringified through ``default=str``.
        """
        exception = record["exception"]

        if exception is not None:
            exception = {
                "type": None if exception.type is None else exception.type.__name__,
                "value": exception.value,
                "traceback": bool(exception.traceback),
            }

        serializable = {
            "text": text,
            "record": {
                "elapsed": {
                    "repr": record["elapsed"],
                    "seconds": record["elapsed"].total_seconds(),
                },
                "exception": exception,
                "extra": record["extra"],
                "file": {"name": record["file"].name, "path": record["file"].path},
                "function": record["function"],
                "level": {
                    "icon": record["level"].icon,
                    "name": record["level"].name,
                    "no": record["level"].no,
                },
                "line": record["line"],
                "message": record["message"],
                "module": record["module"],
                "name": record["name"],
                "process": {"id": record["process"].id, "name": record["process"].name},
                "thread": {"id": record["thread"].id, "name": record["thread"].name},
                "time": {"repr": record["time"], "timestamp": record["time"].timestamp()},
            },
        }

        return json.dumps(serializable, default=str, ensure_ascii=False) + "\n"

    def _queued_writer(self):
        """Writer-thread loop: drain the queue into the sink.

        ``None`` ends the loop, ``True`` acknowledges ``complete_queue()``, and anything else is
        a ``Message`` to write. Errors are reported through the error interceptor.
        """
        message = None
        queue = self._queue

        # We need to use a lock to protect sink during fork.
        # Particularly, writing to stderr may lead to deadlock in child process.
        lock = self._queue_lock

        while True:
            try:
                message = queue.get()
            except Exception:
                with lock:
                    self._error_interceptor.print(None)
                continue

            if message is None:
                break

            if message is True:
                self._confirmation_event.set()
                continue

            with lock:
                try:
                    self._sink.write(message)
                except Exception:
                    self._error_interceptor.print(message.record)

    def __getstate__(self):
        """Return the state to pickle (e.g. when the handler is sent to another process).

        The locks, the thread-local re-entrance flag and the memoized formatter are reset here
        and rebuilt by ``__setstate__``. For enqueued handlers the sink, the writer thread and
        the queue lock remain with the owner process.

        ``_owner_process_pid`` is deliberately kept. A copy unpickled in another process then
        sees a PID mismatch, so its ``stop()`` and ``tasks_to_complete()`` leave the owner's
        queue and sink alone.
        """
        state = self.__dict__.copy()
        state["_lock"] = None
        state["_lock_acquired"] = None
        state["_memoize_dynamic_format"] = None
        if self._enqueue:
            state["_sink"] = None
            state["_thread"] = None
            state["_queue_lock"] = None
        return state

    def __setstate__(self, state):
        """Restore pickled state and recreate the per-process locks and formatter cache."""
        self.__dict__.update(state)
        self._lock = create_handler_lock()
        self._lock_acquired = threading.local()
        if self._enqueue:
            self._queue_lock = create_handler_lock()
        if self._is_formatter_dynamic:
            if self._colorize:
                self._memoize_dynamic_format = memoize(prepare_colored_format)
            else:
                self._memoize_dynamic_format = memoize(prepare_stripped_format)

# --------------------------------------------------------------------------------
# /loguru/_locks_machinery.py:
# --------------------------------------------------------------------------------
import os
import threading
import weakref

if hasattr(os, "register_at_fork"):
    # fork() copies every lock in whatever state it is in. A lock held at that moment would
    # stay acquired forever in the child (deadlock), and a sink could be caught mid-write.
    # Every lock handed out below is therefore tracked (weakly) and taken right before the
    # fork, then released again in both parent and child. Logger locks are always taken
    # before handler locks so the ordering cannot deadlock (e.g. while 'remove()' runs).
    logger_locks = weakref.WeakSet()
    handler_locks = weakref.WeakSet()

    def acquire_locks():
        """Take every tracked lock: all logger locks first, then all handler locks."""
        for tracked in (logger_locks, handler_locks):
            for lock in tracked:
                lock.acquire()

    def release_locks():
        """Release every tracked lock, walking the groups in the same order as acquisition."""
        for tracked in (logger_locks, handler_locks):
            for lock in tracked:
                lock.release()

    os.register_at_fork(
        before=acquire_locks,
        after_in_parent=release_locks,
        after_in_child=release_locks,
    )

    def create_logger_lock():
        """Return a new lock tracked as a logger lock for fork-time sanitizing."""
        new_lock = threading.Lock()
        logger_locks.add(new_lock)
        return new_lock

    def create_handler_lock():
        """Return a new lock tracked as a handler lock for fork-time sanitizing."""
        new_lock = threading.Lock()
        handler_locks.add(new_lock)
        return new_lock

else:
    # Without os.register_at_fork there are no fork hooks to install: hand out plain locks.

    def create_logger_lock():
        """Return a new, untracked lock."""
        return threading.Lock()

    def create_handler_lock():
        """Return a new, untracked lock."""
        return threading.Lock()

# --------------------------------------------------------------------------------
# /loguru/_recattrs.py:
# --------------------------------------------------------------------------------
import pickle
from collections import namedtuple


class RecordLevel:
    """Severity level of a log record: ``name``, numeric ``no`` and ``icon``.

    Formatting an instance formats its ``name``, so ``"{level:<8}"`` pads the level name.
    """

    __slots__ = ("name", "no", "icon")

    def __init__(self, name, no, icon):
        self.name, self.no, self.icon = name, no, icon

    def __repr__(self):
        return "(name=%r, no=%r, icon=%r)" % (self.name, self.no, self.icon)

    def __format__(self, spec):
        return format(self.name, spec)

class RecordFile:
    """Source file of a log call (``name`` and ``path``); formats as its ``name``."""

    __slots__ = ("name", "path")

    def __init__(self, name, path):
        self.name = name
        self.path = path

    def __repr__(self):
        return "(name=%r, path=%r)" % (self.name, self.path)

    def __format__(self, spec):
        return self.name.__format__(spec)


class RecordThread:
    """Thread that emitted a record (``id`` and ``name``); formats as its ``id``."""

    __slots__ = ("id", "name")

    def __init__(self, id_, name):
        self.id = id_
        self.name = name

    def __repr__(self):
        return "(id=%r, name=%r)" % (self.id, self.name)

    def __format__(self, spec):
        return self.id.__format__(spec)


class RecordProcess:
    """Process that emitted a record (``id`` and ``name``); formats as its ``id``."""

    __slots__ = ("id", "name")

    def __init__(self, id_, name):
        self.id = id_
        self.name = name

    def __repr__(self):
        return "(id=%r, name=%r)" % (self.id, self.name)

    def __format__(self, spec):
        return self.id.__format__(spec)


class RecordException(namedtuple("RecordException", ("type", "value", "traceback"))):
    """``(type, value, traceback)`` of a logged exception, with lossy-but-safe pickling."""

    def __repr__(self):
        return "(type=%r, value=%r, traceback=%r)" % (self.type, self.value, self.traceback)

    def __reduce__(self):
        # The traceback is not picklable, therefore it needs to be removed. Additionally, there's a
        # possibility that the exception value is not picklable either. In such cases, we also need
        # to remove it. This is done for user convenience, aiming to prevent error logging caused by
        # custom exceptions from third-party libraries. If the serialization succeeds, we can reuse
        # the pickled value later for optimization (so that it's not pickled twice). It's important
        # to note that custom exceptions might not necessarily raise a PickleError, hence the
        # generic Exception catch.
        try:
            pickled_value = pickle.dumps(self.value)
        except Exception:
            return (RecordException, (self.type, None, None))
        else:
            return (RecordException._from_pickled_value, (self.type, pickled_value, None))

    @classmethod
    def _from_pickled_value(cls, type_, pickled_value, traceback_):
        """Rebuild an instance from ``__reduce__`` output; an unloadable value becomes ``None``."""
        try:
            # It's safe to use "pickle.loads()" in this case because the pickled value is generated
            # by the same code and is not coming from an untrusted source.
            value = pickle.loads(pickled_value)
        except Exception:
            return cls(type_, None, traceback_)
        else:
            return cls(type_, value, traceback_)

# --------------------------------------------------------------------------------
# /loguru/_simple_sinks.py:
# --------------------------------------------------------------------------------
import asyncio
import picologging as logging
import weakref

from ._asyncio_loop import get_running_loop, get_task_loop


class StreamSink:
    """Sink writing to a file-like ``stream``.

    Optional hooks are detected once. ``flush()`` is called after every write, ``stop()`` when
    the handler stops, and a coroutine ``complete()`` is returned by ``tasks_to_complete``.
    """

    def __init__(self, stream):
        self._stream = stream
        self._flushable = callable(getattr(stream, "flush", None))
        self._stoppable = callable(getattr(stream, "stop", None))
        self._completable = asyncio.iscoroutinefunction(getattr(stream, "complete", None))

    def write(self, message):
        self._stream.write(message)
        if self._flushable:
            self._stream.flush()

    def stop(self):
        if self._stoppable:
            self._stream.stop()

    def tasks_to_complete(self):
        if not self._completable:
            return []
        return [self._stream.complete()]


class StandardSink:
    """Sink forwarding messages to a ``logging``-style handler (``picologging`` is used here).

    Each message is converted with the root logger's ``makeRecord``; the already formatted
    loguru text becomes the record's message.
    """

    def __init__(self, handler):
        self._handler = handler

    def write(self, message):
        record = message.record
        message = str(message)
        exc = record["exception"]
        record = logging.getLogger().makeRecord(
            record["name"],
            record["level"].no,
            record["file"].path,
            record["line"],
            message,
            (),
            (exc.type, exc.value, exc.traceback) if exc else None,
            record["function"],
            {"extra": record["extra"]},
        )
        if exc:
            # NOTE(review): a non-empty exc_text presumably stops the handler's formatter from
            # rendering the traceback a second time (the formatted message usually contains
            # it already) — confirm.
            record.exc_text = "\n"
        self._handler.handle(record)

    def stop(self):
        self._handler.close()

    def tasks_to_complete(self):
        return []


class AsyncSink:
    """Sink calling a coroutine function; each message is scheduled as an asyncio task.

    The tasks run on ``loop`` when one was given, otherwise on the loop returned by
    ``get_running_loop()``. If that raises ``RuntimeError`` the message is dropped.
    Task failures are reported through ``error_interceptor``.
    """

    def __init__(self, function, loop, error_interceptor):
        self._function = function
        self._loop = loop
        self._error_interceptor = error_interceptor
        self._tasks = weakref.WeakSet()

    def write(self, message):
        try:
            loop = self._loop or get_running_loop()
        except RuntimeError:
            return

        coroutine = self._function(message)
        task = loop.create_task(coroutine)

        def check_exception(future):
            if future.cancelled() or future.exception() is None:
                return
            if not self._error_interceptor.should_catch():
                raise future.exception()
            self._error_interceptor.print(message.record, exception=future.exception())

        task.add_done_callback(check_exception)
        self._tasks.add(task)

    def stop(self):
        for task in self._tasks:
            task.cancel()

    def tasks_to_complete(self):
        # To avoid errors due to "self._tasks" being mutated while iterated, the
        # "tasks_to_complete()" method must be protected by the same lock as "write()" (which
        # happens to be the handler lock). However, the tasks must not be awaited while the lock is
        # acquired as this could lead to a deadlock. Therefore, we first need to collect the tasks
        # to complete, then return them so that they can be awaited outside of the lock.
        return [self._complete_task(task) for task in self._tasks]

    async def _complete_task(self, task):
        """Await ``task`` if it belongs to the running loop; its errors are already reported."""
        loop = get_running_loop()
        if get_task_loop(task) is not loop:
            return
        try:
            await task
        except Exception:
            pass  # Handled in "check_exception()"

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_tasks"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._tasks = weakref.WeakSet()


class CallableSink:
    """Sink calling a plain function with each message."""

    def __init__(self, function):
        self._function = function

    def write(self, message):
        self._function(message)

    def stop(self):
        pass

    def tasks_to_complete(self):
        return []

# --------------------------------------------------------------------------------
# /loguru/_string_parsers.py:
# --------------------------------------------------------------------------------
import datetime
import re


class Frequencies:
    """Functions mapping a datetime ``t`` to the start of the next period after it."""

    @staticmethod
    def hourly(t):
        """Start of the hour following ``t``'s hour."""
        dt = t + datetime.timedelta(hours=1)
        return dt.replace(minute=0, second=0, microsecond=0)

    @staticmethod
    def daily(t):
        """Midnight of the day after ``t``."""
        dt = t + datetime.timedelta(days=1)
        return dt.replace(hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def weekly(t):
        """Midnight of the next Monday (a full week later when ``t`` is a Monday)."""
        dt = t + datetime.timedelta(days=7 - t.weekday())
        return dt.replace(hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def monthly(t):
        """Midnight of the first day of the month after ``t``."""
        if t.month == 12:
            y, m = t.year + 1, 1
        else:
            y, m = t.year, t.month + 1
        return t.replace(year=y, month=m, day=1, hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def yearly(t):
        """Midnight of January 1st of the year after ``t``."""
        y = t.year + 1
        return t.replace(year=y, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)


def parse_size(size):
    """Parse a size such as ``"10 MB"``, ``"1.5 KiB"`` or ``"8 Mb"`` into a number of bytes.

    The optional prefix (k, m, g, t, p, e, z, y; case-insensitive) is a power of 1000, or of
    1024 when followed by ``i``. A lowercase ``b`` means bits (the result is divided by 8);
    an uppercase ``B`` means bytes.

    Returns:
        The size in bytes as a float, or ``None`` if ``size`` does not look like a size.

    Raises:
        ValueError: if the numeric part is not a valid float (e.g. ``"1..5 MB"``).
    """
    size = size.strip()
    reg = re.compile(r"([e\+\-\.\d]+)\s*([kmgtpezy])?(i)?(b)", flags=re.I)

    match = reg.fullmatch(size)

    if not match:
        return None

    s, u, i, b = match.groups()

    try:
        s = float(s)
    except ValueError as e:
        raise ValueError("Invalid float value while parsing size: '%s'" % s) from e

    u = "kmgtpezy".index(u.lower()) + 1 if u else 0
    i = 1024 if i else 1000
    b = {"b": 8, "B": 1}[b] if b else 1
    return s * i**u / b


def parse_duration(duration):
    """Parse a duration such as ``"1 h 30 min"`` or ``"2 days, 3 hours"`` into a ``timedelta``.

    Units: y/years (365 days), months (2628000 s), w/weeks, d/days, h/hours, min/minutes,
    s/sec/seconds, ms/milliseconds, us/microseconds (case-insensitive).

    Returns:
        The summed ``timedelta``, or ``None`` if the text is not a run of number+unit pairs.

    Raises:
        ValueError: for an invalid number or an unknown unit.
    """
    duration = duration.strip()
    reg = r"(?:([e\+\-\.\d]+)\s*([a-z]+)[\s\,]*)"

    units = [
        ("y|years?", 31536000),
        ("months?", 2628000),
        ("w|weeks?", 604800),
        ("d|days?", 86400),
        ("h|hours?", 3600),
        ("min(?:ute)?s?", 60),
        ("s|sec(?:ond)?s?", 1),  # spellchecker: disable-line
        ("ms|milliseconds?", 0.001),
        ("us|microseconds?", 0.000001),
    ]

    if not re.fullmatch(reg + "+", duration, flags=re.I):
        return None

    seconds = 0

    for value, unit in re.findall(reg, duration, flags=re.I):
        try:
            value = float(value)
        except ValueError as e:
            raise ValueError("Invalid float value while parsing duration: '%s'" % value) from e

        try:
            unit = next(u for r, u in units if re.fullmatch(r, unit, flags=re.I))
        except StopIteration:
            raise ValueError("Invalid unit value while parsing duration: '%s'" % unit) from None

        seconds += value * unit

    return datetime.timedelta(seconds=seconds)


def parse_frequency(frequency):
    """Return the ``Frequencies`` function named by ``frequency`` (e.g. ``"daily"``), else ``None``."""
    frequencies = {
        "hourly": Frequencies.hourly,
        "daily": Frequencies.daily,
        "weekly": Frequencies.weekly,
        "monthly": Frequencies.monthly,
        "yearly": Frequencies.yearly,
    }
    frequency = frequency.strip().lower()
    return frequencies.get(frequency, None)


def parse_day(day):
    """Parse a weekday name (``"monday"``…) or ``"w0"``…``"w6"`` into 0–6, Monday being 0.

    Returns ``None`` if ``day`` is neither form.

    Raises:
        ValueError: for ``"wN"`` with N outside 0–6.
    """
    days = {
        "monday": 0,
        "tuesday": 1,
        "wednesday": 2,
        "thursday": 3,
        "friday": 4,
        "saturday": 5,
        "sunday": 6,
    }
    day = day.strip().lower()
    if day in days:
        return days[day]
    if day.startswith("w") and day[1:].isdigit():
        day = int(day[1:])
        if not 0 <= day < 7:
            raise ValueError("Invalid weekday value while parsing day (expected [0-6]): '%d'" % day)
    else:
        day = None

    return day


def parse_time(time):
    """Parse a clock time such as ``"14:30"``, ``"2 pm"``, ``"10:30 pm"`` or ``"10pm"``.

    Returns:
        A ``datetime.time``, or ``None`` if ``time`` does not look like a time at all (only
        digits, ``.``, ``:`` and an optional am/pm suffix are accepted).

    Raises:
        ValueError: if the text looks like a time but matches none of the known formats.
    """
    time = time.strip()
    reg = re.compile(r"^[\d\.\:]+\s*(?:[ap]m)?$", flags=re.I)

    if not reg.match(time):
        return None

    # The pattern above admits "10pm" (no space before the suffix), but strptime turns a space
    # in the format into "one or more whitespace", so normalize to exactly one space.
    normalized = re.sub(r"\s*([ap]m)$", r" \1", time, flags=re.I)

    formats = [
        "%H",
        "%H:%M",
        "%H:%M:%S",
        "%H:%M:%S.%f",
        "%I %p",
        "%I:%M %p",
        "%I:%M:%S %p",
        "%I:%M:%S.%f %p",
    ]

    for format_ in formats:
        try:
            dt = datetime.datetime.strptime(normalized, format_)
        except ValueError:
            pass
        else:
            return dt.time()

    raise ValueError("Unrecognized format while parsing time: '%s'" % time)


def parse_daytime(daytime):
    """Parse ``"<day> at <time>"``, a bare day, or a bare time.

    Returns:
        ``(day, time)``, where ``day`` is 0–6 or ``None`` and ``time`` is a ``datetime.time``
        or ``None`` (one of them is ``None`` for a bare day or a bare time). Returns ``None``
        if the text is neither a day nor a time.

    Raises:
        ValueError: if either part of the ``"... at ..."`` form is missing or invalid, or if a
            part is malformed (e.g. ``"w9"``). The message quotes the offending text.
    """
    daytime = daytime.strip()
    reg = re.compile(r"^(.*?)\s+at\s+(.*)$", flags=re.I)

    match = reg.match(daytime)
    if match:
        day_text, time_text = match.groups()
    else:
        day_text = time_text = daytime

    try:
        day = parse_day(day_text)
        if match and day is None:
            raise ValueError
    except ValueError as e:
        # Quote the text the user wrote, not the intermediate ``None`` parse result.
        raise ValueError("Invalid day while parsing daytime: '%s'" % day_text) from e

    try:
        time = parse_time(time_text)
        if match and time is None:
            raise ValueError
    except ValueError as e:
        raise ValueError("Invalid time while parsing daytime: '%s'" % time_text) from e

    if day is None and time is None:
        return None

    return day, time

# --------------------------------------------------------------------------------
# /loguru/py.typed:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/FastestMolasses/Fastest-FastAPI/bed27b943813341f83402aedb9ec6cc4090076c5/loguru/py.typed -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # uvicorn main:server --reload 2 | import app.models.mysql 3 | import app.db.connection 4 | import app.api.middleware as middleware 5 | 6 | from fastapi import FastAPI 7 | from app.api.router import apiRouter 8 | from app.core.config import settings 9 | from app.log.setup import setup_logging 10 | from app.types.server import ServerResponse 11 | from fastapi.responses import ORJSONResponse 12 | from starlette.middleware.cors import CORSMiddleware 13 | from prometheus_fastapi_instrumentator import Instrumentator 14 | 15 | server = FastAPI( 16 | title='fastapi-server', 17 | debug=settings.DEBUG, 18 | openapi_url=f'{settings.API_V1_STR}/openapi.json', 19 | default_response_class=ORJSONResponse, 20 | ) 21 | server.include_router(apiRouter, prefix=settings.API_V1_STR) 22 | 23 | setup_logging() 24 | 25 | # Prometheus metrics 26 | Instrumentator( 27 | env_var_name='ENABLE_METRICS', 28 | ).instrument(server).expose(server) 29 | 30 | # Creates the tables if they dont exist 31 | app.db.connection.createMySQLTables() 32 | 33 | # Set all CORS enabled origins 34 | if settings.BACKEND_CORS_ORIGINS: 35 | server.add_middleware( 36 | CORSMiddleware, 37 | allow_origins=[str(origin) 38 | for origin in settings.BACKEND_CORS_ORIGINS], 39 | allow_credentials=True, 40 | allow_methods=['*'], 41 | allow_headers=['*'], 42 | ) 43 | 44 | server.add_middleware(middleware.DBExceptionsMiddleware) 45 | server.add_middleware(middleware.CatchAllMiddleware) 46 | server.add_middleware(middleware.ProfilingMiddleware) 47 | 48 | @server.get('/') 49 | async def root() -> ServerResponse: 50 | return ServerResponse() 51 | -------------------------------------------------------------------------------- 
/make-env-example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ### Makefile wrapper: exports the AWS/ECS deployment variables below, then runs make with every argument passed to this script 4 | ### Copy this file to make-env.sh, fill in the blanks, and run it like make, e.g. ./make-env.sh <target> 5 | 6 | export AWS_REGION= 7 | export AWS_ACCOUNT_ID= 8 | export CONTAINER_REPO_NAME= 9 | export DOCKER_IMAGE_NAME= 10 | export CONTAINER_NAME= 11 | export ECS_CLUSTER_NAME= 12 | export ECS_SERVICE_NAME= 13 | 14 | make "$@" 15 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | plugins = pydantic.mypy 3 | ignore_missing_imports = True 4 | disallow_untyped_defs = False 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fastapi-base" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [{ name = "Abe M", email = "abe.malla8@gmail.com" }] 6 | readme = "README.md" 7 | requires-python = ">=3.11" 8 | 9 | [tool.ruff] 10 | line-length = 100 11 | target-version = "py311" 12 | extend-exclude = [".pytest_cache"] 13 | ignore-init-module-imports = true 14 | 15 | [tool.ruff.format] 16 | quote-style = "single" 17 | indent-style = "space" 18 | skip-magic-trailing-comma = false 19 | line-ending = "auto" 20 | 21 | [tool.ruff.mccabe] 22 | max-complexity = 10 23 | 24 | [tool.ruff.per-file-ignores] 25 | "main.py" = ["E402"] 26 | "shell.py" = ["F401"] 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile pyproject.toml -o requirements.txt 3 | aiohappyeyeballs==2.3.4 4 | # via aiohttp 5 | aiohttp==3.10.1 6 | # via fastapi-base
(pyproject.toml) 7 | aiosignal==1.3.1 8 | # via aiohttp 9 | alembic==1.13.2 10 | # via fastapi-base (pyproject.toml) 11 | annotated-types==0.5.0 12 | # via pydantic 13 | anyio==3.7.1 14 | # via starlette 15 | attrs==23.1.0 16 | # via aiohttp 17 | boto3==1.34.154 18 | # via fastapi-base (pyproject.toml) 19 | botocore==1.34.154 20 | # via 21 | # boto3 22 | # s3transfer 23 | certifi==2023.7.22 24 | # via requests 25 | cffi==1.15.1 26 | # via cryptography 27 | charset-normalizer==3.2.0 28 | # via requests 29 | click==8.1.7 30 | # via uvicorn 31 | cryptography==41.0.4 32 | # via python-jose 33 | ecdsa==0.18.0 34 | # via python-jose 35 | fastapi==0.112.0 36 | # via fastapi-base (pyproject.toml) 37 | frozenlist==1.4.0 38 | # via 39 | # aiohttp 40 | # aiosignal 41 | gunicorn==22.0.0 42 | # via fastapi-base (pyproject.toml) 43 | h11==0.14.0 44 | # via uvicorn 45 | idna==3.4 46 | # via 47 | # anyio 48 | # requests 49 | # yarl 50 | jmespath==1.0.1 51 | # via 52 | # boto3 53 | # botocore 54 | mako==1.2.4 55 | # via alembic 56 | markupsafe==2.1.3 57 | # via mako 58 | multidict==6.0.4 59 | # via 60 | # aiohttp 61 | # yarl 62 | orjson==3.10.6 63 | # via fastapi-base (pyproject.toml) 64 | packaging==23.1 65 | # via gunicorn 66 | picologging==0.9.3 67 | # via fastapi-base (pyproject.toml) 68 | prometheus-client==0.17.1 69 | # via prometheus-fastapi-instrumentator 70 | prometheus-fastapi-instrumentator==7.0.0 71 | # via fastapi-base (pyproject.toml) 72 | psycopg==3.2.1 73 | # via fastapi-base (pyproject.toml) 74 | pyasn1==0.5.0 75 | # via 76 | # python-jose 77 | # rsa 78 | pycparser==2.21 79 | # via cffi 80 | pydantic==2.8.2 81 | # via 82 | # fastapi-base (pyproject.toml) 83 | # fastapi 84 | # pydantic-settings 85 | pydantic-core==2.20.1 86 | # via pydantic 87 | pydantic-settings==2.4.0 88 | # via fastapi-base (pyproject.toml) 89 | pyinstrument==4.7.2 90 | # via fastapi-base (pyproject.toml) 91 | pymysql==1.1.1 92 | # via fastapi-base (pyproject.toml) 93 | python-dateutil==2.8.2 94 
| # via botocore 95 | python-dotenv==1.0.1 96 | # via 97 | # fastapi-base (pyproject.toml) 98 | # pydantic-settings 99 | python-jose==3.3.0 100 | # via fastapi-base (pyproject.toml) 101 | redis==5.0.8 102 | # via fastapi-base (pyproject.toml) 103 | requests==2.32.3 104 | # via fastapi-base (pyproject.toml) 105 | rsa==4.9 106 | # via python-jose 107 | s3transfer==0.10.2 108 | # via boto3 109 | six==1.16.0 110 | # via 111 | # ecdsa 112 | # python-dateutil 113 | sniffio==1.3.0 114 | # via anyio 115 | sqlalchemy==2.0.32 116 | # via 117 | # fastapi-base (pyproject.toml) 118 | # alembic 119 | starlette==0.37.2 120 | # via 121 | # fastapi 122 | # prometheus-fastapi-instrumentator 123 | typing-extensions==4.8.0 124 | # via 125 | # alembic 126 | # fastapi 127 | # psycopg 128 | # pydantic 129 | # pydantic-core 130 | # sqlalchemy 131 | urllib3==1.26.16 132 | # via 133 | # botocore 134 | # requests 135 | uvicorn==0.30.5 136 | # via fastapi-base (pyproject.toml) 137 | uvloop==0.19.0 138 | # via fastapi-base (pyproject.toml) 139 | yarl==1.9.2 140 | # via aiohttp 141 | -------------------------------------------------------------------------------- /shell.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is used to run a shell with all the models and the session 3 | loaded. This is useful for debugging and testing. 4 | """ 5 | # python -i shell.py 6 | # noqa: F401 7 | from app.models.mysql import User, UserNotification, Notification 8 | from app.db.connection import MySqlSession 9 | 10 | 11 | session = MySqlSession() 12 | --------------------------------------------------------------------------------