├── .dockerignore
├── .env.example
├── .github
│   └── workflows
│       └── main.yaml
├── .gitignore
├── .vscode
│   └── settings.json
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── alembic.ini
├── app
│   ├── api
│   │   ├── endpoints
│   │   │   └── auth.py
│   │   ├── middleware.py
│   │   └── router.py
│   ├── auth
│   │   ├── jwt.py
│   │   └── key.py
│   ├── cache
│   │   ├── redis.py
│   │   └── time_cache.py
│   ├── core
│   │   └── config.py
│   ├── db
│   │   ├── connection.py
│   │   ├── dep.py
│   │   └── util.py
│   ├── discord
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── exceptions.py
│   │   └── models
│   │       ├── guild.py
│   │       ├── role.py
│   │       └── user.py
│   ├── lmbd
│   │   ├── invoke.py
│   │   ├── sample_data
│   │   │   └── snsCloudWatchEvent.json
│   │   └── sample_lambda
│   │       └── main.py
│   ├── log
│   │   └── setup.py
│   ├── migrations
│   │   ├── env.py
│   │   └── script.py.mako
│   ├── models
│   │   ├── mysql.py
│   │   └── mysql_no_fk.py
│   ├── types
│   │   ├── cache.py
│   │   ├── jwt.py
│   │   ├── lmbd.py
│   │   ├── server.py
│   │   └── time.py
│   └── util
│       ├── common.py
│       ├── json.py
│       └── time.py
├── dev-requirements.txt
├── docker-compose.yaml
├── docker-entrypoint.sh
├── gunicorn_conf.py
├── loguru
│   ├── __init__.py
│   ├── __init__.pyi
│   ├── _asyncio_loop.py
│   ├── _better_exceptions.py
│   ├── _colorama.py
│   ├── _colorizer.py
│   ├── _contextvars.py
│   ├── _ctime_functions.py
│   ├── _datetime.py
│   ├── _defaults.py
│   ├── _error_interceptor.py
│   ├── _file_sink.py
│   ├── _filters.py
│   ├── _get_frame.py
│   ├── _handler.py
│   ├── _locks_machinery.py
│   ├── _logger.py
│   ├── _recattrs.py
│   ├── _simple_sinks.py
│   └── py.typed
├── main.py
├── make-env-example.sh
├── mypy.ini
├── pyproject.toml
├── requirements.txt
└── shell.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Version control system files
2 | .git
3 | .gitignore
4 |
5 | # Logs and temporary files
6 | *.log
7 | logs/
8 | __pycache__/
9 | .mypy_cache/
10 | .ruff_cache/
11 | .cache/
12 |
13 | # Virtual environment and dependency files
14 | env/
15 | .venv/
16 |
17 | # Environment configuration files
18 | .env.prod
19 | .env.production
20 | .env.dev
21 | .env.development
22 | .env.staging
23 | .env.stage
24 | .env.test
25 | .env.testing
26 | .env.local
27 | .env.example
28 |
29 | # Build and setup scripts
30 | make-env-example.sh
31 | make-env.sh
32 |
33 | # Test files and directories
34 | test.py
35 | test.json
36 | tests/
37 |
38 | # IDE and editor specific files
39 | .vscode/
40 |
41 | # GitHub workflow and contribution files
42 | .github/
43 | CHANGELOG.md
44 | CONTRIBUTING.md
45 |
46 | # Documentation and license files
47 | README.md
48 |
49 | # Docker Compose file (if not needed in build context)
50 | docker-compose.yaml
51 |
52 | # Package and archive files
53 | *.zip
54 | *.7zip
55 | *.dmg
56 | *.gz
57 | *.iso
58 | *.jar
59 | *.rar
60 |
61 | # Other
62 | .DS_Store
63 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | DEBUG=0
2 | ENABLE_METRICS=1
3 | DOMAIN=localhost
4 | SECRET_KEY=
5 | REFRESH_KEY=
6 | PROFILING=0
7 | JWT_USE_NONCE=0
8 |
9 | # Backend
10 | BACKEND_CORS_ORIGINS=["http://localhost:8000","http://localhost:5000"]
11 |
12 | # MySQL
13 | MYSQL_HOST=localhost:3306
14 | MYSQL_USER=root
15 | MYSQL_PASSWORD=root
16 | MYSQL_DATABASE=fastapi-db
17 | MYSQL_SSL=/etc/ssl/cert.pem
18 |
19 | # Redis
20 | REDIS_HOST=localhost
21 | REDIS_PORT=6379
22 | REDIS_PASSWORD=test
23 |
24 | # AWS
25 | AWS_ACCESS_KEY_ID=test
26 | AWS_SECRET_ACCESS_KEY=test
27 | AWS_DEFAULT_REGION=us-east-1
28 |
--------------------------------------------------------------------------------
/.github/workflows/main.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - master
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 | lint:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Checkout
17 | uses: actions/checkout@v4
18 |
19 | - name: Set up Python 3.11
20 | uses: actions/setup-python@v3
21 | with:
22 | python-version: "3.11"
23 |
24 | - name: Install linter
25 | run: |
26 | python -m pip install --upgrade pip
27 | pip install ruff
28 |
29 | - name: Linting code
30 | run: |
31 | ruff check .
32 |
33 | type-check:
34 | runs-on: ubuntu-latest
35 | steps:
36 | - name: Checkout
37 | uses: actions/checkout@v4
38 |
39 | - name: Set up Python 3.11
40 | uses: actions/setup-python@v3
41 | with:
42 | python-version: "3.11"
43 |
44 | - name: Install dependencies
45 | run: |
46 | python -m pip install --upgrade pip
47 | pip install -r requirements.txt
48 | pip install -r dev-requirements.txt
49 |
50 | - name: Type checking code
51 | run: |
52 | mypy . --explicit-package-bases --exclude loguru
53 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled source code files
2 | __pycache__/
3 | *.c
4 | *.so
5 | *.dll
6 | *.pyd
7 | *.exe
8 | *.pyc
9 | compiled/
10 |
11 | # Build and environment files
12 | build/
13 | env/
14 | .venv/
15 | .docker/
16 | docker/
17 | make-env.sh
18 |
19 | # Configuration and environment variable files
20 | .env
21 | .env.prod
22 | .env.production
23 | .env.dev
24 | .env.development
25 | .env.staging
26 | .env.stage
27 | .env.test
28 | .env.testing
29 | .env.local
30 | *prod.env
31 | *production.env
32 | *dev.env
33 | *development.env
34 | *staging.env
35 | *stage.env
36 | *test.env
37 | *testing.env
38 | *local.env
39 |
40 | # Logs and temporary files
41 | *.log
42 | logs/
43 | .ruff_cache/
44 | .cache/
45 | .mypy_cache/
46 |
47 | # Package and archive files
48 | *.zip
49 | *.7zip
50 | *.dmg
51 | *.gz
52 | *.iso
53 | *.jar
54 | *.rar
55 |
56 | # IDE and editor specific files
57 | .vscode/*
58 | !.vscode/settings.json
59 | .DS_Store
59 |
60 | # Test files
61 | test.py
62 | test.json
63 |
64 | # Specific includes and excludes
65 | *.txt
66 | !requirements.txt
67 | !dev-requirements.txt
68 | chromedriver*
69 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.typeCheckingMode": "basic",
3 | }
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Creating a python base with shared environment variables
2 | FROM python:3.11-slim-bullseye AS python-base
3 | ENV PYTHONUNBUFFERED=1 \
4 | PYTHONDONTWRITEBYTECODE=1 \
5 | PIP_NO_CACHE_DIR=off \
6 | PIP_DISABLE_PIP_VERSION_CHECK=on \
7 | PIP_DEFAULT_TIMEOUT=100 \
8 | UV_HOME="/opt/uv" \
9 | PYSETUP_PATH="/opt/pysetup" \
10 | VENV_PATH="/opt/pysetup/.venv"
11 |
12 | ENV PATH="$UV_HOME/bin:$VENV_PATH/bin:$PATH"
13 |
14 | # builder-base is used to build dependencies
15 | FROM python-base AS builder-base
16 | RUN buildDeps="build-essential" \
17 | && apt-get update \
18 | && apt-get install --no-install-recommends -y \
19 | curl \
20 | vim \
21 | netcat \
22 | && apt-get install -y --no-install-recommends $buildDeps \
23 | && rm -rf /var/lib/apt/lists/*
24 |
25 | # Install UV
26 | RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="$UV_HOME" sh
27 |
28 | # We copy our Python requirements here to cache them
29 | # and install only runtime deps using uv
30 | WORKDIR $PYSETUP_PATH
31 | COPY ./requirements.txt ./
32 | RUN uv venv $VENV_PATH && \
33 | . $VENV_PATH/bin/activate && \
34 | uv pip install -r requirements.txt
35 |
36 | # 'development' stage installs all dev deps and can be used to develop code.
37 | # For example using docker-compose to mount local volume under /app
38 | FROM python-base AS development
39 | ENV FASTAPI_ENV=development
40 |
41 | # Copying uv and venv into image
42 | COPY --from=builder-base $UV_HOME $UV_HOME
43 | COPY --from=builder-base $PYSETUP_PATH $PYSETUP_PATH
44 |
45 | # Copying in our entrypoint
46 | COPY docker-entrypoint.sh /docker-entrypoint.sh
47 | RUN chmod +x /docker-entrypoint.sh
48 |
49 | # Because the venv already has the runtime deps installed, this install is quicker
50 | WORKDIR $PYSETUP_PATH
51 | COPY ./dev-requirements.txt ./
52 | RUN . $VENV_PATH/bin/activate && \
53 | uv pip install -r dev-requirements.txt
54 |
55 | WORKDIR /app
56 | COPY . .
57 |
58 | # Needs to be consistent with gunicorn_conf.py
59 | EXPOSE 8000
60 | ENTRYPOINT ["/docker-entrypoint.sh"]
61 |
62 | # 'production' stage uses the clean 'python-base' stage and copies
63 | # in only our runtime deps that were installed in the 'builder-base'
64 | FROM python-base AS production
65 | ENV FASTAPI_ENV=production
66 | ENV PROD=true
67 |
68 | COPY --from=builder-base $VENV_PATH $VENV_PATH
69 | COPY gunicorn_conf.py /gunicorn_conf.py
70 |
71 | COPY docker-entrypoint.sh /docker-entrypoint.sh
72 | RUN chmod +x /docker-entrypoint.sh
73 |
74 | COPY main.py /main.py
75 | COPY .env /.env
76 |
77 | # Create user with the name appuser
78 | RUN groupadd -g 1500 appuser && \
79 | useradd -m -u 1500 -g appuser appuser
80 |
81 | COPY --chown=appuser:appuser ./app /app
82 | USER appuser
83 |
84 | ENTRYPOINT ["/docker-entrypoint.sh"]
85 | CMD ["gunicorn", "--worker-class", "uvicorn.workers.UvicornWorker", "--config", "/gunicorn_conf.py", "main:server"]
86 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Abe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | build-container: ## Build container
2 | docker buildx build --platform=linux/amd64 -t $(CONTAINER_NAME) .
3 |
4 | run-container: ## Run container
5 | docker run --name=$(CONTAINER_NAME) -d -p 5500:8000 $(CONTAINER_NAME)
6 |
7 | upload-container: ## Upload container to ECR
8 | aws ecr get-login-password --region $(AWS_REGION) | docker login --username AWS --password-stdin $(AWS_ACCOUNT_ID).dkr.ecr.$(AWS_REGION).amazonaws.com/$(CONTAINER_REPO_NAME)
9 | docker tag $(CONTAINER_NAME):latest $(AWS_ACCOUNT_ID).dkr.ecr.$(AWS_REGION).amazonaws.com/$(CONTAINER_REPO_NAME):latest
10 | docker push $(AWS_ACCOUNT_ID).dkr.ecr.$(AWS_REGION).amazonaws.com/$(CONTAINER_REPO_NAME):latest
11 |
12 | cutover-ecs: ## Cutover to new ECS task definition
13 | aws ecs update-service --cluster $(ECS_CLUSTER_NAME) --service $(ECS_SERVICE_NAME) --force-new-deployment
14 |
15 | deploy-cloudwatch-lambda: ## Deploy the cloudwatch lambda ## TODO: UPDATE FOR SAMPLE LAMBDA
16 | zip -j sns_cloudwatch_webhook.zip app/lmbd/sns_cloudwatch_webhook/main.py
17 | aws lambda update-function-code --function-name sns-cloudwatch-webhook --zip-file fileb://sns_cloudwatch_webhook.zip
18 | rm sns_cloudwatch_webhook.zip
19 |
20 | help: ## Display help
21 | @awk -F ':|##' '/^[^\t].+?:.*?##/ {printf "\033[36m%-30s\033[0m %s\n", $$1, $$NF}' $(MAKEFILE_LIST) | sort
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ⚡️ Fastest FastAPI
2 |
3 | [License](https://github.com/FastestMolasses/Fast-Python-Server-Template/blob/main/LICENSE)
4 |
5 | A production-ready FastAPI server template, emphasizing performance and type safety. It includes a configurable set of features and options, allowing customization to retain or exclude components.
6 |
7 | Built with FastAPI, Pydantic, Ruff, and MyPy.
8 |
9 | Report Bug · Request Feature
10 |
22 | ## Features
23 |
24 | - ⚡ Async and type safety by default
25 | - 🛠️ CI/CD and tooling setup
26 | - 🚀 High performance libraries integrated ([orjson](https://github.com/ijl/orjson), [uvloop](https://github.com/MagicStack/uvloop), [pydantic2](https://github.com/pydantic/pydantic))
27 | - 📝 [Loguru](https://github.com/Delgan/loguru) + [picologging](https://github.com/microsoft/picologging) for simplified and performant logging
28 | - 🐳 Dockerized and includes AWS deployment flow
29 | - 🗃️ Several database implementations with sample ORM models (MySQL, Postgres, Timescale) & migrations
30 | - 🔐 Optional JWT authentication and authorization
31 | - 🌐 AWS Lambda functions support
32 | - 🧩 Modularized features
33 | - 📊 Prometheus metrics
34 | - 📜 Makefile commands
35 | - 🗺️ Route profiling
36 |
37 | ## Table of Contents
38 |
39 | - [Requirements](#requirements)
40 | - [Installation](#installation)
41 | - [Environment Specific Configuration](#environment-specific-configuration)
42 | - [Upgrading Dependencies](#upgrading-dependencies)
43 | - [Databases](#databases)
44 | - [Shell](#shell)
45 | - [Migrations](#migrations)
46 | - [Downgrade Migration](#downgrade-migration)
47 | - [JWT Auth](#jwt-auth)
48 | - [JWT Overview](#jwt-overview)
49 | - [Modifying JWT Payload Fields](#modifying-jwt-payload-fields)
50 | - [Project Structure](#project-structure)
51 | - [Makefile Commands](#makefile-commands)
52 | - [Contributing](#contributing)
53 |
54 | ## Requirements
55 |
56 | - [Python 3.11+](https://www.python.org/downloads/)
57 | - [Docker](https://www.docker.com/get-started/)
58 |
59 | ## Installation
60 |
61 | 1. Fork this repo ([How to create a private fork](https://gist.github.com/0xjac/85097472043b697ab57ba1b1c7530274))
62 |
63 | 2. Install UV Package Manager:
64 | Follow the installation instructions at https://github.com/astral-sh/uv
65 |
66 | 3. Set up the virtual environment and install dependencies:
67 |
68 | ```bash
69 | uv venv
70 |
71 | # On macOS and Linux
72 | source .venv/bin/activate
73 |
74 | # On Windows
75 | .venv\Scripts\activate
76 |
77 | # Install main dependencies
78 | uv pip install -r requirements.txt
79 |
80 | # Install development dependencies (optional)
81 | uv pip install -r dev-requirements.txt
82 | ```
83 |
84 | 4. Install [Docker](https://www.docker.com/get-started/)
85 |
86 | 5. Start your Docker services:
87 |
88 | ```bash
89 | docker compose up
90 | ```
91 |
92 | 6. Clone `.env.example` to `.env` and update the values:
93 |
94 | ```bash
95 | # macOS and Linux
96 | cp .env.example .env
97 |
98 | # Windows (PowerShell)
99 | Copy-Item .env.example .env
100 | ```
101 |
102 | You can use this command to generate secret keys:
103 |
104 | ```bash
105 | # macOS and Linux
106 | openssl rand -hex 128
107 |
108 | # Windows (PowerShell)
109 | $bytes = New-Object byte[] 128; (New-Object Security.Cryptography.RNGCryptoServiceProvider).GetBytes($bytes); [System.BitConverter]::ToString($bytes) -Replace '-'
110 | ```
111 |
112 | 7. Run the server:
113 |
114 | ```bash
115 | uvicorn main:server --reload
116 | ```
117 |
118 | Note: If you need to update dependencies, you can modify the `requirements.txt` or `dev-requirements.txt` files directly and then run `uv pip install -r requirements.txt` or `uv pip install -r dev-requirements.txt` respectively.
119 |
120 | ## Environment Specific Configuration
121 |
122 | This project uses environment-specific configuration files and symbolic links to manage different environments such as development, production, and staging. Follow the steps below for your operating system to set up the desired environment.
123 |
124 | ```bash
125 | # macOS, linux
126 | ln -s <env-file> .env
127 | # example: ln -s prod.env .env
128 |
129 | # windows
130 | mklink .env <env-file>
131 | # example: mklink .env prod.env
132 | ```
133 |
134 | ## Databases
135 |
136 | ### Shell
137 |
138 | To access the database shell, run this command
139 |
140 | ```bash
141 | python -i shell.py
142 | ```
143 |
144 | This loads `shell.py` into an interactive session with the database session and models already available.
145 |
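For example, once the prompt is up you can query the models directly. A minimal sketch, assuming `shell.py` exposes a `session` and the `User` model:

```python
>>> session.query(User).count()
1
>>> session.query(User).first().id
1
```
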
146 | ### Migrations
147 |
148 | To do a database migration, follow the steps below.
149 |
150 | 1. Update the ORM models in `app/models/` with the changes you want
151 | 2. Run this command to generate the migration file in `migrations/versions`
152 |
153 | ```bash
154 | alembic revision --autogenerate -m "Describe your migration"
155 | ```
156 |
157 | 3. Review the newly generated migration file and verify that it is correct.
158 | 4. Run this command to apply the migration
159 |
160 | ```bash
161 | alembic upgrade head
162 | ```
162 |
163 | ⛔️ Autogenerated migrations cannot detect these changes:
164 |
165 | - Changes of table name
166 | - Changes of column name
167 | - Anonymously named constraints
168 | - Special SQLAlchemy types such as Enum when generated on a backend which doesn’t support ENUM directly
169 |
170 | [Reference](https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect)
171 |
172 | These changes will need to be migrated manually by creating an empty migration file and then writing the code to create the changes.
173 |
174 | ```bash
175 | # Manual creation of empty migration file
176 | alembic revision -m "Describe your migration"
177 | ```
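
For instance, a table rename (one of the changes autogenerate cannot detect) could be written by hand like this. A sketch only; the table names are hypothetical and the revision IDs are placeholders that Alembic fills in when it generates the file:

```python
"""rename user table to account"""
from alembic import op

# Revision identifiers; Alembic generates these when the file is created
revision = 'abcd1234'
down_revision = 'wxyz5678'


def upgrade():
    op.rename_table('user', 'account')


def downgrade():
    op.rename_table('account', 'user')
```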
178 |
179 | ### Downgrade Migration
180 |
181 | Run this command to revert every migration back to the beginning.
182 |
183 | ```bash
184 | alembic downgrade base
185 | ```
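
To step back a single revision instead of reverting everything, Alembic also accepts relative counts:

```bash
alembic downgrade -1
```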
186 |
187 | ## JWT Auth
188 |
189 | In this FastAPI template, JSON Web Tokens (JWT) can optionally be used for authentication. This section explains the JWT implementation and related functionality.
190 |
191 | ### JWT Overview
192 |
193 | The JWT implementation lives in `app/auth/jwt.py`. The primary functions include:
194 |
195 | - Creating access and refresh JWT tokens.
196 | - Verifying and decoding a given JWT token.
197 | - Handling JWT-based authentication for FastAPI routes.
198 |
199 | #### User Management
200 |
201 | If a user associated with a JWT token is not found in the database, a new user will be created. This is managed by the `get_or_create_user` function. When a token is decoded and the corresponding user ID (the `sub` field in the token) is not found, the system will attempt to create a new user with that ID.
202 |
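A minimal sketch of that flow, assuming a helper along these lines (hypothetical; the actual function may differ):

```python
from sqlalchemy.orm import Session

from app.models.mysql import User


def get_or_create_user(session: Session, user_id: str) -> User:
    # Look the user up by the JWT's `sub`; create a row if none exists
    user = session.get(User, int(user_id))
    if user is None:
        user = User(id=int(user_id))
        session.add(user)
        session.commit()
    return user
```
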
203 | #### Nonce Usage
204 |
205 | A nonce is an arbitrary number that can be used just once. It's an optional field in the JWT token to ensure additional security. If a nonce is used:
206 |
207 | - It is stored in Redis for the duration of the refresh token's validity.
208 | - It must match between access and refresh tokens to ensure their pairing.
209 | - Its presence in Redis is verified before the token is considered valid.
210 |
211 | Enabling nonce usage provides an additional layer of security against token reuse, but requires Redis to function.
212 |
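Condensed from `app/auth/jwt.py`, the nonce lifecycle looks like this:

```python
# On login: both tokens share one nonce, cached for the refresh token's lifetime
nonce = generateNonce()
access_token = create_token(data, nonce, settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_token = create_token(data, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES)
set_nonce_in_cache(data.sub, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60)

# On each authenticated request: reject tokens whose nonce is no longer cached
validate_nonce(payload)  # raises HTTP 403 if the nonce is missing or stale
```
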
213 | ### Modifying JWT Payload Fields
214 |
215 | The JWT token payload structure is defined in `app/types/jwt.py` under the `JWTPayload` class. If you wish to add more fields to the JWT token payload:
216 |
217 | 1. Update the `TokenData` and `JWTPayload` classes in `app/types/jwt.py` by adding the desired fields.
218 |
219 | ```python
220 | class JWTPayload(BaseModel):
221 | # ... existing fields ...
222 | new_field: Type
223 |
224 | class TokenData(BaseModel):
225 | # ... existing fields ...
226 | new_field: Type
227 | ```
228 |
229 | `TokenData` is separated from `JWTPayload` to make it clear what is automatically filled in and what is manually added. Both classes must be updated to include the new fields.
230 |
231 | 2. Wherever the token is created, update the payload to include the new fields.
232 |
233 | ```python
234 | from app.auth.jwt import create_jwt
235 | from app.types.jwt import TokenData
236 |
237 | payload = TokenData(
238 | sub='user_id_1',
239 | field1='value1',
240 | # ... all fields ...
241 | )
242 | access_token, refresh_token = create_jwt(payload)
243 | ```
244 |
245 | Remember that the token travels with every request and is subject to header and cookie size limits. The more data you include, the bigger your token becomes, so ensure that you only include essential data in the token payload.
246 |
247 | ## Project Structure
248 |
249 | ```
250 | 📄 main.py - Server entry point
251 | 📁 .github/ - Github specific files
252 | 📁 app/ - Application code
253 | ├── 📁 api - API endpoints and middleware
254 | ├── 📁 auth - Authentication / authorization
255 | ├── 📁 cache - Redis code and caching functions
256 | ├── 📁 core - Core configuration
257 | ├── 📁 db - Database connections
258 | ├── 📁 discord - Discord library for auth (optional)
259 | ├── 📁 lmbd - Holds AWS lambda functions
260 | ├── 📁 migrations - Database migrations
261 | ├── 📁 models - Database ORM models
262 | ├── 📁 types - Type definitions
263 | └── 📁 util - Helper functions
264 | ```
265 |
266 | ## Makefile Commands
267 |
268 | The `Makefile` defines common commands; you can find the full list in the `Makefile` itself.
269 | To use these commands, first copy `make-env-example.sh` to `make-env.sh` and update the values.
270 |
271 | ```bash
272 | # macOS
273 | cp make-env-example.sh make-env.sh
274 |
275 | # windows (powershell)
276 | Copy-Item make-env-example.sh make-env.sh
277 | ```
278 |
279 | Remember to make the file executable
280 |
281 | ```bash
282 | chmod +x make-env.sh
283 | ```
284 |
285 | Then you can run the commands like this
286 |
287 | ```bash
288 | ./make-env.sh
289 | ```
290 |
291 | Try it with the help command, which will list all the available commands.
292 |
293 | ```bash
294 | ./make-env.sh help
295 | ```
296 |
297 | ## Contributing
298 |
299 | 1. **Fork the Repository**: Start by forking the repository to your own GitHub account.
300 | 2. **Clone the Forked Repository**: Clone the fork to your local machine.
301 | 3. **Create a New Branch**: Always create a new branch for your changes.
302 | 4. **Make Your Changes**: Implement your changes.
303 | 5. **Run Tests**: Make sure to test your changes locally.
304 | 6. **Submit a Pull Request**: Commit and push your changes, then create a pull request against the main branch.
305 |
--------------------------------------------------------------------------------
/alembic.ini:
--------------------------------------------------------------------------------
1 | [alembic]
2 | script_location = app/migrations
3 |
4 | # Template used to generate migration file names
5 | file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(slug)s
6 |
7 | # sys.path path, will be prepended to sys.path if present
8 | prepend_sys_path = .
9 |
10 | # Max length of characters to apply to the "slug" field
11 | truncate_slug_length = 8
12 |
13 | # Set to 'true' to run the environment during
14 | # the 'revision' command, regardless of autogenerate
15 | revision_environment = false
16 |
17 | # The output encoding used when revision files
18 | # are written from script.py.mako
19 | output_encoding = utf-8
20 |
21 | # Logging configuration
22 | [loggers]
23 | keys = root,sqlalchemy,alembic
24 |
25 | [handlers]
26 | keys = console
27 |
28 | [formatters]
29 | keys = generic
30 |
31 | [logger_root]
32 | level = WARN
33 | handlers = console
34 | qualname =
35 |
36 | [logger_sqlalchemy]
37 | level = WARN
38 | handlers =
39 | qualname = sqlalchemy.engine
40 |
41 | [logger_alembic]
42 | level = INFO
43 | handlers =
44 | qualname = alembic
45 |
46 | [handler_console]
47 | class = StreamHandler
48 | args = (sys.stderr,)
49 | level = NOTSET
50 | formatter = generic
51 |
52 | [formatter_generic]
53 | format = %(levelname)-5.5s [%(name)s] %(message)s
54 | datefmt = %H:%M:%S
55 |
--------------------------------------------------------------------------------
/app/api/endpoints/auth.py:
--------------------------------------------------------------------------------
1 | from jose import JWTError
2 | from loguru import logger
3 | from fastapi import APIRouter, Depends
4 | from fastapi.responses import ORJSONResponse
5 |
6 | from app.models.mysql import User
7 | from app.core.config import settings
8 | from app.db.connection import MySqlSession
9 | from app.auth.jwt import RequireRefreshToken, RequireJWT, create_jwt
10 |
11 | from app.types.jwt import TokenData, JWTPayload
12 | from app.types.server import ServerResponse, Cookie
13 |
14 | router = APIRouter(prefix='/auth')
15 |
16 |
17 | @router.get('/signup')
18 | async def signup() -> ServerResponse[str]:
19 | """
20 | Example route to create a user.
21 | """
22 | with MySqlSession() as session:
23 | user = User()
24 | session.add(user)
25 | session.commit()
26 | session.refresh(user)
27 |
28 | return ServerResponse()
29 |
30 |
31 | @router.get('/login')
32 | async def login(response: ORJSONResponse) -> ServerResponse[str]:
33 | """
34 | Example route to login a user. Will just grab the first user from the database
35 | and create a JWT for them.
36 | """
37 | # Grab the first user from the database
38 | with MySqlSession() as session:
39 | user = session.query(User).first()
40 |
41 | if not user:
42 | return ServerResponse(status='error', message='No user found')
43 |
44 | token = TokenData(sub=str(user.id))
45 | try:
46 | accessToken, refreshToken = create_jwt(token)
47 | except JWTError as e:
48 | logger.error(f'JWT Error during login: {e}')
49 | return ServerResponse(status='error', message='JWT Error, try again')
50 |
51 | # Save the refresh token in an HTTPOnly cookie
52 | response.set_cookie(
53 | Cookie.REFRESH_TOKEN,
54 | value=refreshToken,
55 | httponly=True,
56 |         max_age=settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60,  # set_cookie expects seconds
57 |         expires=settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60,
58 | )
59 | return ServerResponse[str](data=accessToken)
60 |
61 |
62 | @router.get('/refresh')
63 | async def refresh(
64 | response: ORJSONResponse, payload: JWTPayload = Depends(RequireRefreshToken)
65 | ) -> ServerResponse[str]:
66 | """
67 | Example route to refresh a JWT.
68 | """
69 | token = TokenData(sub=payload.sub)
70 |
71 | try:
72 | accessToken, refreshToken = create_jwt(token)
73 | except JWTError as e:
74 |         logger.error(f'JWT Error during refresh: {e}')
75 | return ServerResponse(status='error', message='JWT Error, try again.')
76 |
77 | # Save the refresh token in an HTTPOnly cookie
78 | response.set_cookie(
79 | Cookie.REFRESH_TOKEN,
80 | value=refreshToken,
81 | httponly=True,
82 |         max_age=settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60,  # set_cookie expects seconds
83 |         expires=settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60,
84 | )
85 | return ServerResponse[str](data=accessToken)
86 |
87 |
88 | @router.get('/decodeToken')
89 | async def decodeToken(payload: JWTPayload = Depends(RequireJWT())):
90 | """
91 | Example route to decode a JWT.
92 | """
93 | return ServerResponse[dict](data=payload.model_dump())
94 |
--------------------------------------------------------------------------------
/app/api/middleware.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | from typing import Callable
3 | from fastapi import Request
4 | # from pyinstrument import Profiler
5 | from app.core.config import settings
6 | from sqlalchemy.exc import IntegrityError
7 | from app.types.server import ServerResponse
8 | from fastapi.responses import ORJSONResponse
9 | from starlette.middleware.base import BaseHTTPMiddleware
10 | from sqlalchemy.exc import NoResultFound, MultipleResultsFound
11 |
12 |
13 | class DBExceptionsMiddleware(BaseHTTPMiddleware):
14 | """
15 | Middleware to catch and handle database exceptions.
16 | """
17 | async def dispatch(self, request: Request, call_next: Callable):
18 | try:
19 | response = await call_next(request)
20 | return response
21 |
22 | except NoResultFound as e:
23 | logger.exception(f'NoResultFound: {e}')
24 | response = ServerResponse(status='error', message='Row not found')
25 |
26 | except MultipleResultsFound as e:
27 | logger.exception(f'MultipleResultsFound: {e}')
28 | response = ServerResponse(
29 | status='error', message='Multiple rows found')
30 |
31 | except IntegrityError as e:
32 | e.hide_parameters = True
33 | logger.exception(f'IntegrityError: {e}')
34 | response = ServerResponse(
35 | status='error', message=str(e))
36 |
37 |         return ORJSONResponse(response.model_dump(), status_code=400)
38 |
39 |
40 | class CatchAllMiddleware(BaseHTTPMiddleware):
41 | """
42 | Middleware to catch errors.
43 | """
44 | async def dispatch(self, request: Request, call_next: Callable):
45 | try:
46 | response = await call_next(request)
47 | return response
48 |
49 | except Exception as e:
50 | # TODO: SEND NOTIFICATION HERE
51 | logger.exception(e)
52 | response = ServerResponse(
53 | status='error', message=str(e))
54 |             return ORJSONResponse(response.model_dump(), status_code=400)
55 |
56 |
57 | class ProfilingMiddleware(BaseHTTPMiddleware):
58 | """
59 |     Middleware to profile requests when settings.PROFILING is enabled.
60 | """
61 | async def dispatch(self, request: Request, call_next: Callable):
62 | if settings.PROFILING:
63 | # profiler = Profiler(interval=settings.profiling_interval, async_mode='enabled')
64 |
65 | # profiler.start()
66 | # response = await call_next(request)
67 | # profiler.stop()
68 |
69 | # return response
70 | return await call_next(request)
71 | else:
72 | return await call_next(request)
73 |
--------------------------------------------------------------------------------
/app/api/router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter
2 | from app.api.endpoints import auth
3 |
4 | apiRouter = APIRouter()
5 | apiRouter.include_router(auth.router, tags=['auth'])
6 |
--------------------------------------------------------------------------------
/app/auth/jwt.py:
--------------------------------------------------------------------------------
1 | from jose import jwt
2 | from loguru import logger
3 | from jose.constants import Algorithms
4 | from fastapi.security import HTTPBearer
5 | from datetime import datetime, timedelta
6 | from fastapi import Request, HTTPException
7 | from jose.exceptions import JWTClaimsError, JWTError, ExpiredSignatureError
8 |
9 | from app.core.config import settings
10 | from app.cache.redis import SessionStore
11 | from app.util.common import generateNonce
12 |
13 | from app.types.server import Cookie
14 | from app.types.jwt import TokenData, JWTPayload
15 | from app.types.cache import RedisTokenPrefix, UserKey
16 |
17 | ALGORITHM = Algorithms.HS256
18 |
19 |
20 | def create_jwt(data: TokenData) -> tuple[str, str]:
21 |     """
22 |     Create access and refresh JWT tokens.
23 |     If JWT_USE_NONCE is enabled, both tokens share a nonce, which is
24 |     stored in the cache for refresh token invalidation.
25 |     """
27 | nonce = None
28 | if settings.JWT_USE_NONCE:
29 | nonce = generateNonce()
30 |
31 | access_token = create_token(data, nonce, settings.ACCESS_TOKEN_EXPIRE_MINUTES)
32 | refresh_token = create_token(data, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES)
33 |
34 | # Save the nonce in the cache for refresh token invalidation, only if using nonce
35 | if nonce:
36 | set_nonce_in_cache(data.sub, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60)
37 |
38 | return access_token, refresh_token
39 |
40 |
41 | def verify_token(token: str) -> JWTPayload | None:
42 | """
43 | Decode a JWT token.
44 | """
45 | try:
46 | payload = JWTPayload(
47 | **jwt.decode(
48 | token,
49 | settings.SECRET_KEY,
50 | algorithms=[ALGORITHM],
51 | options={
52 | 'require_iat': True,
53 | 'require_exp': True,
54 | 'require_sub': True,
55 | },
56 | )
57 | )
58 | if settings.JWT_USE_NONCE and not payload.nonce:
59 | logger.error('Nonce not found in JWT payload.')
60 | return None
61 | return payload
62 | except (JWTError, ExpiredSignatureError, JWTClaimsError) as e:
63 | logger.error(f'Error while verifying JWT: {e}')
64 | return None
65 |
66 |
67 | class RequireJWT(HTTPBearer):
68 | """
69 | Custom FastAPI dependency for JWT authentication.
70 | Returns the decoded JWT payload if the token is valid.
71 | """
72 |
73 | async def __call__(self, request: Request):
74 | credentials = await super(RequireJWT, self).__call__(request)
75 |
76 | if credentials and credentials.credentials:
77 | payload = verify_token(credentials.credentials)
78 | if not payload:
79 | raise HTTPException(status_code=403, detail='Invalid token or expired token.')
80 |
81 | validate_nonce(payload)
82 | return payload
83 | else:
84 | raise HTTPException(status_code=403, detail='Invalid authorization code.')
85 |
86 |
87 | def RequireRefreshToken(request: Request) -> JWTPayload:
88 | refreshToken = request.cookies.get(Cookie.REFRESH_TOKEN, '')
89 | payload = verify_token(refreshToken)
90 | if not payload:
91 | raise HTTPException(status_code=403, detail='Invalid token or expired token.')
92 |
93 | validate_nonce(payload)
94 | return payload
95 |
96 |
97 | def create_token(data: TokenData, nonce: str | None, expire_minutes: int) -> str:
98 | payload = {
99 | **data.model_dump(),
100 | 'exp': datetime.utcnow() + timedelta(minutes=expire_minutes),
101 | 'iat': datetime.utcnow(),
102 | }
103 | if settings.JWT_USE_NONCE and nonce:
104 | payload['nonce'] = nonce
105 | return jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM)
106 |
107 |
108 | def set_nonce_in_cache(user_id: str, nonce: str, expiration_time: int):
109 | """
110 | Store nonce in cache with a specified expiration time.
111 | """
112 | if settings.JWT_USE_NONCE:
113 | cache = SessionStore(RedisTokenPrefix.USER, user_id, ttl=expiration_time)
114 | cache.set(UserKey.NONCE, nonce)
115 |
116 |
117 | def validate_nonce(payload: JWTPayload):
118 | if not settings.JWT_USE_NONCE:
119 | return
120 | if not is_nonce_in_cache(payload.sub, payload.nonce):
121 | raise HTTPException(status_code=403, detail='Invalid authorization code.')
122 |
123 |
124 | def is_nonce_in_cache(user_id: str, nonce: str | None) -> bool:
125 | """
126 | Check if the nonce is in the cache.
127 | """
128 | if not nonce:
129 | return False
130 | cache = SessionStore(RedisTokenPrefix.USER, user_id)
131 | return cache.get(UserKey.NONCE) == nonce
132 |
--------------------------------------------------------------------------------
/app/auth/key.py:
--------------------------------------------------------------------------------
1 | import os
2 | import binascii
3 |
4 |
5 | def generate_key():
6 | """
7 | Generate a random hex key for the application
8 | """
9 | return '0x' + binascii.hexlify(os.urandom(6)).decode()
10 |
--------------------------------------------------------------------------------
/app/cache/redis.py:
--------------------------------------------------------------------------------
1 | import redis
2 |
3 | from app.core.config import settings
4 |
5 |
6 | class SessionStore:
7 | _pool = None
8 |
9 | @classmethod
10 | def get_pool(cls):
11 | if cls._pool is None:
12 | cls._pool = redis.ConnectionPool(
13 | host=settings.REDIS_HOST,
14 | port=settings.REDIS_PORT,
15 | password=settings.REDIS_PASSWORD,
16 | connection_class=redis.SSLConnection,
17 | )
18 | return cls._pool
19 |
20 | def __init__(self, *tokens: str, ttl: int = (60 * 60 * 4)):
21 | """
22 | Params:\n
23 | tokens - Used to create a session in redis. Key/value pairs are unique to this token.
24 | Pass in multiple tokens and they will be joined into one.\n
25 | ttl - Time to live in seconds. Defaults to 4 hours.
26 | """
27 | self.token = ':'.join(tokens)
28 | self.redis = redis.StrictRedis(connection_pool=self.get_pool())
29 | self.ttl = ttl
30 |
31 | def set(self, key: str, value: str) -> int:
32 | self._refresh()
33 | return self.redis.hset(self.token, key, value)
34 |
35 | def get(self, key: str) -> str:
36 | self._refresh()
37 | val = self.redis.hget(self.token, key)
38 | if not val:
39 | return ''
40 | return val.decode('utf-8')
41 |
42 | def delete(self, key: str) -> int:
43 | self._refresh()
44 | return self.redis.hdel(self.token, key)
45 |
46 | def deleteSelf(self):
47 | self.redis.delete(self.token)
48 |
49 | def incr(self, key: str, amount: int = 1) -> int:
50 | self._refresh()
51 | return self.redis.hincrby(self.token, key, amount)
52 |
53 | def _refresh(self):
54 | self.redis.expire(self.token, self.ttl)
55 |
--------------------------------------------------------------------------------
/app/cache/time_cache.py:
--------------------------------------------------------------------------------
1 | import time
2 | import asyncio
3 | import functools
4 |
5 | from collections import OrderedDict
6 |
7 |
8 | def time_cache(max_age_seconds, maxsize=128, typed=False):
9 | """
10 | Least-recently-used cache decorator with time-based cache invalidation.
11 |
12 | Args:
13 | max_age_seconds: Time to live for cached results (in seconds).
14 | maxsize: Maximum cache size (see `functools.lru_cache`).
15 | typed: Cache on distinct input types (see `functools.lru_cache`).
16 | """
17 | def _decorator(fn):
18 | @functools.lru_cache(maxsize=maxsize, typed=typed)
19 | def _new(*args, __time_salt, **kwargs):
20 | return fn(*args, **kwargs)
21 |
22 | @functools.wraps(fn)
23 | def _wrapped(*args, **kwargs):
24 |             # The salt changes every max_age_seconds, expiring lru_cache entries
25 |             return _new(*args, **kwargs, __time_salt=int(time.time() / max_age_seconds))
25 |
26 | return _wrapped
27 |
28 | return _decorator
29 |
30 |
31 | def aio_time_cache(max_age_seconds, maxsize=128, typed=False):
32 | """Least-recently-used cache decorator with time-based cache invalidation for async functions.
33 |
34 | Args:
35 | max_age_seconds: Time to live for cached results (in seconds).
36 | maxsize: Maximum cache size.
37 | typed: Cache on distinct input types.
38 | """
39 | cache = OrderedDict()
40 | lock = asyncio.Lock()
41 |
42 | def _key(args, kwargs):
43 | return args, tuple(sorted(kwargs.items()))
44 |
45 | def _decorator(fn):
46 | @functools.wraps(fn)
47 | async def _wrapped(*args, **kwargs):
48 | async with lock:
49 | key = _key(args, kwargs)
50 | if typed:
51 | key += (tuple(type(arg) for arg in args),
52 | tuple(type(value) for value in kwargs.values()))
53 | now = time.time()
54 |
55 | # Cache hit and check if value is still fresh
56 | if key in cache:
57 | result, timestamp = cache.pop(key)
58 | if now - timestamp <= max_age_seconds:
59 | # Move to end to show that it was recently used
60 | cache[key] = (result, timestamp)
61 | return result
62 |
63 | # Cache miss or value has expired
64 | result = await fn(*args, **kwargs)
65 | cache[key] = (result, now)
66 |
67 | # Remove oldest items if cache is full
68 | while len(cache) > maxsize:
69 | cache.popitem(last=False)
70 |
71 | return result
72 |
73 | return _wrapped
74 |
75 | return _decorator
76 |
--------------------------------------------------------------------------------
/app/core/config.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from pydantic_settings import BaseSettings
3 | from pydantic import AnyHttpUrl, validator
4 |
5 |
6 | class EnvConfigSettings(BaseSettings):
7 | """
8 | Settings for the app. Reads from environment variables.
9 | """
10 | DEBUG: bool
11 | API_V1_STR: str = '/api/v1'
12 | SECRET_KEY: str
13 | REFRESH_KEY: str
14 | PROFILING: bool = False
15 | JWT_USE_NONCE: bool
16 | ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # 30 minutes
17 | REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 3 # 3 days
18 | BACKEND_CORS_ORIGINS: list[AnyHttpUrl] = []
19 |
20 | @validator('BACKEND_CORS_ORIGINS', pre=True)
21 | def assemble_cors_origins(cls, v: str | list[str]) -> list[str] | str:
22 | if isinstance(v, str) and not v.startswith('['):
23 | return [i.strip() for i in v.split(',')]
24 | elif isinstance(v, (list, str)):
25 | return v
26 | raise ValueError(v)
27 |
28 | # MySQL
29 | MYSQL_HOST: str
30 | MYSQL_USER: str
31 | MYSQL_PASSWORD: str
32 | MYSQL_DATABASE: str
33 | MYSQL_DATABASE_URI: str | None = None
34 | MYSQL_SSL: str
35 |
36 | @validator('MYSQL_DATABASE_URI', pre=True)
37 | def assemble_mysql_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
38 | if isinstance(v, str):
39 | return v
40 | return (f"mysql+pymysql://{values.get('MYSQL_USER')}:{values.get('MYSQL_PASSWORD')}"
41 | f"@{values.get('MYSQL_HOST')}/{values.get('MYSQL_DATABASE')}")
42 |
43 | # Redis
44 | REDIS_HOST: str
45 | REDIS_PORT: int
46 | REDIS_PASSWORD: str
47 |
48 | class Config:
49 | case_sensitive = True
50 | env_file = '.env'
51 | extra = 'ignore'
52 |
53 |
54 | settings = EnvConfigSettings(**{})  # '**{}' keeps type checkers happy; values are read from the environment / .env
55 |
--------------------------------------------------------------------------------
/app/db/connection.py:
--------------------------------------------------------------------------------
1 | from app.core.config import settings
2 | from sqlalchemy.engine import Engine
3 | from sqlalchemy import create_engine
4 | from sqlalchemy.orm import sessionmaker, Session, DeclarativeBase
5 |
6 | if not settings.MYSQL_DATABASE_URI:
7 | raise ValueError('Missing database URI')
8 |
9 | mysqlEngine: Engine = create_engine(
10 | settings.MYSQL_DATABASE_URI,
11 | echo=settings.DEBUG,
12 | pool_pre_ping=True,
13 | connect_args={
14 | 'ssl_ca': settings.MYSQL_SSL,
15 | },
16 | )
17 |
18 |
19 | class MySQLTableBase(DeclarativeBase):
20 | pass
21 |
22 |
23 | def MySqlSession(expireOnCommit: bool = False) -> Session:
24 | return sessionmaker(bind=mysqlEngine, expire_on_commit=expireOnCommit)()
25 |
26 |
27 | def createMySQLTables():
28 | MySQLTableBase.metadata.create_all(mysqlEngine)
29 |
30 |
31 | def dropMySQLTables():
32 | MySQLTableBase.metadata.drop_all(mysqlEngine)
33 |
--------------------------------------------------------------------------------
/app/db/dep.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 | from sqlalchemy.orm import Session
3 | from app.db.connection import MySqlSession
4 | from app.cache.time_cache import time_cache
5 |
6 |
7 | def get_db() -> Iterator[Session]:
8 | """
9 | Dependency that gets a cached database session.
10 | """
11 | session = _get_session()
12 | try:
13 | yield session
14 | session.commit()
15 | except Exception as exc:
16 | session.rollback()
17 | raise exc
18 | finally:
19 | session.close()
20 |
21 |
22 | @time_cache(max_age_seconds=60 * 10)
23 | def _get_session():
24 | return MySqlSession()
25 |
--------------------------------------------------------------------------------
/app/db/util.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import inspect
2 |
3 | from app.db.connection import mysqlEngine
4 |
5 |
6 | def printTables():
7 |     # Engine.table_names() was removed in SQLAlchemy 2.0; use the inspector
8 |     tables = inspect(mysqlEngine).get_table_names()
6 | print('Tables:')
7 | for table in tables:
8 | print('-', table)
9 |
--------------------------------------------------------------------------------
/app/discord/client.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 |
3 | from typing import Optional
4 | from fastapi import Depends, Request
5 | from app.cache.time_cache import aio_time_cache
6 | from typing_extensions import TypedDict, Literal
7 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
8 |
9 | from app.discord.models.user import User
10 | from app.discord.models.guild import Guild, GuildPreview
11 | from app.discord.config import DISCORD_API_URL, DISCORD_OAUTH_AUTHENTICATION_URL, DISCORD_TOKEN_URL
12 | from app.discord.exceptions import RateLimited, ScopeMissing, Unauthorized, InvalidToken, \
13 | ClientSessionNotInitialized
14 |
15 |
16 | class RefreshTokenPayload(TypedDict):
17 | client_id: str
18 | client_secret: str
19 | grant_type: Literal['refresh_token']
20 | refresh_token: str
21 |
22 |
23 | class TokenGrantPayload(TypedDict):
24 | client_id: str
25 | client_secret: str
26 | grant_type: Literal['authorization_code']
27 | code: str
28 | redirect_uri: str
29 |
30 |
31 | class TokenResponse(TypedDict):
32 | access_token: str
33 | token_type: str
34 | expires_in: int
35 | refresh_token: str
36 | scope: str
37 |
38 |
39 | PAYLOAD = TokenGrantPayload | RefreshTokenPayload
40 |
41 |
42 | def _tokens(resp: TokenResponse) -> tuple[str, str]:
43 | """
44 | Extracts tokens from TokenResponse
45 |
46 | Parameters
47 | ----------
48 | resp: TokenResponse
49 | Response
50 |
51 | Returns
52 | -------
53 | Tuple[str, str]
54 |         A tuple of access_token and refresh_token
55 |
56 | Raises
57 | ------
58 | InvalidToken
59 | If tokens are `None`
60 | """
61 | access_token, refresh_token = resp.get(
62 | 'access_token'), resp.get('refresh_token')
63 | if access_token is None or refresh_token is None:
64 | raise InvalidToken('Tokens can\'t be None')
65 | return access_token, refresh_token
66 |
67 |
68 | class DiscordOAuthClient:
69 | """
70 | Client for Discord Oauth2.
71 |
72 | Parameters
73 | ----------
74 | client_id:
75 | Discord application client ID.
76 | client_secret:
77 | Discord application client secret.
78 | redirect_uri:
79 | Discord application redirect URI.
80 | scopes:
81 | Optional Discord Oauth scopes
82 | proxy:
83 | Optional proxy url
84 | proxy_auth:
85 | Optional aiohttp.BasicAuth proxy authentification
86 | """
87 | client_id: str
88 | client_secret: str
89 | redirect_uri: str
90 | scopes: str
91 | proxy: Optional[str] = None
92 | proxy_auth: Optional[aiohttp.BasicAuth] = None
93 | client_session: Optional[aiohttp.ClientSession] = None
94 |
95 | def __init__(
96 | self,
97 | client_id,
98 | client_secret,
99 | redirect_uri,
100 | scopes=("identify",),
101 | proxy=None,
102 | proxy_auth: Optional[aiohttp.BasicAuth] = None,
103 | ):
104 | self.client_id = client_id
105 | self.client_secret = client_secret
106 | self.redirect_uri = redirect_uri
107 | self.scopes = '%20'.join(scope for scope in scopes)
108 | self.proxy = proxy
109 | self.proxy_auth = proxy_auth
110 |
111 | async def init(self):
112 | """
113 |         Initializes the connection to the Discord API
114 | """
115 | if self.client_session is not None:
116 | return
117 | self.client_session = aiohttp.ClientSession()
118 |
119 | def get_oauth_login_url(self, state: Optional[str] = None):
120 |         """
121 |         Returns a Discord login URL.
122 |         """
125 | client_id = f'client_id={self.client_id}'
126 | redirect_uri = f'redirect_uri={self.redirect_uri}'
127 | scopes = f'scope={self.scopes}'
128 | response_type = 'response_type=code'
129 | state = f'&state={state}' if state else ''
130 | return f'{DISCORD_OAUTH_AUTHENTICATION_URL}?{client_id}&{redirect_uri}&{scopes}' + \
131 | f'&{response_type}{state}'
132 |
133 | oauth_login_url = property(get_oauth_login_url)
134 |
135 | @aio_time_cache(max_age_seconds=550)
136 | async def request(self, route: str, token: Optional[str] = None,
137 | method: Literal['GET', 'POST'] = 'GET'):
138 | if self.client_session is None:
139 | raise ClientSessionNotInitialized
140 | headers: dict = {}
141 | if token:
142 | headers = {'Authorization': f'Bearer {token}'}
143 | if method == 'GET':
144 | async with self.client_session.get(
145 | f'{DISCORD_API_URL}{route}',
146 | headers=headers,
147 | proxy=self.proxy,
148 | proxy_auth=self.proxy_auth,
149 | ) as resp:
150 | data = await resp.json()
151 | elif method == 'POST':
152 | async with self.client_session.post(
153 | f'{DISCORD_API_URL}{route}',
154 | headers=headers,
155 | proxy=self.proxy,
156 | proxy_auth=self.proxy_auth,
157 | ) as resp:
158 | data = await resp.json()
159 |         else:
160 |             raise Exception(
161 |                 'HTTP methods other than GET and POST are currently not supported')
162 | if resp.status == 401:
163 | raise Unauthorized
164 | if resp.status == 429:
165 | raise RateLimited(data, resp.headers)
166 | return data
167 |
168 | async def get_token_response(self, payload: PAYLOAD) -> TokenResponse:
169 | if self.client_session is None:
170 | raise ClientSessionNotInitialized
171 | async with self.client_session.post(
172 | DISCORD_TOKEN_URL,
173 | data=payload,
174 | proxy=self.proxy,
175 | proxy_auth=self.proxy_auth,
176 | ) as resp:
177 | return await resp.json()
178 |
179 | async def get_access_token(self, code: str) -> tuple[str, str]:
180 | payload: TokenGrantPayload = {
181 | 'client_id': self.client_id,
182 | 'client_secret': self.client_secret,
183 | 'grant_type': 'authorization_code',
184 | 'code': code,
185 | 'redirect_uri': self.redirect_uri,
186 | }
187 | resp = await self.get_token_response(payload)
188 | return _tokens(resp)
189 |
190 | async def refresh_access_token(self, refresh_token: str) -> tuple[str, str]:
191 | payload: RefreshTokenPayload = {
192 | 'client_id': self.client_id,
193 | 'client_secret': self.client_secret,
194 | 'grant_type': 'refresh_token',
195 | 'refresh_token': refresh_token,
196 | }
197 | resp = await self.get_token_response(payload)
198 | return _tokens(resp)
199 |
200 | async def user(self, request: Request):
201 | if 'identify' not in self.scopes:
202 | raise ScopeMissing('identify')
203 | route = '/users/@me'
204 | token = self.get_token(request)
205 | return User(**(await self.request(route, token)))
206 |
207 | async def guilds(self, request: Request) -> list[GuildPreview]:
208 | if 'guilds' not in self.scopes:
209 | raise ScopeMissing('guilds')
210 | route = '/users/@me/guilds'
211 | token = self.get_token(request)
212 | return [Guild(**guild) for guild in await self.request(route, token)]
213 |
214 | def get_token(self, request: Request):
215 | authorization_header = request.headers.get('Authorization')
216 | if not authorization_header:
217 | raise Unauthorized
218 | all_headers = authorization_header.split(' ')
219 | if not all_headers[0] == 'Bearer' or len(all_headers) > 2:
220 | raise Unauthorized
221 |
222 | token = all_headers[1]
223 | return token
224 |
225 | async def isAuthenticated(self, token: str):
226 | route = '/oauth2/@me'
227 | try:
228 | await self.request(route, token)
229 | return True
230 | except Unauthorized:
231 | return False
232 |
233 | async def requires_authorization(self, bearer: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer())): # noqa: E501
234 | if bearer is None:
235 | raise Unauthorized
236 | if not await self.isAuthenticated(bearer.credentials):
237 | raise Unauthorized
238 |
--------------------------------------------------------------------------------
/app/discord/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Discord module code taken and updated from:
3 | https://github.com/Tert0/fastapi-discord
4 | """
5 |
6 | DISCORD_URL = 'https://discord.com'
7 | DISCORD_API_URL = f'{DISCORD_URL}/api/v10'
8 | DISCORD_OAUTH_URL = f'{DISCORD_URL}/api/oauth2'
9 | DISCORD_TOKEN_URL = f'{DISCORD_OAUTH_URL}/token'
10 | DISCORD_OAUTH_AUTHENTICATION_URL = f'{DISCORD_OAUTH_URL}/authorize'
11 |
--------------------------------------------------------------------------------
/app/discord/exceptions.py:
--------------------------------------------------------------------------------
1 | class Unauthorized(Exception):
2 | """Raised when user is not authorized."""
3 |
4 |
5 | class InvalidRequest(Exception):
6 | """Raised when a Request is not Valid"""
7 |
8 |
9 | class RateLimited(Exception):
10 |     """Raised when the client is rate limited by the Discord API"""
11 |
12 | def __init__(self, json, headers):
13 | self.json = json
14 | self.headers = headers
15 | self.message = json['message']
16 | self.retry_after = json['retry_after']
17 | super().__init__(self.message)
18 |
19 |
20 | class InvalidToken(Exception):
21 | """Raised when a Response has invalid tokens"""
22 |
23 |
24 | class ScopeMissing(Exception):
25 | scope: str
26 |
27 | def __init__(self, scope: str):
28 | self.scope = scope
29 | super().__init__(self.scope)
30 |
31 |
32 | class ClientSessionNotInitialized(Exception):
33 | """Raised when no Client Session is initialized but one would be needed"""
34 | pass
35 |
--------------------------------------------------------------------------------
/app/discord/models/guild.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from pydantic import BaseModel
3 | from app.discord.models.role import Role
4 |
5 |
6 | class GuildPreview(BaseModel):
7 | id: str
8 | name: str
9 | icon: Optional[str] = None
10 | owner: bool
11 | permissions: int
12 | features: list[str]
13 |
14 |
15 | class Guild(GuildPreview):
16 | owner_id: Optional[int] = None
17 | verification_level: Optional[int] = None
18 | default_message_notifications: Optional[int] = None
19 | roles: Optional[list[Role]] = None
20 |
--------------------------------------------------------------------------------
/app/discord/models/role.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
4 | class Role(BaseModel):
5 | id: int
6 | name: str
7 | color: int
8 | position: int
9 | permissions: int
10 | managed: bool
11 | mentionable: bool
12 |
--------------------------------------------------------------------------------
/app/discord/models/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from pydantic import BaseModel
3 |
4 |
5 | class User(BaseModel):
6 | id: str
7 | username: str
8 | discriminator: Optional[int] = None
9 | global_name: Optional[str] = None
10 |     avatar: Optional[str] = None
11 | avatar_url: Optional[str] = 'https://cdn.discordapp.com/embed/avatars/1.png'
12 | locale: str
13 | email: Optional[str] = None
14 | mfa_enabled: Optional[bool] = None
15 | flags: Optional[int] = None
16 | premium_type: Optional[int] = None
17 | public_flags: Optional[int] = None
18 | banner: Optional[str] = None
19 | accent_color: Optional[int] = None
20 | verified: Optional[bool] = None
21 | avatar_decoration: Optional[str] = None
22 |
23 | def __init__(self, **data):
24 | super().__init__(**data)
25 | if self.avatar:
26 | self.avatar_url = f'https://cdn.discordapp.com/avatars/{self.id}/{self.avatar}.png'
27 | else:
28 | self.avatar_url = 'https://cdn.discordapp.com/embed/avatars/1.png'
29 | if self.discriminator == 0:
30 | self.discriminator = None
31 |
--------------------------------------------------------------------------------
/app/lmbd/invoke.py:
--------------------------------------------------------------------------------
1 | import boto3
2 |
3 | from typing import Callable
4 | from app.types.lmbd import InvokeResponse
5 |
6 | LAMBDA = boto3.client('lambda')
7 |
8 |
9 | def catchLambdaError(func: Callable):
10 | def wrapper(*args, **kwargs):
11 | try:
12 | return func(*args, **kwargs)
13 | # TODO: Catch specific Lambda errors
14 | except Exception:
15 | # TODO: SEND ERROR NOTIFICATION
16 | return {}
17 | return wrapper
18 |
19 |
20 | @catchLambdaError
21 | def invokeEvent(payload: str, functionName: str) -> InvokeResponse:
22 | return LAMBDA.invoke(
23 | FunctionName=functionName,
24 | InvocationType='Event',
25 | LogType='None',
26 | Payload=payload,
27 | )
28 |
--------------------------------------------------------------------------------
/app/lmbd/sample_data/snsCloudWatchEvent.json:
--------------------------------------------------------------------------------
1 | {
2 | "Records": [
3 | {
4 | "EventVersion": "1.0",
5 | "EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
6 | "EventSource": "aws:sns",
7 | "Sns": {
8 | "SignatureVersion": "1",
9 | "Timestamp": "1970-01-01T00:00:00.000Z",
10 | "Signature": "EXAMPLE",
11 | "SigningCertUrl": "EXAMPLE",
12 | "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
13 | "Message": {
14 | "AlarmName": "Saffron-Octopus-RDS",
15 | "AlarmDescription": null,
16 | "AWSAccountId": "498849832712",
17 | "NewStateValue": "ALARM",
18 | "NewStateReason": "Threshold Crossed: 1 datapoint [2.1533759377604764 (20/07/20 21: 07: 00)] was greater than or equal to the threshold (0.0175).",
19 | "StateChangeTime": "2020-07-20T21: 12: 01.544+0000",
20 | "Region": "US East (N. Virginia)",
21 | "AlarmArn": "arn:aws:cloudwatch:us-east-1: 498849832712:alarm:Saffron-Octopus-RDS",
22 | "OldStateValue": "INSUFFICIENT_DATA",
23 | "Trigger": {
24 | "MetricName": "CPUUtilization",
25 | "Namespace": "AWS/RDS",
26 | "StatisticType": "Statistic",
27 | "Statistic": "AVERAGE",
28 | "Unit": null,
29 | "Dimensions": [
30 | {
31 | "value": "sm16lm1jrrjf0rk",
32 | "name": "DBInstanceIdentifier"
33 | }
34 | ],
35 | "Period": 300,
36 | "EvaluationPeriods": 1,
37 | "ComparisonOperator": "GreaterThanOrEqualToThreshold",
38 | "Threshold": 0.0175,
39 | "TreatMissingData": "",
40 | "EvaluateLowSampleCountPercentile": ""
41 | }
42 | },
43 | "MessageAttributes": {
44 | "Test": {
45 | "Type": "String",
46 | "Value": "TestString"
47 | },
48 | "TestBinary": {
49 | "Type": "Binary",
50 | "Value": "TestBinary"
51 | }
52 | },
53 | "Type": "Notification",
54 | "UnsubscribeUrl": "EXAMPLE",
55 | "TopicArn": "arn:aws:sns:EXAMPLE",
56 | "Subject": "TestInvoke"
57 | }
58 | }
59 | ]
60 | }
--------------------------------------------------------------------------------
/app/lmbd/sample_lambda/main.py:
--------------------------------------------------------------------------------
1 | # TODO: SAMPLE LAMBDA
2 |
--------------------------------------------------------------------------------
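
The sample Lambda is still a TODO; as a purely illustrative placeholder, a minimal handler compatible with the SNS sample event above might look like this:

    def handler(event: dict, context) -> dict:
        # Iterate over SNS records and acknowledge each message.
        for record in event.get('Records', []):
            message_id = record.get('Sns', {}).get('MessageId')
            print(f'Processing SNS message {message_id}')
        return {'statusCode': 200}
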
/app/log/setup.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 |
3 |
4 | def setup_logging():
5 | # Format of the logs
6 | log_format = (
7 | '{time:YYYY-MM-DD HH:mm:ss!UTC} | '
8 | '{level: <8} | '
9 | '{name}:{function}:{line} - '
10 | '{message}'
11 | )
12 |
13 | # General logs
14 | logger.add(
15 | 'logs/server.log',
16 | rotation='10 MB',
17 | retention='30 days',
18 | level='INFO',  # INFO and above (WARNING, ERROR, etc.) are logged
19 | format=log_format,
20 | enqueue=True,
21 | filter=lambda record: not bool(record['exception']),
22 | )
23 |
24 | # Exception logs
25 | logger.add(
26 | 'logs/exceptions.log',
27 | rotation='10 MB',
28 | retention='30 days',
29 | level='ERROR',  # only ERROR and above are logged here
30 | format=log_format,
31 | enqueue=True,
32 | backtrace=True,
33 | diagnose=True,
34 | filter=lambda record: bool(record['exception']),
35 | )
36 |
--------------------------------------------------------------------------------
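
The two sinks above are mutually exclusive thanks to the exception filter: records carrying an exception go only to exceptions.log, everything else only to server.log. A usage sketch:

    from loguru import logger
    from app.log.setup import setup_logging

    setup_logging()
    logger.info('Handled request')           # -> logs/server.log
    try:
        1 / 0
    except ZeroDivisionError:
        logger.exception('Division failed')  # -> logs/exceptions.log, with traceback
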
/app/migrations/env.py:
--------------------------------------------------------------------------------
1 | from alembic import context
2 | from sqlalchemy import pool
3 | from app.core.config import settings
4 | from logging.config import fileConfig
5 | from sqlalchemy import engine_from_config
6 | from app.db.connection import MySQLTableBase
7 |
8 | # This is the Alembic Config object, which provides
9 | # access to the values within the .ini file in use.
10 | config = context.config
11 |
12 | # Interpret the config file for Python logging.
13 | # This essentially sets up the Python loggers.
14 | if config.config_file_name is not None:
15 | fileConfig(config.config_file_name)
16 |
17 | TARGET_METADATA = MySQLTableBase.metadata
18 |
19 |
20 | def run_migrations_offline():
21 | """
22 | Run migrations in 'offline' mode.
23 |
24 | This configures the context with just a URL
25 | and not an Engine, though an Engine is acceptable
26 | here as well. By skipping the Engine creation
27 | we don't even need a DBAPI to be available.
28 |
29 | Calls to context.execute() here emit the given string to the
30 | script output.
31 | """
32 | context.configure(
33 | url=settings.MYSQL_DATABASE_URI,
34 | target_metadata=TARGET_METADATA,
35 | literal_binds=True,
36 | dialect_opts={'paramstyle': 'named'},
37 | )
38 |
39 | with context.begin_transaction():
40 | context.run_migrations()
41 |
42 |
43 | def run_migrations_online():
44 | """
45 | Run migrations in 'online' mode.
46 |
47 | In this scenario we need to create an Engine
48 | and associate a connection with the context.
49 | """
50 | connectable = engine_from_config(
51 | config.get_section(config.config_ini_section),
52 | prefix='sqlalchemy.',
53 | poolclass=pool.NullPool,
54 | )
55 |
56 | with connectable.connect() as connection:
57 | context.configure(
58 | connection=connection,
59 | target_metadata=TARGET_METADATA
60 | )
61 |
62 | with context.begin_transaction():
63 | context.run_migrations()
64 |
65 |
66 | if context.is_offline_mode():
67 | run_migrations_offline()
68 | else:
69 | run_migrations_online()
70 |
--------------------------------------------------------------------------------
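
env.py dispatches on context.is_offline_mode(): online mode opens a real connection, while offline mode only renders SQL. Both can be driven from the CLI or programmatically; a sketch using Alembic's command API:

    from alembic import command
    from alembic.config import Config

    cfg = Config('alembic.ini')
    command.revision(cfg, message='create tables', autogenerate=True)  # new migration
    command.upgrade(cfg, 'head')             # online: executes against the database
    command.upgrade(cfg, 'head', sql=True)   # offline: prints the SQL instead
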
/app/migrations/script.py.mako:
--------------------------------------------------------------------------------
1 | """${message}
2 |
3 | Revision ID: ${up_revision}
4 | Revises: ${down_revision | comma,n}
5 | Create Date: ${create_date}
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 | ${imports if imports else ""}
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = ${repr(up_revision)}
14 | down_revision = ${repr(down_revision)}
15 | branch_labels = ${repr(branch_labels)}
16 | depends_on = ${repr(depends_on)}
17 |
18 |
19 | def upgrade() -> None:
20 | ${upgrades if upgrades else "pass"}
21 |
22 |
23 | def downgrade() -> None:
24 | ${downgrades if downgrades else "pass"}
25 |
--------------------------------------------------------------------------------
/app/models/mysql.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from app.db.connection import MySQLTableBase
3 |
4 | from sqlalchemy.sql import func
5 | from sqlalchemy.dialects.mysql import BIGINT
6 | from sqlalchemy import String, ForeignKey, DateTime, Boolean
7 | from sqlalchemy.orm import relationship, Mapped, mapped_column
8 |
9 |
10 | class User(MySQLTableBase):
11 | __tablename__ = 'User'
12 |
13 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True)
14 | lastActive: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
15 | joinDate: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
16 |
17 | # One-to-many relationship with Post
18 | posts: Mapped[list['Post']] = relationship(back_populates='user',
19 | lazy='select')
20 | # One-to-one relationship with Profile
21 | profile: Mapped['Profile'] = relationship(back_populates='user',
22 | lazy='select')
23 |
24 | # Many-to-many with association object UserNotification
25 | notifications: Mapped[list['UserNotification']] = relationship(
26 | 'UserNotification', back_populates='user', lazy='select', cascade='all, delete-orphan'
27 | )
28 |
29 | def __repr__(self) -> str:
30 | return f'<User id={self.id}>'
31 |
32 |
33 | class UserNotification(MySQLTableBase):
34 | """
35 | Association Object for User and Notification.
36 | Represents the notifications received by a user.
37 | """
38 | __tablename__ = 'UserNotification'
39 |
40 | userID: Mapped[int] = mapped_column(
41 | BIGINT(unsigned=True), ForeignKey('User.id', ondelete='CASCADE'), primary_key=True
42 | )
43 | notificationID: Mapped[int] = mapped_column(
44 | BIGINT(unsigned=True), ForeignKey('Notification.id', ondelete='CASCADE'), primary_key=True
45 | )
46 | read: Mapped[bool] = mapped_column(Boolean)
47 |
48 | # Many-to-one relationship with User and Notification
49 | user: Mapped['User'] = relationship('User', back_populates='notifications')
50 | notification: Mapped['Notification'] = relationship('Notification', back_populates='users')
51 |
52 | def __repr__(self) -> str:
53 | return f'<UserNotification userID={self.userID} notificationID={self.notificationID} read={self.read}>'
54 |
55 |
56 | class Notification(MySQLTableBase):
57 | """
58 | Notification class representing the notifications sent to the users.
59 | """
60 |
61 | __tablename__ = 'Notification'
62 |
63 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True)
64 | message: Mapped[str] = mapped_column(String(length=256))
65 | time: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
66 |
67 | # One-to-many relationship with UserNotification
68 | users: Mapped[list['UserNotification']] = relationship(
69 | 'UserNotification', back_populates='notification', lazy='select', passive_deletes=True
70 | )
71 |
72 | def __repr__(self) -> str:
73 | return f'<Notification id={self.id}>'
74 |
75 |
76 | class Post(MySQLTableBase):
77 | """
78 | Post class representing the posts created by the users.
79 | """
80 |
81 | __tablename__ = 'Post'
82 |
83 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True)
84 | content: Mapped[str] = mapped_column(String(length=1024))
85 | created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
86 | user_id: Mapped[int] = mapped_column(BIGINT(unsigned=True), ForeignKey('User.id'))
87 |
88 | # Relationship with User
89 | user: Mapped['User'] = relationship('User', back_populates='posts')
90 |
91 | def __repr__(self) -> str:
92 | return f'<Post id={self.id} user_id={self.user_id}>'
93 |
94 |
95 | class Profile(MySQLTableBase):
96 | """
97 | Profile class representing the profile of each user.
98 | """
99 |
100 | __tablename__ = 'Profile'
101 |
102 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True), primary_key=True, autoincrement=True)
103 | bio: Mapped[str] = mapped_column(String(length=256))
104 | user_id: Mapped[int] = mapped_column(BIGINT(unsigned=True), ForeignKey('User.id'), unique=True)
105 |
106 | # Relationship with User
107 | user: Mapped['User'] = relationship('User', back_populates='profile')
108 |
109 | def __repr__(self) -> str:
110 | return f'<Profile id={self.id} user_id={self.user_id}>'
111 |
--------------------------------------------------------------------------------
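
A persistence sketch for the models above (`engine` is assumed to be exported by app.db.connection; adjust to the project's actual engine factory):

    from sqlalchemy.orm import Session
    from app.db.connection import engine  # assumed export
    from app.models.mysql import Notification, Post, Profile, User, UserNotification

    with Session(engine) as session:
        user = User(profile=Profile(bio='hello'), posts=[Post(content='first!')])
        # The association object carries extra state ('read') on the many-to-many link.
        user.notifications.append(
            UserNotification(notification=Notification(message='Welcome!'), read=False)
        )
        session.add(user)
        session.commit()
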
/app/models/mysql_no_fk.py:
--------------------------------------------------------------------------------
1 | """
2 | In this example, we use SQLAlchemy's ORM to create a MySQL tables *without* foreign keys.
3 | Instead, we create relationships in place of foreign keys. There are many reasons why
4 | someone might want to not use foreign keys:
5 | 1. Schema flexibility
6 | 2. Cross-database relationships
7 | 3. Performance
8 | 4. database agnostic code
9 | """
10 |
11 | from datetime import datetime
12 | from sqlalchemy.sql import func
13 | from app.db.connection import MySQLTableBase
14 | from sqlalchemy.dialects.mysql import BIGINT
15 | from sqlalchemy import String, Boolean, DateTime
16 | from sqlalchemy.orm import relationship, Session, mapped_column, Mapped
17 |
18 | # TODO: CHECK IF THESE RELATIONSHIPS LOAD LAZILY OR EAGERLY
19 | # TODO: FIX WARNING ABOUT RELATIONSHIP COPYING COLUMN
20 |
21 |
22 | class User(MySQLTableBase):
23 | """
24 | User class representing the user.
25 | """
26 | __tablename__ = 'User'
27 |
28 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True),
29 | primary_key=True,
30 | autoincrement=True)
31 | address: Mapped[str] = mapped_column(String(length=42),
32 | primary_key=True,
33 | unique=True)
34 | lastActive: Mapped[datetime] = mapped_column(DateTime,
35 | server_default=func.now(),
36 | onupdate=func.now())
37 | joinDate: Mapped[datetime] = mapped_column(DateTime,
38 | server_default=func.now())
39 | isBetaUser: Mapped[bool] = mapped_column(Boolean, default=False)
40 |
41 | # Relationships
42 | notifications: Mapped[list['UserNotification']] = relationship(
43 | primaryjoin='User.id == foreign(UserNotification.userID)',
44 | )
45 |
46 | def load(self, session: Session) -> 'User':
47 | """
48 | Load the user's data from the database and return the user instance.
49 | """
50 | return session.query(User).filter_by(address=self.address).scalar()
51 |
52 | def __repr__(self) -> str:
53 | return f'<User id={self.id} address={self.address}>'
54 |
55 |
56 | class UserNotification(MySQLTableBase):
57 | """
58 | Association Table for User and Notification.
59 | Represents the notifications received by a user.
60 | """
61 | __tablename__ = 'UserNotification'
62 |
63 | userID: Mapped[int] = mapped_column(
64 | BIGINT(unsigned=True), primary_key=True)
65 | notificationID: Mapped[int] = mapped_column(
66 | BIGINT(unsigned=True), primary_key=True)
67 | read: Mapped[bool] = mapped_column(Boolean, default=False)
68 |
69 | # Relationships
70 | user: Mapped[User] = relationship(
71 | primaryjoin='UserNotification.userID == foreign(User.id)'
72 | )
73 | notification: Mapped['Notification'] = relationship(
74 | primaryjoin='UserNotification.notificationID == foreign(Notification.id)'
75 | )
76 |
77 | def __repr__(self) -> str:
78 | return f'<UserNotification userID={self.userID} notificationID={self.notificationID}>'
79 |
80 |
81 | class Notification(MySQLTableBase):
82 | """
83 | Notification class representing the notification sent to the users.
84 | """
85 | __tablename__ = 'Notification'
86 |
87 | id: Mapped[int] = mapped_column(BIGINT(unsigned=True),
88 | primary_key=True,
89 | autoincrement=True)
90 | message: Mapped[str] = mapped_column(String(length=256), primary_key=True)
91 | time: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
92 |
93 | # Relationships
94 | users: Mapped[list['UserNotification']] = relationship(
95 | primaryjoin='Notification.id == foreign(UserNotification.notificationID)',
96 | )
97 |
98 | def __repr__(self) -> str:
99 | return f'<Notification id={self.id}>'
100 |
--------------------------------------------------------------------------------
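
Because there are no FK constraints, the database does not enforce integrity; the foreign() annotation in each primaryjoin is what tells SQLAlchemy which column plays the referencing role at query time. A read sketch (`engine` assumed as above):

    from sqlalchemy.orm import Session
    from app.db.connection import engine  # assumed export
    from app.models.mysql_no_fk import UserNotification

    with Session(engine) as session:
        pending = (
            session.query(UserNotification)
            .filter(UserNotification.userID == 1, UserNotification.read.is_(False))
            .all()
        )
        for link in pending:
            print(link.notification.message)  # joined via primaryjoin, no FK involved
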
/app/types/cache.py:
--------------------------------------------------------------------------------
1 | # We don't use enums here because the places where these values are used
2 | # usually require plain strings rather than Enum members.
3 |
4 |
5 | class RedisTokenPrefix:
6 | """
7 | When creating a SessionStore, it is useful to have a prefix to avoid
8 | collisions with other keys in Redis.
9 | """
10 | USER = 'user'
11 |
12 |
13 | class UserKey:
14 | """
15 | Keys used to store user data in Redis.
16 | """
17 | NONCE = 'nonce'
18 |
--------------------------------------------------------------------------------
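
The intended use is key namespacing; a sketch (the colon-separated layout is an assumption, any consistent scheme works):

    from app.types.cache import RedisTokenPrefix, UserKey

    user_id = 42  # hypothetical
    key = f'{RedisTokenPrefix.USER}:{user_id}:{UserKey.NONCE}'
    print(key)  # user:42:nonce
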
/app/types/jwt.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pydantic import BaseModel
3 |
4 |
5 | class JWTPayload(BaseModel):
6 | """
7 | The payload of a JWT token.
8 | """
9 | # Subject
10 | sub: str
11 | # Expiration
12 | exp: datetime
13 | # Issued at
14 | iat: datetime
15 | # Unique hex number
16 | nonce: str | None = None
17 |
18 |
19 | class TokenData(BaseModel):
20 | """
21 | Data passed to create a JWT token.
22 | """
23 | # Subject
24 | sub: str
25 |
--------------------------------------------------------------------------------
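
Building a payload from these models might look like this (the 15-minute lifetime is an arbitrary example value):

    from datetime import datetime, timedelta, timezone
    from app.types.jwt import JWTPayload, TokenData

    data = TokenData(sub='user:123')
    now = datetime.now(timezone.utc)
    payload = JWTPayload(sub=data.sub, iat=now, exp=now + timedelta(minutes=15))
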
/app/types/lmbd.py:
--------------------------------------------------------------------------------
1 | # AWS Lambda
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class InvokeResponse(BaseModel):
7 | StatusCode: int
8 | FunctionError: str
9 | Payload: bytes
10 | LogResult: str
11 | ExecutedVersion: str
12 |
--------------------------------------------------------------------------------
/app/types/server.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 | from pydantic import BaseModel
3 | from typing import Generic, TypeVar, Optional, Any, Literal
4 |
5 | DataT = TypeVar('DataT')
6 |
7 | ServerStatus = Literal['ok', 'error', 'missing_parameters']
8 |
9 |
10 | class ServerCode(IntEnum):
11 | OK = 200
12 | CREATED = 201
13 | NO_CONTENT = 204
14 | BAD_REQUEST = 400
15 | UNAUTHORIZED = 401
16 | FORBIDDEN = 403
17 | NOT_FOUND = 404
18 | INTERNAL_SERVER_ERROR = 500
19 |
20 |
21 | # Not an enum because it's not a finite set of values
22 | class Cookie:
23 | REFRESH_TOKEN = 'refresh_token'
24 |
25 |
26 | class ServerResponse(BaseModel, Generic[DataT]):
27 | data: Optional[DataT] = None
28 | status: ServerStatus = 'ok'
29 | message: Optional[str] = None
30 |
31 | def dict(self, *args, **kwargs) -> dict[str, Any]:
32 | """
33 | Override the default dict method to exclude None values in the response.
34 | """
35 | kwargs.pop('exclude_none', None)
36 | return super().model_dump(*args, exclude_none=True, **kwargs)
37 |
--------------------------------------------------------------------------------
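
The exclude_none override keeps responses compact; for example:

    from app.types.server import ServerResponse

    resp = ServerResponse[dict](data={'id': 1})
    print(resp.dict())  # {'data': {'id': 1}, 'status': 'ok'} -- the None message is dropped
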
/app/types/time.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class TimeBucketInterval(str, Enum):
5 | FIVE_MIN = '5 min'
6 | FIFTEEN_MIN = '15 min'
7 | ONE_HOUR = '1 hour'
8 | ONE_DAY = '1 day'
9 | ONE_WEEK = '1 week'
10 | ONE_MONTH = '1 month'
11 | ONE_YEAR = '1 year'
12 |
--------------------------------------------------------------------------------
/app/util/common.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from typing import Callable
4 |
5 |
6 | def getFunctionArguments(func: Callable) -> tuple[str, ...]:
7 | """
8 | Gets the positional argument names of a function.
9 | """
10 | return func.__code__.co_varnames[:func.__code__.co_argcount]
11 |
12 |
13 | def generateNonce() -> str:
14 | """
15 | Generates a 32 length nonce hex string.
16 | """
17 | return os.urandom(16).hex()
18 |
--------------------------------------------------------------------------------
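
A quick check of both helpers:

    from app.util.common import generateNonce, getFunctionArguments

    def greet(name, punctuation='!'):
        return f'Hello, {name}{punctuation}'

    print(getFunctionArguments(greet))  # ('name', 'punctuation')
    print(len(generateNonce()))         # 32 -- 16 random bytes, hex-encoded
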
/app/util/json.py:
--------------------------------------------------------------------------------
1 | import orjson
2 |
3 |
4 | def orjson_dumps_str(obj):
5 | return orjson.dumps(obj).decode('utf-8')
6 |
--------------------------------------------------------------------------------
/app/util/time.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 |
3 |
4 | def iso8601ToTimestamp(iso8601: str) -> int:
5 | """
6 | Converts an ISO 8601 string (assumed UTC, without an offset suffix) to a Unix timestamp.
7 | """
8 | if iso8601.count('.') == 1:
9 | dt = datetime.strptime(iso8601, '%Y-%m-%dT%H:%M:%S.%f')
10 | else:
11 | dt = datetime.strptime(iso8601, '%Y-%m-%dT%H:%M:%S')
12 |
13 | dt = dt.replace(tzinfo=timezone.utc)
14 | return int(dt.timestamp())
15 |
--------------------------------------------------------------------------------
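
A worked example covering both accepted formats:

    from app.util.time import iso8601ToTimestamp

    print(iso8601ToTimestamp('1970-01-01T00:00:01'))      # 1
    print(iso8601ToTimestamp('2020-07-20T21:12:01.544'))  # 1595279521 (fraction truncated by int())
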
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | autopep8==2.3.1
2 | mypy==1.11.1
3 | types-redis==4.6.0.20240806
4 | ruff==0.5.6
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | mysql-server:
3 | image: mysql/mysql-server:latest
4 | container_name: fastapi-mysql
5 | environment:
6 | MYSQL_ROOT_PASSWORD: "root"
7 | MYSQL_DATABASE: "fastapi-db"
8 | MYSQL_USER: "root"  # NOTE: official MySQL images reject MYSQL_USER=root; prefer a dedicated user
9 | MYSQL_PASSWORD: "root"
10 | ports:
11 | - "3306:3306"
12 | volumes:
13 | - "./docker/mysql:/var/lib/mysql"
14 |
15 | redis-server:
16 | image: redis:latest
17 | container_name: fastapi-redis
18 | ports:
19 | - "6379:6379"
20 |
--------------------------------------------------------------------------------
/docker-entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 |
5 | # Activate virtual environment
6 | . /opt/pysetup/.venv/bin/activate
7 |
8 | # You can put other setup logic here
9 |
10 | # Execute passed in command
11 | exec "$@"
12 |
--------------------------------------------------------------------------------
/gunicorn_conf.py:
--------------------------------------------------------------------------------
1 | # From: https://github.com/tiangolo/uvicorn-gunicorn-docker/blob/315f04413114e938ff37a410b5979126facc90af/python3.7/gunicorn_conf.py
2 |
3 | import os
4 | import orjson
5 | import multiprocessing
6 |
7 | workers_per_core_str = os.getenv('WORKERS_PER_CORE', '1')
8 | web_concurrency_str = os.getenv('WEB_CONCURRENCY', None)
9 | host = '0.0.0.0'
10 | port = '8000'
11 | bind_env = os.getenv('BIND', None)
12 | use_loglevel = os.getenv('LOG_LEVEL', 'info')
13 | if bind_env:
14 | use_bind = bind_env
15 | else:
16 | use_bind = f'{host}:{port}'
17 |
18 | cores = multiprocessing.cpu_count()
19 | workers_per_core = float(workers_per_core_str)
20 | default_web_concurrency = workers_per_core * cores
21 | if web_concurrency_str:
22 | web_concurrency = int(web_concurrency_str)
23 | assert web_concurrency > 0
24 | else:
25 | web_concurrency = max(int(default_web_concurrency), 2)
26 |
27 | # Gunicorn config variables
28 | loglevel = use_loglevel
29 | workers = 2  # NOTE: hard-coded override; set to web_concurrency to scale with CPU cores
30 | bind = use_bind
31 | keepalive = 120
32 | errorlog = '-'
33 | timeout = 600
34 |
35 | # For debugging and testing
36 | log_data = {
37 | 'loglevel': loglevel,
38 | 'workers': workers,
39 | 'bind': bind,
40 | # Additional, non-gunicorn variables
41 | 'workers_per_core': workers_per_core,
42 | 'host': host,
43 | 'port': port,
44 | }
45 | print(orjson.dumps(log_data).decode())  # orjson.dumps returns bytes; decode for readable output
46 |
--------------------------------------------------------------------------------
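
For intuition about the worker arithmetic above, a worked example with assumed values (and note that workers is currently pinned to 2 regardless):

    cores = 4                 # assumed machine
    workers_per_core = 0.5    # WORKERS_PER_CORE=0.5, WEB_CONCURRENCY unset
    default_web_concurrency = workers_per_core * cores       # 2.0
    web_concurrency = max(int(default_web_concurrency), 2)   # 2 (floor of two workers)
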
/loguru/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The Loguru library provides a pre-instanced logger to facilitate dealing with logging in Python.
3 |
4 | Just ``from loguru import logger``.
5 | """
6 |
7 | import atexit as _atexit
8 | import sys as _sys
9 |
10 | from . import _defaults
11 | from ._logger import Core as _Core
12 | from ._logger import Logger as _Logger
13 |
14 | __version__ = "0.7.2"
15 |
16 | __all__ = ["logger"]
17 |
18 | logger = _Logger(
19 | core=_Core(),
20 | exception=None,
21 | depth=0,
22 | record=False,
23 | lazy=False,
24 | colors=False,
25 | raw=False,
26 | capture=True,
27 | patchers=[],
28 | extra={},
29 | )
30 |
31 | if _defaults.LOGURU_AUTOINIT and _sys.stderr:
32 | logger.add(_sys.stderr)
33 |
34 | _atexit.register(logger.remove)
35 |
--------------------------------------------------------------------------------
/loguru/__init__.pyi:
--------------------------------------------------------------------------------
1 | """
2 | .. |str| replace:: :class:`str`
3 | .. |namedtuple| replace:: :func:`namedtuple`
4 | .. |dict| replace:: :class:`dict`
5 |
6 | .. |Logger| replace:: :class:`~loguru._logger.Logger`
7 | .. |catch| replace:: :meth:`~loguru._logger.Logger.catch()`
8 | .. |contextualize| replace:: :meth:`~loguru._logger.Logger.contextualize()`
9 | .. |complete| replace:: :meth:`~loguru._logger.Logger.complete()`
10 | .. |bind| replace:: :meth:`~loguru._logger.Logger.bind()`
11 | .. |patch| replace:: :meth:`~loguru._logger.Logger.patch()`
12 | .. |opt| replace:: :meth:`~loguru._logger.Logger.opt()`
13 | .. |level| replace:: :meth:`~loguru._logger.Logger.level()`
14 |
15 | .. _stub file: https://www.python.org/dev/peps/pep-0484/#stub-files
16 | .. _string literals: https://www.python.org/dev/peps/pep-0484/#forward-references
17 | .. _postponed evaluation of annotations: https://www.python.org/dev/peps/pep-0563/
18 | .. |future| replace:: ``__future__``
19 | .. _future: https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
20 | .. |loguru-mypy| replace:: ``loguru-mypy``
21 | .. _loguru-mypy: https://github.com/kornicameister/loguru-mypy
22 | .. |documentation of loguru-mypy| replace:: documentation of ``loguru-mypy``
23 | .. _documentation of loguru-mypy:
24 | https://github.com/kornicameister/loguru-mypy/blob/master/README.md
25 | .. _@kornicameister: https://github.com/kornicameister
26 |
27 | Loguru relies on a `stub file`_ to document its types. This implies that these types are not
28 | accessible during execution of your program, however they can be used by type checkers and IDE.
29 | Also, this means that your Python interpreter has to support `postponed evaluation of annotations`_
30 | to prevent error at runtime. This is achieved with a |future|_ import in Python 3.7+ or by using
31 | `string literals`_ for earlier versions.
32 |
33 | A basic usage example could look like this:
34 |
35 | .. code-block:: python
36 |
37 | from __future__ import annotations
38 |
39 | import loguru
40 | from loguru import logger
41 |
42 | def good_sink(message: loguru.Message):
43 | print("My name is", message.record["name"])
44 |
45 | def bad_filter(record: loguru.Record):
46 | return record["invalid"]
47 |
48 | logger.add(good_sink, filter=bad_filter)
49 |
50 |
51 | .. code-block:: bash
52 |
53 | $ mypy test.py
54 | test.py:8: error: TypedDict "Record" has no key 'invalid'
55 | Found 1 error in 1 file (checked 1 source file)
56 |
57 | There are several internal types to which you can be exposed using Loguru's public API, they are
58 | listed here and might be useful to type hint your code:
59 |
60 | - ``Logger``: the usual |logger| object (also returned by |opt|, |bind| and |patch|).
61 | - ``Message``: the formatted logging message sent to the sinks (a |str| with ``record``
62 | attribute).
63 | - ``Record``: the |dict| containing all contextual information of the logged message.
64 | - ``Level``: the |namedtuple| returned by |level| (with ``name``, ``no``, ``color`` and ``icon``
65 | attributes).
66 | - ``Catcher``: the context decorator returned by |catch|.
67 | - ``Contextualizer``: the context decorator returned by |contextualize|.
68 | - ``AwaitableCompleter``: the awaitable object returned by |complete|.
69 | - ``RecordFile``: the ``record["file"]`` with ``name`` and ``path`` attributes.
70 | - ``RecordLevel``: the ``record["level"]`` with ``name``, ``no`` and ``icon`` attributes.
71 | - ``RecordThread``: the ``record["thread"]`` with ``id`` and ``name`` attributes.
72 | - ``RecordProcess``: the ``record["process"]`` with ``id`` and ``name`` attributes.
73 | - ``RecordException``: the ``record["exception"]`` with ``type``, ``value`` and ``traceback``
74 | attributes.
75 |
76 | If that is not enough, one can also use the |loguru-mypy|_ library developed by `@kornicameister`_.
77 | Plugin can be installed separately using::
78 |
79 | pip install loguru-mypy
80 |
81 | It helps to catch several possible runtime errors by performing additional checks like:
82 |
83 | - ``opt(lazy=True)`` loggers accepting only ``typing.Callable[[], typing.Any]`` arguments
84 | - ``opt(record=True)`` loggers wrongly calling log handler like so ``logger.info(..., record={})``
85 | - and even more...
86 |
87 | For more details, go to official |documentation of loguru-mypy|_.
88 | """
89 |
90 | import sys
91 | from asyncio import AbstractEventLoop
92 | from datetime import datetime, time, timedelta
93 | from picologging import Handler
94 | from multiprocessing.context import BaseContext
95 | from types import TracebackType
96 | from typing import (
97 | Any,
98 | BinaryIO,
99 | Callable,
100 | Dict,
101 | Generator,
102 | Generic,
103 | List,
104 | NamedTuple,
105 | NewType,
106 | Optional,
107 | Pattern,
108 | Sequence,
109 | TextIO,
110 | Tuple,
111 | Type,
112 | TypeVar,
113 | Union,
114 | overload,
115 | )
116 |
117 | if sys.version_info >= (3, 5, 3):
118 | from typing import Awaitable
119 | else:
120 | from typing_extensions import Awaitable
121 |
122 | if sys.version_info >= (3, 6):
123 | from os import PathLike
124 | from typing import ContextManager
125 |
126 | PathLikeStr = PathLike[str]
127 | else:
128 | from pathlib import PurePath as PathLikeStr
129 |
130 | from typing_extensions import ContextManager
131 |
132 | if sys.version_info >= (3, 8):
133 | from typing import Protocol, TypedDict
134 | else:
135 | from typing_extensions import Protocol, TypedDict
136 |
137 | _T = TypeVar("_T")
138 | _F = TypeVar("_F", bound=Callable[..., Any])
139 | ExcInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]
140 |
141 | class _GeneratorContextManager(ContextManager[_T], Generic[_T]):
142 | def __call__(self, func: _F) -> _F: ...
143 | def __exit__(
144 | self,
145 | typ: Optional[Type[BaseException]],
146 | value: Optional[BaseException],
147 | traceback: Optional[TracebackType],
148 | ) -> Optional[bool]: ...
149 |
150 | Catcher = NewType("Catcher", _GeneratorContextManager[None])
151 | Contextualizer = NewType("Contextualizer", _GeneratorContextManager[None])
152 | AwaitableCompleter = Awaitable[None]
153 |
154 | class Level(NamedTuple):
155 | name: str
156 | no: int
157 | color: str
158 | icon: str
159 |
160 | class _RecordAttribute:
161 | def __repr__(self) -> str: ...
162 | def __format__(self, spec: str) -> str: ...
163 |
164 | class RecordFile(_RecordAttribute):
165 | name: str
166 | path: str
167 |
168 | class RecordLevel(_RecordAttribute):
169 | name: str
170 | no: int
171 | icon: str
172 |
173 | class RecordThread(_RecordAttribute):
174 | id: int
175 | name: str
176 |
177 | class RecordProcess(_RecordAttribute):
178 | id: int
179 | name: str
180 |
181 | class RecordException(NamedTuple):
182 | type: Optional[Type[BaseException]]
183 | value: Optional[BaseException]
184 | traceback: Optional[TracebackType]
185 |
186 | class Record(TypedDict):
187 | elapsed: timedelta
188 | exception: Optional[RecordException]
189 | extra: Dict[Any, Any]
190 | file: RecordFile
191 | function: str
192 | level: RecordLevel
193 | line: int
194 | message: str
195 | module: str
196 | name: Union[str, None]
197 | process: RecordProcess
198 | thread: RecordThread
199 | time: datetime
200 |
201 | class Message(str):
202 | record: Record
203 |
204 | class Writable(Protocol):
205 | def write(self, message: Message) -> None: ...
206 |
207 | FilterDict = Dict[Union[str, None], Union[str, int, bool]]
208 | FilterFunction = Callable[[Record], bool]
209 | FormatFunction = Callable[[Record], str]
210 | PatcherFunction = Callable[[Record], None]
211 | RotationFunction = Callable[[Message, TextIO], bool]
212 | RetentionFunction = Callable[[List[str]], None]
213 | CompressionFunction = Callable[[str], None]
214 |
215 | # Actually unusable because TypedDict can't allow extra keys: python/mypy#4617
216 | class _HandlerConfig(TypedDict, total=False):
217 | sink: Union[str, PathLikeStr, TextIO, Writable, Callable[[Message], None], Handler]
218 | level: Union[str, int]
219 | format: Union[str, FormatFunction]
220 | filter: Optional[Union[str, FilterFunction, FilterDict]]
221 | colorize: Optional[bool]
222 | serialize: bool
223 | backtrace: bool
224 | diagnose: bool
225 | enqueue: bool
226 | catch: bool
227 |
228 | class LevelConfig(TypedDict, total=False):
229 | name: str
230 | no: int
231 | color: str
232 | icon: str
233 |
234 | ActivationConfig = Tuple[Union[str, None], bool]
235 |
236 | class Logger:
237 | @overload
238 | def add(
239 | self,
240 | sink: Union[TextIO, Writable, Callable[[Message], None], Handler],
241 | *,
242 | level: Union[str, int] = ...,
243 | format: Union[str, FormatFunction] = ...,
244 | filter: Optional[Union[str, FilterFunction, FilterDict]] = ...,
245 | colorize: Optional[bool] = ...,
246 | serialize: bool = ...,
247 | backtrace: bool = ...,
248 | diagnose: bool = ...,
249 | enqueue: bool = ...,
250 | context: Optional[Union[str, BaseContext]] = ...,
251 | catch: bool = ...
252 | ) -> int: ...
253 | @overload
254 | def add(
255 | self,
256 | sink: Callable[[Message], Awaitable[None]],
257 | *,
258 | level: Union[str, int] = ...,
259 | format: Union[str, FormatFunction] = ...,
260 | filter: Optional[Union[str, FilterFunction, FilterDict]] = ...,
261 | colorize: Optional[bool] = ...,
262 | serialize: bool = ...,
263 | backtrace: bool = ...,
264 | diagnose: bool = ...,
265 | enqueue: bool = ...,
266 | context: Optional[Union[str, BaseContext]] = ...,
267 | catch: bool = ...,
268 | loop: Optional[AbstractEventLoop] = ...
269 | ) -> int: ...
270 | @overload
271 | def add(
272 | self,
273 | sink: Union[str, PathLikeStr],
274 | *,
275 | level: Union[str, int] = ...,
276 | format: Union[str, FormatFunction] = ...,
277 | filter: Optional[Union[str, FilterFunction, FilterDict]] = ...,
278 | colorize: Optional[bool] = ...,
279 | serialize: bool = ...,
280 | backtrace: bool = ...,
281 | diagnose: bool = ...,
282 | enqueue: bool = ...,
283 | context: Optional[Union[str, BaseContext]] = ...,
284 | catch: bool = ...,
285 | rotation: Optional[Union[str, int, time, timedelta, RotationFunction]] = ...,
286 | retention: Optional[Union[str, int, timedelta, RetentionFunction]] = ...,
287 | compression: Optional[Union[str, CompressionFunction]] = ...,
288 | delay: bool = ...,
289 | watch: bool = ...,
290 | mode: str = ...,
291 | buffering: int = ...,
292 | encoding: str = ...,
293 | **kwargs: Any
294 | ) -> int: ...
295 | def remove(self, handler_id: Optional[int] = ...) -> None: ...
296 | def complete(self) -> AwaitableCompleter: ...
297 | @overload
298 | def catch(
299 | self,
300 | exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = ...,
301 | *,
302 | level: Union[str, int] = ...,
303 | reraise: bool = ...,
304 | onerror: Optional[Callable[[BaseException], None]] = ...,
305 | exclude: Optional[Union[Type[BaseException], Tuple[Type[BaseException], ...]]] = ...,
306 | default: Any = ...,
307 | message: str = ...
308 | ) -> Catcher: ...
309 | @overload
310 | def catch(self, function: _F) -> _F: ...
311 | def opt(
312 | self,
313 | *,
314 | exception: Optional[Union[bool, ExcInfo, BaseException]] = ...,
315 | record: bool = ...,
316 | lazy: bool = ...,
317 | colors: bool = ...,
318 | raw: bool = ...,
319 | capture: bool = ...,
320 | depth: int = ...,
321 | ansi: bool = ...
322 | ) -> Logger: ...
323 | def bind(__self, **kwargs: Any) -> Logger: ... # noqa: N805
324 | def contextualize(__self, **kwargs: Any) -> Contextualizer: ... # noqa: N805
325 | def patch(self, patcher: PatcherFunction) -> Logger: ...
326 | @overload
327 | def level(self, name: str) -> Level: ...
328 | @overload
329 | def level(
330 | self, name: str, no: int = ..., color: Optional[str] = ..., icon: Optional[str] = ...
331 | ) -> Level: ...
332 | @overload
333 | def level(
334 | self,
335 | name: str,
336 | no: Optional[int] = ...,
337 | color: Optional[str] = ...,
338 | icon: Optional[str] = ...,
339 | ) -> Level: ...
340 | def disable(self, name: Union[str, None]) -> None: ...
341 | def enable(self, name: Union[str, None]) -> None: ...
342 | def configure(
343 | self,
344 | *,
345 | handlers: Sequence[Dict[str, Any]] = ...,
346 | levels: Optional[Sequence[LevelConfig]] = ...,
347 | extra: Optional[Dict[Any, Any]] = ...,
348 | patcher: Optional[PatcherFunction] = ...,
349 | activation: Optional[Sequence[ActivationConfig]] = ...
350 | ) -> List[int]: ...
351 | # @staticmethod cannot be used with @overload in mypy (python/mypy#7781).
352 | # However Logger is not exposed and logger is an instance of Logger
353 | # so for type checkers it is all the same whether it is defined here
354 | # as a static method or an instance method.
355 | @overload
356 | def parse(
357 | self,
358 | file: Union[str, PathLikeStr, TextIO],
359 | pattern: Union[str, Pattern[str]],
360 | *,
361 | cast: Union[Dict[str, Callable[[str], Any]], Callable[[Dict[str, str]], None]] = ...,
362 | chunk: int = ...
363 | ) -> Generator[Dict[str, Any], None, None]: ...
364 | @overload
365 | def parse(
366 | self,
367 | file: BinaryIO,
368 | pattern: Union[bytes, Pattern[bytes]],
369 | *,
370 | cast: Union[Dict[str, Callable[[bytes], Any]], Callable[[Dict[str, bytes]], None]] = ...,
371 | chunk: int = ...
372 | ) -> Generator[Dict[str, Any], None, None]: ...
373 | @overload
374 | def trace(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
375 | @overload
376 | def trace(__self, __message: Any) -> None: ... # noqa: N805
377 | @overload
378 | def debug(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
379 | @overload
380 | def debug(__self, __message: Any) -> None: ... # noqa: N805
381 | @overload
382 | def info(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
383 | @overload
384 | def info(__self, __message: Any) -> None: ... # noqa: N805
385 | @overload
386 | def success(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
387 | @overload
388 | def success(__self, __message: Any) -> None: ... # noqa: N805
389 | @overload
390 | def warning(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
391 | @overload
392 | def warning(__self, __message: Any) -> None: ... # noqa: N805
393 | @overload
394 | def error(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
395 | @overload
396 | def error(__self, __message: Any) -> None: ... # noqa: N805
397 | @overload
398 | def critical(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
399 | @overload
400 | def critical(__self, __message: Any) -> None: ... # noqa: N805
401 | @overload
402 | def exception(__self, __message: str, *args: Any, **kwargs: Any) -> None: ... # noqa: N805
403 | @overload
404 | def exception(__self, __message: Any) -> None: ... # noqa: N805
405 | @overload
406 | def log(
407 | __self, __level: Union[int, str], __message: str, *args: Any, **kwargs: Any # noqa: N805
408 | ) -> None: ...
409 | @overload
410 | def log(__self, __level: Union[int, str], __message: Any) -> None: ... # noqa: N805
411 | def start(self, *args: Any, **kwargs: Any) -> int: ...
412 | def stop(self, *args: Any, **kwargs: Any) -> None: ...
413 |
414 | logger: Logger
415 |
--------------------------------------------------------------------------------
/loguru/_asyncio_loop.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import sys
3 |
4 |
5 | def load_loop_functions():
6 | if sys.version_info >= (3, 7):
7 |
8 | def get_task_loop(task):
9 | return task.get_loop()
10 |
11 | get_running_loop = asyncio.get_running_loop
12 |
13 | else:
14 |
15 | def get_task_loop(task):
16 | return task._loop
17 |
18 | def get_running_loop():
19 | loop = asyncio.get_event_loop()
20 | if not loop.is_running():
21 | raise RuntimeError("There is no running event loop")
22 | return loop
23 |
24 | return get_task_loop, get_running_loop
25 |
26 |
27 | get_task_loop, get_running_loop = load_loop_functions()
28 |
--------------------------------------------------------------------------------
/loguru/_better_exceptions.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import inspect
3 | import io
4 | import keyword
5 | import linecache
6 | import os
7 | import re
8 | import sys
9 | import sysconfig
10 | import tokenize
11 | import traceback
12 |
13 | if sys.version_info >= (3, 11):
14 |
15 | def is_exception_group(exc):
16 | return isinstance(exc, ExceptionGroup)
17 |
18 | else:
19 | try:
20 | from exceptiongroup import ExceptionGroup
21 | except ImportError:
22 |
23 | def is_exception_group(exc):
24 | return False
25 |
26 | else:
27 |
28 | def is_exception_group(exc):
29 | return isinstance(exc, ExceptionGroup)
30 |
31 |
32 | class SyntaxHighlighter:
33 | _default_style = {
34 | "comment": "\x1b[30m\x1b[1m{}\x1b[0m",
35 | "keyword": "\x1b[35m\x1b[1m{}\x1b[0m",
36 | "builtin": "\x1b[1m{}\x1b[0m",
37 | "string": "\x1b[36m{}\x1b[0m",
38 | "number": "\x1b[34m\x1b[1m{}\x1b[0m",
39 | "operator": "\x1b[35m\x1b[1m{}\x1b[0m",
40 | "punctuation": "\x1b[1m{}\x1b[0m",
41 | "constant": "\x1b[36m\x1b[1m{}\x1b[0m",
42 | "identifier": "\x1b[1m{}\x1b[0m",
43 | "other": "{}",
44 | }
45 |
46 | _builtins = set(dir(builtins))
47 | _constants = {"True", "False", "None"}
48 | _punctuation = {"(", ")", "[", "]", "{", "}", ":", ",", ";"}
49 | _strings = {tokenize.STRING}
50 | _fstring_middle = None
51 |
52 | if sys.version_info >= (3, 12):
53 | _strings.update({tokenize.FSTRING_START, tokenize.FSTRING_MIDDLE, tokenize.FSTRING_END})
54 | _fstring_middle = tokenize.FSTRING_MIDDLE
55 |
56 | def __init__(self, style=None):
57 | self._style = style or self._default_style
58 |
59 | def highlight(self, source):
60 | style = self._style
61 | row, column = 0, 0
62 | output = ""
63 |
64 | for token in self.tokenize(source):
65 | type_, string, (start_row, start_column), (_, end_column), line = token
66 |
67 | if type_ == self._fstring_middle:
68 | # When an f-string contains "{{" or "}}", they appear as "{" or "}" in the "string"
69 | # attribute of the token. However, they do not count in the column position.
70 | end_column += string.count("{") + string.count("}")
71 |
72 | if type_ == tokenize.NAME:
73 | if string in self._constants:
74 | color = style["constant"]
75 | elif keyword.iskeyword(string):
76 | color = style["keyword"]
77 | elif string in self._builtins:
78 | color = style["builtin"]
79 | else:
80 | color = style["identifier"]
81 | elif type_ == tokenize.OP:
82 | if string in self._punctuation:
83 | color = style["punctuation"]
84 | else:
85 | color = style["operator"]
86 | elif type_ == tokenize.NUMBER:
87 | color = style["number"]
88 | elif type_ in self._strings:
89 | color = style["string"]
90 | elif type_ == tokenize.COMMENT:
91 | color = style["comment"]
92 | else:
93 | color = style["other"]
94 |
95 | if start_row != row:
96 | source = source[column:]
97 | row, column = start_row, 0
98 |
99 | if type_ != tokenize.ENCODING:
100 | output += line[column:start_column]
101 | output += color.format(line[start_column:end_column])
102 |
103 | column = end_column
104 |
105 | output += source[column:]
106 |
107 | return output
108 |
109 | @staticmethod
110 | def tokenize(source):
111 | # Worth reading: https://www.asmeurer.com/brown-water-python/
112 | source = source.encode("utf-8")
113 | source = io.BytesIO(source)
114 |
115 | try:
116 | yield from tokenize.tokenize(source.readline)
117 | except tokenize.TokenError:
118 | return
119 |
120 |
121 | class ExceptionFormatter:
122 | _default_theme = {
123 | "introduction": "\x1b[33m\x1b[1m{}\x1b[0m",
124 | "cause": "\x1b[1m{}\x1b[0m",
125 | "context": "\x1b[1m{}\x1b[0m",
126 | "dirname": "\x1b[32m{}\x1b[0m",
127 | "basename": "\x1b[32m\x1b[1m{}\x1b[0m",
128 | "line": "\x1b[33m{}\x1b[0m",
129 | "function": "\x1b[35m{}\x1b[0m",
130 | "exception_type": "\x1b[31m\x1b[1m{}\x1b[0m",
131 | "exception_value": "\x1b[1m{}\x1b[0m",
132 | "arrows": "\x1b[36m{}\x1b[0m",
133 | "value": "\x1b[36m\x1b[1m{}\x1b[0m",
134 | }
135 |
136 | def __init__(
137 | self,
138 | colorize=False,
139 | backtrace=False,
140 | diagnose=True,
141 | theme=None,
142 | style=None,
143 | max_length=128,
144 | encoding="ascii",
145 | hidden_frames_filename=None,
146 | prefix="",
147 | ):
148 | self._colorize = colorize
149 | self._diagnose = diagnose
150 | self._theme = theme or self._default_theme
151 | self._backtrace = backtrace
152 | self._syntax_highlighter = SyntaxHighlighter(style)
153 | self._max_length = max_length
154 | self._encoding = encoding
155 | self._hidden_frames_filename = hidden_frames_filename
156 | self._prefix = prefix
157 | self._lib_dirs = self._get_lib_dirs()
158 | self._pipe_char = self._get_char("\u2502", "|")
159 | self._cap_char = self._get_char("\u2514", "->")
160 | self._catch_point_identifier = " <catch point here>"
161 |
162 | @staticmethod
163 | def _get_lib_dirs():
164 | schemes = sysconfig.get_scheme_names()
165 | names = ["stdlib", "platstdlib", "platlib", "purelib"]
166 | paths = {sysconfig.get_path(name, scheme) for scheme in schemes for name in names}
167 | return [os.path.abspath(path).lower() + os.sep for path in paths if path in sys.path]
168 |
169 | @staticmethod
170 | def _indent(text, count, *, prefix="| "):
171 | if count == 0:
172 | yield text
173 | return
174 | for line in text.splitlines(True):
175 | indented = " " * count + prefix + line
176 | yield indented.rstrip() + "\n"
177 |
178 | def _get_char(self, char, default):
179 | try:
180 | char.encode(self._encoding)
181 | except (UnicodeEncodeError, LookupError):
182 | return default
183 | else:
184 | return char
185 |
186 | def _is_file_mine(self, file):
187 | filepath = os.path.abspath(file).lower()
188 | if not filepath.endswith(".py"):
189 | return False
190 | return not any(filepath.startswith(d) for d in self._lib_dirs)
191 |
192 | def _extract_frames(self, tb, is_first, *, limit=None, from_decorator=False):
193 | frames, final_source = [], None
194 |
195 | if tb is None or (limit is not None and limit <= 0):
196 | return frames, final_source
197 |
198 | def is_valid(frame):
199 | return frame.f_code.co_filename != self._hidden_frames_filename
200 |
201 | def get_info(frame, lineno):
202 | filename = frame.f_code.co_filename
203 | function = frame.f_code.co_name
204 | source = linecache.getline(filename, lineno).strip()
205 | return filename, lineno, function, source
206 |
207 | infos = []
208 |
209 | if is_valid(tb.tb_frame):
210 | infos.append((get_info(tb.tb_frame, tb.tb_lineno), tb.tb_frame))
211 |
212 | get_parent_only = from_decorator and not self._backtrace
213 |
214 | if (self._backtrace and is_first) or get_parent_only:
215 | frame = tb.tb_frame.f_back
216 | while frame:
217 | if is_valid(frame):
218 | infos.insert(0, (get_info(frame, frame.f_lineno), frame))
219 | if get_parent_only:
220 | break
221 | frame = frame.f_back
222 |
223 | if infos and not get_parent_only:
224 | (filename, lineno, function, source), frame = infos[-1]
225 | function += self._catch_point_identifier
226 | infos[-1] = ((filename, lineno, function, source), frame)
227 |
228 | tb = tb.tb_next
229 |
230 | while tb:
231 | if is_valid(tb.tb_frame):
232 | infos.append((get_info(tb.tb_frame, tb.tb_lineno), tb.tb_frame))
233 | tb = tb.tb_next
234 |
235 | if limit is not None:
236 | infos = infos[-limit:]
237 |
238 | for (filename, lineno, function, source), frame in infos:
239 | final_source = source
240 | if source:
241 | colorize = self._colorize and self._is_file_mine(filename)
242 | lines = []
243 | if colorize:
244 | lines.append(self._syntax_highlighter.highlight(source))
245 | else:
246 | lines.append(source)
247 | if self._diagnose:
248 | relevant_values = self._get_relevant_values(source, frame)
249 | values = self._format_relevant_values(list(relevant_values), colorize)
250 | lines += list(values)
251 | source = "\n ".join(lines)
252 | frames.append((filename, lineno, function, source))
253 |
254 | return frames, final_source
255 |
256 | def _get_relevant_values(self, source, frame):
257 | value = None
258 | pending = None
259 | is_attribute = False
260 | is_valid_value = False
261 | is_assignment = True
262 |
263 | for token in self._syntax_highlighter.tokenize(source):
264 | type_, string, (_, col), *_ = token
265 |
266 | if pending is not None:
267 | # Keyword arguments are ignored
268 | if type_ != tokenize.OP or string != "=" or is_assignment:
269 | yield pending
270 | pending = None
271 |
272 | if type_ == tokenize.NAME and not keyword.iskeyword(string):
273 | if not is_attribute:
274 | for variables in (frame.f_locals, frame.f_globals):
275 | try:
276 | value = variables[string]
277 | except KeyError:
278 | continue
279 | else:
280 | is_valid_value = True
281 | pending = (col, self._format_value(value))
282 | break
283 | elif is_valid_value:
284 | try:
285 | value = inspect.getattr_static(value, string)
286 | except AttributeError:
287 | is_valid_value = False
288 | else:
289 | yield (col, self._format_value(value))
290 | elif type_ == tokenize.OP and string == ".":
291 | is_attribute = True
292 | is_assignment = False
293 | elif type_ == tokenize.OP and string == ";":
294 | is_assignment = True
295 | is_attribute = False
296 | is_valid_value = False
297 | else:
298 | is_attribute = False
299 | is_valid_value = False
300 | is_assignment = False
301 |
302 | if pending is not None:
303 | yield pending
304 |
305 | def _format_relevant_values(self, relevant_values, colorize):
306 | for i in reversed(range(len(relevant_values))):
307 | col, value = relevant_values[i]
308 | pipe_cols = [pcol for pcol, _ in relevant_values[:i]]
309 | pre_line = ""
310 | index = 0
311 |
312 | for pc in pipe_cols:
313 | pre_line += (" " * (pc - index)) + self._pipe_char
314 | index = pc + 1
315 |
316 | pre_line += " " * (col - index)
317 | value_lines = value.split("\n")
318 |
319 | for n, value_line in enumerate(value_lines):
320 | if n == 0:
321 | arrows = pre_line + self._cap_char + " "
322 | else:
323 | arrows = pre_line + " " * (len(self._cap_char) + 1)
324 |
325 | if colorize:
326 | arrows = self._theme["arrows"].format(arrows)
327 | value_line = self._theme["value"].format(value_line)
328 |
329 | yield arrows + value_line
330 |
331 | def _format_value(self, v):
332 | try:
333 | v = repr(v)
334 | except Exception:
335 | v = "" % type(v).__name__
336 |
337 | max_length = self._max_length
338 | if max_length is not None and len(v) > max_length:
339 | v = v[: max_length - 3] + "..."
340 | return v
341 |
342 | def _format_locations(self, frames_lines, *, has_introduction):
343 | prepend_with_new_line = has_introduction
344 | regex = r'^  File "(?P<file>.*?)", line (?P<line>[^,]+)(?:, in (?P<function>.*))?\n'
345 |
346 | for frame in frames_lines:
347 | match = re.match(regex, frame)
348 |
349 | if match:
350 | file, line, function = match.group("file", "line", "function")
351 |
352 | is_mine = self._is_file_mine(file)
353 |
354 | if function is not None:
355 | pattern = ' File "{}", line {}, in {}\n'
356 | else:
357 | pattern = ' File "{}", line {}\n'
358 |
359 | if self._backtrace and function and function.endswith(self._catch_point_identifier):
360 | function = function[: -len(self._catch_point_identifier)]
361 | pattern = ">" + pattern[1:]
362 |
363 | if self._colorize and is_mine:
364 | dirname, basename = os.path.split(file)
365 | if dirname:
366 | dirname += os.sep
367 | dirname = self._theme["dirname"].format(dirname)
368 | basename = self._theme["basename"].format(basename)
369 | file = dirname + basename
370 | line = self._theme["line"].format(line)
371 | function = self._theme["function"].format(function)
372 |
373 | if self._diagnose and (is_mine or prepend_with_new_line):
374 | pattern = "\n" + pattern
375 |
376 | location = pattern.format(file, line, function)
377 | frame = location + frame[match.end() :]
378 | prepend_with_new_line = is_mine
379 |
380 | yield frame
381 |
382 | def _format_exception(
383 | self, value, tb, *, seen=None, is_first=False, from_decorator=False, group_nesting=0
384 | ):
385 | # Implemented from built-in traceback module:
386 | # https://github.com/python/cpython/blob/a5b76167/Lib/traceback.py#L468
387 | exc_type, exc_value, exc_traceback = type(value), value, tb
388 |
389 | if seen is None:
390 | seen = set()
391 |
392 | seen.add(id(exc_value))
393 |
394 | if exc_value:
395 | if exc_value.__cause__ is not None and id(exc_value.__cause__) not in seen:
396 | yield from self._format_exception(
397 | exc_value.__cause__,
398 | exc_value.__cause__.__traceback__,
399 | seen=seen,
400 | group_nesting=group_nesting,
401 | )
402 | cause = "The above exception was the direct cause of the following exception:"
403 | if self._colorize:
404 | cause = self._theme["cause"].format(cause)
405 | if self._diagnose:
406 | yield from self._indent("\n\n" + cause + "\n\n\n", group_nesting)
407 | else:
408 | yield from self._indent("\n" + cause + "\n\n", group_nesting)
409 |
410 | elif (
411 | exc_value.__context__ is not None
412 | and id(exc_value.__context__) not in seen
413 | and not exc_value.__suppress_context__
414 | ):
415 | yield from self._format_exception(
416 | exc_value.__context__,
417 | exc_value.__context__.__traceback__,
418 | seen=seen,
419 | group_nesting=group_nesting,
420 | )
421 | context = "During handling of the above exception, another exception occurred:"
422 | if self._colorize:
423 | context = self._theme["context"].format(context)
424 | if self._diagnose:
425 | yield from self._indent("\n\n" + context + "\n\n\n", group_nesting)
426 | else:
427 | yield from self._indent("\n" + context + "\n\n", group_nesting)
428 |
429 | is_grouped = is_exception_group(value)
430 |
431 | if is_grouped and group_nesting == 0:
432 | yield from self._format_exception(
433 | value,
434 | tb,
435 | seen=seen,
436 | group_nesting=1,
437 | is_first=is_first,
438 | from_decorator=from_decorator,
439 | )
440 | return
441 |
442 | try:
443 | traceback_limit = sys.tracebacklimit
444 | except AttributeError:
445 | traceback_limit = None
446 |
447 | frames, final_source = self._extract_frames(
448 | exc_traceback, is_first, limit=traceback_limit, from_decorator=from_decorator
449 | )
450 | exception_only = traceback.format_exception_only(exc_type, exc_value)
451 |
452 | # Determining the correct index for the "Exception: message" part in the formatted exception
453 | # is challenging. This is because it might be preceded by multiple lines specific to
454 | # "SyntaxError" or followed by various notes. However, we can make an educated guess based
455 | # on the indentation; the preliminary context for "SyntaxError" is always indented, while
456 | # the Exception itself is not. This allows us to identify the correct index for the
457 | # exception message.
458 | no_indented_indexes = (i for i, p in enumerate(exception_only) if not p.startswith(" "))
459 | error_message_index = next(no_indented_indexes, None)
460 |
461 | if error_message_index is not None:
462 | # Remove final new line temporarily.
463 | error_message = exception_only[error_message_index][:-1]
464 |
465 | if self._colorize:
466 | if ":" in error_message:
467 | exception_type, exception_value = error_message.split(":", 1)
468 | exception_type = self._theme["exception_type"].format(exception_type)
469 | exception_value = self._theme["exception_value"].format(exception_value)
470 | error_message = exception_type + ":" + exception_value
471 | else:
472 | error_message = self._theme["exception_type"].format(error_message)
473 |
474 | if self._diagnose and frames:
475 | if issubclass(exc_type, AssertionError) and not str(exc_value) and final_source:
476 | if self._colorize:
477 | final_source = self._syntax_highlighter.highlight(final_source)
478 | error_message += ": " + final_source
479 |
480 | error_message = "\n" + error_message
481 |
482 | exception_only[error_message_index] = error_message + "\n"
483 |
484 | if is_first:
485 | yield self._prefix
486 |
487 | has_introduction = bool(frames)
488 |
489 | if has_introduction:
490 | if is_grouped:
491 | introduction = "Exception Group Traceback (most recent call last):"
492 | else:
493 | introduction = "Traceback (most recent call last):"
494 | if self._colorize:
495 | introduction = self._theme["introduction"].format(introduction)
496 | if group_nesting == 1: # Implies we're processing the root ExceptionGroup.
497 | yield from self._indent(introduction + "\n", group_nesting, prefix="+ ")
498 | else:
499 | yield from self._indent(introduction + "\n", group_nesting)
500 |
501 | frames_lines = self._format_list(frames) + exception_only
502 | if self._colorize or self._backtrace or self._diagnose:
503 | frames_lines = self._format_locations(frames_lines, has_introduction=has_introduction)
504 |
505 | yield from self._indent("".join(frames_lines), group_nesting)
506 |
507 | if is_grouped:
508 | exc = None
509 | for n, exc in enumerate(value.exceptions, start=1):
510 | ruler = "+" + (" %s " % ("..." if n > 15 else n)).center(35, "-")
511 | yield from self._indent(ruler, group_nesting, prefix="+-" if n == 1 else " ")
512 | if n > 15:
513 | message = "and %d more exceptions\n" % (len(value.exceptions) - 15)
514 | yield from self._indent(message, group_nesting + 1)
515 | break
516 | elif group_nesting == 10 and is_exception_group(exc):
517 | message = "... (max_group_depth is 10)\n"
518 | yield from self._indent(message, group_nesting + 1)
519 | else:
520 | yield from self._format_exception(
521 | exc,
522 | exc.__traceback__,
523 | seen=seen,
524 | group_nesting=group_nesting + 1,
525 | )
526 | if not is_exception_group(exc) or group_nesting == 10:
527 | yield from self._indent("-" * 35, group_nesting + 1, prefix="+-")
528 |
529 | def _format_list(self, frames):
530 | result = []
531 | for filename, lineno, name, line in frames:
532 | row = []
533 | row.append(' File "{}", line {}, in {}\n'.format(filename, lineno, name))
534 | if line:
535 | row.append(" {}\n".format(line.strip()))
536 | result.append("".join(row))
537 | return result
538 |
539 | def format_exception(self, type_, value, tb, *, from_decorator=False):
540 | yield from self._format_exception(value, tb, is_first=True, from_decorator=from_decorator)
541 |
--------------------------------------------------------------------------------
/loguru/_colorama.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import os
3 | import sys
4 |
5 |
6 | def should_colorize(stream):
7 | if stream is None:
8 | return False
9 |
10 | if getattr(builtins, "__IPYTHON__", False) and (stream is sys.stdout or stream is sys.stderr):
11 | try:
12 | import ipykernel
13 | import IPython
14 |
15 | ipython = IPython.get_ipython()
16 | is_jupyter_stream = isinstance(stream, ipykernel.iostream.OutStream)
17 | is_jupyter_shell = isinstance(ipython, ipykernel.zmqshell.ZMQInteractiveShell)
18 | except Exception:
19 | pass
20 | else:
21 | if is_jupyter_stream and is_jupyter_shell:
22 | return True
23 |
24 | if stream is sys.__stdout__ or stream is sys.__stderr__:
25 | if "CI" in os.environ and any(
26 | ci in os.environ
27 | for ci in ["TRAVIS", "CIRCLECI", "APPVEYOR", "GITLAB_CI", "GITHUB_ACTIONS"]
28 | ):
29 | return True
30 | if "PYCHARM_HOSTED" in os.environ:
31 | return True
32 | if os.name == "nt" and "TERM" in os.environ:
33 | return True
34 |
35 | try:
36 | return stream.isatty()
37 | except Exception:
38 | return False
39 |
40 |
41 | def should_wrap(stream):
42 | if os.name != "nt":
43 | return False
44 |
45 | if stream is not sys.__stdout__ and stream is not sys.__stderr__:
46 | return False
47 |
48 | from colorama.win32 import winapi_test
49 |
50 | if not winapi_test():
51 | return False
52 |
53 | try:
54 | from colorama.winterm import enable_vt_processing
55 | except ImportError:
56 | return True
57 |
58 | try:
59 | return not enable_vt_processing(stream.fileno())
60 | except Exception:
61 | return True
62 |
63 |
64 | def wrap(stream):
65 | from colorama import AnsiToWin32
66 |
67 | return AnsiToWin32(stream, convert=True, strip=True, autoreset=False).stream
68 |
--------------------------------------------------------------------------------
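Taken together, the three helpers above decide whether a stream gets ANSI colors and whether colorama must emulate them. A minimal sketch of how a sink might combine them (colorama is only touched on legacy Windows consoles, where `should_wrap()` returns True):

```python
import sys

from loguru._colorama import should_colorize, should_wrap, wrap

stream = sys.stderr
colorize = should_colorize(stream)    # TTY / CI / PyCharm / Jupyter heuristics
if colorize and should_wrap(stream):  # True only on legacy Windows consoles
    stream = wrap(stream)             # AnsiToWin32 converts ANSI escape codes
stream.write("colorize=%s\n" % colorize)
```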
/loguru/_colorizer.py:
--------------------------------------------------------------------------------
1 | import re
2 | from string import Formatter
3 |
4 |
5 | class Style:
6 | RESET_ALL = 0
7 | BOLD = 1
8 | DIM = 2
9 | ITALIC = 3
10 | UNDERLINE = 4
11 | BLINK = 5
12 | REVERSE = 7
13 | HIDE = 8
14 | STRIKE = 9
15 | NORMAL = 22
16 |
17 |
18 | class Fore:
19 | BLACK = 30
20 | RED = 31
21 | GREEN = 32
22 | YELLOW = 33
23 | BLUE = 34
24 | MAGENTA = 35
25 | CYAN = 36
26 | WHITE = 37
27 | RESET = 39
28 |
29 | LIGHTBLACK_EX = 90
30 | LIGHTRED_EX = 91
31 | LIGHTGREEN_EX = 92
32 | LIGHTYELLOW_EX = 93
33 | LIGHTBLUE_EX = 94
34 | LIGHTMAGENTA_EX = 95
35 | LIGHTCYAN_EX = 96
36 | LIGHTWHITE_EX = 97
37 |
38 |
39 | class Back:
40 | BLACK = 40
41 | RED = 41
42 | GREEN = 42
43 | YELLOW = 43
44 | BLUE = 44
45 | MAGENTA = 45
46 | CYAN = 46
47 | WHITE = 47
48 | RESET = 49
49 |
50 | LIGHTBLACK_EX = 100
51 | LIGHTRED_EX = 101
52 | LIGHTGREEN_EX = 102
53 | LIGHTYELLOW_EX = 103
54 | LIGHTBLUE_EX = 104
55 | LIGHTMAGENTA_EX = 105
56 | LIGHTCYAN_EX = 106
57 | LIGHTWHITE_EX = 107
58 |
59 |
60 | def ansi_escape(codes):
61 | return {name: "\033[%dm" % code for name, code in codes.items()}
62 |
63 |
64 | class TokenType:
65 | TEXT = 1
66 | ANSI = 2
67 | LEVEL = 3
68 | CLOSING = 4
69 |
70 |
71 | class AnsiParser:
72 | _style = ansi_escape(
73 | {
74 | "b": Style.BOLD,
75 | "d": Style.DIM,
76 | "n": Style.NORMAL,
77 | "h": Style.HIDE,
78 | "i": Style.ITALIC,
79 | "l": Style.BLINK,
80 | "s": Style.STRIKE,
81 | "u": Style.UNDERLINE,
82 | "v": Style.REVERSE,
83 | "bold": Style.BOLD,
84 | "dim": Style.DIM,
85 | "normal": Style.NORMAL,
86 | "hide": Style.HIDE,
87 | "italic": Style.ITALIC,
88 | "blink": Style.BLINK,
89 | "strike": Style.STRIKE,
90 | "underline": Style.UNDERLINE,
91 | "reverse": Style.REVERSE,
92 | }
93 | )
94 |
95 | _foreground = ansi_escape(
96 | {
97 | "k": Fore.BLACK,
98 | "r": Fore.RED,
99 | "g": Fore.GREEN,
100 | "y": Fore.YELLOW,
101 | "e": Fore.BLUE,
102 | "m": Fore.MAGENTA,
103 | "c": Fore.CYAN,
104 | "w": Fore.WHITE,
105 | "lk": Fore.LIGHTBLACK_EX,
106 | "lr": Fore.LIGHTRED_EX,
107 | "lg": Fore.LIGHTGREEN_EX,
108 | "ly": Fore.LIGHTYELLOW_EX,
109 | "le": Fore.LIGHTBLUE_EX,
110 | "lm": Fore.LIGHTMAGENTA_EX,
111 | "lc": Fore.LIGHTCYAN_EX,
112 | "lw": Fore.LIGHTWHITE_EX,
113 | "black": Fore.BLACK,
114 | "red": Fore.RED,
115 | "green": Fore.GREEN,
116 | "yellow": Fore.YELLOW,
117 | "blue": Fore.BLUE,
118 | "magenta": Fore.MAGENTA,
119 | "cyan": Fore.CYAN,
120 | "white": Fore.WHITE,
121 | "light-black": Fore.LIGHTBLACK_EX,
122 | "light-red": Fore.LIGHTRED_EX,
123 | "light-green": Fore.LIGHTGREEN_EX,
124 | "light-yellow": Fore.LIGHTYELLOW_EX,
125 | "light-blue": Fore.LIGHTBLUE_EX,
126 | "light-magenta": Fore.LIGHTMAGENTA_EX,
127 | "light-cyan": Fore.LIGHTCYAN_EX,
128 | "light-white": Fore.LIGHTWHITE_EX,
129 | }
130 | )
131 |
132 | _background = ansi_escape(
133 | {
134 | "K": Back.BLACK,
135 | "R": Back.RED,
136 | "G": Back.GREEN,
137 | "Y": Back.YELLOW,
138 | "E": Back.BLUE,
139 | "M": Back.MAGENTA,
140 | "C": Back.CYAN,
141 | "W": Back.WHITE,
142 | "LK": Back.LIGHTBLACK_EX,
143 | "LR": Back.LIGHTRED_EX,
144 | "LG": Back.LIGHTGREEN_EX,
145 | "LY": Back.LIGHTYELLOW_EX,
146 | "LE": Back.LIGHTBLUE_EX,
147 | "LM": Back.LIGHTMAGENTA_EX,
148 | "LC": Back.LIGHTCYAN_EX,
149 | "LW": Back.LIGHTWHITE_EX,
150 | "BLACK": Back.BLACK,
151 | "RED": Back.RED,
152 | "GREEN": Back.GREEN,
153 | "YELLOW": Back.YELLOW,
154 | "BLUE": Back.BLUE,
155 | "MAGENTA": Back.MAGENTA,
156 | "CYAN": Back.CYAN,
157 | "WHITE": Back.WHITE,
158 | "LIGHT-BLACK": Back.LIGHTBLACK_EX,
159 | "LIGHT-RED": Back.LIGHTRED_EX,
160 | "LIGHT-GREEN": Back.LIGHTGREEN_EX,
161 | "LIGHT-YELLOW": Back.LIGHTYELLOW_EX,
162 | "LIGHT-BLUE": Back.LIGHTBLUE_EX,
163 | "LIGHT-MAGENTA": Back.LIGHTMAGENTA_EX,
164 | "LIGHT-CYAN": Back.LIGHTCYAN_EX,
165 | "LIGHT-WHITE": Back.LIGHTWHITE_EX,
166 | }
167 | )
168 |
169 |     _regex_tag = re.compile(r"\\?</?((?:[fb]g\s)?[^<>\s]*)>")
170 |
171 | def __init__(self):
172 | self._tokens = []
173 | self._tags = []
174 | self._color_tokens = []
175 |
176 | @staticmethod
177 | def strip(tokens):
178 | output = ""
179 | for type_, value in tokens:
180 | if type_ == TokenType.TEXT:
181 | output += value
182 | return output
183 |
184 | @staticmethod
185 | def colorize(tokens, ansi_level):
186 | output = ""
187 |
188 | for type_, value in tokens:
189 | if type_ == TokenType.LEVEL:
190 | if ansi_level is None:
191 | raise ValueError(
192 |                         "The '<level>' color tag is not allowed in this context, "
193 |                         "it has not yet been associated with any color value."
194 | )
195 | value = ansi_level
196 | output += value
197 |
198 | return output
199 |
200 | @staticmethod
201 | def wrap(tokens, *, ansi_level, color_tokens):
202 | output = ""
203 |
204 | for type_, value in tokens:
205 | if type_ == TokenType.LEVEL:
206 | value = ansi_level
207 | output += value
208 | if type_ == TokenType.CLOSING:
209 | for subtype, subvalue in color_tokens:
210 | if subtype == TokenType.LEVEL:
211 | subvalue = ansi_level
212 | output += subvalue
213 |
214 | return output
215 |
216 | def feed(self, text, *, raw=False):
217 | if raw:
218 | self._tokens.append((TokenType.TEXT, text))
219 | return
220 |
221 | position = 0
222 |
223 | for match in self._regex_tag.finditer(text):
224 | markup, tag = match.group(0), match.group(1)
225 |
226 | self._tokens.append((TokenType.TEXT, text[position : match.start()]))
227 |
228 | position = match.end()
229 |
230 | if markup[0] == "\\":
231 | self._tokens.append((TokenType.TEXT, markup[1:]))
232 | continue
233 |
234 | if markup[1] == "/":
235 | if self._tags and (tag == "" or tag == self._tags[-1]):
236 | self._tags.pop()
237 | self._color_tokens.pop()
238 | self._tokens.append((TokenType.CLOSING, "\033[0m"))
239 | self._tokens.extend(self._color_tokens)
240 | continue
241 | if tag in self._tags:
242 | raise ValueError('Closing tag "%s" violates nesting rules' % markup)
243 | raise ValueError('Closing tag "%s" has no corresponding opening tag' % markup)
244 |
245 | if tag in {"lvl", "level"}:
246 | token = (TokenType.LEVEL, None)
247 | else:
248 | ansi = self._get_ansicode(tag)
249 |
250 | if ansi is None:
251 | raise ValueError(
252 | 'Tag "%s" does not correspond to any known color directive, '
253 |                         "make sure you did not misspell it (or prepend '\\' to escape it)"
254 | % markup
255 | )
256 |
257 | token = (TokenType.ANSI, ansi)
258 |
259 | self._tags.append(tag)
260 | self._color_tokens.append(token)
261 | self._tokens.append(token)
262 |
263 | self._tokens.append((TokenType.TEXT, text[position:]))
264 |
265 | def done(self, *, strict=True):
266 | if strict and self._tags:
267 | faulty_tag = self._tags.pop(0)
268 | raise ValueError('Opening tag "<%s>" has no corresponding closing tag' % faulty_tag)
269 | return self._tokens
270 |
271 | def current_color_tokens(self):
272 | return list(self._color_tokens)
273 |
274 | def _get_ansicode(self, tag):
275 | style = self._style
276 | foreground = self._foreground
277 | background = self._background
278 |
279 | # Substitute on a direct match.
280 | if tag in style:
281 | return style[tag]
282 | if tag in foreground:
283 | return foreground[tag]
284 | if tag in background:
285 | return background[tag]
286 |
287 |         # An alternative syntax for setting the color (e.g. <fg red>, <bg red>).
288 | if tag.startswith("fg ") or tag.startswith("bg "):
289 | st, color = tag[:2], tag[3:]
290 | code = "38" if st == "fg" else "48"
291 |
292 | if st == "fg" and color.lower() in foreground:
293 | return foreground[color.lower()]
294 | if st == "bg" and color.upper() in background:
295 | return background[color.upper()]
296 | if color.isdigit() and int(color) <= 255:
297 | return "\033[%s;5;%sm" % (code, color)
298 | if re.match(r"#(?:[a-fA-F0-9]{3}){1,2}$", color):
299 | hex_color = color[1:]
300 | if len(hex_color) == 3:
301 | hex_color *= 2
302 | rgb = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
303 | return "\033[%s;2;%s;%s;%sm" % ((code,) + rgb)
304 | if color.count(",") == 2:
305 | colors = tuple(color.split(","))
306 | if all(x.isdigit() and int(x) <= 255 for x in colors):
307 | return "\033[%s;2;%s;%s;%sm" % ((code,) + colors)
308 |
309 | return None
310 |
311 |
312 | class ColoringMessage(str):
313 | __fields__ = ("_messages",)
314 |
315 | def __format__(self, spec):
316 | return next(self._messages).__format__(spec)
317 |
318 |
319 | class ColoredMessage:
320 | def __init__(self, tokens):
321 | self.tokens = tokens
322 | self.stripped = AnsiParser.strip(tokens)
323 |
324 | def colorize(self, ansi_level):
325 | return AnsiParser.colorize(self.tokens, ansi_level)
326 |
327 |
328 | class ColoredFormat:
329 | def __init__(self, tokens, messages_color_tokens):
330 | self._tokens = tokens
331 | self._messages_color_tokens = messages_color_tokens
332 |
333 | def strip(self):
334 | return AnsiParser.strip(self._tokens)
335 |
336 | def colorize(self, ansi_level):
337 | return AnsiParser.colorize(self._tokens, ansi_level)
338 |
339 | def make_coloring_message(self, message, *, ansi_level, colored_message):
340 | messages = [
341 | (
342 | message
343 | if color_tokens is None
344 | else AnsiParser.wrap(
345 | colored_message.tokens, ansi_level=ansi_level, color_tokens=color_tokens
346 | )
347 | )
348 | for color_tokens in self._messages_color_tokens
349 | ]
350 | coloring = ColoringMessage(message)
351 | coloring._messages = iter(messages)
352 | return coloring
353 |
354 |
355 | class Colorizer:
356 | @staticmethod
357 | def prepare_format(string):
358 | tokens, messages_color_tokens = Colorizer._parse_without_formatting(string)
359 | return ColoredFormat(tokens, messages_color_tokens)
360 |
361 | @staticmethod
362 | def prepare_message(string, args=(), kwargs={}): # noqa: B006
363 | tokens = Colorizer._parse_with_formatting(string, args, kwargs)
364 | return ColoredMessage(tokens)
365 |
366 | @staticmethod
367 | def prepare_simple_message(string):
368 | parser = AnsiParser()
369 | parser.feed(string)
370 | tokens = parser.done()
371 | return ColoredMessage(tokens)
372 |
373 | @staticmethod
374 | def ansify(text):
375 | parser = AnsiParser()
376 | parser.feed(text.strip())
377 | tokens = parser.done(strict=False)
378 | return AnsiParser.colorize(tokens, None)
379 |
380 | @staticmethod
381 | def _parse_with_formatting(
382 | string, args, kwargs, *, recursion_depth=2, auto_arg_index=0, recursive=False
383 | ):
384 | # This function re-implements Formatter._vformat()
385 |
386 | if recursion_depth < 0:
387 | raise ValueError("Max string recursion exceeded")
388 |
389 | formatter = Formatter()
390 | parser = AnsiParser()
391 |
392 | for literal_text, field_name, format_spec, conversion in formatter.parse(string):
393 | parser.feed(literal_text, raw=recursive)
394 |
395 | if field_name is not None:
396 | if field_name == "":
397 | if auto_arg_index is False:
398 | raise ValueError(
399 | "cannot switch from manual field "
400 | "specification to automatic field "
401 | "numbering"
402 | )
403 | field_name = str(auto_arg_index)
404 | auto_arg_index += 1
405 | elif field_name.isdigit():
406 | if auto_arg_index:
407 | raise ValueError(
408 | "cannot switch from manual field "
409 | "specification to automatic field "
410 | "numbering"
411 | )
412 | auto_arg_index = False
413 |
414 | obj, _ = formatter.get_field(field_name, args, kwargs)
415 | obj = formatter.convert_field(obj, conversion)
416 |
417 | format_spec, auto_arg_index = Colorizer._parse_with_formatting(
418 | format_spec,
419 | args,
420 | kwargs,
421 | recursion_depth=recursion_depth - 1,
422 | auto_arg_index=auto_arg_index,
423 | recursive=True,
424 | )
425 |
426 | formatted = formatter.format_field(obj, format_spec)
427 | parser.feed(formatted, raw=True)
428 |
429 | tokens = parser.done()
430 |
431 | if recursive:
432 | return AnsiParser.strip(tokens), auto_arg_index
433 |
434 | return tokens
435 |
436 | @staticmethod
437 | def _parse_without_formatting(string, *, recursion_depth=2, recursive=False):
438 | if recursion_depth < 0:
439 | raise ValueError("Max string recursion exceeded")
440 |
441 | formatter = Formatter()
442 | parser = AnsiParser()
443 |
444 | messages_color_tokens = []
445 |
446 | for literal_text, field_name, format_spec, conversion in formatter.parse(string):
447 | if literal_text and literal_text[-1] in "{}":
448 | literal_text += literal_text[-1]
449 |
450 | parser.feed(literal_text, raw=recursive)
451 |
452 | if field_name is not None:
453 | if field_name == "message":
454 | if recursive:
455 | messages_color_tokens.append(None)
456 | else:
457 | color_tokens = parser.current_color_tokens()
458 | messages_color_tokens.append(color_tokens)
459 | field = "{%s" % field_name
460 | if conversion:
461 | field += "!%s" % conversion
462 | if format_spec:
463 | field += ":%s" % format_spec
464 | field += "}"
465 | parser.feed(field, raw=True)
466 |
467 | _, color_tokens = Colorizer._parse_without_formatting(
468 | format_spec, recursion_depth=recursion_depth - 1, recursive=True
469 | )
470 | messages_color_tokens.extend(color_tokens)
471 |
472 | return parser.done(), messages_color_tokens
473 |
--------------------------------------------------------------------------------
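A short usage sketch of the markup grammar implemented by `AnsiParser`: named tags resolve through `_get_ansicode()`, `</...>` (or the shorthand `</>`) pops the innermost tag and restores the enclosing colors, and `fg #rrggbb` exercises the alternative color syntax:

```python
from loguru._colorizer import Colorizer

message = Colorizer.prepare_simple_message("<green>ok</green> <fg #ff8800>warn</>")
print(message.stripped)              # "ok warn" - markup removed
print(repr(message.colorize(None)))  # same text wrapped in ANSI escape sequences
```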
/loguru/_contextvars.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | def load_contextvar_class():
5 | if sys.version_info >= (3, 7):
6 | from contextvars import ContextVar
7 | elif sys.version_info >= (3, 5, 3):
8 | from aiocontextvars import ContextVar
9 | else:
10 |         raise ImportError("ContextVar requires Python >= 3.5.3")
11 |
12 | return ContextVar
13 |
14 |
15 | ContextVar = load_contextvar_class()
16 |
--------------------------------------------------------------------------------
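The exported `ContextVar` is the primitive behind context-local logger state (e.g. `logger.contextualize()`). On Python 3.7+ it is simply the stdlib class; a minimal sketch of the set/reset discipline:

```python
from loguru._contextvars import ContextVar

request_id = ContextVar("request_id", default="-")

token = request_id.set("abc123")  # enter a scope
assert request_id.get() == "abc123"
request_id.reset(token)           # restore the previous value on exit
assert request_id.get() == "-"
```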
/loguru/_ctime_functions.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def load_ctime_functions():
5 | if os.name == "nt":
6 | import win32_setctime
7 |
8 | def get_ctime_windows(filepath):
9 | return os.stat(filepath).st_ctime
10 |
11 | def set_ctime_windows(filepath, timestamp):
12 | if not win32_setctime.SUPPORTED:
13 | return
14 |
15 | try:
16 | win32_setctime.setctime(filepath, timestamp)
17 | except (OSError, ValueError):
18 | pass
19 |
20 | return get_ctime_windows, set_ctime_windows
21 |
22 | if hasattr(os.stat_result, "st_birthtime"):
23 |
24 | def get_ctime_macos(filepath):
25 | return os.stat(filepath).st_birthtime
26 |
27 | def set_ctime_macos(filepath, timestamp):
28 | pass
29 |
30 | return get_ctime_macos, set_ctime_macos
31 |
32 | if hasattr(os, "getxattr") and hasattr(os, "setxattr"):
33 |
34 | def get_ctime_linux(filepath):
35 | try:
36 | return float(os.getxattr(filepath, b"user.loguru_crtime"))
37 | except OSError:
38 | return os.stat(filepath).st_mtime
39 |
40 | def set_ctime_linux(filepath, timestamp):
41 | try:
42 | os.setxattr(filepath, b"user.loguru_crtime", str(timestamp).encode("ascii"))
43 | except OSError:
44 | pass
45 |
46 | return get_ctime_linux, set_ctime_linux
47 |
48 | def get_ctime_fallback(filepath):
49 | return os.stat(filepath).st_mtime
50 |
51 | def set_ctime_fallback(filepath, timestamp):
52 | pass
53 |
54 | return get_ctime_fallback, set_ctime_fallback
55 |
56 |
57 | get_ctime, set_ctime = load_ctime_functions()
58 |
--------------------------------------------------------------------------------
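A round-trip sketch of the platform-dispatched pair: on Linux the file creation time is emulated through the `user.loguru_crtime` extended attribute (plain `os.stat()` exposes no birth time there), and both functions degrade gracefully on filesystems without xattr support:

```python
import tempfile
import time

from loguru._ctime_functions import get_ctime, set_ctime

with tempfile.NamedTemporaryFile(delete=False) as f:
    path = f.name

set_ctime(path, time.time() - 3600)  # pretend the file was created an hour ago
print(get_ctime(path))               # stored xattr if supported, else st_mtime
```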
/loguru/_datetime.py:
--------------------------------------------------------------------------------
1 | import re
2 | from calendar import day_abbr, day_name, month_abbr, month_name
3 | from datetime import datetime as datetime_
4 | from datetime import timedelta, timezone
5 | from time import localtime, strftime
6 |
7 | tokens = r"H{1,2}|h{1,2}|m{1,2}|s{1,2}|S+|YYYY|YY|M{1,4}|D{1,4}|Z{1,2}|zz|A|X|x|E|Q|dddd|ddd|d"
8 |
9 | pattern = re.compile(r"(?:{0})|\[(?:{0}|!UTC|)\]".format(tokens))
10 |
11 |
12 | class datetime(datetime_): # noqa: N801
13 | def __format__(self, spec):
14 | if spec.endswith("!UTC"):
15 | dt = self.astimezone(timezone.utc)
16 | spec = spec[:-4]
17 | else:
18 | dt = self
19 |
20 | if not spec:
21 | spec = "%Y-%m-%dT%H:%M:%S.%f%z"
22 |
23 | if "%" in spec:
24 | return datetime_.__format__(dt, spec)
25 |
26 | if "SSSSSSS" in spec:
27 | raise ValueError(
28 | "Invalid time format: the provided format string contains more than six successive "
29 | "'S' characters. This may be due to an attempt to use nanosecond precision, which "
30 | "is not supported."
31 | )
32 |
33 | year, month, day, hour, minute, second, weekday, yearday, _ = dt.timetuple()
34 | microsecond = dt.microsecond
35 | timestamp = dt.timestamp()
36 | tzinfo = dt.tzinfo or timezone(timedelta(seconds=0))
37 | offset = tzinfo.utcoffset(dt).total_seconds()
38 | sign = ("-", "+")[offset >= 0]
39 | (h, m), s = divmod(abs(offset // 60), 60), abs(offset) % 60
40 |
41 | rep = {
42 | "YYYY": "%04d" % year,
43 | "YY": "%02d" % (year % 100),
44 | "Q": "%d" % ((month - 1) // 3 + 1),
45 | "MMMM": month_name[month],
46 | "MMM": month_abbr[month],
47 | "MM": "%02d" % month,
48 | "M": "%d" % month,
49 | "DDDD": "%03d" % yearday,
50 | "DDD": "%d" % yearday,
51 | "DD": "%02d" % day,
52 | "D": "%d" % day,
53 | "dddd": day_name[weekday],
54 | "ddd": day_abbr[weekday],
55 | "d": "%d" % weekday,
56 | "E": "%d" % (weekday + 1),
57 | "HH": "%02d" % hour,
58 | "H": "%d" % hour,
59 | "hh": "%02d" % ((hour - 1) % 12 + 1),
60 | "h": "%d" % ((hour - 1) % 12 + 1),
61 | "mm": "%02d" % minute,
62 | "m": "%d" % minute,
63 | "ss": "%02d" % second,
64 | "s": "%d" % second,
65 | "S": "%d" % (microsecond // 100000),
66 | "SS": "%02d" % (microsecond // 10000),
67 | "SSS": "%03d" % (microsecond // 1000),
68 | "SSSS": "%04d" % (microsecond // 100),
69 | "SSSSS": "%05d" % (microsecond // 10),
70 | "SSSSSS": "%06d" % microsecond,
71 | "A": ("AM", "PM")[hour // 12],
72 | "Z": "%s%02d:%02d%s" % (sign, h, m, (":%09.06f" % s)[: 11 if s % 1 else 3] * (s > 0)),
73 | "ZZ": "%s%02d%02d%s" % (sign, h, m, ("%09.06f" % s)[: 10 if s % 1 else 2] * (s > 0)),
74 | "zz": tzinfo.tzname(dt) or "",
75 | "X": "%d" % timestamp,
76 | "x": "%d" % (int(timestamp) * 1000000 + microsecond),
77 | }
78 |
79 | def get(m):
80 | try:
81 | return rep[m.group(0)]
82 | except KeyError:
83 | return m.group(0)[1:-1]
84 |
85 | return pattern.sub(get, spec)
86 |
87 |
88 | def aware_now():
89 | now = datetime_.now()
90 | timestamp = now.timestamp()
91 | local = localtime(timestamp)
92 |
93 | try:
94 | seconds = local.tm_gmtoff
95 | zone = local.tm_zone
96 | except AttributeError:
97 | # Workaround for Python 3.5.
98 | utc_naive = datetime_.fromtimestamp(timestamp, tz=timezone.utc).replace(tzinfo=None)
99 | offset = datetime_.fromtimestamp(timestamp) - utc_naive
100 | seconds = offset.total_seconds()
101 | zone = strftime("%Z")
102 |
103 | tzinfo = timezone(timedelta(seconds=seconds), zone)
104 |
105 | return datetime.combine(now.date(), now.time().replace(tzinfo=tzinfo))
106 |
--------------------------------------------------------------------------------
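The subclass accepts loguru's token grammar in format specs, passes `%`-style specs through to the standard library, treats bracketed tokens as literals, and converts to UTC first when the spec ends with `!UTC`:

```python
from loguru._datetime import aware_now

now = aware_now()
print(format(now, "YYYY-MM-DD HH:mm:ss.SSS ZZ"))  # e.g. "2024-06-01 12:30:45.123 +0200"
print(format(now, "[YYYY] YYYY"))                 # "YYYY 2024" - brackets escape tokens
print(format(now, "HH:mm!UTC"))                   # converted to UTC before formatting
```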
/loguru/_defaults.py:
--------------------------------------------------------------------------------
1 | from os import environ
2 |
3 |
4 | def env(key, type_, default=None):
5 | if key not in environ:
6 | return default
7 |
8 | val = environ[key]
9 |
10 |     if type_ is str:
11 | return val
12 |     if type_ is bool:
13 | if val.lower() in ["1", "true", "yes", "y", "ok", "on"]:
14 | return True
15 | if val.lower() in ["0", "false", "no", "n", "nok", "off"]:
16 | return False
17 | raise ValueError(
18 | "Invalid environment variable '%s' (expected a boolean): '%s'" % (key, val)
19 | )
20 |     if type_ is int:
21 | try:
22 | return int(val)
23 | except ValueError:
24 | raise ValueError(
25 | "Invalid environment variable '%s' (expected an integer): '%s'" % (key, val)
26 | ) from None
27 | raise ValueError("The requested type '%r' is not supported" % type_)
28 |
29 |
30 | LOGURU_AUTOINIT = env("LOGURU_AUTOINIT", bool, True)
31 |
32 | LOGURU_FORMAT = env(
33 | "LOGURU_FORMAT",
34 | str,
35 | "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
36 | "{level: <8} | "
37 | "{name}:{function}:{line} - {message}",
38 | )
39 | LOGURU_FILTER = env("LOGURU_FILTER", str, None)
40 | LOGURU_LEVEL = env("LOGURU_LEVEL", str, "DEBUG")
41 | LOGURU_COLORIZE = env("LOGURU_COLORIZE", bool, None)
42 | LOGURU_SERIALIZE = env("LOGURU_SERIALIZE", bool, False)
43 | LOGURU_BACKTRACE = env("LOGURU_BACKTRACE", bool, True)
44 | LOGURU_DIAGNOSE = env("LOGURU_DIAGNOSE", bool, True)
45 | LOGURU_ENQUEUE = env("LOGURU_ENQUEUE", bool, False)
46 | LOGURU_CONTEXT = env("LOGURU_CONTEXT", str, None)
47 | LOGURU_CATCH = env("LOGURU_CATCH", bool, True)
48 |
49 | LOGURU_TRACE_NO = env("LOGURU_TRACE_NO", int, 5)
50 | LOGURU_TRACE_COLOR = env("LOGURU_TRACE_COLOR", str, "")
51 | LOGURU_TRACE_ICON = env("LOGURU_TRACE_ICON", str, "\u270F\uFE0F") # Pencil
52 |
53 | LOGURU_DEBUG_NO = env("LOGURU_DEBUG_NO", int, 10)
54 | LOGURU_DEBUG_COLOR = env("LOGURU_DEBUG_COLOR", str, "")
55 | LOGURU_DEBUG_ICON = env("LOGURU_DEBUG_ICON", str, "\U0001F41E") # Lady Beetle
56 |
57 | LOGURU_INFO_NO = env("LOGURU_INFO_NO", int, 20)
58 | LOGURU_INFO_COLOR = env("LOGURU_INFO_COLOR", str, "")
59 | LOGURU_INFO_ICON = env("LOGURU_INFO_ICON", str, "\u2139\uFE0F") # Information
60 |
61 | LOGURU_SUCCESS_NO = env("LOGURU_SUCCESS_NO", int, 25)
62 | LOGURU_SUCCESS_COLOR = env("LOGURU_SUCCESS_COLOR", str, "")
63 | LOGURU_SUCCESS_ICON = env("LOGURU_SUCCESS_ICON", str, "\u2705") # White Heavy Check Mark
64 |
65 | LOGURU_WARNING_NO = env("LOGURU_WARNING_NO", int, 30)
66 | LOGURU_WARNING_COLOR = env("LOGURU_WARNING_COLOR", str, "")
67 | LOGURU_WARNING_ICON = env("LOGURU_WARNING_ICON", str, "\u26A0\uFE0F") # Warning
68 |
69 | LOGURU_ERROR_NO = env("LOGURU_ERROR_NO", int, 40)
70 | LOGURU_ERROR_COLOR = env("LOGURU_ERROR_COLOR", str, "")
71 | LOGURU_ERROR_ICON = env("LOGURU_ERROR_ICON", str, "\u274C") # Cross Mark
72 |
73 | LOGURU_CRITICAL_NO = env("LOGURU_CRITICAL_NO", int, 50)
74 | LOGURU_CRITICAL_COLOR = env("LOGURU_CRITICAL_COLOR", str, "")
75 | LOGURU_CRITICAL_ICON = env("LOGURU_CRITICAL_ICON", str, "\u2620\uFE0F") # Skull and Crossbones
76 |
--------------------------------------------------------------------------------
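`env()` dispatches on the type object passed as its second argument, which is why the defaults above call it as `env("LOGURU_LEVEL", str, ...)`. A quick sketch of the accepted spellings (the `DEMO_*` variable names are illustrative only):

```python
import os

from loguru._defaults import env

os.environ["DEMO_FLAG"] = "yes"
os.environ["DEMO_COUNT"] = "15"

assert env("DEMO_FLAG", bool) is True             # "1/true/yes/y/ok/on" -> True
assert env("DEMO_COUNT", int) == 15
assert env("DEMO_MISSING", str, "fallback") == "fallback"
```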
/loguru/_error_interceptor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import traceback
3 |
4 |
5 | class ErrorInterceptor:
6 | def __init__(self, should_catch, handler_id):
7 | self._should_catch = should_catch
8 | self._handler_id = handler_id
9 |
10 | def should_catch(self):
11 | return self._should_catch
12 |
13 | def print(self, record=None, *, exception=None):
14 | if not sys.stderr:
15 | return
16 |
17 | if exception is None:
18 | type_, value, traceback_ = sys.exc_info()
19 | else:
20 | type_, value, traceback_ = (type(exception), exception, exception.__traceback__)
21 |
22 | try:
23 | sys.stderr.write("--- Logging error in Loguru Handler #%d ---\n" % self._handler_id)
24 | try:
25 | record_repr = str(record)
26 | except Exception:
27 | record_repr = "/!\\ Unprintable record /!\\"
28 | sys.stderr.write("Record was: %s\n" % record_repr)
29 | traceback.print_exception(type_, value, traceback_, None, sys.stderr)
30 | sys.stderr.write("--- End of logging error ---\n")
31 | except OSError:
32 | pass
33 | finally:
34 | del type_, value, traceback_
35 |
--------------------------------------------------------------------------------
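When a sink raises, the handler reports through this interceptor instead of letting the logging call crash the application. A minimal sketch of the output it writes to stderr:

```python
from loguru._error_interceptor import ErrorInterceptor

interceptor = ErrorInterceptor(should_catch=True, handler_id=0)

try:
    raise RuntimeError("sink failed")
except RuntimeError:
    interceptor.print({"message": "original record"})
# --- Logging error in Loguru Handler #0 ---
# Record was: {'message': 'original record'}
# ...traceback ending in: RuntimeError: sink failed
# --- End of logging error ---
```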
/loguru/_file_sink.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import decimal
3 | import glob
4 | import numbers
5 | import os
6 | import shutil
7 | import string
8 | from functools import partial
9 | from stat import ST_DEV, ST_INO
10 |
11 | from . import _string_parsers as string_parsers
12 | from ._ctime_functions import get_ctime, set_ctime
13 | from ._datetime import aware_now
14 |
15 |
16 | def generate_rename_path(root, ext, creation_time):
17 | creation_datetime = datetime.datetime.fromtimestamp(creation_time)
18 | date = FileDateFormatter(creation_datetime)
19 |
20 | renamed_path = "{}.{}{}".format(root, date, ext)
21 | counter = 1
22 |
23 | while os.path.exists(renamed_path):
24 | counter += 1
25 | renamed_path = "{}.{}.{}{}".format(root, date, counter, ext)
26 |
27 | return renamed_path
28 |
29 |
30 | class FileDateFormatter:
31 | def __init__(self, datetime=None):
32 | self.datetime = datetime or aware_now()
33 |
34 | def __format__(self, spec):
35 | if not spec:
36 | spec = "%Y-%m-%d_%H-%M-%S_%f"
37 | return self.datetime.__format__(spec)
38 |
39 |
40 | class Compression:
41 | @staticmethod
42 | def add_compress(path_in, path_out, opener, **kwargs):
43 | with opener(path_out, **kwargs) as f_comp:
44 | f_comp.add(path_in, os.path.basename(path_in))
45 |
46 | @staticmethod
47 | def write_compress(path_in, path_out, opener, **kwargs):
48 | with opener(path_out, **kwargs) as f_comp:
49 | f_comp.write(path_in, os.path.basename(path_in))
50 |
51 | @staticmethod
52 | def copy_compress(path_in, path_out, opener, **kwargs):
53 | with open(path_in, "rb") as f_in:
54 | with opener(path_out, **kwargs) as f_out:
55 | shutil.copyfileobj(f_in, f_out)
56 |
57 | @staticmethod
58 | def compression(path_in, ext, compress_function):
59 | path_out = "{}{}".format(path_in, ext)
60 |
61 | if os.path.exists(path_out):
62 | creation_time = get_ctime(path_out)
63 | root, ext_before = os.path.splitext(path_in)
64 | renamed_path = generate_rename_path(root, ext_before + ext, creation_time)
65 | os.rename(path_out, renamed_path)
66 | compress_function(path_in, path_out)
67 | os.remove(path_in)
68 |
69 |
70 | class Retention:
71 | @staticmethod
72 | def retention_count(logs, number):
73 | def key_log(log):
74 | return (-os.stat(log).st_mtime, log)
75 |
76 | for log in sorted(logs, key=key_log)[number:]:
77 | os.remove(log)
78 |
79 | @staticmethod
80 | def retention_age(logs, seconds):
81 | t = datetime.datetime.now().timestamp()
82 | for log in logs:
83 | if os.stat(log).st_mtime <= t - seconds:
84 | os.remove(log)
85 |
86 |
87 | class Rotation:
88 | @staticmethod
89 | def forward_day(t):
90 | return t + datetime.timedelta(days=1)
91 |
92 | @staticmethod
93 | def forward_weekday(t, weekday):
94 | while True:
95 | t += datetime.timedelta(days=1)
96 | if t.weekday() == weekday:
97 | return t
98 |
99 | @staticmethod
100 | def forward_interval(t, interval):
101 | return t + interval
102 |
103 | @staticmethod
104 | def rotation_size(message, file, size_limit):
105 | file.seek(0, 2)
106 | return file.tell() + len(message) > size_limit
107 |
108 | class RotationTime:
109 | def __init__(self, step_forward, time_init=None):
110 | self._step_forward = step_forward
111 | self._time_init = time_init
112 | self._limit = None
113 |
114 | def __call__(self, message, file):
115 | record_time = message.record["time"]
116 |
117 | if self._limit is None:
118 | filepath = os.path.realpath(file.name)
119 | creation_time = get_ctime(filepath)
120 | set_ctime(filepath, creation_time)
121 | start_time = datetime.datetime.fromtimestamp(
122 | creation_time, tz=datetime.timezone.utc
123 | )
124 |
125 | time_init = self._time_init
126 |
127 | if time_init is None:
128 | limit = start_time.astimezone(record_time.tzinfo).replace(tzinfo=None)
129 | limit = self._step_forward(limit)
130 | else:
131 | tzinfo = record_time.tzinfo if time_init.tzinfo is None else time_init.tzinfo
132 | limit = start_time.astimezone(tzinfo).replace(
133 | hour=time_init.hour,
134 | minute=time_init.minute,
135 | second=time_init.second,
136 | microsecond=time_init.microsecond,
137 | )
138 |
139 | if limit <= start_time:
140 | limit = self._step_forward(limit)
141 |
142 | if time_init.tzinfo is None:
143 | limit = limit.replace(tzinfo=None)
144 |
145 | self._limit = limit
146 |
147 | if self._limit.tzinfo is None:
148 | record_time = record_time.replace(tzinfo=None)
149 |
150 | if record_time >= self._limit:
151 | while self._limit <= record_time:
152 | self._limit = self._step_forward(self._limit)
153 | return True
154 | return False
155 |
156 |
157 | class FileSink:
158 | def __init__(
159 | self,
160 | path,
161 | *,
162 | rotation=None,
163 | retention=None,
164 | compression=None,
165 | delay=False,
166 | watch=False,
167 | mode="a",
168 | buffering=1,
169 | encoding="utf8",
170 | **kwargs
171 | ):
172 | self.encoding = encoding
173 |
174 | self._kwargs = {**kwargs, "mode": mode, "buffering": buffering, "encoding": self.encoding}
175 | self._path = str(path)
176 |
177 | self._glob_patterns = self._make_glob_patterns(self._path)
178 | self._rotation_function = self._make_rotation_function(rotation)
179 | self._retention_function = self._make_retention_function(retention)
180 | self._compression_function = self._make_compression_function(compression)
181 |
182 | self._file = None
183 | self._file_path = None
184 |
185 | self._watch = watch
186 | self._file_dev = -1
187 | self._file_ino = -1
188 |
189 | if not delay:
190 | path = self._create_path()
191 | self._create_dirs(path)
192 | self._create_file(path)
193 |
194 | def write(self, message):
195 | if self._file is None:
196 | path = self._create_path()
197 | self._create_dirs(path)
198 | self._create_file(path)
199 |
200 | if self._watch:
201 | self._reopen_if_needed()
202 |
203 | if self._rotation_function is not None and self._rotation_function(message, self._file):
204 | self._terminate_file(is_rotating=True)
205 |
206 | self._file.write(message)
207 |
208 | def stop(self):
209 | if self._watch:
210 | self._reopen_if_needed()
211 |
212 | self._terminate_file(is_rotating=False)
213 |
214 | def tasks_to_complete(self):
215 | return []
216 |
217 | def _create_path(self):
218 | path = self._path.format_map({"time": FileDateFormatter()})
219 | return os.path.abspath(path)
220 |
221 | def _create_dirs(self, path):
222 | dirname = os.path.dirname(path)
223 | os.makedirs(dirname, exist_ok=True)
224 |
225 | def _create_file(self, path):
226 | self._file = open(path, **self._kwargs)
227 | self._file_path = path
228 |
229 | if self._watch:
230 | fileno = self._file.fileno()
231 | result = os.fstat(fileno)
232 | self._file_dev = result[ST_DEV]
233 | self._file_ino = result[ST_INO]
234 |
235 | def _close_file(self):
236 | self._file.flush()
237 | self._file.close()
238 |
239 | self._file = None
240 | self._file_path = None
241 | self._file_dev = -1
242 | self._file_ino = -1
243 |
244 | def _reopen_if_needed(self):
245 | # Implemented based on standard library:
246 | # https://github.com/python/cpython/blob/cb589d1b/Lib/logging/handlers.py#L486
247 | if not self._file:
248 | return
249 |
250 | filepath = self._file_path
251 |
252 | try:
253 | result = os.stat(filepath)
254 | except FileNotFoundError:
255 | result = None
256 |
257 | if not result or result[ST_DEV] != self._file_dev or result[ST_INO] != self._file_ino:
258 | self._close_file()
259 | self._create_dirs(filepath)
260 | self._create_file(filepath)
261 |
262 | def _terminate_file(self, *, is_rotating=False):
263 | old_path = self._file_path
264 |
265 | if self._file is not None:
266 | self._close_file()
267 |
268 | if is_rotating:
269 | new_path = self._create_path()
270 | self._create_dirs(new_path)
271 |
272 | if new_path == old_path:
273 | creation_time = get_ctime(old_path)
274 | root, ext = os.path.splitext(old_path)
275 | renamed_path = generate_rename_path(root, ext, creation_time)
276 | os.rename(old_path, renamed_path)
277 | old_path = renamed_path
278 |
279 | if is_rotating or self._rotation_function is None:
280 | if self._compression_function is not None and old_path is not None:
281 | self._compression_function(old_path)
282 |
283 | if self._retention_function is not None:
284 | logs = {
285 | file
286 | for pattern in self._glob_patterns
287 | for file in glob.glob(pattern)
288 | if os.path.isfile(file)
289 | }
290 | self._retention_function(list(logs))
291 |
292 | if is_rotating:
293 | self._create_file(new_path)
294 | set_ctime(new_path, datetime.datetime.now().timestamp())
295 |
296 | @staticmethod
297 | def _make_glob_patterns(path):
298 | formatter = string.Formatter()
299 | tokens = formatter.parse(path)
300 | escaped = "".join(glob.escape(text) + "*" * (name is not None) for text, name, *_ in tokens)
301 |
302 | root, ext = os.path.splitext(escaped)
303 |
304 | if not ext:
305 | return [escaped, escaped + ".*"]
306 |
307 | return [escaped, escaped + ".*", root + ".*" + ext, root + ".*" + ext + ".*"]
308 |
309 | @staticmethod
310 | def _make_rotation_function(rotation):
311 | if rotation is None:
312 | return None
313 | if isinstance(rotation, str):
314 | size = string_parsers.parse_size(rotation)
315 | if size is not None:
316 | return FileSink._make_rotation_function(size)
317 | interval = string_parsers.parse_duration(rotation)
318 | if interval is not None:
319 | return FileSink._make_rotation_function(interval)
320 | frequency = string_parsers.parse_frequency(rotation)
321 | if frequency is not None:
322 | return Rotation.RotationTime(frequency)
323 | daytime = string_parsers.parse_daytime(rotation)
324 | if daytime is not None:
325 | day, time = daytime
326 | if day is None:
327 | return FileSink._make_rotation_function(time)
328 | if time is None:
329 | time = datetime.time(0, 0, 0)
330 | step_forward = partial(Rotation.forward_weekday, weekday=day)
331 | return Rotation.RotationTime(step_forward, time)
332 | raise ValueError("Cannot parse rotation from: '%s'" % rotation)
333 | if isinstance(rotation, (numbers.Real, decimal.Decimal)):
334 | return partial(Rotation.rotation_size, size_limit=rotation)
335 | if isinstance(rotation, datetime.time):
336 | return Rotation.RotationTime(Rotation.forward_day, rotation)
337 | if isinstance(rotation, datetime.timedelta):
338 | step_forward = partial(Rotation.forward_interval, interval=rotation)
339 | return Rotation.RotationTime(step_forward)
340 | if callable(rotation):
341 | return rotation
342 | raise TypeError("Cannot infer rotation for objects of type: '%s'" % type(rotation).__name__)
343 |
344 | @staticmethod
345 | def _make_retention_function(retention):
346 | if retention is None:
347 | return None
348 | if isinstance(retention, str):
349 | interval = string_parsers.parse_duration(retention)
350 | if interval is None:
351 | raise ValueError("Cannot parse retention from: '%s'" % retention)
352 | return FileSink._make_retention_function(interval)
353 | if isinstance(retention, int):
354 | return partial(Retention.retention_count, number=retention)
355 | if isinstance(retention, datetime.timedelta):
356 | return partial(Retention.retention_age, seconds=retention.total_seconds())
357 | if callable(retention):
358 | return retention
359 | raise TypeError(
360 | "Cannot infer retention for objects of type: '%s'" % type(retention).__name__
361 | )
362 |
363 | @staticmethod
364 | def _make_compression_function(compression):
365 | if compression is None:
366 | return None
367 | if isinstance(compression, str):
368 | ext = compression.strip().lstrip(".")
369 |
370 | if ext == "gz":
371 | import gzip
372 |
373 | compress = partial(Compression.copy_compress, opener=gzip.open, mode="wb")
374 | elif ext == "bz2":
375 | import bz2
376 |
377 | compress = partial(Compression.copy_compress, opener=bz2.open, mode="wb")
378 |
379 | elif ext == "xz":
380 | import lzma
381 |
382 | compress = partial(
383 | Compression.copy_compress, opener=lzma.open, mode="wb", format=lzma.FORMAT_XZ
384 | )
385 |
386 | elif ext == "lzma":
387 | import lzma
388 |
389 | compress = partial(
390 | Compression.copy_compress, opener=lzma.open, mode="wb", format=lzma.FORMAT_ALONE
391 | )
392 | elif ext == "tar":
393 | import tarfile
394 |
395 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:")
396 | elif ext == "tar.gz":
397 | import gzip
398 | import tarfile
399 |
400 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:gz")
401 | elif ext == "tar.bz2":
402 | import bz2
403 | import tarfile
404 |
405 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:bz2")
406 |
407 | elif ext == "tar.xz":
408 | import lzma
409 | import tarfile
410 |
411 | compress = partial(Compression.add_compress, opener=tarfile.open, mode="w:xz")
412 | elif ext == "zip":
413 | import zipfile
414 |
415 | compress = partial(
416 | Compression.write_compress,
417 | opener=zipfile.ZipFile,
418 | mode="w",
419 | compression=zipfile.ZIP_DEFLATED,
420 | )
421 | else:
422 | raise ValueError("Invalid compression format: '%s'" % ext)
423 |
424 | return partial(Compression.compression, ext="." + ext, compress_function=compress)
425 | if callable(compression):
426 | return compression
427 | raise TypeError(
428 | "Cannot infer compression for objects of type: '%s'" % type(compression).__name__
429 | )
430 |
--------------------------------------------------------------------------------
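The factory staticmethods above are what turn `logger.add(..., rotation="100 MB", retention="10 days")` strings into callables. They are private, so poking at them directly is for illustration only:

```python
from loguru._file_sink import FileSink

rotate = FileSink._make_rotation_function("100 MB")  # size-based predicate
keep = FileSink._make_retention_function("10 days")  # age-based cleaner
print(FileSink._make_glob_patterns("/var/log/app_{time}.log"))
# ['/var/log/app_*.log', '/var/log/app_*.log.*',
#  '/var/log/app_*.*.log', '/var/log/app_*.*.log.*']
```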
/loguru/_filters.py:
--------------------------------------------------------------------------------
1 | def filter_none(record):
2 | return record["name"] is not None
3 |
4 |
5 | def filter_by_name(record, parent, length):
6 | name = record["name"]
7 | if name is None:
8 | return False
9 | return (name + ".")[:length] == parent
10 |
11 |
12 | def filter_by_level(record, level_per_module):
13 | name = record["name"]
14 |
15 | while True:
16 | level = level_per_module.get(name, None)
17 | if level is False:
18 | return False
19 | if level is not None:
20 | return record["level"].no >= level
21 | if not name:
22 | return True
23 | index = name.rfind(".")
24 | name = name[:index] if index != -1 else ""
25 |
--------------------------------------------------------------------------------
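`filter_by_level()` walks a dotted module name from most specific up to the root `""`, so one mapping can set per-subtree thresholds or mute a subtree outright with `False`. A sketch using stand-in records:

```python
from types import SimpleNamespace

from loguru._filters import filter_by_level

level_per_module = {"": 20, "app.db": 10, "noisy": False}

record = {"name": "app.db.connection", "level": SimpleNamespace(no=10)}
assert filter_by_level(record, level_per_module)      # matched by "app.db" -> 10

record = {"name": "noisy.submodule", "level": SimpleNamespace(no=50)}
assert not filter_by_level(record, level_per_module)  # "noisy" subtree is muted
```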
/loguru/_get_frame.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from sys import exc_info
3 |
4 |
5 | def get_frame_fallback(n):
6 | try:
7 | raise Exception
8 | except Exception:
9 | frame = exc_info()[2].tb_frame.f_back
10 | for _ in range(n):
11 | frame = frame.f_back
12 | return frame
13 |
14 |
15 | def load_get_frame_function():
16 | if hasattr(sys, "_getframe"):
17 | get_frame = sys._getframe
18 | else:
19 | get_frame = get_frame_fallback
20 | return get_frame
21 |
22 |
23 | get_frame = load_get_frame_function()
24 |
--------------------------------------------------------------------------------
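The fallback raises a throwaway exception purely to reach the current frame through its traceback, then walks `f_back` n times, mirroring `sys._getframe(n)`:

```python
from loguru._get_frame import get_frame

def who_called_me():
    return get_frame(1).f_code.co_name  # depth 1: the caller's frame

def outer():
    return who_called_me()

print(outer())  # "outer"
```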
/loguru/_handler.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import json
3 | import multiprocessing
4 | import os
5 | import threading
6 | from contextlib import contextmanager
7 | from threading import Thread
8 |
9 | from ._colorizer import Colorizer
10 | from ._locks_machinery import create_handler_lock
11 |
12 |
13 | def prepare_colored_format(format_, ansi_level):
14 | colored = Colorizer.prepare_format(format_)
15 | return colored, colored.colorize(ansi_level)
16 |
17 |
18 | def prepare_stripped_format(format_):
19 | colored = Colorizer.prepare_format(format_)
20 | return colored.strip()
21 |
22 |
23 | def memoize(function):
24 | return functools.lru_cache(maxsize=64)(function)
25 |
26 |
27 | class Message(str):
28 | __slots__ = ("record",)
29 |
30 |
31 | class Handler:
32 | def __init__(
33 | self,
34 | *,
35 | sink,
36 | name,
37 | levelno,
38 | formatter,
39 | is_formatter_dynamic,
40 | filter_,
41 | colorize,
42 | serialize,
43 | enqueue,
44 | multiprocessing_context,
45 | error_interceptor,
46 | exception_formatter,
47 | id_,
48 | levels_ansi_codes
49 | ):
50 | self._name = name
51 | self._sink = sink
52 | self._levelno = levelno
53 | self._formatter = formatter
54 | self._is_formatter_dynamic = is_formatter_dynamic
55 | self._filter = filter_
56 | self._colorize = colorize
57 | self._serialize = serialize
58 | self._enqueue = enqueue
59 | self._multiprocessing_context = multiprocessing_context
60 | self._error_interceptor = error_interceptor
61 | self._exception_formatter = exception_formatter
62 | self._id = id_
63 |         self._levels_ansi_codes = levels_ansi_codes  # Warning: this mapping is shared among all handlers.
64 |
65 | self._decolorized_format = None
66 | self._precolorized_formats = {}
67 | self._memoize_dynamic_format = None
68 |
69 | self._stopped = False
70 | self._lock = create_handler_lock()
71 | self._lock_acquired = threading.local()
72 | self._queue = None
73 | self._queue_lock = None
74 | self._confirmation_event = None
75 | self._confirmation_lock = None
76 | self._owner_process_pid = None
77 | self._thread = None
78 |
79 | if self._is_formatter_dynamic:
80 | if self._colorize:
81 | self._memoize_dynamic_format = memoize(prepare_colored_format)
82 | else:
83 | self._memoize_dynamic_format = memoize(prepare_stripped_format)
84 | else:
85 | if self._colorize:
86 | for level_name in self._levels_ansi_codes:
87 | self.update_format(level_name)
88 | else:
89 | self._decolorized_format = self._formatter.strip()
90 |
91 | if self._enqueue:
92 | if self._multiprocessing_context is None:
93 | self._queue = multiprocessing.SimpleQueue()
94 | self._confirmation_event = multiprocessing.Event()
95 | self._confirmation_lock = multiprocessing.Lock()
96 | else:
97 | self._queue = self._multiprocessing_context.SimpleQueue()
98 | self._confirmation_event = self._multiprocessing_context.Event()
99 | self._confirmation_lock = self._multiprocessing_context.Lock()
100 | self._queue_lock = create_handler_lock()
101 | self._owner_process_pid = os.getpid()
102 | self._thread = Thread(
103 | target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id
104 | )
105 | self._thread.start()
106 |
107 | def __repr__(self):
108 | return "(id=%d, level=%d, sink=%s)" % (self._id, self._levelno, self._name)
109 |
110 | @contextmanager
111 | def _protected_lock(self):
112 |         """Acquire the lock, but fail fast if it's already acquired by the current thread."""
113 | if getattr(self._lock_acquired, "acquired", False):
114 | raise RuntimeError(
115 | "Could not acquire internal lock because it was already in use (deadlock avoided). "
116 | "This likely happened because the logger was re-used inside a sink, a signal "
117 | "handler or a '__del__' method. This is not permitted because the logger and its "
118 | "handlers are not re-entrant."
119 | )
120 | self._lock_acquired.acquired = True
121 | try:
122 | with self._lock:
123 | yield
124 | finally:
125 | self._lock_acquired.acquired = False
126 |
127 | def emit(self, record, level_id, from_decorator, is_raw, colored_message):
128 | try:
129 | if self._levelno > record["level"].no:
130 | return
131 |
132 | if self._filter is not None:
133 | if not self._filter(record):
134 | return
135 |
136 | if self._is_formatter_dynamic:
137 | dynamic_format = self._formatter(record)
138 |
139 | formatter_record = record.copy()
140 |
141 | if not record["exception"]:
142 | formatter_record["exception"] = ""
143 | else:
144 | type_, value, tb = record["exception"]
145 | formatter = self._exception_formatter
146 | lines = formatter.format_exception(type_, value, tb, from_decorator=from_decorator)
147 | formatter_record["exception"] = "".join(lines)
148 |
149 | if colored_message is not None and colored_message.stripped != record["message"]:
150 | colored_message = None
151 |
152 | if is_raw:
153 | if colored_message is None or not self._colorize:
154 | formatted = record["message"]
155 | else:
156 | ansi_level = self._levels_ansi_codes[level_id]
157 | formatted = colored_message.colorize(ansi_level)
158 | elif self._is_formatter_dynamic:
159 | if not self._colorize:
160 | precomputed_format = self._memoize_dynamic_format(dynamic_format)
161 | formatted = precomputed_format.format_map(formatter_record)
162 | elif colored_message is None:
163 | ansi_level = self._levels_ansi_codes[level_id]
164 | _, precomputed_format = self._memoize_dynamic_format(dynamic_format, ansi_level)
165 | formatted = precomputed_format.format_map(formatter_record)
166 | else:
167 | ansi_level = self._levels_ansi_codes[level_id]
168 | formatter, precomputed_format = self._memoize_dynamic_format(
169 | dynamic_format, ansi_level
170 | )
171 | coloring_message = formatter.make_coloring_message(
172 | record["message"], ansi_level=ansi_level, colored_message=colored_message
173 | )
174 | formatter_record["message"] = coloring_message
175 | formatted = precomputed_format.format_map(formatter_record)
176 |
177 | else:
178 | if not self._colorize:
179 | precomputed_format = self._decolorized_format
180 | formatted = precomputed_format.format_map(formatter_record)
181 | elif colored_message is None:
182 | ansi_level = self._levels_ansi_codes[level_id]
183 | precomputed_format = self._precolorized_formats[level_id]
184 | formatted = precomputed_format.format_map(formatter_record)
185 | else:
186 | ansi_level = self._levels_ansi_codes[level_id]
187 | precomputed_format = self._precolorized_formats[level_id]
188 | coloring_message = self._formatter.make_coloring_message(
189 | record["message"], ansi_level=ansi_level, colored_message=colored_message
190 | )
191 | formatter_record["message"] = coloring_message
192 | formatted = precomputed_format.format_map(formatter_record)
193 |
194 | if self._serialize:
195 | formatted = self._serialize_record(formatted, record)
196 |
197 | str_record = Message(formatted)
198 | str_record.record = record
199 |
200 | with self._protected_lock():
201 | if self._stopped:
202 | return
203 | if self._enqueue:
204 | self._queue.put(str_record)
205 | else:
206 | self._sink.write(str_record)
207 | except Exception:
208 | if not self._error_interceptor.should_catch():
209 | raise
210 | self._error_interceptor.print(record)
211 |
212 | def stop(self):
213 | with self._protected_lock():
214 | self._stopped = True
215 | if self._enqueue:
216 | if self._owner_process_pid != os.getpid():
217 | return
218 | self._queue.put(None)
219 | self._thread.join()
220 | if hasattr(self._queue, "close"):
221 | self._queue.close()
222 |
223 | self._sink.stop()
224 |
225 | def complete_queue(self):
226 | if not self._enqueue:
227 | return
228 |
229 | with self._confirmation_lock:
230 | self._queue.put(True)
231 | self._confirmation_event.wait()
232 | self._confirmation_event.clear()
233 |
234 | def tasks_to_complete(self):
235 | if self._enqueue and self._owner_process_pid != os.getpid():
236 | return []
237 | lock = self._queue_lock if self._enqueue else self._protected_lock()
238 | with lock:
239 | return self._sink.tasks_to_complete()
240 |
241 | def update_format(self, level_id):
242 | if not self._colorize or self._is_formatter_dynamic:
243 | return
244 | ansi_code = self._levels_ansi_codes[level_id]
245 | self._precolorized_formats[level_id] = self._formatter.colorize(ansi_code)
246 |
247 | @property
248 | def levelno(self):
249 | return self._levelno
250 |
251 | @staticmethod
252 | def _serialize_record(text, record):
253 | exception = record["exception"]
254 |
255 | if exception is not None:
256 | exception = {
257 | "type": None if exception.type is None else exception.type.__name__,
258 | "value": exception.value,
259 | "traceback": bool(exception.traceback),
260 | }
261 |
262 | serializable = {
263 | "text": text,
264 | "record": {
265 | "elapsed": {
266 | "repr": record["elapsed"],
267 | "seconds": record["elapsed"].total_seconds(),
268 | },
269 | "exception": exception,
270 | "extra": record["extra"],
271 | "file": {"name": record["file"].name, "path": record["file"].path},
272 | "function": record["function"],
273 | "level": {
274 | "icon": record["level"].icon,
275 | "name": record["level"].name,
276 | "no": record["level"].no,
277 | },
278 | "line": record["line"],
279 | "message": record["message"],
280 | "module": record["module"],
281 | "name": record["name"],
282 | "process": {"id": record["process"].id, "name": record["process"].name},
283 | "thread": {"id": record["thread"].id, "name": record["thread"].name},
284 | "time": {"repr": record["time"], "timestamp": record["time"].timestamp()},
285 | },
286 | }
287 |
288 | return json.dumps(serializable, default=str, ensure_ascii=False) + "\n"
289 |
290 | def _queued_writer(self):
291 | message = None
292 | queue = self._queue
293 |
294 | # We need to use a lock to protect sink during fork.
295 | # Particularly, writing to stderr may lead to deadlock in child process.
296 | lock = self._queue_lock
297 |
298 | while True:
299 | try:
300 | message = queue.get()
301 | except Exception:
302 | with lock:
303 | self._error_interceptor.print(None)
304 | continue
305 |
306 | if message is None:
307 | break
308 |
309 | if message is True:
310 | self._confirmation_event.set()
311 | continue
312 |
313 | with lock:
314 | try:
315 | self._sink.write(message)
316 | except Exception:
317 | self._error_interceptor.print(message.record)
318 |
319 | def __getstate__(self):
320 | state = self.__dict__.copy()
321 | state["_lock"] = None
322 | state["_lock_acquired"] = None
323 | state["_memoize_dynamic_format"] = None
324 | if self._enqueue:
325 | state["_sink"] = None
326 | state["_thread"] = None
327 |             state["_owner_process_pid"] = None
328 | state["_queue_lock"] = None
329 | return state
330 |
331 | def __setstate__(self, state):
332 | self.__dict__.update(state)
333 | self._lock = create_handler_lock()
334 | self._lock_acquired = threading.local()
335 | if self._enqueue:
336 | self._queue_lock = create_handler_lock()
337 | if self._is_formatter_dynamic:
338 | if self._colorize:
339 | self._memoize_dynamic_format = memoize(prepare_colored_format)
340 | else:
341 | self._memoize_dynamic_format = memoize(prepare_stripped_format)
342 |
--------------------------------------------------------------------------------
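With `serialize=True`, `_serialize_record()` above turns each formatted message into one JSON object per line. A sketch through the public API of the vendored logger:

```python
import sys

from loguru import logger

logger.remove()
logger.add(sys.stdout, serialize=True)
logger.bind(request_id="abc123").info("hello")
# {"text": "... | INFO ... - hello\n",
#  "record": {"extra": {"request_id": "abc123"}, "level": {...}, ...}}
```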
/loguru/_locks_machinery.py:
--------------------------------------------------------------------------------
1 | import os
2 | import threading
3 | import weakref
4 |
5 | if not hasattr(os, "register_at_fork"):
6 |
7 | def create_logger_lock():
8 | return threading.Lock()
9 |
10 | def create_handler_lock():
11 | return threading.Lock()
12 |
13 | else:
14 |     # While forking, we need to sanitize all locks to make sure the child process doesn't run
15 |     # into a deadlock (if an already-acquired lock is inherited) and to protect the sinks from
16 |     # a corrupted state. It's very important to acquire logger locks before handler ones to
17 |     # prevent a possible deadlock when 'remove()' is called, for example.
18 |
19 | logger_locks = weakref.WeakSet()
20 | handler_locks = weakref.WeakSet()
21 |
22 | def acquire_locks():
23 | for lock in logger_locks:
24 | lock.acquire()
25 |
26 | for lock in handler_locks:
27 | lock.acquire()
28 |
29 | def release_locks():
30 | for lock in logger_locks:
31 | lock.release()
32 |
33 | for lock in handler_locks:
34 | lock.release()
35 |
36 | os.register_at_fork(
37 | before=acquire_locks,
38 | after_in_parent=release_locks,
39 | after_in_child=release_locks,
40 | )
41 |
42 | def create_logger_lock():
43 | lock = threading.Lock()
44 | logger_locks.add(lock)
45 | return lock
46 |
47 | def create_handler_lock():
48 | lock = threading.Lock()
49 | handler_locks.add(lock)
50 | return lock
51 |
--------------------------------------------------------------------------------
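The same at-fork discipline in isolation: every registered lock is acquired before `fork()` and released on both sides, so a child process can never inherit a lock frozen in the acquired state by another thread:

```python
import os
import threading

locks = [threading.Lock() for _ in range(2)]

if hasattr(os, "register_at_fork"):  # POSIX with Python 3.7+
    os.register_at_fork(
        before=lambda: [lock.acquire() for lock in locks],
        after_in_parent=lambda: [lock.release() for lock in locks],
        after_in_child=lambda: [lock.release() for lock in locks],
    )
```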
/loguru/_recattrs.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import namedtuple
3 |
4 |
5 | class RecordLevel:
6 | __slots__ = ("name", "no", "icon")
7 |
8 | def __init__(self, name, no, icon):
9 | self.name = name
10 | self.no = no
11 | self.icon = icon
12 |
13 | def __repr__(self):
14 | return "(name=%r, no=%r, icon=%r)" % (self.name, self.no, self.icon)
15 |
16 | def __format__(self, spec):
17 | return self.name.__format__(spec)
18 |
19 |
20 | class RecordFile:
21 | __slots__ = ("name", "path")
22 |
23 | def __init__(self, name, path):
24 | self.name = name
25 | self.path = path
26 |
27 | def __repr__(self):
28 | return "(name=%r, path=%r)" % (self.name, self.path)
29 |
30 | def __format__(self, spec):
31 | return self.name.__format__(spec)
32 |
33 |
34 | class RecordThread:
35 | __slots__ = ("id", "name")
36 |
37 | def __init__(self, id_, name):
38 | self.id = id_
39 | self.name = name
40 |
41 | def __repr__(self):
42 | return "(id=%r, name=%r)" % (self.id, self.name)
43 |
44 | def __format__(self, spec):
45 | return self.id.__format__(spec)
46 |
47 |
48 | class RecordProcess:
49 | __slots__ = ("id", "name")
50 |
51 | def __init__(self, id_, name):
52 | self.id = id_
53 | self.name = name
54 |
55 | def __repr__(self):
56 | return "(id=%r, name=%r)" % (self.id, self.name)
57 |
58 | def __format__(self, spec):
59 | return self.id.__format__(spec)
60 |
61 |
62 | class RecordException(namedtuple("RecordException", ("type", "value", "traceback"))):
63 | def __repr__(self):
64 | return "(type=%r, value=%r, traceback=%r)" % (self.type, self.value, self.traceback)
65 |
66 | def __reduce__(self):
67 | # The traceback is not picklable, therefore it needs to be removed. Additionally, there's a
68 | # possibility that the exception value is not picklable either. In such cases, we also need
69 | # to remove it. This is done for user convenience, aiming to prevent error logging caused by
70 | # custom exceptions from third-party libraries. If the serialization succeeds, we can reuse
71 | # the pickled value later for optimization (so that it's not pickled twice). It's important
72 | # to note that custom exceptions might not necessarily raise a PickleError, hence the
73 | # generic Exception catch.
74 | try:
75 | pickled_value = pickle.dumps(self.value)
76 | except Exception:
77 | return (RecordException, (self.type, None, None))
78 | else:
79 | return (RecordException._from_pickled_value, (self.type, pickled_value, None))
80 |
81 | @classmethod
82 | def _from_pickled_value(cls, type_, pickled_value, traceback_):
83 | try:
84 | # It's safe to use "pickle.loads()" in this case because the pickled value is generated
85 | # by the same code and is not coming from an untrusted source.
86 | value = pickle.loads(pickled_value)
87 | except Exception:
88 | return cls(type_, None, traceback_)
89 | else:
90 | return cls(type_, value, traceback_)
91 |
--------------------------------------------------------------------------------
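The `__reduce__` fallback above means an unpicklable exception value degrades to `None` instead of breaking `enqueue=True` queues. A sketch of both outcomes:

```python
import pickle

from loguru._recattrs import RecordException

ok = RecordException(ValueError, ValueError("boom"), None)
print(repr(pickle.loads(pickle.dumps(ok)).value))   # ValueError('boom')

class Unpicklable(Exception):
    def __reduce__(self):
        raise TypeError("nope")

bad = RecordException(Unpicklable, Unpicklable(), None)
print(repr(pickle.loads(pickle.dumps(bad)).value))  # None - dropped, not crashed
```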
/loguru/_simple_sinks.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import picologging as logging
3 | import weakref
4 |
5 | from ._asyncio_loop import get_running_loop, get_task_loop
6 |
7 |
8 | class StreamSink:
9 | def __init__(self, stream):
10 | self._stream = stream
11 | self._flushable = callable(getattr(stream, "flush", None))
12 | self._stoppable = callable(getattr(stream, "stop", None))
13 | self._completable = asyncio.iscoroutinefunction(getattr(stream, "complete", None))
14 |
15 | def write(self, message):
16 | self._stream.write(message)
17 | if self._flushable:
18 | self._stream.flush()
19 |
20 | def stop(self):
21 | if self._stoppable:
22 | self._stream.stop()
23 |
24 | def tasks_to_complete(self):
25 | if not self._completable:
26 | return []
27 | return [self._stream.complete()]
28 |
29 |
30 | class StandardSink:
31 | def __init__(self, handler):
32 | self._handler = handler
33 |
34 | def write(self, message):
35 | record = message.record
36 | message = str(message)
37 | exc = record["exception"]
38 | record = logging.getLogger().makeRecord(
39 | record["name"],
40 | record["level"].no,
41 | record["file"].path,
42 | record["line"],
43 | message,
44 | (),
45 | (exc.type, exc.value, exc.traceback) if exc else None,
46 | record["function"],
47 | {"extra": record["extra"]},
48 | )
49 | if exc:
50 | record.exc_text = "\n"
51 | self._handler.handle(record)
52 |
53 | def stop(self):
54 | self._handler.close()
55 |
56 | def tasks_to_complete(self):
57 | return []
58 |
59 |
60 | class AsyncSink:
61 | def __init__(self, function, loop, error_interceptor):
62 | self._function = function
63 | self._loop = loop
64 | self._error_interceptor = error_interceptor
65 | self._tasks = weakref.WeakSet()
66 |
67 | def write(self, message):
68 | try:
69 | loop = self._loop or get_running_loop()
70 | except RuntimeError:
71 | return
72 |
73 | coroutine = self._function(message)
74 | task = loop.create_task(coroutine)
75 |
76 | def check_exception(future):
77 | if future.cancelled() or future.exception() is None:
78 | return
79 | if not self._error_interceptor.should_catch():
80 | raise future.exception()
81 | self._error_interceptor.print(message.record, exception=future.exception())
82 |
83 | task.add_done_callback(check_exception)
84 | self._tasks.add(task)
85 |
86 | def stop(self):
87 | for task in self._tasks:
88 | task.cancel()
89 |
90 | def tasks_to_complete(self):
91 | # To avoid errors due to "self._tasks" being mutated while iterated, the
92 | # "tasks_to_complete()" method must be protected by the same lock as "write()" (which
93 | # happens to be the handler lock). However, the tasks must not be awaited while the lock is
94 | # acquired as this could lead to a deadlock. Therefore, we first need to collect the tasks
95 | # to complete, then return them so that they can be awaited outside of the lock.
96 | return [self._complete_task(task) for task in self._tasks]
97 |
98 | async def _complete_task(self, task):
99 | loop = get_running_loop()
100 | if get_task_loop(task) is not loop:
101 | return
102 | try:
103 | await task
104 | except Exception:
105 | pass # Handled in "check_exception()"
106 |
107 | def __getstate__(self):
108 | state = self.__dict__.copy()
109 | state["_tasks"] = None
110 | return state
111 |
112 | def __setstate__(self, state):
113 | self.__dict__.update(state)
114 | self._tasks = weakref.WeakSet()
115 |
116 |
117 | class CallableSink:
118 | def __init__(self, function):
119 | self._function = function
120 |
121 | def write(self, message):
122 | self._function(message)
123 |
124 | def stop(self):
125 | pass
126 |
127 | def tasks_to_complete(self):
128 | return []
129 |
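# How this surfaces to users (an illustrative sketch): coroutine sinks are
# scheduled as tasks on write(), and awaiting "logger.complete()" drives the
# awaitables returned by tasks_to_complete() outside of the handler lock.
#
#     import asyncio
#     from loguru import logger
#
#     async def sink(message):
#         await asyncio.sleep(0)  # stand-in for real non-blocking I/O
#         print(message, end="")
#
#     logger.add(sink)
#
#     async def main():
#         logger.info("scheduled, not yet flushed")
#         await logger.complete()  # awaits the pending sink tasks
#
#     asyncio.run(main())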
--------------------------------------------------------------------------------
/loguru/_string_parsers.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import re
3 |
4 |
5 | class Frequencies:
6 | @staticmethod
7 | def hourly(t):
8 | dt = t + datetime.timedelta(hours=1)
9 | return dt.replace(minute=0, second=0, microsecond=0)
10 |
11 | @staticmethod
12 | def daily(t):
13 | dt = t + datetime.timedelta(days=1)
14 | return dt.replace(hour=0, minute=0, second=0, microsecond=0)
15 |
16 | @staticmethod
17 | def weekly(t):
18 | dt = t + datetime.timedelta(days=7 - t.weekday())
19 | return dt.replace(hour=0, minute=0, second=0, microsecond=0)
20 |
21 | @staticmethod
22 | def monthly(t):
23 | if t.month == 12:
24 | y, m = t.year + 1, 1
25 | else:
26 | y, m = t.year, t.month + 1
27 | return t.replace(year=y, month=m, day=1, hour=0, minute=0, second=0, microsecond=0)
28 |
29 | @staticmethod
30 | def yearly(t):
31 | y = t.year + 1
32 | return t.replace(year=y, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
33 |
34 |
35 | def parse_size(size):
36 | size = size.strip()
37 | reg = re.compile(r"([e\+\-\.\d]+)\s*([kmgtpezy])?(i)?(b)", flags=re.I)
38 |
39 | match = reg.fullmatch(size)
40 |
41 | if not match:
42 | return None
43 |
44 | s, u, i, b = match.groups()
45 |
46 | try:
47 | s = float(s)
48 | except ValueError as e:
49 | raise ValueError("Invalid float value while parsing size: '%s'" % s) from e
50 |
51 | u = "kmgtpezy".index(u.lower()) + 1 if u else 0
52 | i = 1024 if i else 1000
53 | b = {"b": 8, "B": 1}[b] if b else 1
54 | return s * i**u / b
55 |
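# Worked examples of the unit arithmetic above: "i" selects binary (1024) over
# decimal (1000) prefixes, and a lowercase "b" means bits, so the result is
# divided by 8 to yield bytes. Illustrative checks, cheap to run on import:
assert parse_size("100 MB") == 100 * 1000**2
assert parse_size("1.5 GiB") == 1.5 * 1024**3
assert parse_size("8 kb") == 8 * 1000 / 8  # kilobits -> 1000.0 bytes
assert parse_size("not a size") is None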
56 |
57 | def parse_duration(duration):
58 | duration = duration.strip()
59 | reg = r"(?:([e\+\-\.\d]+)\s*([a-z]+)[\s\,]*)"
60 |
61 | units = [
62 | ("y|years?", 31536000),
63 | ("months?", 2628000),
64 | ("w|weeks?", 604800),
65 | ("d|days?", 86400),
66 | ("h|hours?", 3600),
67 | ("min(?:ute)?s?", 60),
68 | ("s|sec(?:ond)?s?", 1), # spellchecker: disable-line
69 | ("ms|milliseconds?", 0.001),
70 | ("us|microseconds?", 0.000001),
71 | ]
72 |
73 | if not re.fullmatch(reg + "+", duration, flags=re.I):
74 | return None
75 |
76 | seconds = 0
77 |
78 | for value, unit in re.findall(reg, duration, flags=re.I):
79 | try:
80 | value = float(value)
81 | except ValueError as e:
82 | raise ValueError("Invalid float value while parsing duration: '%s'" % value) from e
83 |
84 | try:
85 | unit = next(u for r, u in units if re.fullmatch(r, unit, flags=re.I))
86 | except StopIteration:
87 | raise ValueError("Invalid unit value while parsing duration: '%s'" % unit) from None
88 |
89 | seconds += value * unit
90 |
91 | return datetime.timedelta(seconds=seconds)
92 |
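# Worked example: each "<number> <unit>" chunk is converted to seconds and
# summed, so "1 hour, 30 minutes" is 3600 + 30 * 60 = 5400 seconds.
assert parse_duration("1 hour, 30 minutes") == datetime.timedelta(seconds=5400)
assert parse_duration("1.5 d") == datetime.timedelta(days=1.5)
assert parse_duration("not a duration") is None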
93 |
94 | def parse_frequency(frequency):
95 | frequencies = {
96 | "hourly": Frequencies.hourly,
97 | "daily": Frequencies.daily,
98 | "weekly": Frequencies.weekly,
99 | "monthly": Frequencies.monthly,
100 | "yearly": Frequencies.yearly,
101 | }
102 | frequency = frequency.strip().lower()
103 | return frequencies.get(frequency, None)
104 |
105 |
106 | def parse_day(day):
107 | days = {
108 | "monday": 0,
109 | "tuesday": 1,
110 | "wednesday": 2,
111 | "thursday": 3,
112 | "friday": 4,
113 | "saturday": 5,
114 | "sunday": 6,
115 | }
116 | day = day.strip().lower()
117 | if day in days:
118 | return days[day]
119 | if day.startswith("w") and day[1:].isdigit():
120 | day = int(day[1:])
121 | if not 0 <= day < 7:
122 | raise ValueError("Invalid weekday value while parsing day (expected [0-6]): '%d'" % day)
123 | else:
124 | day = None
125 |
126 | return day
127 |
128 |
129 | def parse_time(time):
130 | time = time.strip()
131 | reg = re.compile(r"^[\d\.\:]+\s*(?:[ap]m)?$", flags=re.I)
132 |
133 | if not reg.match(time):
134 | return None
135 |
136 | formats = [
137 | "%H",
138 | "%H:%M",
139 | "%H:%M:%S",
140 | "%H:%M:%S.%f",
141 | "%I %p",
142 | "%I:%M %p",
143 | "%I:%M:%S %p",
144 | "%I:%M:%S.%f %p",
145 | ]
146 |
147 | for format_ in formats:
148 | try:
149 | dt = datetime.datetime.strptime(time, format_)
150 | except ValueError:
151 | pass
152 | else:
153 | return dt.time()
154 |
155 | raise ValueError("Unrecognized format while parsing time: '%s'" % time)
156 |
157 |
158 | def parse_daytime(daytime):
159 | daytime = daytime.strip()
160 | reg = re.compile(r"^(.*?)\s+at\s+(.*)$", flags=re.I)
161 |
162 | match = reg.match(daytime)
163 | if match:
164 | day, time = match.groups()
165 | else:
166 | day = time = daytime
167 |
168 | try:
169 | day = parse_day(day)
170 | if match and day is None:
171 | raise ValueError
172 | except ValueError as e:
173 | raise ValueError("Invalid day while parsing daytime: '%s'" % day) from e
174 |
175 | try:
176 | time = parse_time(time)
177 | if match and time is None:
178 | raise ValueError
179 | except ValueError as e:
180 | raise ValueError("Invalid time while parsing daytime: '%s'" % time) from e
181 |
182 | if day is None and time is None:
183 | return None
184 |
185 | return day, time
186 |
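# Worked examples: parse_daytime() splits "<day> at <time>" and otherwise tries
# the whole string as a day, then as a time; the missing half comes back None.
assert parse_daytime("Monday at 12:00") == (0, datetime.time(12, 0))
assert parse_daytime("w3") == (3, None)  # "w3" is Thursday (Monday == 0)
assert parse_daytime("13:30") == (None, datetime.time(13, 30))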
--------------------------------------------------------------------------------
/loguru/py.typed:
--------------------------------------------------------------------------------
(empty file: "py.typed" is the PEP 561 marker telling type checkers that this package ships type information)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # uvicorn main:server --reload
2 | import app.models.mysql  # imported for its side effects: the ORM models must be registered before createMySQLTables()
3 | import app.db.connection
4 | import app.api.middleware as middleware
5 |
6 | from fastapi import FastAPI
7 | from app.api.router import apiRouter
8 | from app.core.config import settings
9 | from app.log.setup import setup_logging
10 | from app.types.server import ServerResponse
11 | from fastapi.responses import ORJSONResponse
12 | from starlette.middleware.cors import CORSMiddleware
13 | from prometheus_fastapi_instrumentator import Instrumentator
14 |
15 | server = FastAPI(
16 | title='fastapi-server',
17 | debug=settings.DEBUG,
18 | openapi_url=f'{settings.API_V1_STR}/openapi.json',
19 | default_response_class=ORJSONResponse,
20 | )
21 | server.include_router(apiRouter, prefix=settings.API_V1_STR)
22 |
23 | setup_logging()
24 |
25 | # Prometheus metrics
26 | Instrumentator(
27 | env_var_name='ENABLE_METRICS',
28 | ).instrument(server).expose(server)
29 |
30 | # Create the tables if they don't exist
31 | app.db.connection.createMySQLTables()
32 |
33 | # Set all CORS enabled origins
34 | if settings.BACKEND_CORS_ORIGINS:
35 | server.add_middleware(
36 | CORSMiddleware,
37 | allow_origins=[str(origin)
38 | for origin in settings.BACKEND_CORS_ORIGINS],
39 | allow_credentials=True,
40 | allow_methods=['*'],
41 | allow_headers=['*'],
42 | )
43 |
44 | server.add_middleware(middleware.DBExceptionsMiddleware)
45 | server.add_middleware(middleware.CatchAllMiddleware)
46 | server.add_middleware(middleware.ProfilingMiddleware)
47 |
48 | @server.get('/')
49 | async def root() -> ServerResponse:
50 | return ServerResponse()
51 |
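# A minimal smoke test for the app above (a sketch: it assumes httpx is
# installed, which backs fastapi.testclient.TestClient, and that the MySQL and
# Redis services from .env are reachable, since importing "main" creates the
# tables):
#
#     from fastapi.testclient import TestClient
#     from main import server
#
#     client = TestClient(server)
#     assert client.get("/").status_code == 200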
--------------------------------------------------------------------------------
/make-env-example.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | ### Used with Makefiles to set environment variables
4 | ### Copy this file into make-env.sh and fill in the blanks
5 |
6 | export AWS_REGION=
7 | export AWS_ACCOUNT_ID=
8 | export CONTAINER_REPO_NAME=
9 | export DOCKER_IMAGE_NAME=
10 | export CONTAINER_NAME=
11 | export ECS_CLUSTER_NAME=
12 | export ECS_SERVICE_NAME=
13 |
14 | make "$@"
15 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | plugins = pydantic.mypy
3 | ignore_missing_imports = True
4 | disallow_untyped_defs = False
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "fastapi-base"
3 | version = "0.1.0"
4 | description = ""
5 | authors = [{ name = "Abe M", email = "abe.malla8@gmail.com" }]
6 | readme = "README.md"
7 | requires-python = ">=3.11"
8 |
9 | [tool.ruff]
10 | line-length = 100
11 | target-version = "py311"
12 | extend-exclude = [".pytest_cache"]
13 | ignore-init-module-imports = true
14 |
15 | [tool.ruff.format]
16 | quote-style = "single"
17 | indent-style = "space"
18 | skip-magic-trailing-comma = false
19 | line-ending = "auto"
20 |
21 | [tool.ruff.mccabe]
22 | max-complexity = 10
23 |
24 | [tool.ruff.per-file-ignores]
25 | "main.py" = ["E402"]
26 | "shell.py" = ["F401"]
27 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile pyproject.toml -o requirements.txt
3 | aiohappyeyeballs==2.3.4
4 | # via aiohttp
5 | aiohttp==3.10.1
6 | # via fastapi-base (pyproject.toml)
7 | aiosignal==1.3.1
8 | # via aiohttp
9 | alembic==1.13.2
10 | # via fastapi-base (pyproject.toml)
11 | annotated-types==0.5.0
12 | # via pydantic
13 | anyio==3.7.1
14 | # via starlette
15 | attrs==23.1.0
16 | # via aiohttp
17 | boto3==1.34.154
18 | # via fastapi-base (pyproject.toml)
19 | botocore==1.34.154
20 | # via
21 | # boto3
22 | # s3transfer
23 | certifi==2023.7.22
24 | # via requests
25 | cffi==1.15.1
26 | # via cryptography
27 | charset-normalizer==3.2.0
28 | # via requests
29 | click==8.1.7
30 | # via uvicorn
31 | cryptography==41.0.4
32 | # via python-jose
33 | ecdsa==0.18.0
34 | # via python-jose
35 | fastapi==0.112.0
36 | # via fastapi-base (pyproject.toml)
37 | frozenlist==1.4.0
38 | # via
39 | # aiohttp
40 | # aiosignal
41 | gunicorn==22.0.0
42 | # via fastapi-base (pyproject.toml)
43 | h11==0.14.0
44 | # via uvicorn
45 | idna==3.4
46 | # via
47 | # anyio
48 | # requests
49 | # yarl
50 | jmespath==1.0.1
51 | # via
52 | # boto3
53 | # botocore
54 | mako==1.2.4
55 | # via alembic
56 | markupsafe==2.1.3
57 | # via mako
58 | multidict==6.0.4
59 | # via
60 | # aiohttp
61 | # yarl
62 | orjson==3.10.6
63 | # via fastapi-base (pyproject.toml)
64 | packaging==23.1
65 | # via gunicorn
66 | picologging==0.9.3
67 | # via fastapi-base (pyproject.toml)
68 | prometheus-client==0.17.1
69 | # via prometheus-fastapi-instrumentator
70 | prometheus-fastapi-instrumentator==7.0.0
71 | # via fastapi-base (pyproject.toml)
72 | psycopg==3.2.1
73 | # via fastapi-base (pyproject.toml)
74 | pyasn1==0.5.0
75 | # via
76 | # python-jose
77 | # rsa
78 | pycparser==2.21
79 | # via cffi
80 | pydantic==2.8.2
81 | # via
82 | # fastapi-base (pyproject.toml)
83 | # fastapi
84 | # pydantic-settings
85 | pydantic-core==2.20.1
86 | # via pydantic
87 | pydantic-settings==2.4.0
88 | # via fastapi-base (pyproject.toml)
89 | pyinstrument==4.7.2
90 | # via fastapi-base (pyproject.toml)
91 | pymysql==1.1.1
92 | # via fastapi-base (pyproject.toml)
93 | python-dateutil==2.8.2
94 | # via botocore
95 | python-dotenv==1.0.1
96 | # via
97 | # fastapi-base (pyproject.toml)
98 | # pydantic-settings
99 | python-jose==3.3.0
100 | # via fastapi-base (pyproject.toml)
101 | redis==5.0.8
102 | # via fastapi-base (pyproject.toml)
103 | requests==2.32.3
104 | # via fastapi-base (pyproject.toml)
105 | rsa==4.9
106 | # via python-jose
107 | s3transfer==0.10.2
108 | # via boto3
109 | six==1.16.0
110 | # via
111 | # ecdsa
112 | # python-dateutil
113 | sniffio==1.3.0
114 | # via anyio
115 | sqlalchemy==2.0.32
116 | # via
117 | # fastapi-base (pyproject.toml)
118 | # alembic
119 | starlette==0.37.2
120 | # via
121 | # fastapi
122 | # prometheus-fastapi-instrumentator
123 | typing-extensions==4.8.0
124 | # via
125 | # alembic
126 | # fastapi
127 | # psycopg
128 | # pydantic
129 | # pydantic-core
130 | # sqlalchemy
131 | urllib3==1.26.16
132 | # via
133 | # botocore
134 | # requests
135 | uvicorn==0.30.5
136 | # via fastapi-base (pyproject.toml)
137 | uvloop==0.19.0
138 | # via fastapi-base (pyproject.toml)
139 | yarl==1.9.2
140 | # via aiohttp
141 |
--------------------------------------------------------------------------------
/shell.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is used to run a shell with all the models and the session
3 | loaded. This is useful for debugging and testing.
4 | """
5 | # python -i shell.py
6 | # (unused imports are expected here; F401 is ignored for this file in pyproject.toml)
7 | from app.models.mysql import User, UserNotification, Notification
8 | from app.db.connection import MySqlSession
9 |
10 |
11 | session = MySqlSession()
12 |
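# An example interactive session (the queries are hypothetical; adapt them to
# the actual model fields):
#
#     $ python -i shell.py
#     >>> from sqlalchemy import select
#     >>> session.execute(select(User).limit(5)).scalars().all()
#     >>> session.query(Notification).count()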
--------------------------------------------------------------------------------