├── .github
│   ├── CODEOWNERS
│   ├── dependabot.yaml
│   └── workflows
│       ├── ci.yaml
│       ├── pr-title.yaml
│       ├── publish.yaml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.rst
├── LICENSE
├── Makefile
├── README.md
├── examples
│   ├── __init__.py
│   ├── basic.py
│   ├── postgres.py
│   └── tasks.py
├── litestar_saq
│   ├── __init__.py
│   ├── __metadata__.py
│   ├── base.py
│   ├── cli.py
│   ├── config.py
│   ├── controllers.py
│   ├── exceptions.py
│   ├── plugin.py
│   └── py.typed
├── pyproject.toml
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_cli
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   └── test_cli.py
│   └── test_plugin.py
└── uv.lock
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Code owner settings for `litestar-org`
2 | # @maintainers should be assigned to all reviews.
3 | # The most specific assignment takes precedence, so if you add an assignment more specific than the `*` glob, you must also include @maintainers.
4 | # For more info about code owners see https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners#codeowners-file-example
5 |
6 | # Global Assignment
7 | * @litestar-org/maintainers @litestar-org/members
8 |
--------------------------------------------------------------------------------
/.github/dependabot.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: Tests And Linting
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - main
8 |
9 | concurrency:
10 | group: test-${{ github.head_ref }}
11 | cancel-in-progress: true
12 |
13 |
14 | jobs:
15 | validate:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - uses: actions/checkout@v4
19 |
20 | - name: Install uv
21 | uses: astral-sh/setup-uv@v6
22 |
23 | - name: Set up Python
24 | run: uv python install 3.12
25 |
26 | - name: Create virtual environment
27 | run: uv sync --all-extras --dev
28 |
29 | - name: Install Pre-Commit hooks
30 | run: uv run pre-commit install
31 |
32 | - name: Load cached Pre-Commit Dependencies
33 | id: cached-pre-commit-dependencies
34 | uses: actions/cache@v4
35 | with:
36 | path: ~/.cache/pre-commit/
37 | key: pre-commit|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
38 |
39 | - name: Execute Pre-Commit
40 | run: uv run pre-commit run --show-diff-on-failure --color=always --all-files
41 | mypy:
42 | runs-on: ubuntu-latest
43 | steps:
44 | - uses: actions/checkout@v4
45 |
46 | - name: Install uv
47 | uses: astral-sh/setup-uv@v6
48 |
49 | - name: Set up Python
50 | run: uv python install 3.12
51 |
52 | - name: Install dependencies
53 | run: uv sync --all-extras --dev
54 |
55 | - name: Run mypy
56 | run: uv run mypy litestar_saq/
57 |
58 | pyright:
59 | runs-on: ubuntu-latest
60 | steps:
61 | - uses: actions/checkout@v4
62 |
63 | - name: Install uv
64 | uses: astral-sh/setup-uv@v6
65 |
66 | - name: Set up Python
67 | run: uv python install 3.12
68 |
69 | - name: Install dependencies
70 | run: uv sync --all-extras --dev
71 |
72 | - name: Run pyright
73 | run: uv run pyright
74 |
75 | slotscheck:
76 | runs-on: ubuntu-latest
77 | steps:
78 | - uses: actions/checkout@v4
79 |
80 | - name: Install uv
81 | uses: astral-sh/setup-uv@v6
82 |
83 | - name: Set up Python
84 | run: uv python install 3.12
85 |
86 | - name: Install dependencies
87 | run: uv sync --all-extras --dev
88 |
89 | - name: Run slotscheck
90 | run: uv run slotscheck -m litestar_saq
91 |
92 |
93 | test_python:
94 | name: "test (python ${{ matrix.python-version }})"
95 | strategy:
96 | fail-fast: true
97 | matrix:
98 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
99 | uses: ./.github/workflows/test.yml
100 | with:
101 | coverage: ${{ matrix.python-version == '3.12' }}
102 | python-version: ${{ matrix.python-version }}
103 |
--------------------------------------------------------------------------------
/.github/workflows/pr-title.yaml:
--------------------------------------------------------------------------------
1 | name: "Lint PR Title"
2 |
3 | on:
4 | pull_request_target:
5 | types:
6 | - opened
7 | - edited
8 | - synchronize
9 |
10 | permissions:
11 | pull-requests: read
12 |
13 | jobs:
14 | main:
15 | name: Validate PR title
16 | runs-on: ubuntu-latest
17 | steps:
18 | - uses: amannn/action-semantic-pull-request@v5
19 | env:
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Latest Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | jobs:
9 | publish-python-release:
10 | runs-on: ubuntu-latest
11 | permissions:
12 | id-token: write
13 | environment: release
14 | steps:
15 | - name: Check out repository
16 | uses: actions/checkout@v4
17 |
18 | - name: Install uv
19 | uses: astral-sh/setup-uv@v6
20 |
21 | - name: Set up Python
22 | run: uv python install 3.12
23 |
24 | - name: Install dependencies
25 | run: uv sync --all-extras
26 |
27 | - name: Build package
28 | run: uv build
29 |
30 | - name: Publish package distributions to PyPI
31 | uses: pypa/gh-action-pypi-publish@release/v1
32 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | python-version:
7 | required: true
8 | type: string
9 | coverage:
10 | required: false
11 | type: boolean
12 | default: false
13 | os:
14 | required: false
15 | type: string
16 | default: "ubuntu-latest"
17 | timeout:
18 | required: false
19 | type: number
20 | default: 60
21 |
22 | jobs:
23 | test:
24 | runs-on: ${{ inputs.os }}
25 | timeout-minutes: ${{ inputs.timeout }}
26 | defaults:
27 | run:
28 | shell: bash
29 | steps:
30 | - name: Check out repository
31 | uses: actions/checkout@v4
32 |
33 | - name: Install uv
34 | uses: astral-sh/setup-uv@v6
35 |
36 | - name: Set up Python
37 | run: uv python install ${{ inputs.python-version }}
38 |
39 | - name: Install dependencies
40 | run: uv sync --all-extras --dev
41 |
42 | - name: Set PYTHONPATH
43 | run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV
44 |
45 | - name: Test
46 | if: ${{ !inputs.coverage }}
47 | run: uv run pytest --dist "loadgroup" -m "" -n 2
48 |
49 | - name: Test with coverage
50 | if: ${{ inputs.coverage }}
51 | run: uv run pytest --dist "loadgroup" -m "" --cov=litestar_saq --cov-report=xml -n 2
52 |
53 | - uses: actions/upload-artifact@v4
54 | if: ${{ inputs.coverage }}
55 | with:
56 | name: coverage-xml
57 | path: coverage.xml
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 | .pdm-python
113 | .pdm-build/
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | .vscode/
166 | examples/tmp.py
167 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: "3"
3 | repos:
4 | - repo: https://github.com/compilerla/conventional-pre-commit
5 | rev: v4.1.0
6 | hooks:
7 | - id: conventional-pre-commit
8 | stages: [commit-msg]
9 | - repo: https://github.com/pre-commit/pre-commit-hooks
10 | rev: v5.0.0
11 | hooks:
12 | - id: check-ast
13 | - id: check-case-conflict
14 | - id: check-toml
15 | - id: debug-statements
16 | - id: end-of-file-fixer
17 | - id: mixed-line-ending
18 | - id: trailing-whitespace
19 | - repo: https://github.com/charliermarsh/ruff-pre-commit
20 | rev: "v0.11.8"
21 | hooks:
22 | # Run the linter.
23 | - id: ruff
24 | types_or: [ python, pyi ]
25 | args: [ --fix ]
26 | # Run the formatter.
27 | - id: ruff-format
28 | types_or: [ python, pyi ]
29 | - repo: https://github.com/codespell-project/codespell
30 | rev: v2.4.1
31 | hooks:
32 | - id: codespell
33 | exclude: "uv.lock|package.json|package-lock.json"
34 | additional_dependencies:
35 | - tomli
36 | - repo: https://github.com/sphinx-contrib/sphinx-lint
37 | rev: "v1.0.0"
38 | hooks:
39 | - id: sphinx-lint
40 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | Contribution guide
2 | ==================
3 |
4 | Setting up the environment
5 | --------------------------
6 |
7 | 1. Install `uv <https://docs.astral.sh/uv/>`_ (or run ``make install-uv``)
8 | 2. Run ``uv sync --all-extras --dev`` to create a virtual environment and install
9 |    the dependencies
10 | 3. If you're working on the documentation and need to build it locally, the same command installs the documentation dependencies
11 | 4. Install `pre-commit <https://pre-commit.com/>`_
12 | 5. Run ``pre-commit install`` to install pre-commit hooks
13 |
14 | Code contributions
15 | ------------------
16 |
17 | Workflow
18 | ++++++++
19 |
20 | 1. `Fork <https://docs.github.com/en/get-started/quickstart/fork-a-repo>`_ the `litestar-saq repository <https://github.com/cofin/litestar-saq>`_
21 | 2. Clone your fork locally with git
22 | 3. `Set up the environment <#setting-up-the-environment>`_
23 | 4. Make your changes
24 | 5. (Optional) Run ``pre-commit run --all-files`` to run linters and formatters. This step is optional and will be executed
25 | automatically by git before you make a commit, but you may want to run it manually in order to apply fixes
26 | 6. Commit your changes to git
27 | 7. Push the changes to your fork
28 | 8. Open a `pull request <https://docs.github.com/en/pull-requests>`_. Give the pull request a descriptive title
29 | indicating what it changes. If it has a corresponding open issue, the issue number should be included in the title as
30 | well. For example a pull request that fixes issue ``bug: Increased stack size making it impossible to find needle #100``
31 | could be titled ``fix(#100): Make needles easier to find by applying fire to haystack``
32 |
33 | .. tip:: Pull requests and commits all need to follow the
34 |    `Conventional Commit format <https://www.conventionalcommits.org/>`_
35 |
36 | Guidelines for writing code
37 | ----------------------------
38 |
39 | - All code should be fully typed. This is enforced via `mypy <https://mypy-lang.org/>`_
40 |   and `pyright <https://microsoft.github.io/pyright/>`_.
41 | - All code should be tested. This is enforced via `pytest <https://docs.pytest.org/>`_.
42 | - All code should be properly formatted. This is enforced via `Ruff <https://docs.astral.sh/ruff/>`_.
43 |
44 | Writing and running tests
45 | +++++++++++++++++++++++++
46 |
47 | .. todo:: Write this section
48 |
49 | Project documentation
50 | ---------------------
51 |
52 | The documentation is located in the ``/docs`` directory and is written in `ReStructuredText <https://docutils.sourceforge.io/rst.html>`_ using
53 | `Sphinx <https://www.sphinx-doc.org/>`_. If you're unfamiliar with either of those, the
54 | `ReStructuredText primer <https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html>`_ and
55 | `Sphinx quickstart <https://www.sphinx-doc.org/en/master/usage/quickstart.html>`_ are recommended reads.
56 |
57 | Running the docs locally
58 | ++++++++++++++++++++++++
59 |
60 | To run or build the docs locally, you need to first install the required dependencies:
61 |
62 | ``uv sync --all-extras --dev``
63 |
64 | Then you can serve the documentation with ``make docs-serve``, or build them with ``make docs``.
65 |
66 | Creating a new release
67 | ----------------------
68 |
69 | 1. Increment the version in ``pyproject.toml``.
70 |    .. note:: The version should follow `semantic versioning <https://semver.org/>`_ and `PEP 440 <https://peps.python.org/pep-0440/>`_.
71 | 2. `Draft a new release <https://github.com/cofin/litestar-saq/releases/new>`_ on GitHub
72 |
73 | * Use ``vMAJOR.MINOR.PATCH`` (e.g. ``v1.2.3``) as both the tag and release title
74 | * Fill in the release description. You can use the "Generate release notes" function to get a draft for this
75 | 3. Commit your changes and push to ``main``
76 | 4. Publish the release
77 | 5. Go to `Actions <https://github.com/cofin/litestar-saq/actions>`_ and approve the release workflow
78 | 6. Check that the workflow runs successfully
79 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Cody Fincher
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL := /bin/bash
2 |
3 | # =============================================================================
4 | # Configuration and Environment Variables
5 | # =============================================================================
6 |
7 | .DEFAULT_GOAL:=help
8 | .ONESHELL:
9 | .EXPORT_ALL_VARIABLES:
10 | MAKEFLAGS += --no-print-directory
11 |
12 | # -----------------------------------------------------------------------------
13 | # Display Formatting and Colors
14 | # -----------------------------------------------------------------------------
15 | BLUE := $(shell printf "\033[1;34m")
16 | GREEN := $(shell printf "\033[1;32m")
17 | RED := $(shell printf "\033[1;31m")
18 | YELLOW := $(shell printf "\033[1;33m")
19 | NC := $(shell printf "\033[0m")
20 | INFO := $(shell printf "$(BLUE)ℹ$(NC)")
21 | OK := $(shell printf "$(GREEN)✓$(NC)")
22 | WARN := $(shell printf "$(YELLOW)⚠$(NC)")
23 | ERROR := $(shell printf "$(RED)✖$(NC)")
24 |
25 | # =============================================================================
26 | # Help and Documentation
27 | # =============================================================================
28 |
29 | .PHONY: help
30 | help: ## Display this help text for Makefile
31 | @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z0-9_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
32 |
33 | # =============================================================================
34 | # Installation and Environment Setup
35 | # =============================================================================
36 |
37 | .PHONY: install-uv
38 | install-uv: ## Install latest version of uv
39 | @echo "${INFO} Installing uv..."
40 | @curl -LsSf https://astral.sh/uv/install.sh | sh >/dev/null 2>&1
41 | @echo "${OK} UV installed successfully"
42 |
43 | .PHONY: install
44 | install: destroy clean ## Install the project, dependencies, and pre-commit
45 | @echo "${INFO} Starting fresh installation..."
46 | @uv sync --all-extras --dev
47 | @echo "${OK} Installation complete! 🎉"
48 |
49 | .PHONY: destroy
50 | destroy: ## Destroy the virtual environment
51 | @echo "${INFO} Destroying virtual environment... 🗑️"
52 | @uv run pre-commit clean >/dev/null 2>&1
53 | @rm -rf .venv
54 | @echo "${OK} Virtual environment destroyed 🗑️"
55 |
56 | # =============================================================================
57 | # Dependency Management
58 | # =============================================================================
59 |
60 | .PHONY: upgrade
61 | upgrade: ## Upgrade all dependencies to latest stable versions
62 | @echo "${INFO} Updating all dependencies... 🔄"
63 | @uv lock --upgrade
64 | @echo "${OK} Dependencies updated 🔄"
65 | @uv run pre-commit autoupdate
66 | @echo "${OK} Updated Pre-commit hooks 🔄"
67 |
68 | .PHONY: lock
69 | lock: ## Rebuild lockfiles from scratch
70 | @echo "${INFO} Rebuilding lockfiles... 🔄"
71 | @uv lock --upgrade >/dev/null 2>&1
72 | @echo "${OK} Lockfiles updated"
73 |
74 | # =============================================================================
75 | # Build and Release
76 | # =============================================================================
77 |
78 | .PHONY: build
79 | build: ## Build the package
80 | @echo "${INFO} Building package... 📦"
81 | @uv build
82 | @echo "${OK} Package build complete"
83 |
84 | .PHONY: release
85 | release: ## Bump version and create release tag
86 | @echo "${INFO} Preparing for release... 📦"
87 | @make docs
88 | @make clean
89 | @make build
90 | @uv lock --upgrade-package litestar-saq
91 | @uv run bump-my-version bump $(bump)
92 | @echo "${OK} Release complete 🎉"
93 |
94 | # =============================================================================
95 | # Cleaning and Maintenance
96 | # =============================================================================
97 |
98 | .PHONY: clean
99 | clean: ## Cleanup temporary build artifacts
100 | @echo "${INFO} Cleaning working directory... 🧹"
101 | 	@rm -rf pytest_cache .ruff_cache .hypothesis build/ dist/ .eggs/ .coverage coverage.xml coverage.json htmlcov/ .pytest_cache tests/.pytest_cache tests/**/.pytest_cache .mypy_cache .unasyncd_cache/ .auto_pytabs_cache node_modules >/dev/null 2>&1
102 | @find . -name '*.egg-info' -exec rm -rf {} + >/dev/null 2>&1
103 | @find . -type f -name '*.egg' -exec rm -f {} + >/dev/null 2>&1
104 | @find . -name '*.pyc' -exec rm -f {} + >/dev/null 2>&1
105 | @find . -name '*.pyo' -exec rm -f {} + >/dev/null 2>&1
106 | @find . -name '*~' -exec rm -f {} + >/dev/null 2>&1
107 | @find . -name '__pycache__' -exec rm -rf {} + >/dev/null 2>&1
108 | @find . -name '.ipynb_checkpoints' -exec rm -rf {} + >/dev/null 2>&1
109 | @echo "${OK} Working directory cleaned"
110 |
111 | # =============================================================================
112 | # Testing and Quality Checks
113 | # =============================================================================
114 |
115 | .PHONY: test
116 | test: ## Run the tests
117 | @echo "${INFO} Running test cases... 🧪"
118 | @uv run pytest -n 2 --quiet
119 | @echo "${OK} Tests passed ✨"
120 |
121 | .PHONY: coverage
122 | coverage: ## Run tests with coverage report
123 | @echo "${INFO} Running tests with coverage... 📊"
124 | @uv run pytest --cov -n auto --quiet
125 | @uv run coverage html >/dev/null 2>&1
126 | @uv run coverage xml >/dev/null 2>&1
127 | @echo "${OK} Coverage report generated ✨"
128 |
129 | # -----------------------------------------------------------------------------
130 | # Type Checking
131 | # -----------------------------------------------------------------------------
132 |
133 | .PHONY: mypy
134 | mypy: ## Run mypy
135 | @echo "${INFO} Running mypy... 🔍"
136 | @uv run dmypy run litestar_saq/
137 | @echo "${OK} Mypy checks passed ✨"
138 |
139 | .PHONY: pyright
140 | pyright: ## Run pyright
141 | @echo "${INFO} Running pyright... 🔍"
142 | @uv run pyright
143 | @echo "${OK} Pyright checks passed ✨"
144 |
145 | .PHONY: type-check
146 | type-check: mypy pyright ## Run all type checking
147 |
148 | # -----------------------------------------------------------------------------
149 | # Linting and Formatting
150 | # -----------------------------------------------------------------------------
151 |
152 | .PHONY: pre-commit
153 | pre-commit: ## Run pre-commit hooks
154 | @echo "${INFO} Running pre-commit checks... 🔎"
155 | @NODE_OPTIONS="--no-deprecation --disable-warning=ExperimentalWarning" uv run pre-commit run --color=always --all-files
156 | @echo "${OK} Pre-commit checks passed ✨"
157 |
158 | .PHONY: slotscheck
159 | slotscheck: ## Run slotscheck
160 | @echo "${INFO} Running slots check... 🔍"
161 | @uv run slotscheck -m litestar_saq
162 | @echo "${OK} Slots check passed ✨"
163 |
164 | .PHONY: fix
165 | fix: ## Run code formatters
166 | @echo "${INFO} Running code formatters... 🔧"
167 | @uv run ruff check --fix --unsafe-fixes
168 | @echo "${OK} Code formatting complete ✨"
169 |
170 | .PHONY: lint
171 | lint: pre-commit type-check slotscheck ## Run all linting checks
172 |
173 | .PHONY: check-all
174 | check-all: lint test coverage ## Run all checks (lint, test, coverage)
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Litestar SAQ
2 |
3 | ## Installation
4 |
5 | ```shell
6 | pip install litestar-saq
7 | ```
8 |
9 | ## Usage
10 |
11 | Here is a basic application that demonstrates how to use the plugin.
12 |
13 | ```python
14 | from __future__ import annotations
15 |
16 | from litestar import Litestar
17 |
18 | from litestar_saq import QueueConfig, SAQConfig, SAQPlugin
19 |
20 | saq = SAQPlugin(config=SAQConfig(queue_configs=[QueueConfig(dsn="redis://localhost:6379/0", name="samples")]))
21 | app = Litestar(plugins=[saq])
22 |
23 |
24 | ```
25 |
11 | You can now start a background worker with the following command:
27 |
28 | ```shell
29 | litestar --app-dir=examples/ --app basic:app workers run
30 | Using Litestar app from env: 'basic:app'
31 | Starting SAQ Workers ──────────────────────────────────────────────────────────────────
32 | INFO - 2023-10-04 17:39:03,255 - saq - worker - Worker starting: Queue<Redis<ConnectionPool<Connection<host=localhost,port=6379,db=0>>>, name='samples'>
33 | INFO - 2023-10-04 17:39:06,545 - saq - worker - Worker shutting down
34 | ```
35 |
36 | You can also start the process for only specific queues. This is helpful if you want separated processes working on different queues instead of combining them.
37 |
38 | ```shell
39 | litestar --app-dir=examples/ --app basic:app workers run --queues samples
40 | Using Litestar app from env: 'basic:app'
41 | Starting SAQ Workers ──────────────────────────────────────────────────────────────────
42 | INFO - 2023-10-04 17:39:03,255 - saq - worker - Worker starting: Queue<Redis<ConnectionPool<Connection<host=localhost,port=6379,db=0>>>, name='samples'>
43 | INFO - 2023-10-04 17:39:06,545 - saq - worker - Worker shutting down
44 | ```
45 |
46 | If you run workers for only specific queues but still need to read from another queue, or enqueue a task into a queue that was not initialized in your worker (or lives in another process), you can construct the queue directly:
47 |
48 | ```python
49 | import os
50 | from saq import Queue
51 |
52 |
53 | def get_queue_directly(queue_name: str, redis_url: str) -> Queue:
54 | return Queue.from_url(redis_url, name=queue_name)
55 |
56 | redis_url = os.getenv("REDIS_URL")
57 | queue = get_queue_directly("queue-in-other-process", redis_url)
58 | # Get queue info
59 | info = await queue.info(jobs=True)
60 | # Enqueue new task
61 | queue.enqueue(
62 | ....
63 | )
64 | ```
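73 |
74 | Inside a running application, the plugin also registers a `task_queues` dependency that can be injected into any route handler. Below is a minimal sketch of enqueuing work from a handler, assuming the `samples` queue from the example above; the handler itself and the `"system_task"` name are illustrative and must match a function registered in that queue's `tasks` list:
75 |
76 | ```python
77 | from __future__ import annotations
78 |
79 | from litestar import post
80 |
81 | from litestar_saq import TaskQueues
82 |
83 |
84 | @post(path="/tasks")
85 | async def enqueue_background_task(task_queues: TaskQueues) -> dict[str, str]:
86 |     # Look up a queue configured by the plugin and enqueue a registered task
87 |     queue = task_queues.get("samples")
88 |     await queue.enqueue("system_task")
89 |     return {"status": "enqueued"}
90 | ```
91 |
92 | Register the handler in your app's `route_handlers` as usual. For a Postgres-backed queue, pass a `postgresql://` DSN to `QueueConfig` instead of a Redis URL (see `examples/postgres.py` for a complete example).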
65 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cofin/litestar-saq/329934de25c5a3f7f3459fd8e8d2ca7013d6fe50/examples/__init__.py
--------------------------------------------------------------------------------
/examples/basic.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from litestar import Controller, Litestar, get
6 |
7 | from examples import tasks
8 | from litestar_saq import CronJob, QueueConfig, SAQConfig, SAQPlugin
9 |
10 | if TYPE_CHECKING:
11 | from saq.types import QueueInfo
12 |
13 | from litestar_saq.config import TaskQueues
14 |
15 |
16 | class SampleController(Controller):
17 | @get(path="/samples")
18 | async def samples_queue_info(self, task_queues: TaskQueues) -> QueueInfo:
19 | """Check database available and returns app config info."""
20 | queue = task_queues.get("samples")
21 | return await queue.info()
22 |
23 |
24 | saq = SAQPlugin(
25 | config=SAQConfig(
26 | web_enabled=True,
27 | use_server_lifespan=True,
28 | queue_configs=[
29 | QueueConfig(
30 |                 dsn="redis://localhost:6379/0",
31 | name="samples",
32 | tasks=[tasks.background_worker_task, tasks.system_task, tasks.system_upkeep],
33 | scheduled_tasks=[CronJob(function=tasks.system_upkeep, cron="* * * * *", timeout=600, ttl=2000)],
34 | ),
35 | ],
36 | ),
37 | )
38 | app = Litestar(plugins=[saq], route_handlers=[SampleController])
39 |
--------------------------------------------------------------------------------
/examples/postgres.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from logging import getLogger
5 | from typing import TYPE_CHECKING
6 |
7 | from litestar import Controller, Litestar, get
8 |
9 | from examples import tasks
10 | from litestar_saq import CronJob, QueueConfig, SAQConfig, SAQPlugin
11 |
12 | if TYPE_CHECKING:
13 | from saq.types import Context, QueueInfo
14 |
15 | from litestar_saq.config import TaskQueues
16 |
17 | logger = getLogger(__name__)
18 |
19 |
20 | async def system_upkeep(_: Context) -> None:
21 | logger.info("Performing system upkeep operations.")
22 | logger.info("Simulating a long running operation. Sleeping for 60 seconds.")
23 | await asyncio.sleep(3)
24 | logger.info("Simulating an even longer running operation. Sleeping for 120 seconds.")
25 | await asyncio.sleep(3)
26 | logger.info("Long running process complete.")
27 | logger.info("Performing system upkeep operations.")
28 |
29 |
30 | async def background_worker_task(_: Context) -> None:
31 | logger.info("Performing background worker task.")
32 | await asyncio.sleep(1)
33 | logger.info("Performing system upkeep operations.")
34 |
35 |
36 | async def system_task(_: Context) -> None:
37 | logger.info("Performing simple system task")
38 | await asyncio.sleep(2)
39 | logger.info("System task complete.")
40 |
41 |
42 | class SampleController(Controller):
43 | @get(path="/samples")
44 | async def samples_queue_info(self, task_queues: TaskQueues) -> QueueInfo:
45 | queue = task_queues.get("samples")
46 | return await queue.info()
47 |
48 |
49 | saq = SAQPlugin(
50 | config=SAQConfig(
51 | web_enabled=True,
52 | use_server_lifespan=True,
53 | queue_configs=[
54 | QueueConfig(
55 | dsn="postgresql://app:app@localhost:15432/app",
56 | tasks=[tasks.background_worker_task, tasks.system_task, tasks.system_upkeep],
57 | scheduled_tasks=[CronJob(function=tasks.system_upkeep, cron="* * * * *", timeout=600, ttl=2000)],
58 | )
59 | ],
60 | ),
61 | )
62 | app = Litestar(plugins=[saq], route_handlers=[SampleController])
63 |
--------------------------------------------------------------------------------
/examples/tasks.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | from logging import getLogger
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 | from saq.types import Context
9 |
10 | logger = getLogger(__name__)
11 |
12 |
13 | async def system_upkeep(_: Context) -> None:
14 | logger.info("Performing system upkeep operations.")
15 | logger.info("Simulating a long running operation. Sleeping for 60 seconds.")
16 | await asyncio.sleep(60)
17 | logger.info("Simulating an even longer running operation. Sleeping for 120 seconds.")
18 | await asyncio.sleep(120)
19 | logger.info("Long running process complete.")
20 | logger.info("Performing system upkeep operations.")
21 |
22 |
23 | async def background_worker_task(_: Context) -> None:
24 | logger.info("Performing background worker task.")
25 | await asyncio.sleep(20)
26 | logger.info("Performing system upkeep operations.")
27 |
28 |
29 | async def system_task(_: Context) -> None:
30 | logger.info("Performing simple system task")
31 | await asyncio.sleep(2)
32 | logger.info("System task complete.")
33 |
--------------------------------------------------------------------------------
/litestar_saq/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from litestar_saq.base import CronJob, Job, Worker
4 | from litestar_saq.config import PostgresQueueOptions, QueueConfig, RedisQueueOptions, SAQConfig, TaskQueues
5 | from litestar_saq.plugin import SAQPlugin
6 |
7 | __all__ = (
8 | "CronJob",
9 | "Job",
10 | "PostgresQueueOptions",
11 | "QueueConfig",
12 | "RedisQueueOptions",
13 | "SAQConfig",
14 | "SAQPlugin",
15 | "TaskQueues",
16 | "Worker",
17 | )
18 |
--------------------------------------------------------------------------------
/litestar_saq/__metadata__.py:
--------------------------------------------------------------------------------
1 | """Metadata for the Project."""
2 |
3 | from __future__ import annotations
4 |
5 | from importlib.metadata import PackageNotFoundError, metadata, version
6 |
7 | __all__ = ("__project__", "__version__")
8 |
9 | try:
10 | __version__ = version("litestar_saq")
11 | """Version of the project."""
12 | __project__ = metadata("litestar_saq")["Name"]
13 | """Name of the project."""
14 | except PackageNotFoundError: # pragma: no cover
15 | __version__ = "0.0.0"
16 | __project__ = "Litestar SAQ"
17 | finally:
18 | del version, PackageNotFoundError, metadata
19 |
--------------------------------------------------------------------------------
/litestar_saq/base.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass, field
3 | from datetime import timezone, tzinfo
4 | from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
5 |
6 | from litestar.utils.module_loader import import_string
7 | from saq import Job as SaqJob
8 | from saq.job import CronJob as SaqCronJob
9 | from saq.worker import Worker as SaqWorker
10 |
11 | if TYPE_CHECKING:
12 | from collections.abc import Collection
13 |
14 | from saq.queue.base import Queue
15 | from saq.types import Function, PartialTimersDict, ReceivesContext
16 |
17 | JsonDict = dict[str, Any]
18 |
19 |
20 | @dataclass
21 | class Job(SaqJob):
22 | """Job Details"""
23 |
24 |
25 | @dataclass
26 | class CronJob(SaqCronJob):
27 | """Cron Job Details"""
28 |
29 | function: "Union[Function, str]" # type: ignore[assignment]
30 | meta: "dict[str, Any]" = field(default_factory=dict) # pyright: ignore
31 |
32 | def __post_init__(self) -> None:
33 | self.function = self._get_or_import_function(self.function) # pyright: ignore[reportIncompatibleMethodOverride]
34 |
35 | @staticmethod
36 | def _get_or_import_function(function_or_import_string: "Union[str, Function]") -> "Function":
37 | if isinstance(function_or_import_string, str):
38 | return cast("Function", import_string(function_or_import_string))
39 | return function_or_import_string
40 |
41 |
42 | class Worker(SaqWorker):
43 | """Worker."""
44 |
45 | def __init__(
46 | self,
47 | queue: "Queue",
48 | functions: "Collection[Union[Function, tuple[str, Function]]]",
49 | *,
50 | id: "Optional[str]" = None, # noqa: A002
51 | concurrency: int = 10,
52 | cron_jobs: "Optional[Collection[CronJob]]" = None,
53 | cron_tz: "tzinfo" = timezone.utc,
54 | startup: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None,
55 | shutdown: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None,
56 | before_process: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None,
57 | after_process: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None,
58 | timers: "Optional[PartialTimersDict]" = None,
59 | dequeue_timeout: float = 0,
60 | burst: bool = False,
61 | max_burst_jobs: "Optional[int]" = None,
62 | metadata: "Optional[JsonDict]" = None,
63 | separate_process: bool = True,
64 | multiprocessing_mode: Literal["multiprocessing", "threading"] = "multiprocessing",
65 | ) -> None:
66 | self.separate_process = separate_process
67 | self.multiprocessing_mode = multiprocessing_mode
68 | super().__init__(
69 | queue,
70 | functions,
71 | id=id,
72 | concurrency=concurrency,
73 | cron_jobs=cron_jobs,
74 | cron_tz=cron_tz,
75 | startup=startup,
76 | shutdown=shutdown,
77 | before_process=before_process,
78 | after_process=after_process,
79 | timers=timers,
80 | dequeue_timeout=dequeue_timeout,
81 | burst=burst,
82 | max_burst_jobs=max_burst_jobs,
83 | metadata=metadata,
84 | )
85 |
86 | async def on_app_startup(self) -> None:
87 | """Attach the worker to the running event loop."""
88 | if not self.separate_process:
89 | self.SIGNALS = []
90 | loop = asyncio.get_running_loop()
91 | self._saq_asyncio_tasks = loop.create_task(self.start())
92 |
93 | async def on_app_shutdown(self) -> None:
94 |         """Stop the worker when the application shuts down."""
95 | if not self.separate_process:
96 | loop = asyncio.get_running_loop()
97 | self._saq_asyncio_tasks = loop.create_task(self.stop())
98 |
--------------------------------------------------------------------------------
/litestar_saq/cli.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Optional
2 |
3 | if TYPE_CHECKING:
4 | from click import Group
5 | from litestar import Litestar
6 | from litestar.logging.config import BaseLoggingConfig
7 |
8 | from litestar_saq.base import Worker
9 | from litestar_saq.plugin import SAQPlugin
10 |
11 |
12 | def build_cli_app() -> "Group": # noqa: C901
13 | import asyncio
14 | import multiprocessing
15 | import platform
16 | from typing import cast
17 |
18 | from click import IntRange, group, option
19 | from litestar.cli._utils import LitestarGroup, console # pyright: ignore
20 |
21 | @group(cls=LitestarGroup, name="workers", no_args_is_help=True)
22 | def background_worker_group() -> None:
23 | """Manage background task workers."""
24 |
25 | @background_worker_group.command(
26 | name="run",
27 | help="Run background worker processes.",
28 | )
29 | @option(
30 | "--workers",
31 | help="The number of worker processes to start.",
32 | type=IntRange(min=1),
33 | default=1,
34 | required=False,
35 | show_default=True,
36 | )
37 | @option(
38 | "--queues",
39 | help="List of queue names to process.",
40 | type=str,
41 | multiple=True,
42 | required=False,
43 | show_default=False,
44 | )
45 | @option("-v", "--verbose", help="Enable verbose logging.", is_flag=True, default=None, type=bool, required=False)
46 | @option("-d", "--debug", help="Enable debugging.", is_flag=True, default=None, type=bool, required=False)
47 | def run_worker( # pyright: ignore[reportUnusedFunction]
48 | app: "Litestar",
49 | workers: int,
50 | queues: "Optional[tuple[str, ...]]",
51 | verbose: "Optional[bool]",
52 | debug: "Optional[bool]",
53 | ) -> None:
54 |         """Run background worker processes."""
55 | console.rule("[yellow]Starting SAQ Workers[/]", align="left")
56 | if platform.system() == "Darwin":
57 | multiprocessing.set_start_method("fork", force=True)
58 |
59 | if app.logging_config is not None:
60 | app.logging_config.configure()
61 | if debug is not None or verbose is not None:
62 | app.debug = True
63 | plugin = get_saq_plugin(app)
64 | if queues:
65 | queue_list = list(queues)
66 | limited_start_up(plugin, queue_list)
67 | show_saq_info(app, workers, plugin)
68 | managed_workers = list(plugin.get_workers().values())
69 | processes: list[multiprocessing.Process] = []
70 |         if workers > 1:  # spawn (workers - 1) additional processes for each managed worker
71 | for _ in range(workers - 1):
72 | for worker in managed_workers:
73 | p = multiprocessing.Process(
74 | target=run_saq_worker,
75 | args=(
76 | worker,
77 | app.logging_config,
78 | ),
79 | )
80 | p.start()
81 | processes.append(p)
82 |
83 |         if len(managed_workers) > 1:  # run every managed worker after the first in its own process
84 | for j in range(len(managed_workers) - 1):
85 | p = multiprocessing.Process(target=run_saq_worker, args=(managed_workers[j + 1], app.logging_config))
86 | p.start()
87 | processes.append(p)
88 |
89 | try:
90 | run_saq_worker(
91 | worker=managed_workers[0],
92 | logging_config=cast("BaseLoggingConfig", app.logging_config),
93 | )
94 | except KeyboardInterrupt:
95 | loop = asyncio.get_event_loop()
96 | for w in managed_workers:
97 | loop.run_until_complete(w.stop())
98 | console.print("[yellow]SAQ workers stopped.[/]")
99 |
100 | @background_worker_group.command(
101 | name="status",
102 | help="Check the status of currently configured workers and queues.",
103 | )
104 | @option("-v", "--verbose", help="Enable verbose logging.", is_flag=True, default=None, type=bool, required=False)
105 | @option("-d", "--debug", help="Enable debugging.", is_flag=True, default=None, type=bool, required=False)
106 | def worker_status( # pyright: ignore[reportUnusedFunction]
107 | app: "Litestar",
108 | verbose: "Optional[bool]",
109 | debug: "Optional[bool]",
110 | ) -> None:
111 | """Check the status of currently configured workers and queues."""
112 | console.rule("[yellow]Checking SAQ worker status[/]", align="left")
113 | if app.logging_config is not None:
114 | app.logging_config.configure()
115 | if debug is not None or verbose is not None:
116 | app.debug = True
117 | plugin = get_saq_plugin(app)
118 | show_saq_info(app, plugin.config.worker_processes, plugin)
119 |
120 | return background_worker_group
121 |
122 |
123 | def limited_start_up(plugin: "SAQPlugin", queues: "list[str]") -> None:
124 | """Reset the workers and include only the specified queues."""
125 | plugin.remove_workers()
126 | plugin.config.filter_delete_queues(queues)
127 |
128 |
129 | def get_saq_plugin(app: "Litestar") -> "SAQPlugin":
130 | """Retrieve a SAQ plugin from the Litestar application's plugins.
131 |
132 | This function attempts to find a SAQ plugin instance.
133 |     If the plugin is not found, it raises an ImproperConfigurationError.
134 |
135 | Args:
136 | app: The Litestar application instance.
137 |
138 | Returns:
139 | The SAQ plugin instance.
140 |
141 | Raises:
142 | ImproperConfigurationError: If the SAQ plugin is not found.
143 | """
144 | from contextlib import suppress
145 |
146 | from litestar_saq.exceptions import ImproperConfigurationError
147 | from litestar_saq.plugin import SAQPlugin
148 |
149 | with suppress(KeyError):
150 | return app.plugins.get(SAQPlugin)
151 | msg = "Failed to initialize SAQ. The required plugin (SAQPlugin) is missing."
152 | raise ImproperConfigurationError(
153 | msg,
154 | )
155 |
156 |
157 | def show_saq_info(app: "Litestar", workers: int, plugin: "SAQPlugin") -> None: # pragma: no cover
158 | """Display basic information about the application and its configuration."""
159 |
160 | from litestar.cli._utils import _format_is_enabled, console # pyright: ignore
161 | from rich.table import Table
162 | from saq import __version__ as saq_version
163 |
164 | table = Table(show_header=False)
165 | table.add_column("title", style="cyan")
166 | table.add_column("value", style="bright_blue")
167 |
168 | table.add_row("SAQ version", saq_version)
169 | table.add_row("Debug mode", _format_is_enabled(app.debug))
170 | table.add_row("Number of Processes", str(workers))
171 | table.add_row("Queues", str(len(plugin.config.queue_configs)))
172 |
173 | console.print(table)
174 |
175 |
176 | def run_saq_worker(worker: "Worker", logging_config: "Optional[BaseLoggingConfig]") -> None:
177 | """Run a worker."""
178 | import asyncio
179 |
180 | loop = asyncio.get_event_loop()
181 | if logging_config is not None:
182 | logging_config.configure()
183 |
184 | async def worker_start(w: "Worker") -> None:
185 | try:
186 | await w.queue.connect()
187 | await w.start()
188 | finally:
189 | await w.queue.disconnect()
190 |
191 | try:
192 | if worker.separate_process:
193 | loop.run_until_complete(loop.create_task(worker_start(worker)))
194 | except KeyboardInterrupt:
195 | loop.run_until_complete(loop.create_task(worker.stop()))
196 |
--------------------------------------------------------------------------------
/litestar_saq/config.py:
--------------------------------------------------------------------------------
1 | from collections.abc import AsyncGenerator, Collection, Mapping
2 | from dataclasses import dataclass, field
3 | from datetime import timezone, tzinfo
4 | from pathlib import Path
5 | from typing import TYPE_CHECKING, Literal, Optional, TypedDict, TypeVar, Union, cast
6 |
7 | from litestar.exceptions import ImproperlyConfiguredException
8 | from litestar.serialization import decode_json, encode_json
9 | from litestar.utils.module_loader import import_string, module_to_os_path
10 | from saq.queue.base import Queue
11 | from saq.types import DumpType, LoadType, PartialTimersDict, QueueInfo, ReceivesContext, WorkerInfo
12 | from typing_extensions import NotRequired
13 |
14 | from litestar_saq.base import CronJob, Job, JsonDict, Worker
15 |
16 | if TYPE_CHECKING:
17 | from typing import Any
18 |
19 | from litestar.types.callable_types import Guard # pyright: ignore[reportUnknownVariableType]
20 | from saq.types import Function
21 |
22 | T = TypeVar("T")
23 |
24 |
25 | def serializer(value: "Any") -> str:
26 | """Serialize JSON field values.
27 |
28 | Args:
29 | value: Any json serializable value.
30 |
31 | Returns:
32 | JSON string.
33 |
34 | """
35 | return encode_json(value).decode("utf-8")
36 |
37 |
38 | def _get_static_files() -> Path:
39 | return Path(module_to_os_path("saq") / "web" / "static")
40 |
41 |
42 | @dataclass
43 | class TaskQueues:
44 | """Task queues."""
45 |
46 | queues: "Mapping[str, Queue]" = field(default_factory=dict) # pyright: ignore
47 |
48 | def get(self, name: str) -> "Queue":
49 | """Get a queue by name.
50 |
51 | Args:
52 | name: The name of the queue.
53 |
54 | Returns:
55 | The queue.
56 |
57 | Raises:
58 | ImproperlyConfiguredException: If the queue does not exist.
59 |
60 | """
61 | queue = self.queues.get(name)
62 | if queue is not None:
63 | return queue
64 | msg = "Could not find the specified queue. Please check your configuration."
65 | raise ImproperlyConfiguredException(msg)
66 |
67 |
68 | @dataclass
69 | class SAQConfig:
70 | """SAQ Configuration."""
71 |
72 | queue_configs: "Collection[QueueConfig]" = field(default_factory=list) # pyright: ignore
73 | """Configuration for Queues"""
74 |
75 | queue_instances: "Optional[Mapping[str, Queue]]" = None
76 |     """Currently configured queue instances. When None, queues will be auto-created on startup."""
77 | queues_dependency_key: str = field(default="task_queues")
78 | """Key to use for storing dependency information in litestar."""
79 | worker_processes: int = 1
80 | """The number of worker processes to spawn.
81 |
82 | Default is set to 1.
83 | """
84 |
85 | json_deserializer: "LoadType" = decode_json
86 | """This is a Python callable that will
87 | convert a JSON string to a Python object. By default, this is set to Litestar's
88 | :attr:`decode_json() <.serialization.decode_json>` function."""
89 | json_serializer: "DumpType" = serializer
90 | """This is a Python callable that will render a given object as JSON.
91 | By default, Litestar's :attr:`encode_json() <.serialization.encode_json>` is used."""
92 | static_files: Path = field(default_factory=_get_static_files)
93 | """Location of the static files to serve for the SAQ UI"""
94 | web_enabled: bool = False
95 |     """If true, the worker admin UI is launched on worker startup."""
96 | web_path: str = "/saq"
97 | """Base path to serve the SAQ web UI"""
98 | web_guards: "Optional[list[Guard]]" = field(default=None)
99 | """Guards to apply to web endpoints."""
100 | web_include_in_schema: bool = False
101 | """Include Queue API endpoints in generated OpenAPI schema"""
102 | use_server_lifespan: bool = False
103 | """Utilize the server lifespan hook to run SAQ."""
104 |
105 | @property
106 | def signature_namespace(self) -> "dict[str, Any]":
107 | """Return the plugin's signature namespace.
108 |
109 | Returns:
110 | A string keyed dict of names to be added to the namespace for signature forward reference resolution.
111 | """
112 | return {
113 | "Queue": Queue,
114 | "Worker": Worker,
115 | "QueueInfo": QueueInfo,
116 | "WorkerInfo": WorkerInfo,
117 | "Job": Job,
118 | "TaskQueues": TaskQueues,
119 | }
120 |
121 | async def provide_queues(self) -> "AsyncGenerator[TaskQueues, None]":
122 | """Provide the configured job queues.
123 |
124 | Yields:
125 | The configured job queues.
126 | """
127 | queues = self.get_queues()
128 | for queue in queues.queues.values():
129 | await queue.connect()
130 | yield queues
131 |
132 | def filter_delete_queues(self, queues: "list[str]") -> None:
133 | """Remove all queues except the ones in the given list."""
134 | new_config = [queue_config for queue_config in self.queue_configs if queue_config.name in queues]
135 | self.queue_configs = new_config
136 | if self.queue_instances is not None:
137 | for queue_name in dict(self.queue_instances):
138 | if queue_name not in queues:
139 | del self.queue_instances[queue_name] # type: ignore
140 |
141 | def get_queues(self) -> "TaskQueues":
142 | """Get the configured SAQ queues.
143 |
144 | Returns:
145 | The configured job queues.
146 | """
147 | if self.queue_instances is not None:
148 | return TaskQueues(queues=self.queue_instances)
149 |
150 | self.queue_instances = {}
151 | for c in self.queue_configs:
152 | self.queue_instances[c.name] = c.queue_class( # type: ignore
153 | c.get_broker(),
154 | name=c.name, # pyright: ignore[reportCallIssue]
155 | dump=self.json_serializer,
156 | load=self.json_deserializer,
157 | **c._broker_options, # pyright: ignore[reportArgumentType,reportPrivateUsage] # noqa: SLF001
158 | )
159 | self.queue_instances[c.name]._is_pool_provided = False # type: ignore # noqa: SLF001
160 | return TaskQueues(queues=self.queue_instances)
161 |
162 |
163 | class RedisQueueOptions(TypedDict, total=False):
164 | """Options for the Redis backend."""
165 |
166 | max_concurrent_ops: NotRequired[int]
167 | """Maximum concurrent operations. (default 15)
168 | This throttles calls to `enqueue`, `job`, and `abort` to prevent the Queue
169 | from consuming too many Redis connections."""
170 | swept_error_message: NotRequired[str]
171 |
172 |
173 | class PostgresQueueOptions(TypedDict, total=False):
174 | """Options for the Postgres backend."""
175 |
176 | versions_table: NotRequired[str]
177 | jobs_table: NotRequired[str]
178 | stats_table: NotRequired[str]
179 | min_size: NotRequired[int]
180 | max_size: NotRequired[int]
181 | saq_lock_keyspace: NotRequired[int]
182 | job_lock_keyspace: NotRequired[int]
183 | job_lock_sweep: NotRequired[bool]
184 | priorities: NotRequired[tuple[int, int]]
185 | swept_error_message: NotRequired[str]
186 | manage_pool_lifecycle: NotRequired[bool]
187 |
188 |
189 | @dataclass
190 | class QueueConfig:
191 | """SAQ Queue Configuration"""
192 |
193 | dsn: "Optional[str]" = None
194 | """DSN for connecting to backend. e.g. 'redis://...' or 'postgres://...'.
195 | """
196 |
197 | broker_instance: "Optional[Any]" = None
198 |     """An instance of a supported SAQ broker connection (e.g. a Redis client or psycopg connection pool).
199 |     """
200 | id: "Optional[str]" = None
201 | """An optional ID to supply for the worker."""
202 | name: str = "default"
203 | """The name of the queue to create."""
204 | concurrency: int = 10
205 | """Number of jobs to process concurrently."""
206 | broker_options: "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]" = field(default_factory=dict) # pyright: ignore
207 | """Broker-specific options. For Redis or Postgres backends."""
208 | tasks: "Collection[Union[ReceivesContext, tuple[str, Function], str]]" = field(default_factory=list) # pyright: ignore
209 | """Allowed list of functions to execute in this queue."""
210 | scheduled_tasks: "Collection[CronJob]" = field(default_factory=list) # pyright: ignore
211 | """Scheduled cron jobs to execute in this queue."""
212 | cron_tz: "tzinfo" = timezone.utc
213 | """Timezone for cron jobs."""
214 | startup: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None
215 | """Async callable to call on startup."""
216 | shutdown: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None
217 | """Async callable to call on shutdown."""
218 | before_process: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None
219 | """Async callable to call before a job processes."""
220 | after_process: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None
221 | """Async callable to call after a job processes."""
222 | timers: "Optional[PartialTimersDict]" = None
223 | """Dict with various timer overrides in seconds
224 | schedule: how often we poll to schedule jobs
225 | stats: how often to update stats
226 | sweep: how often to clean up stuck jobs
227 | abort: how often to check if a job is aborted"""
228 | dequeue_timeout: float = 0
229 | """How long to wait to dequeue."""
230 | burst: bool = False
231 | """If True, the worker will process jobs in burst mode."""
232 | max_burst_jobs: "Optional[int]" = None
233 | """The maximum number of jobs to process in burst mode."""
234 | metadata: "Optional[JsonDict]" = None
235 | """Arbitrary data to pass to the worker which it will register with saq."""
236 | multiprocessing_mode: 'Literal["multiprocessing", "threading"]' = "multiprocessing"
237 |     """Executes with the multiprocessing or threading backend. Multiprocessing is recommended and is how SAQ is designed to work."""
238 | separate_process: bool = True
239 | """Executes as a separate event loop when True.
240 |     Set it to False to execute within the Litestar application."""
241 |
242 | def __post_init__(self) -> None:
243 | if self.dsn and self.broker_instance:
244 | msg = "Cannot specify both `dsn` and `broker_instance`"
245 | raise ImproperlyConfiguredException(msg)
246 | if not self.dsn and not self.broker_instance:
247 | msg = "Must specify either `dsn` or `broker_instance`"
248 | raise ImproperlyConfiguredException(msg)
249 | self.tasks = [self._get_or_import_task(task) for task in self.tasks]
250 | if self.startup is not None and not isinstance(self.startup, Collection):
251 | self.startup = [self.startup]
252 | if self.shutdown is not None and not isinstance(self.shutdown, Collection):
253 | self.shutdown = [self.shutdown]
254 | if self.before_process is not None and not isinstance(self.before_process, Collection):
255 | self.before_process = [self.before_process]
256 | if self.after_process is not None and not isinstance(self.after_process, Collection):
257 | self.after_process = [self.after_process]
258 | self.startup = [self._get_or_import_task(task) for task in self.startup or []] # pyright: ignore
259 | self.shutdown = [self._get_or_import_task(task) for task in self.shutdown or []] # pyright: ignore
260 | self.before_process = [self._get_or_import_task(task) for task in self.before_process or []] # pyright: ignore
261 | self.after_process = [self._get_or_import_task(task) for task in self.after_process or []] # pyright: ignore
262 | self._broker_type: Optional[Literal["redis", "postgres", "http"]] = None
263 | self._queue_class: Optional[type[Queue]] = None
264 |
265 | def get_broker(self) -> "Any":
266 | """Get the configured Broker connection.
267 |
268 | Raises:
269 | ImproperlyConfiguredException: If the broker type is invalid.
270 |
271 | Returns:
272 | Dictionary of queues.
273 | """
274 |
275 | if self.broker_instance is not None:
276 | return self.broker_instance
277 |
278 | if self.dsn and self.dsn.startswith("redis"):
279 | from redis.asyncio import from_url as redis_from_url # pyright: ignore[reportUnknownVariableType]
280 | from saq.queue.redis import RedisQueue
281 |
282 | self.broker_instance = redis_from_url(self.dsn)
283 | self._broker_type = "redis"
284 | self._queue_class = RedisQueue
285 | elif self.dsn and self.dsn.startswith("postgresql"):
286 | from psycopg_pool import AsyncConnectionPool
287 | from saq.queue.postgres import PostgresQueue
288 |
289 | self.broker_instance = AsyncConnectionPool(self.dsn, check=AsyncConnectionPool.check_connection, open=False)
290 | self._broker_type = "postgres"
291 | self._queue_class = PostgresQueue
292 | elif self.dsn and self.dsn.startswith("http"):
293 | from saq.queue.http import HttpQueue
294 |
295 | self.broker_instance = HttpQueue(self.dsn)
296 | self._broker_type = "http"
297 | self._queue_class = HttpQueue
298 | else:
299 | msg = "Invalid broker type"
300 | raise ImproperlyConfiguredException(msg)
301 | return self.broker_instance
302 |
303 | @property
304 | def broker_type(self) -> 'Literal["redis", "postgres", "http"]':
305 | """Type of broker to use.
306 |
307 | Raises:
308 | ImproperlyConfiguredException: If the broker type is invalid.
309 |
310 | Returns:
311 | The broker type.
312 | """
313 | if self._broker_type is None and self.broker_instance is not None:
314 | if self.broker_instance.__class__.__name__ == "AsyncConnectionPool":
315 | self._broker_type = "postgres"
316 | elif self.broker_instance.__class__.__name__ == "Redis":
317 | self._broker_type = "redis"
318 | elif self.broker_instance.__class__.__name__ == "HttpQueue":
319 | self._broker_type = "http"
320 | if self._broker_type is None:
321 | self.get_broker()
322 | if self._broker_type is None:
323 | msg = "Invalid broker type"
324 | raise ImproperlyConfiguredException(msg)
325 | return self._broker_type
326 |
327 | @property
328 | def _broker_options(self) -> "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]":
329 | """Broker-specific options.
330 |
331 | Returns:
332 | The broker options.
333 | """
334 | if self._broker_type == "postgres" and "manage_pool_lifecycle" not in self.broker_options:
335 | self.broker_options["manage_pool_lifecycle"] = True # type: ignore[typeddict-unknown-key]
336 | return self.broker_options
337 |
338 | @property
339 | def queue_class(self) -> "type[Queue]":
340 | """Type of queue to use.
341 |
342 | Raises:
343 | ImproperlyConfiguredException: If the queue class is invalid.
344 |
345 | Returns:
346 | The queue class.
347 | """
348 | if self._queue_class is None and self.broker_instance is not None:
349 | if self.broker_instance.__class__.__name__ == "AsyncConnectionPool":
350 | from saq.queue.postgres import PostgresQueue
351 |
352 | self._queue_class = PostgresQueue
353 | elif self.broker_instance.__class__.__name__ == "Redis":
354 | from saq.queue.redis import RedisQueue
355 |
356 | self._queue_class = RedisQueue
357 | elif self.broker_instance.__class__.__name__ == "HttpQueue":
358 | from saq.queue.http import HttpQueue
359 |
360 | self._queue_class = HttpQueue
361 | if self._queue_class is None:
362 | self.get_broker()
363 | if self._queue_class is None:
364 | msg = "Invalid queue class"
365 | raise ImproperlyConfiguredException(msg)
366 | return self._queue_class
367 |
368 | @staticmethod
369 | def _get_or_import_task(
370 | task_or_import_string: "Union[str, tuple[str, Function], ReceivesContext]",
371 | ) -> "ReceivesContext":
372 | """Get or import a task.
373 |
374 | Args:
375 | task_or_import_string: The task or import string.
376 |
377 | Returns:
378 | The task.
379 | """
380 | if isinstance(task_or_import_string, str):
381 | return cast("ReceivesContext", import_string(task_or_import_string))
382 | if isinstance(task_or_import_string, tuple):
383 | return task_or_import_string[1] # pyright: ignore
384 | return task_or_import_string
385 |
--------------------------------------------------------------------------------
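A minimal sketch of driving the DSN-based dispatch in get_broker() through the public API. It assumes a reachable Redis at localhost:6379; QueueConfig is re-exported from the litestar_saq package root, as the test fixtures further below show:

    from litestar_saq import QueueConfig

    # `dsn` and `broker_instance` are mutually exclusive (enforced in __post_init__)
    config = QueueConfig(dsn="redis://localhost:6379/0", name="samples")

    broker = config.get_broker()          # lazily creates a redis.asyncio client
    assert config.broker_type == "redis"  # inferred from the DSN scheme
    assert config.queue_class.__name__ == "RedisQueue"

A postgresql:// DSN instead builds a psycopg AsyncConnectionPool and selects PostgresQueue; an http(s):// DSN selects HttpQueue.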
/litestar_saq/controllers.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: PLR6301
2 | from functools import lru_cache
3 | from typing import TYPE_CHECKING, Any, Optional, cast
4 |
5 | from litestar.exceptions import NotFoundException
6 |
7 | if TYPE_CHECKING:
8 | from litestar import Controller
9 | from litestar.types.callable_types import Guard # pyright: ignore[reportUnknownVariableType]
10 | from saq.queue.base import Queue
11 | from saq.types import QueueInfo
12 |
13 | from litestar_saq.base import Job
14 | from litestar_saq.config import TaskQueues
15 |
16 |
17 | async def job_info(queue: "Queue", job_id: str) -> "Job":
18 | job = await queue.job(job_id)
19 | if not job:
20 | msg = f"Could not find job ID {job_id}"
21 | raise NotFoundException(msg)
22 | return cast("Job", job)
23 |
24 |
25 | @lru_cache(typed=True)
26 | def build_controller( # noqa: C901
27 | url_base: str = "/saq",
28 | controller_guards: "Optional[list[Guard]]" = None, # pyright: ignore[reportUnknownParameterType]
29 | include_in_schema_: bool = False,
30 | ) -> "type[Controller]":
31 | from litestar import Controller, MediaType, get, post
32 | from litestar.exceptions import NotFoundException
33 | from litestar.status_codes import HTTP_202_ACCEPTED
34 |
35 | class SAQController(Controller):
36 | tags = ["SAQ"]
37 | guards = controller_guards # pyright: ignore[reportUnknownVariableType]
38 | include_in_schema = include_in_schema_
39 |
40 | @get(
41 | operation_id="WorkerQueueList",
42 | name="worker:queue-list",
43 | path=[f"{url_base}/api/queues"],
44 | media_type=MediaType.JSON,
45 | cache=False,
46 | summary="Queue List",
47 | description="List configured worker queues.",
48 | )
49 | async def queue_list(self, task_queues: "TaskQueues") -> "dict[str, list[QueueInfo]]":
50 | """Get Worker queues.
51 |
52 | Args:
53 | task_queues: The task queues.
54 |
55 | Returns:
56 | The worker queues.
57 | """
58 | return {"queues": [await queue.info() for queue in task_queues.queues.values()]}
59 |
60 | @get(
61 | operation_id="WorkerQueueDetail",
62 | name="worker:queue-detail",
63 | path=f"{url_base}/api/queues/{{queue_id:str}}",
64 | media_type=MediaType.JSON,
65 | cache=False,
66 | summary="Queue Detail",
67 | description="List queue details.",
68 | )
69 | async def queue_detail(self, task_queues: "TaskQueues", queue_id: str) -> "dict[str, QueueInfo]":
70 | """Get queue information.
71 |
72 | Args:
73 | task_queues: The task queues.
74 | queue_id: The queue ID.
75 |
76 | Raises:
77 | NotFoundException: If the queue is not found.
78 |
79 | Returns:
80 | The queue information.
81 | """
82 | queue = task_queues.get(queue_id)
83 | if not queue:
84 | msg = f"Could not find the {queue_id} queue"
85 | raise NotFoundException(msg)
86 | return {"queue": await queue.info(jobs=True)}
87 |
88 | @get(
89 | operation_id="WorkerJobDetail",
90 | name="worker:job-detail",
91 | path=f"{url_base}/api/queues/{{queue_id:str}}/jobs/{{job_id:str}}",
92 | media_type=MediaType.JSON,
93 | cache=False,
94 | summary="Job Details",
95 | description="List job details.",
96 | )
97 | async def job_detail(
98 | self, task_queues: "TaskQueues", queue_id: str, job_id: str
99 | ) -> "dict[str, dict[str, Any]]":
100 | """Get job information.
101 |
102 | Args:
103 | task_queues: The task queues.
104 | queue_id: The queue ID.
105 | job_id: The job ID.
106 |
107 | Raises:
108 | NotFoundException: If the queue or job is not found.
109 |
110 | Returns:
111 | The job information.
112 | """
113 | queue = task_queues.get(queue_id)
114 | if not queue:
115 | msg = f"Could not find the {queue_id} queue"
116 | raise NotFoundException(msg)
117 | job = await job_info(queue, job_id)
118 | job_dict = job.to_dict()
119 | if "kwargs" in job_dict:
120 | job_dict["kwargs"] = repr(job_dict["kwargs"])
121 | if "result" in job_dict:
122 | job_dict["result"] = repr(job_dict["result"])
123 | return {"job": job_dict}
124 |
125 | @post(
126 | operation_id="WorkerJobRetry",
127 | name="worker:job-retry",
128 | path=f"{url_base}/api/queues/{{queue_id:str}}/jobs/{{job_id:str}}/retry",
129 | media_type=MediaType.JSON,
130 | cache=False,
131 | summary="Job Retry",
132 | description="Retry a failed job..",
133 | status_code=HTTP_202_ACCEPTED,
134 | )
135 | async def job_retry(self, task_queues: "TaskQueues", queue_id: str, job_id: str) -> "dict[str, str]":
136 | """Retry job.
137 |
138 | Args:
139 | task_queues: The task queues.
140 | queue_id: The queue ID.
141 | job_id: The job ID.
142 |
143 | Raises:
144 | NotFoundException: If the queue or job is not found.
145 |
146 | Returns:
147 | An empty mapping; the retry request has been accepted.
148 | """
149 | queue = task_queues.get(queue_id)
150 | if not queue:
151 | msg = f"Could not find the {queue_id} queue"
152 | raise NotFoundException(msg)
153 | job = await job_info(queue, job_id)
154 | await job.retry("retried from ui")
155 | return {}
156 |
157 | @post(
158 | operation_id="WorkerJobAbort",
159 | name="worker:job-abort",
160 | path=f"{url_base}/api/queues/{{queue_id:str}}/jobs/{{job_id:str}}/abort",
161 | media_type=MediaType.JSON,
162 | cache=False,
163 | summary="Job Abort",
164 | description="Abort active job.",
165 | status_code=HTTP_202_ACCEPTED,
166 | )
167 | async def job_abort(self, task_queues: "TaskQueues", queue_id: str, job_id: str) -> "dict[str, str]":
168 | """Abort job.
169 |
170 | Args:
171 | task_queues: The task queues.
172 | queue_id: The queue ID.
173 | job_id: The job ID.
174 |
175 | Raises:
176 | NotFoundException: If the queue or job is not found.
177 |
178 | Returns:
179 | An empty mapping; the abort request has been accepted.
180 | """
181 | queue = task_queues.get(queue_id)
182 | if not queue:
183 | msg = f"Could not find the {queue_id} queue"
184 | raise NotFoundException(msg)
185 | job = await job_info(queue, job_id)
186 | await job.abort("aborted from ui")
187 | return {}
188 |
189 | # static site
190 | @get(
191 | [
192 | f"{url_base}/",
193 | f"{url_base}/queues/{{queue_id:str}}",
194 | f"{url_base}/queues/{{queue_id:str}}/jobs/{{job_id:str}}",
195 | ],
196 | operation_id="WorkerIndex",
197 | name="worker:index",
198 | media_type=MediaType.HTML,
199 | include_in_schema=False,
200 | )
201 | async def index(self) -> str:
202 | """Serve site root.
203 |
204 | Returns:
205 | The site root.
206 | """
207 | return f"""
208 |
209 |
210 |
211 |
212 |
213 |
214 | SAQ
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 | """.strip()
223 |
224 | return SAQController
225 |
--------------------------------------------------------------------------------
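The controller above is only mounted when the plugin is configured with web_enabled=True (see plugin.py below). A hedged sketch of exercising its JSON endpoints, assuming an existing Litestar app wired that way with the default "/saq" base path and a queue named "samples":

    from litestar.testing import TestClient

    with TestClient(app=app) as client:  # `app` is assumed to exist
        queues = client.get("/saq/api/queues").json()          # WorkerQueueList
        detail = client.get("/saq/api/queues/samples").json()  # WorkerQueueDetail
        job_id = "..."  # a job key taken from the queue detail payload
        client.post(f"/saq/api/queues/samples/jobs/{job_id}/retry")  # 202 Accepted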
/litestar_saq/exceptions.py:
--------------------------------------------------------------------------------
1 | class LitestarSaqError(Exception):
2 | """Base exception type for the Litestar Saq."""
3 |
4 |
5 | class ImproperConfigurationError(LitestarSaqError):
6 | """Improper Configuration error.
7 |
8 | Raised when the integration is configured improperly, e.g. when a module depends on an optional dependency that has not been installed.
9 | """
10 |
11 |
12 | class BackgroundTaskError(Exception):
13 | """Base class for `Task` related exceptions."""
14 |
--------------------------------------------------------------------------------
/litestar_saq/plugin.py:
--------------------------------------------------------------------------------
1 | import signal
2 | import sys
3 | import time
4 | from contextlib import contextmanager
5 | from importlib.util import find_spec
6 | from multiprocessing import Process
7 | from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
8 |
9 | from litestar.plugins import CLIPlugin, InitPluginProtocol
10 |
11 | from litestar_saq.base import Worker
12 |
13 | if TYPE_CHECKING:
14 | from collections.abc import Collection, Iterator
15 |
16 | from click import Group
17 | from litestar import Litestar
18 | from litestar.config.app import AppConfig
19 | from saq.queue.base import Queue
20 | from saq.types import Function, ReceivesContext
21 |
22 | from litestar_saq.config import SAQConfig, TaskQueues
23 |
24 | T = TypeVar("T")
25 |
26 | STRUCTLOG_INSTALLED = find_spec("structlog") is not None
27 |
28 |
29 | class SAQPlugin(InitPluginProtocol, CLIPlugin):
30 | """SAQ plugin."""
31 |
32 | __slots__ = ("_config", "_processes", "_worker_instances")
33 |
34 | WORKER_SHUTDOWN_TIMEOUT = 5.0 # seconds
35 | WORKER_JOIN_TIMEOUT = 1.0 # seconds
36 |
37 | def __init__(self, config: "SAQConfig") -> None:
38 | """Initialize ``SAQPlugin``.
39 |
40 | Args:
41 | config: configure and start SAQ.
42 | """
43 | self._config = config
44 | self._worker_instances: Optional[dict[str, Worker]] = None
45 |
46 | @property
47 | def config(self) -> "SAQConfig":
48 | return self._config
49 |
50 | def on_cli_init(self, cli: "Group") -> None:
51 | from litestar_saq.cli import build_cli_app
52 |
53 | cli.add_command(build_cli_app())
54 | return super().on_cli_init(cli) # type: ignore[safe-super]
55 |
56 | def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
57 | """Configure application for use with SQLAlchemy.
58 |
59 | Args:
60 | app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
61 |
62 | Returns:
63 | The :class:`AppConfig <.config.app.AppConfig>` instance.
64 | """
65 |
66 | from litestar.di import Provide
67 | from litestar.static_files import create_static_files_router # pyright: ignore[reportUnknownVariableType]
68 |
69 | from litestar_saq.controllers import build_controller
70 |
71 | app_config.dependencies.update(
72 | {self._config.queues_dependency_key: Provide(dependency=self._config.provide_queues)}
73 | )
74 | if self._config.web_enabled:
75 | app_config.route_handlers.append(
76 | create_static_files_router(
77 | directories=[self._config.static_files],
78 | path=f"{self._config.web_path}/static",
79 | name="saq",
80 | html_mode=False,
81 | opt={"exclude_from_auth": True},
82 | include_in_schema=False,
83 | ),
84 | )
85 | app_config.route_handlers.append(
86 | build_controller(self._config.web_path, self._config.web_guards, self._config.web_include_in_schema), # type: ignore[arg-type]
87 | )
88 | app_config.signature_namespace.update(self._config.signature_namespace)
89 |
90 | workers = self.get_workers()
91 | for worker in workers.values():
92 | app_config.on_startup.append(worker.on_app_startup)
93 | app_config.on_shutdown.append(worker.on_app_shutdown)
94 | app_config.on_shutdown.extend([self.remove_workers])
95 | return app_config
96 |
97 | def get_workers(self) -> "dict[str, Worker]":
98 | """Return workers"""
99 | if self._worker_instances is not None:
100 | return self._worker_instances
101 | self._worker_instances = {
102 | queue_config.name: Worker(
103 | queue=self.get_queue(queue_config.name),
104 | id=queue_config.id,
105 | functions=cast("Collection[Function]", queue_config.tasks),
106 | cron_jobs=queue_config.scheduled_tasks,
107 | cron_tz=queue_config.cron_tz,
108 | concurrency=queue_config.concurrency,
109 | startup=cast("Collection[ReceivesContext]", queue_config.startup),
110 | shutdown=cast("Collection[ReceivesContext]", queue_config.shutdown),
111 | before_process=cast("Collection[ReceivesContext]", queue_config.before_process),
112 | after_process=cast("Collection[ReceivesContext]", queue_config.after_process),
113 | timers=queue_config.timers,
114 | dequeue_timeout=queue_config.dequeue_timeout,
115 | separate_process=queue_config.separate_process,
116 | burst=queue_config.burst,
117 | max_burst_jobs=queue_config.max_burst_jobs,
118 | metadata=queue_config.metadata,
119 | )
120 | for queue_config in self._config.queue_configs
121 | }
122 |
123 | return self._worker_instances
124 |
125 | def remove_workers(self) -> None:
126 | self._worker_instances = None
127 |
128 | def get_queues(self) -> "TaskQueues":
129 | return self._config.get_queues()
130 |
131 | def get_queue(self, name: str) -> "Queue":
132 | return self.get_queues().get(name)
133 |
134 | @contextmanager
135 | def server_lifespan(self, app: "Litestar") -> "Iterator[None]":
136 | import multiprocessing
137 | import platform
138 |
139 | from litestar.cli._utils import console # pyright: ignore
140 |
141 | from litestar_saq.cli import run_saq_worker
142 |
143 | if platform.system() == "Darwin":
144 | multiprocessing.set_start_method("fork", force=True)
145 |
146 | if not self._config.use_server_lifespan:
147 | yield
148 | return
149 |
150 | console.rule("[yellow]Starting SAQ Workers[/]", align="left")
151 | self._processes: list[Process] = []
152 |
153 | def handle_shutdown(_signum: Any, _frame: Any) -> None:
154 | """Handle shutdown signals gracefully."""
155 | console.print("[yellow]Received shutdown signal, stopping workers...[/]")
156 | self._terminate_workers(self._processes)
157 | sys.exit(0)
158 |
159 | # Register signal handlers
160 | signal.signal(signal.SIGTERM, handle_shutdown)
161 | signal.signal(signal.SIGINT, handle_shutdown)
162 |
163 | try:
164 | for worker_name, worker in self.get_workers().items():
165 | for i in range(self.config.worker_processes):
166 | console.print(f"[yellow]Starting worker process {i + 1} for {worker_name}[/]")
167 | process = Process(
168 | target=run_saq_worker,
169 | args=(
170 | worker,
171 | app.logging_config,
172 | ),
173 | name=f"worker-{worker_name}-{i + 1}",
174 | )
175 | process.start()
176 | self._processes.append(process)
177 |
178 | yield
179 |
180 | except Exception as e:
181 | console.print(f"[red]Error in worker processes: {e}[/]")
182 | raise
183 | finally:
184 | console.print("[yellow]Shutting down SAQ workers...[/]")
185 | self._terminate_workers(self._processes)
186 | console.print("[yellow]SAQ workers stopped.[/]")
187 |
188 | @staticmethod
189 | def _terminate_workers(processes: "list[Process]", timeout: float = 5.0) -> None:
190 | """Gracefully terminate worker processes with timeout.
191 |
192 | Args:
193 | processes: List of worker processes to terminate
194 | timeout: Maximum time to wait for graceful shutdown in seconds
195 | """
196 | # Send SIGTERM to all processes
197 | from litestar.cli._utils import console # pyright: ignore
198 |
199 | for p in processes:
200 | if p.is_alive():
201 | p.terminate()
202 |
203 | # Wait for processes to terminate gracefully
204 | termination_start = time.time()
205 | while time.time() - termination_start < timeout:
206 | if not any(p.is_alive() for p in processes):
207 | break
208 | time.sleep(0.1)
209 |
210 | # Force kill any remaining processes
211 | for p in processes:
212 | if p.is_alive():
213 | try:
214 | p.kill() # Send SIGKILL
215 | p.join(timeout=1.0)
216 | except Exception as e: # noqa: BLE001
217 | console.print(f"[red]Error killing worker process: {e}[/]")
218 |
--------------------------------------------------------------------------------
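A minimal wiring sketch for the plugin above, mirroring the app embedded in the CLI test fixtures below (assumes a reachable local Redis):

    from litestar import Litestar
    from litestar_saq import QueueConfig, SAQConfig, SAQPlugin

    saq = SAQPlugin(
        config=SAQConfig(
            use_server_lifespan=True,  # server_lifespan() spawns the worker processes
            queue_configs=[QueueConfig(dsn="redis://localhost:6379/0", name="samples")],
        )
    )
    app = Litestar(plugins=[saq])

With use_server_lifespan=True, starting the server launches worker_processes OS processes per configured worker and tears them down via _terminate_workers() on shutdown.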
/litestar_saq/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cofin/litestar-saq/329934de25c5a3f7f3459fd8e8d2ca7013d6fe50/litestar_saq/py.typed
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | authors = [{ name = "Cody Fincher", email = "cody.fincher@gmail.com" }]
3 | classifiers = [
4 | "Development Status :: 3 - Alpha",
5 | "Environment :: Web Environment",
6 | "License :: OSI Approved :: MIT License",
7 | "Natural Language :: English",
8 | "Operating System :: OS Independent",
9 | "Programming Language :: Python :: 3.9",
10 | "Programming Language :: Python :: 3.10",
11 | "Programming Language :: Python :: 3.11",
12 | "Programming Language :: Python :: 3.12",
13 | "Programming Language :: Python :: 3.13",
14 | "Programming Language :: Python",
15 | "Topic :: Software Development",
16 | "Typing :: Typed",
17 | ]
18 | dependencies = [
19 | "litestar>=2.0.1",
20 | "saq>=0.24.4",
21 | ]
22 | description = "Litestar integration for SAQ"
23 | keywords = ["litestar", "saq"]
24 | license = { text = "MIT" }
25 | name = "litestar-saq"
26 | readme = "README.md"
27 | requires-python = ">=3.9"
28 | version = "0.5.3"
29 |
30 | [project.optional-dependencies]
31 | hiredis = ["hiredis"]
32 | psycopg = ["psycopg[pool,binary]"]
33 |
34 | [project.urls]
35 | Changelog = "https://cofin.github.io/litestar-saq/latest/changelog"
36 | Discord = "https://discord.gg/X3FJqy8d2j"
37 | Documentation = "https://cofin.github.io/litestar-saq/latest/"
38 | Homepage = "https://cofin.github.io/litestar-saq/latest/"
39 | Issue = "https://github.com/cofin/litestar-saq/issues/"
40 | Source = "https://github.com/cofin/litestar-saq"
41 |
42 | [build-system]
43 | build-backend = "hatchling.build"
44 | requires = ["hatchling"]
45 |
46 | [dependency-groups]
47 | dev = [{include-group = "build"}, {include-group = "linting"}, {include-group = "test"}]
48 | build = ["bump-my-version"]
49 | linting = [
50 | "pre-commit",
51 | "mypy",
52 | "ruff",
53 | "types-click",
54 | "types-redis",
55 | "types-croniter",
56 | "pyright",
57 | "slotscheck",
58 | ]
59 | test = [
60 | "pytest",
61 | "pytest-mock",
62 | "httpx",
63 | "pytest-cov",
64 | "coverage",
65 | "pytest-databases",
66 | "pytest-sugar",
67 | "pytest-asyncio",
68 | "pytest-xdist",
69 | "anyio",
70 | "litestar[jinja,redis,standard]",
71 | "psycopg[pool,binary]",
72 | ]
73 |
74 |
75 | [tool.bumpversion]
76 | allow_dirty = true
77 | commit = false
78 | commit_args = "--no-verify"
79 | current_version = "0.5.3"
80 | ignore_missing_files = false
81 | ignore_missing_version = false
82 | message = "chore(release): bump to `v{new_version}`"
83 | parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
84 | regex = false
85 | replace = "{new_version}"
86 | search = "{current_version}"
87 | serialize = ["{major}.{minor}.{patch}"]
88 | sign_tags = false
89 | tag = false
90 | tag_message = "chore(release): `v{new_version}`"
91 | tag_name = "v{new_version}"
92 |
93 | [[tool.bumpversion.files]]
94 | filename = "pyproject.toml"
95 | replace = 'version = "{new_version}"'
96 | search = 'version = "{current_version}"'
97 |
98 |
99 | [[tool.bumpversion.files]]
100 | filename = "uv.lock"
101 | replace = """
102 | name = "litestar-saq"
103 | version = "{new_version}"
104 | """
105 | search = """
106 | name = "litestar-saq"
107 | version = "{current_version}"
108 | """
109 |
110 | [tool.pytest.ini_options]
111 | addopts = ["-q", "-ra"]
112 | filterwarnings = [
113 | "ignore::DeprecationWarning:pkg_resources",
114 | "ignore::DeprecationWarning:xdist.*",
115 | ]
116 | minversion = "6.0"
117 | testpaths = ["tests"]
118 | tmp_path_retention_policy = "failed"
119 | tmp_path_retention_count = 3
120 | asyncio_default_fixture_loop_scope = "function"
121 | asyncio_mode = "auto"
122 |
123 | [tool.coverage.report]
124 | exclude_lines = [
125 | 'if TYPE_CHECKING:',
126 | 'pragma: no cover',
127 | "if __name__ == .__main__.:",
128 | 'def __repr__',
129 | 'if self\.debug:',
130 | 'if settings\.DEBUG',
131 | 'raise AssertionError',
132 | 'raise NotImplementedError',
133 | 'if 0:',
134 | 'class .*\bProtocol\):',
135 | '@(abc\.)?abstractmethod',
136 | ]
137 | omit = ["*/tests/*"]
138 | show_missing = true
139 |
140 |
141 | [tool.coverage.run]
142 | branch = true
143 | concurrency = ["multiprocessing" ]
144 | omit = ["tests/*"]
145 | parallel = true
146 |
147 | [tool.slotscheck]
148 | strict-imports = false
149 |
150 | [tool.ruff]
151 | exclude = [
152 | ".bzr",
153 | ".direnv",
154 | ".eggs",
155 | ".git",
156 | ".hg",
157 | ".mypy_cache",
158 | ".nox",
159 | ".pants.d",
160 | ".ruff_cache",
161 | ".svn",
162 | ".tox",
163 | ".venv",
164 | "__pypackages__",
165 | "_build",
166 | "buck-out",
167 | "build",
168 | "dist",
169 | "node_modules",
170 | "venv",
171 | '__pycache__',
172 | ]
173 | fix = true
174 | line-length = 120
175 | lint.fixable = ["ALL"]
176 | lint.ignore = [
177 | "A003", # flake8-builtins - class attribute {name} is shadowing a python builtin
178 | "B010", # flake8-bugbear - do not call setattr with a constant attribute value
179 | "D100", # pydocstyle - missing docstring in public module
180 | "D101", # pydocstyle - missing docstring in public class
181 | "D102", # pydocstyle - missing docstring in public method
182 | "D103", # pydocstyle - missing docstring in public function
183 | "D104", # pydocstyle - missing docstring in public package
184 | "D105", # pydocstyle - missing docstring in magic method
185 | "D106", # pydocstyle - missing docstring in public nested class
186 | "D107", # pydocstyle - missing docstring in __init__
187 | "D202", # pydocstyle - no blank lines allowed after function docstring
188 | "D205", # pydocstyle - 1 blank line required between summary line and description
189 | "D415", # pydocstyle - first line should end with a period, question mark, or exclamation point
190 | "E501", # pycodestyle line too long, handled by black
191 | "PLW2901", # pylint - for loop variable overwritten by assignment target
192 | "RUF012", # Ruff-specific rule - annotated with classvar
193 | "ANN401",
194 | "FBT",
195 | "PLR0913", # too many arguments
196 | "PT",
197 | "TD",
198 | "ARG002", # ignore for now; investigate
199 | "PERF203", # ignore for now; investigate
200 | "ISC001",
201 | "COM812",
202 | "FA100", # use __future__ import annotations
203 | "CPY001", # copyright at the top of the file
204 | "PLC0415", # import at the top of the file
205 | "PGH003", # Use specific rule codes when ignoring type issues
206 | "PLC2701", # ignore for now; investigate
207 | ]
208 | lint.select = ["ALL"]
209 | # Allow unused variables when underscore-prefixed.
210 | lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
211 | src = ["litestar_saq", "tests"]
212 | target-version = "py39"
213 | unsafe-fixes = true
214 |
215 | [tool.ruff.lint.pydocstyle]
216 | convention = "google"
217 |
218 | [tool.ruff.lint.mccabe]
219 | max-complexity = 12
220 |
221 | [tool.ruff.lint.pep8-naming]
222 | classmethod-decorators = [
223 | "sqlalchemy.ext.declarative.declared_attr",
224 | "sqlalchemy.orm.declared_attr.directive",
225 | "sqlalchemy.orm.declared_attr",
226 | ]
227 |
228 | [tool.ruff.lint.per-file-ignores]
229 | "tests/**/*.*" = [
230 | "A",
231 | "ARG",
232 | "B",
233 | "BLE",
234 | "C901",
235 | "D",
236 | "DTZ",
237 | "EM",
238 | "FBT",
239 | "G",
240 | "N",
241 | "PGH",
242 | "PIE",
243 | "PLR",
244 | "PLW",
245 | "PTH",
246 | "RSE",
247 | "S",
248 | "S101",
249 | "SIM",
250 | "TC",
251 | "TRY",
252 | "UP006",
253 | "SLF001",
254 | "ERA001",
255 |
256 | ]
257 | "tools/*.py" = [ "PLR0911"]
258 |
259 | [tool.ruff.lint.isort]
260 | known-first-party = ["litestar_saq", "tests"]
261 |
262 | [tool.mypy]
263 | disallow_any_generics = false
264 | disallow_untyped_decorators = true
265 | implicit_reexport = false
266 | show_error_codes = true
267 | strict = true
268 | warn_redundant_casts = true
269 | warn_return_any = true
270 | warn_unreachable = true
271 | warn_unused_configs = true
272 | warn_unused_ignores = true
273 |
274 | [[tool.mypy.overrides]]
275 | disable_error_code = "attr-defined"
276 | disallow_untyped_decorators = false
277 | module = "tests.*"
278 |
279 | [tool.pyright]
280 | venvPath = "."
281 | disableBytesTypePromotions = true
282 | exclude = [
283 | "docs",
284 | "tests/helpers.py",
285 | ]
286 | include = ["litestar_saq"]
287 | pythonVersion = "3.9"
288 | strict = ["litestar_saq/**/*"]
289 | venv = ".venv"
290 |
291 | [tool.codespell]
292 | ignore-words-list = "selectin"
293 | skip = 'uv.lock,pyproject.toml'
294 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cofin/litestar-saq/329934de25c5a3f7f3459fd8e8d2ca7013d6fe50/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import pytest
6 | from redis.asyncio import Redis
7 |
8 | if TYPE_CHECKING:
9 | from collections.abc import AsyncGenerator
10 |
11 | from pytest_databases.docker.redis import RedisService
12 |
13 | pytestmark = pytest.mark.anyio
14 | pytest_plugins = [
15 | "pytest_databases.docker",
16 | "pytest_databases.docker.redis",
17 | ]
18 |
19 |
20 | @pytest.fixture(scope="session")
21 | def anyio_backend() -> str:
22 | return "asyncio"
23 |
24 |
25 | @pytest.fixture(name="redis", autouse=True)
26 | async def fx_redis(redis_service: RedisService) -> AsyncGenerator[Redis, None]:
27 | """Redis instance for testing.
28 |
29 | Returns:
30 | Redis client instance, function scoped.
31 | """
32 | yield Redis(host=redis_service.host, port=redis_service.port)
33 |
--------------------------------------------------------------------------------
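A sketch of a test consuming the autouse fixtures above (hypothetical test function; redis_service is provisioned by pytest-databases):

    async def test_redis_roundtrip(redis: Redis) -> None:
        assert await redis.ping()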
/tests/test_cli/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | APP_DEFAULT_CONFIG_FILE_CONTENT = """
4 | from __future__ import annotations
5 |
6 | import asyncio
7 | from logging import getLogger
8 | from typing import TYPE_CHECKING
9 |
10 | from examples import tasks
11 | from litestar import Controller, Litestar, get
12 |
13 | from litestar_saq import CronJob, QueueConfig, SAQConfig, SAQPlugin
14 |
15 | if TYPE_CHECKING:
16 | from saq.types import Context, QueueInfo
17 |
18 | from litestar_saq.config import TaskQueues
19 |
20 | logger = getLogger(__name__)
21 |
22 |
23 | async def system_upkeep(_: Context) -> None:
24 | logger.info("Performing system upkeep operations.")
25 | logger.info("Simulating a long running operation. Sleeping for 60 seconds.")
26 | await asyncio.sleep(3)
27 | logger.info("Simulating an even longer running operation. Sleeping for 120 seconds.")
28 | await asyncio.sleep(3)
29 | logger.info("Long running process complete.")
30 | logger.info("Performing system upkeep operations.")
31 |
32 |
33 | async def background_worker_task(_: Context) -> None:
34 | logger.info("Performing background worker task.")
35 | await asyncio.sleep(1)
36 | logger.info("Performing system upkeep operations.")
37 |
38 |
39 | async def system_task(_: Context) -> None:
40 | logger.info("Performing simple system task")
41 | await asyncio.sleep(2)
42 | logger.info("System task complete.")
43 |
44 |
45 | class SampleController(Controller):
46 | @get(path="/samples")
47 | async def samples_queue_info(self, task_queues: TaskQueues) -> QueueInfo:
48 | queue = task_queues.get("samples")
49 | return await queue.info()
50 |
51 |
52 | saq = SAQPlugin(
53 | config=SAQConfig(
54 | web_enabled=True,
55 | use_server_lifespan=True,
56 | queue_configs=[
57 | QueueConfig(
58 | dsn="redis://localhost:6397/0",
59 | name="samples",
60 | tasks=[tasks.background_worker_task, tasks.system_task, tasks.system_upkeep],
61 | scheduled_tasks=[CronJob(function=tasks.system_upkeep, cron="* * * * *", timeout=600, ttl=2000)],
62 | ),
63 | ],
64 | ),
65 | )
66 | app = Litestar(plugins=[saq], route_handlers=[SampleController])
67 |
68 | """
69 |
--------------------------------------------------------------------------------
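Given the app defined in APP_DEFAULT_CONFIG_FILE_CONTENT, jobs can be enqueued against the "samples" queue by task name. A hedged sketch (hypothetical handler; TaskQueues is the dependency the plugin injects):

    from litestar_saq.config import TaskQueues

    async def enqueue_upkeep(task_queues: TaskQueues) -> None:
        queue = task_queues.get("samples")
        await queue.enqueue("system_upkeep")  # saq resolves registered tasks by name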
/tests/test_cli/conftest.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import importlib.util
4 | import sys
5 | from collections.abc import Generator
6 | from pathlib import Path
7 | from shutil import rmtree
8 | from typing import TYPE_CHECKING, Callable, Protocol, cast
9 |
10 | import pytest
11 | from _pytest.fixtures import FixtureRequest
12 | from _pytest.monkeypatch import MonkeyPatch
13 | from click.testing import CliRunner
14 | from litestar.cli._utils import _path_to_dotted_path
15 | from pytest_mock import MockerFixture
16 |
17 | from . import APP_DEFAULT_CONFIG_FILE_CONTENT
18 |
19 | if TYPE_CHECKING:
20 | from unittest.mock import MagicMock
21 |
22 | from litestar.cli._utils import LitestarGroup
23 |
24 |
25 | @pytest.fixture(autouse=True)
26 | def reset_litestar_app_env(monkeypatch: MonkeyPatch) -> None:
27 | monkeypatch.delenv("LITESTAR_APP", raising=False)
28 |
29 |
30 | @pytest.fixture()
31 | def root_command() -> LitestarGroup:
32 | import litestar.cli.main
33 |
34 | return cast("LitestarGroup", importlib.reload(litestar.cli.main).litestar_group)
35 |
36 |
37 | @pytest.fixture
38 | def patch_autodiscovery_paths(request: FixtureRequest) -> Callable[[list[str]], None]:
39 | def patcher(paths: list[str]) -> None:
40 | from litestar.cli._utils import AUTODISCOVERY_FILE_NAMES
41 |
42 | old_paths = AUTODISCOVERY_FILE_NAMES[::]
43 | AUTODISCOVERY_FILE_NAMES[:] = paths
44 |
45 | def finalizer() -> None:
46 | AUTODISCOVERY_FILE_NAMES[:] = old_paths
47 |
48 | request.addfinalizer(finalizer)
49 |
50 | return patcher
51 |
52 |
53 | @pytest.fixture(autouse=True)
54 | def tmp_project_dir(monkeypatch: MonkeyPatch, tmp_path: Path) -> Path:
55 | path = tmp_path / "project_dir"
56 | path.mkdir(exist_ok=True)
57 | monkeypatch.chdir(path)
58 | return path
59 |
60 |
61 | class CreateAppFileFixture(Protocol):
62 | def __call__(
63 | self,
64 | file: str | Path,
65 | directory: str | Path | None = None,
66 | content: str | None = None,
67 | init_content: str = "",
68 | subdir: str | None = None,
69 | ) -> Path: ...
70 |
71 |
72 | def _purge_module(module_names: list[str], path: str | Path) -> None:
73 | for name in module_names:
74 | if name in sys.modules:
75 | del sys.modules[name]
76 | Path(importlib.util.cache_from_source(path)).unlink(missing_ok=True) # type: ignore[arg-type]
77 |
78 |
79 | @pytest.fixture
80 | def create_app_file(tmp_project_dir: Path, request: FixtureRequest) -> CreateAppFileFixture:
81 | def _create_app_file(
82 | file: str | Path,
83 | directory: str | Path | None = None,
84 | content: str | None = None,
85 | init_content: str = "",
86 | subdir: str | None = None,
87 | ) -> Path:
88 | base = tmp_project_dir
89 | if directory:
90 | base /= Path(Path(directory) / subdir) if subdir else Path(directory)
91 | base.mkdir(parents=True)
92 | base.joinpath("__init__.py").write_text(init_content)
93 |
94 | tmp_app_file = base / file
95 | tmp_app_file.write_text(content or APP_DEFAULT_CONFIG_FILE_CONTENT)
96 |
97 | if directory:
98 | request.addfinalizer(lambda: rmtree(directory))
99 | request.addfinalizer(
100 | lambda: _purge_module(
101 | [directory, _path_to_dotted_path(tmp_app_file.relative_to(Path.cwd()))], # type: ignore[list-item]
102 | tmp_app_file,
103 | ),
104 | )
105 | else:
106 | request.addfinalizer(tmp_app_file.unlink)
107 | request.addfinalizer(lambda: _purge_module([str(file).replace(".py", "")], tmp_app_file))
108 | return tmp_app_file
109 |
110 | return _create_app_file
111 |
112 |
113 | @pytest.fixture
114 | def app_file(create_app_file: CreateAppFileFixture) -> Path:
115 | return create_app_file("app.py")
116 |
117 |
118 | @pytest.fixture
119 | def runner() -> CliRunner:
120 | return CliRunner()
121 |
122 |
123 | @pytest.fixture
124 | def mock_uvicorn_run(mocker: MockerFixture) -> MagicMock:
125 | return mocker.patch("uvicorn.run")
126 |
127 |
128 | @pytest.fixture()
129 | def mock_subprocess_run(mocker: MockerFixture) -> MagicMock:
130 | return mocker.patch("subprocess.run")
131 |
132 |
133 | @pytest.fixture
134 | def mock_confirm_ask(mocker: MockerFixture) -> Generator[MagicMock, None, None]:
135 | yield mocker.patch("rich.prompt.Confirm.ask", return_value=True)
136 |
137 |
138 | @pytest.fixture(
139 | params=[
140 | pytest.param((APP_DEFAULT_CONFIG_FILE_CONTENT, "app"), id="app_obj"),
141 | ],
142 | )
143 | def _app_file_content(request: FixtureRequest) -> tuple[str, str]:
144 | return cast("tuple[str, str]", request.param)
145 |
146 |
147 | @pytest.fixture
148 | def app_file_content(_app_file_content: tuple[str, str]) -> str:
149 | return _app_file_content[0]
150 |
151 |
152 | @pytest.fixture
153 | def app_file_app_name(_app_file_content: tuple[str, str]) -> str:
154 | return _app_file_content[1]
155 |
--------------------------------------------------------------------------------
/tests/test_cli/test_cli.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from click.testing import CliRunner
3 | from litestar.cli._utils import LitestarGroup
4 | from redis.asyncio import Redis
5 |
6 | from tests.test_cli import APP_DEFAULT_CONFIG_FILE_CONTENT
7 | from tests.test_cli.conftest import CreateAppFileFixture
8 |
9 | pytestmark = pytest.mark.anyio
10 |
11 |
12 | async def test_basic_command(
13 | runner: CliRunner,
14 | create_app_file: CreateAppFileFixture,
15 | root_command: LitestarGroup,
16 | redis_service: None,
17 | redis: Redis,
18 | ) -> None:
19 | app_file = create_app_file("command_test_app.py", content=APP_DEFAULT_CONFIG_FILE_CONTENT)
20 | result = runner.invoke(root_command, ["--app", f"{app_file.stem}:app", "workers"])
21 |
22 | assert not result.exception
23 | assert "Manage background task workers." in result.output
24 |
--------------------------------------------------------------------------------
/tests/test_plugin.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import pytest
4 | from litestar.cli._utils import LitestarGroup as Group
5 |
6 | from litestar_saq.config import SAQConfig, TaskQueues
7 | from litestar_saq.plugin import SAQPlugin
8 |
9 | # Assuming SAQConfig, Worker, Group, AppConfig, TaskQueues, Queue are available and can be imported
10 | # Assuming there are meaningful __eq__ methods for comparisons where needed
11 |
12 |
13 | # Test on_cli_init method
14 | @pytest.mark.parametrize(
15 | "cli_group",
16 | [
17 | Mock(Group),
18 | ],
19 | )
20 | def test_on_cli_init(cli_group: Group) -> None:
21 | # Arrange
22 | config = Mock(SAQConfig)
23 | plugin = SAQPlugin(config)
24 |
25 | # Act
26 | plugin.on_cli_init(cli_group)
27 |
28 | # Assert
29 | cli_group.add_command.assert_called_once() # pyright: ignore[reportFunctionMemberAccess]
30 |
31 |
32 | # Test get_queues method
33 | @pytest.mark.parametrize(
34 | "queues",
35 | [
36 | Mock(),
37 | ],
38 | )
39 | def test_get_queues(queues: TaskQueues) -> None:
40 | # Arrange
41 | config = Mock(SAQConfig, get_queues=Mock(return_value=queues))
42 | plugin = SAQPlugin(config)
43 |
44 | # Act
45 | result = plugin.get_queues()
46 |
47 | # Assert
48 | assert result == queues
49 |
--------------------------------------------------------------------------------