├── .github ├── CODEOWNERS ├── dependabot.yaml └── workflows │ ├── ci.yaml │ ├── pr-title.yaml │ ├── publish.yaml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.rst ├── LICENSE ├── Makefile ├── README.md ├── examples ├── __init__.py ├── basic.py ├── postgres.py └── tasks.py ├── litestar_saq ├── __init__.py ├── __metadata__.py ├── base.py ├── cli.py ├── config.py ├── controllers.py ├── exceptions.py ├── plugin.py └── py.typed ├── pyproject.toml ├── tests ├── __init__.py ├── conftest.py ├── test_cli │ ├── __init__.py │ ├── conftest.py │ └── test_cli.py └── test_plugin.py └── uv.lock /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Code owner settings for `litstar-org` 2 | # @maintainers should be assigned to all reviews. 3 | # Most specific assignment takes precedence though, so if you add a more specific thing than the `*` glob, you must also add @maintainers 4 | # For more info about code owners see https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners#codeowners-file-example 5 | 6 | # Global Assignment 7 | * @litestar-org/maintainers @litestar-org/members 8 | -------------------------------------------------------------------------------- /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Tests And Linting 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | concurrency: 10 | group: test-${{ github.head_ref }} 11 | cancel-in-progress: true 12 | 13 | 14 | jobs: 15 | validate: 16 | runs-on: ubuntu-latest 
17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Install uv 21 | uses: astral-sh/setup-uv@v6 22 | 23 | - name: Set up Python 24 | run: uv python install 3.12 25 | 26 | - name: Create virtual environment 27 | run: uv sync --all-extras --dev 28 | 29 | - name: Install Pre-Commit hooks 30 | run: uv run pre-commit install 31 | 32 | - name: Load cached Pre-Commit Dependencies 33 | id: cached-pre-commit-dependencies 34 | uses: actions/cache@v4 35 | with: 36 | path: ~/.cache/pre-commit/ 37 | key: pre-commit|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }} 38 | 39 | - name: Execute Pre-Commit 40 | run: uv run pre-commit run --show-diff-on-failure --color=always --all-files 41 | mypy: 42 | runs-on: ubuntu-latest 43 | steps: 44 | - uses: actions/checkout@v4 45 | 46 | - name: Install uv 47 | uses: astral-sh/setup-uv@v6 48 | 49 | - name: Set up Python 50 | run: uv python install 3.12 51 | 52 | - name: Install dependencies 53 | run: uv sync --all-extras --dev 54 | 55 | - name: Run mypy 56 | run: uv run mypy litestar_saq/ 57 | 58 | pyright: 59 | runs-on: ubuntu-latest 60 | steps: 61 | - uses: actions/checkout@v4 62 | 63 | - name: Install uv 64 | uses: astral-sh/setup-uv@v6 65 | 66 | - name: Set up Python 67 | run: uv python install 3.12 68 | 69 | - name: Install dependencies 70 | run: uv sync --all-extras --dev 71 | 72 | - name: Run pyright 73 | run: uv run pyright 74 | 75 | slotscheck: 76 | runs-on: ubuntu-latest 77 | steps: 78 | - uses: actions/checkout@v4 79 | 80 | - name: Install uv 81 | uses: astral-sh/setup-uv@v6 82 | 83 | - name: Set up Python 84 | run: uv python install 3.12 85 | 86 | - name: Install dependencies 87 | run: uv sync --all-extras --dev 88 | 89 | - name: Run slotscheck 90 | run: uv run slotscheck -m litestar_saq 91 | 92 | 93 | test_python: 94 | name: "test (python ${{ matrix.python-version }})" 95 | strategy: 96 | fail-fast: true 97 | matrix: 98 | python-version: ["3.9", "3.10", "3.11", "3.12","3.13"] 99 | uses: 
./.github/workflows/test.yml 100 | with: 101 | coverage: ${{ matrix.python-version == '3.12' }} 102 | python-version: ${{ matrix.python-version }} 103 | -------------------------------------------------------------------------------- /.github/workflows/pr-title.yaml: -------------------------------------------------------------------------------- 1 | name: "Lint PR Title" 2 | 3 | on: 4 | pull_request_target: 5 | types: 6 | - opened 7 | - edited 8 | - synchronize 9 | 10 | permissions: 11 | pull-requests: read 12 | 13 | jobs: 14 | main: 15 | name: Validate PR title 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: amannn/action-semantic-pull-request@v5 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: PublishLatest Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | publish-python-release: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | id-token: write 13 | environment: release 14 | steps: 15 | - name: Check out repository 16 | uses: actions/checkout@v4 17 | 18 | - name: Install uv 19 | uses: astral-sh/setup-uv@v6 20 | 21 | - name: Set up Python 22 | run: uv python install 3.12 23 | 24 | - name: Install dependencies 25 | run: uv sync --all-extras 26 | 27 | - name: Build package 28 | run: uv build 29 | 30 | - name: Publish package distributions to PyPI 31 | uses: pypa/gh-action-pypi-publish@release/v1 32 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | python-version: 7 | required: true 8 | type: string 9 | coverage: 10 | required: false 11 | type: boolean 12 | default: false 13 | os: 14 | required: 
false 15 | type: string 16 | default: "ubuntu-latest" 17 | timeout: 18 | required: false 19 | type: number 20 | default: 60 21 | 22 | jobs: 23 | test: 24 | runs-on: ${{ inputs.os }} 25 | timeout-minutes: ${{ inputs.timeout }} 26 | defaults: 27 | run: 28 | shell: bash 29 | steps: 30 | - name: Check out repository 31 | uses: actions/checkout@v4 32 | 33 | - name: Install uv 34 | uses: astral-sh/setup-uv@v6 35 | 36 | - name: Set up Python 37 | run: uv python install ${{ inputs.python-version }} 38 | 39 | - name: Install dependencies 40 | run: uv sync --all-extras --dev 41 | 42 | - name: Set PYTHONPATH 43 | run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV 44 | 45 | - name: Test 46 | if: ${{ !inputs.coverage }} 47 | run: uv run pytest --dist "loadgroup" -m "" -n 2 48 | 49 | - name: Test with coverage 50 | if: ${{ inputs.coverage }} 51 | run: uv run pytest --dist "loadgroup" -m "" --cov=litestar_saq --cov-report=xml -n 2 52 | 53 | - uses: actions/upload-artifact@v4 54 | if: ${{ inputs.coverage }} 55 | with: 56 | name: coverage-xml 57 | path: coverage.xml 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | 165 | .vscode/ 166 | examples/tmp.py 167 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: "3" 3 | repos: 4 | - repo: https://github.com/compilerla/conventional-pre-commit 5 | rev: v4.1.0 6 | hooks: 7 | - id: conventional-pre-commit 8 | stages: [commit-msg] 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v5.0.0 11 | hooks: 12 | - id: check-ast 13 | - id: check-case-conflict 14 | - id: check-toml 15 | - id: debug-statements 16 | - id: end-of-file-fixer 17 | - id: mixed-line-ending 18 | - id: trailing-whitespace 19 | - repo: https://github.com/charliermarsh/ruff-pre-commit 20 | rev: "v0.11.8" 21 | hooks: 22 | # Run the linter. 23 | - id: ruff 24 | types_or: [ python, pyi ] 25 | args: [ --fix ] 26 | # Run the formatter. 27 | - id: ruff-format 28 | types_or: [ python, pyi ] 29 | - repo: https://github.com/codespell-project/codespell 30 | rev: v2.4.1 31 | hooks: 32 | - id: codespell 33 | exclude: "uv.lock|package.json|package-lock.json" 34 | additional_dependencies: 35 | - tomli 36 | - repo: https://github.com/sphinx-contrib/sphinx-lint 37 | rev: "v1.0.0" 38 | hooks: 39 | - id: sphinx-lint 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contribution guide 2 | ================== 3 | 4 | Setting up the environment 5 | -------------------------- 6 | 7 | 1. Install `PDM `_ 8 | 2. Run ``pdm install -G:all`` to create a `virtual environment `_ and install 9 | the dependencies 10 | 3. If you're working on the documentation and need to build it locally, install the extra dependencies with ``pdm install -G:docs`` 11 | 4. Install `pre-commit `_ 12 | 5. 
Run ``pre-commit install`` to install pre-commit hooks 13 | 14 | Code contributions 15 | ------------------ 16 | 17 | Workflow 18 | ++++++++ 19 | 20 | 1. `Fork `_ the `litestar-saq repository `_ 21 | 2. Clone your fork locally with git 22 | 3. `Set up the environment <#setting-up-the-environment>`_ 23 | 4. Make your changes 24 | 5. (Optional) Run ``pre-commit run --all-files`` to run linters and formatters. This step is optional and will be executed 25 | automatically by git before you make a commit, but you may want to run it manually in order to apply fixes 26 | 6. Commit your changes to git 27 | 7. Push the changes to your fork 28 | 8. Open a `pull request `_. Give the pull request a descriptive title 29 | indicating what it changes. If it has a corresponding open issue, the issue number should be included in the title as 30 | well. For example, a pull request that fixes issue ``bug: Increased stack size making it impossible to find needle #100`` 31 | could be titled ``fix(#100): Make needles easier to find by applying fire to haystack`` 32 | 33 | .. tip:: Pull requests and commits all need to follow the 34 | `Conventional Commit format `_ 35 | 36 | Guidelines for writing code 37 | ---------------------------- 38 | 39 | - All code should be fully `typed `_. This is enforced via 40 | `mypy `_. 41 | - All code should be tested. This is enforced via `pytest `_. 42 | - All code should be properly formatted. This is enforced via `Ruff `_ (linter and formatter). 43 | 44 | Writing and running tests 45 | +++++++++++++++++++++++++ 46 | 47 | .. todo:: Write this section 48 | 49 | Project documentation 50 | --------------------- 51 | 52 | The documentation is located in the ``/docs`` directory and uses `ReST `_ and 53 | `Sphinx `_. If you're unfamiliar with any of those, 54 | `ReStructuredText primer `_ and 55 | `Sphinx quickstart `_ are recommended reads. 
56 | 57 | Running the docs locally 58 | ++++++++++++++++++++++++ 59 | 60 | To run or build the docs locally, you need to first install the required dependencies: 61 | 62 | ``pdm install -G:docs`` 63 | 64 | Then you can serve the documentation with ``make docs-serve``, or build them with ``make docs``. 65 | 66 | Creating a new release 67 | ---------------------- 68 | 69 | 1. Increment the version in `pyproject.toml `_. 70 | .. note:: The version should follow `semantic versioning `_ and `PEP 440 `_. 71 | 2. `Draft a new release `_ on GitHub 72 | 73 | * Use ``vMAJOR.MINOR.PATCH`` (e.g. ``v1.2.3``) as both the tag and release title 74 | * Fill in the release description. You can use the "Generate release notes" function to get a draft for this 75 | 3. Commit your changes and push to ``main`` 76 | 4. Publish the release 77 | 5. Go to `Actions `_ and approve the release workflow 78 | 6. Check that the workflow runs successfully 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Cody Fincher 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | # ============================================================================= 4 | # Configuration and Environment Variables 5 | # ============================================================================= 6 | 7 | .DEFAULT_GOAL:=help 8 | .ONESHELL: 9 | .EXPORT_ALL_VARIABLES: 10 | MAKEFLAGS += --no-print-directory 11 | 12 | # ----------------------------------------------------------------------------- 13 | # Display Formatting and Colors 14 | # ----------------------------------------------------------------------------- 15 | BLUE := $(shell printf "\033[1;34m") 16 | GREEN := $(shell printf "\033[1;32m") 17 | RED := $(shell printf "\033[1;31m") 18 | YELLOW := $(shell printf "\033[1;33m") 19 | NC := $(shell printf "\033[0m") 20 | INFO := $(shell printf "$(BLUE)ℹ$(NC)") 21 | OK := $(shell printf "$(GREEN)✓$(NC)") 22 | WARN := $(shell printf "$(YELLOW)⚠$(NC)") 23 | ERROR := $(shell printf "$(RED)✖$(NC)") 24 | 25 | # ============================================================================= 26 | # Help and Documentation 27 | # ============================================================================= 28 | 29 | .PHONY: help 30 | help: ## Display this help text for Makefile 31 | @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z0-9_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) 32 | 33 | # 
============================================================================= 34 | # Installation and Environment Setup 35 | # ============================================================================= 36 | 37 | .PHONY: install-uv 38 | install-uv: ## Install latest version of uv 39 | @echo "${INFO} Installing uv..." 40 | @curl -LsSf https://astral.sh/uv/install.sh | sh >/dev/null 2>&1 41 | @echo "${OK} UV installed successfully" 42 | 43 | .PHONY: install 44 | install: destroy clean ## Install the project, dependencies, and pre-commit 45 | @echo "${INFO} Starting fresh installation..." 46 | @uv sync --all-extras --dev 47 | @echo "${OK} Installation complete! 🎉" 48 | 49 | .PHONY: destroy 50 | destroy: ## Destroy the virtual environment 51 | @echo "${INFO} Destroying virtual environment... 🗑️" 52 | @uv run pre-commit clean >/dev/null 2>&1 53 | @rm -rf .venv 54 | @echo "${OK} Virtual environment destroyed 🗑️" 55 | 56 | # ============================================================================= 57 | # Dependency Management 58 | # ============================================================================= 59 | 60 | .PHONY: upgrade 61 | upgrade: ## Upgrade all dependencies to latest stable versions 62 | @echo "${INFO} Updating all dependencies... 🔄" 63 | @uv lock --upgrade 64 | @echo "${OK} Dependencies updated 🔄" 65 | @uv run pre-commit autoupdate 66 | @echo "${OK} Updated Pre-commit hooks 🔄" 67 | 68 | .PHONY: lock 69 | lock: ## Rebuild lockfiles from scratch 70 | @echo "${INFO} Rebuilding lockfiles... 🔄" 71 | @uv lock --upgrade >/dev/null 2>&1 72 | @echo "${OK} Lockfiles updated" 73 | 74 | # ============================================================================= 75 | # Build and Release 76 | # ============================================================================= 77 | 78 | .PHONY: build 79 | build: ## Build the package 80 | @echo "${INFO} Building package... 
📦" 81 | @uv build 82 | @echo "${OK} Package build complete" 83 | 84 | .PHONY: release 85 | release: ## Bump version and create release tag 86 | @echo "${INFO} Preparing for release... 📦" 87 | @make docs 88 | @make clean 89 | @make build 90 | @uv lock --upgrade-package litestar-saq 91 | @uv run bump-my-version bump $(bump) 92 | @echo "${OK} Release complete 🎉" 93 | 94 | # ============================================================================= 95 | # Cleaning and Maintenance 96 | # ============================================================================= 97 | 98 | .PHONY: clean 99 | clean: ## Cleanup temporary build artifacts 100 | @echo "${INFO} Cleaning working directory... 🧹" 101 | @rm -rf pytest_cache .ruff_cache .hypothesis build/ -rf dist/ .eggs/ .coverage coverage.xml coverage.json htmlcov/ .pytest_cache tests/.pytest_cache tests/**/.pytest_cache .mypy_cache .unasyncd_cache/ .auto_pytabs_cache node_modules >/dev/null 2>&1 102 | @find . -name '*.egg-info' -exec rm -rf {} + >/dev/null 2>&1 103 | @find . -type f -name '*.egg' -exec rm -f {} + >/dev/null 2>&1 104 | @find . -name '*.pyc' -exec rm -f {} + >/dev/null 2>&1 105 | @find . -name '*.pyo' -exec rm -f {} + >/dev/null 2>&1 106 | @find . -name '*~' -exec rm -f {} + >/dev/null 2>&1 107 | @find . -name '__pycache__' -exec rm -rf {} + >/dev/null 2>&1 108 | @find . -name '.ipynb_checkpoints' -exec rm -rf {} + >/dev/null 2>&1 109 | @echo "${OK} Working directory cleaned" 110 | 111 | # ============================================================================= 112 | # Testing and Quality Checks 113 | # ============================================================================= 114 | 115 | .PHONY: test 116 | test: ## Run the tests 117 | @echo "${INFO} Running test cases... 🧪" 118 | @uv run pytest -n 2 --quiet 119 | @echo "${OK} Tests passed ✨" 120 | 121 | .PHONY: coverage 122 | coverage: ## Run tests with coverage report 123 | @echo "${INFO} Running tests with coverage... 
📊" 124 | @uv run pytest --cov -n auto --quiet 125 | @uv run coverage html >/dev/null 2>&1 126 | @uv run coverage xml >/dev/null 2>&1 127 | @echo "${OK} Coverage report generated ✨" 128 | 129 | # ----------------------------------------------------------------------------- 130 | # Type Checking 131 | # ----------------------------------------------------------------------------- 132 | 133 | .PHONY: mypy 134 | mypy: ## Run mypy 135 | @echo "${INFO} Running mypy... 🔍" 136 | @uv run dmypy run litestar_saq/ 137 | @echo "${OK} Mypy checks passed ✨" 138 | 139 | .PHONY: pyright 140 | pyright: ## Run pyright 141 | @echo "${INFO} Running pyright... 🔍" 142 | @uv run pyright 143 | @echo "${OK} Pyright checks passed ✨" 144 | 145 | .PHONY: type-check 146 | type-check: mypy pyright ## Run all type checking 147 | 148 | # ----------------------------------------------------------------------------- 149 | # Linting and Formatting 150 | # ----------------------------------------------------------------------------- 151 | 152 | .PHONY: pre-commit 153 | pre-commit: ## Run pre-commit hooks 154 | @echo "${INFO} Running pre-commit checks... 🔎" 155 | @NODE_OPTIONS="--no-deprecation --disable-warning=ExperimentalWarning" uv run pre-commit run --color=always --all-files 156 | @echo "${OK} Pre-commit checks passed ✨" 157 | 158 | .PHONY: slotscheck 159 | slotscheck: ## Run slotscheck 160 | @echo "${INFO} Running slots check... 🔍" 161 | @uv run slotscheck -m litestar_saq 162 | @echo "${OK} Slots check passed ✨" 163 | 164 | .PHONY: fix 165 | fix: ## Run code formatters 166 | @echo "${INFO} Running code formatters... 
🔧" 167 | @uv run ruff check --fix --unsafe-fixes 168 | @echo "${OK} Code formatting complete ✨" 169 | 170 | .PHONY: lint 171 | lint: pre-commit type-check slotscheck ## Run all linting checks 172 | 173 | .PHONY: check-all 174 | check-all: lint test coverage ## Run all checks (lint, test, coverage) 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Litestar SAQ 2 | 3 | ## Installation 4 | 5 | ```shell 6 | pip install litestar-saq 7 | ``` 8 | 9 | ## Usage 10 | 11 | Here is a basic application that demonstrates how to use the plugin. 12 | 13 | ```python 14 | from __future__ import annotations 15 | 16 | from litestar import Litestar 17 | 18 | from litestar_saq import QueueConfig, SAQConfig, SAQPlugin 19 | 20 | saq = SAQPlugin(config=SAQConfig(dsn="redis://localhost:6397/0", queue_configs=[QueueConfig(name="samples")])) 21 | app = Litestar(plugins=[saq]) 22 | 23 | 24 | ``` 25 | 26 | You can start a background worker with the following command now: 27 | 28 | ```shell 29 | litestar --app-dir=examples/ --app basic:app workers run 30 | Using Litestar app from env: 'basic:app' 31 | Starting SAQ Workers ────────────────────────────────────────────────────────────────── 32 | INFO - 2023-10-04 17:39:03,255 - saq - worker - Worker starting: Queue>>, name='samples'> 33 | INFO - 2023-10-04 17:39:06,545 - saq - worker - Worker shutting down 34 | ``` 35 | 36 | You can also start the process for only specific queues. This is helpful if you want separated processes working on different queues instead of combining them. 
37 | 38 | ```shell 39 | litestar --app-dir=examples/ --app basic:app workers run --queues samples 40 | Using Litestar app from env: 'basic:app' 41 | Starting SAQ Workers ────────────────────────────────────────────────────────────────── 42 | INFO - 2023-10-04 17:39:03,255 - saq - worker - Worker starting: Queue>>, name='samples'> 43 | INFO - 2023-10-04 17:39:06,545 - saq - worker - Worker shutting down 44 | ``` 45 | 46 | If you are starting the process for only specific queues and still want to read from the other queues, or enqueue a task into a queue that was not initialized in your worker or that lives somewhere else, you can connect to that queue directly, as shown here: 47 | 48 | ```python 49 | import os 50 | from saq import Queue 51 | 52 | 53 | def get_queue_directly(queue_name: str, redis_url: str) -> Queue: 54 | return Queue.from_url(redis_url, name=queue_name) 55 | 56 | redis_url = os.getenv("REDIS_URL") 57 | queue = get_queue_directly("queue-in-other-process", redis_url) 58 | # Get queue info 59 | info = await queue.info(jobs=True) 60 | # Enqueue new task 61 | await queue.enqueue( 62 | .... 
63 | ) 64 | ``` 65 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cofin/litestar-saq/329934de25c5a3f7f3459fd8e8d2ca7013d6fe50/examples/__init__.py -------------------------------------------------------------------------------- /examples/basic.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from litestar import Controller, Litestar, get 6 | 7 | from examples import tasks 8 | from litestar_saq import CronJob, QueueConfig, SAQConfig, SAQPlugin 9 | 10 | if TYPE_CHECKING: 11 | from saq.types import QueueInfo 12 | 13 | from litestar_saq.config import TaskQueues 14 | 15 | 16 | class SampleController(Controller): 17 | @get(path="/samples") 18 | async def samples_queue_info(self, task_queues: TaskQueues) -> QueueInfo: 19 | """Return the info reported by the ``samples`` task queue.""" 20 | queue = task_queues.get("samples") 21 | return await queue.info() 22 | 23 | 24 | saq = SAQPlugin( 25 | config=SAQConfig( 26 | web_enabled=True, 27 | use_server_lifespan=True, 28 | queue_configs=[ 29 | QueueConfig( 30 | dsn="redis://localhost:6397/0", 31 | name="samples", 32 | tasks=[tasks.background_worker_task, tasks.system_task, tasks.system_upkeep], 33 | scheduled_tasks=[CronJob(function=tasks.system_upkeep, cron="* * * * *", timeout=600, ttl=2000)], 34 | ), 35 | ], 36 | ), 37 | ) 38 | app = Litestar(plugins=[saq], route_handlers=[SampleController]) 39 | -------------------------------------------------------------------------------- /examples/postgres.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from logging import getLogger 5 | from typing import TYPE_CHECKING 6 | 7 | from litestar import 
Controller, Litestar, get 8 | 9 | from examples import tasks 10 | from litestar_saq import CronJob, QueueConfig, SAQConfig, SAQPlugin 11 | 12 | if TYPE_CHECKING: 13 | from saq.types import Context, QueueInfo 14 | 15 | from litestar_saq.config import TaskQueues 16 | 17 | logger = getLogger(__name__) 18 | 19 | 20 | async def system_upkeep(_: Context) -> None: 21 | logger.info("Performing system upkeep operations.") 22 | logger.info("Simulating a long running operation. Sleeping for 60 seconds.") 23 | await asyncio.sleep(3) 24 | logger.info("Simulating an even longer running operation. Sleeping for 120 seconds.") 25 | await asyncio.sleep(3) 26 | logger.info("Long running process complete.") 27 | logger.info("Performing system upkeep operations.") 28 | 29 | 30 | async def background_worker_task(_: Context) -> None: 31 | logger.info("Performing background worker task.") 32 | await asyncio.sleep(1) 33 | logger.info("Performing system upkeep operations.") 34 | 35 | 36 | async def system_task(_: Context) -> None: 37 | logger.info("Performing simple system task") 38 | await asyncio.sleep(2) 39 | logger.info("System task complete.") 40 | 41 | 42 | class SampleController(Controller): 43 | @get(path="/samples") 44 | async def samples_queue_info(self, task_queues: TaskQueues) -> QueueInfo: 45 | queue = task_queues.get("samples") 46 | return await queue.info() 47 | 48 | 49 | saq = SAQPlugin( 50 | config=SAQConfig( 51 | web_enabled=True, 52 | use_server_lifespan=True, 53 | queue_configs=[ 54 | QueueConfig( 55 | dsn="postgresql://app:app@localhost:15432/app", 56 | tasks=[tasks.background_worker_task, tasks.system_task, tasks.system_upkeep], 57 | scheduled_tasks=[CronJob(function=tasks.system_upkeep, cron="* * * * *", timeout=600, ttl=2000)], 58 | ) 59 | ], 60 | ), 61 | ) 62 | app = Litestar(plugins=[saq], route_handlers=[SampleController]) 63 | -------------------------------------------------------------------------------- /examples/tasks.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from logging import getLogger 5 | from typing import TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from saq.types import Context 9 | 10 | logger = getLogger(__name__) 11 | 12 | 13 | async def system_upkeep(_: Context) -> None: 14 | """Simulate a long-running upkeep job: sleep 60 seconds, then 120 seconds, logging each step.""" 15 | logger.info("Performing system upkeep operations.") 16 | logger.info("Simulating a long running operation. Sleeping for 60 seconds.") 17 | await asyncio.sleep(60) 18 | logger.info("Simulating an even longer running operation. Sleeping for 120 seconds.") 19 | await asyncio.sleep(120) 20 | logger.info("Long running process complete.") 21 | logger.info("Performing system upkeep operations.") 22 | 23 | 24 | async def background_worker_task(_: Context) -> None: 25 | """Simulate a background job that sleeps for 20 seconds.""" 26 | logger.info("Performing background worker task.") 27 | await asyncio.sleep(20) 28 | logger.info("Performing system upkeep operations.") 29 | 30 | 31 | async def system_task(_: Context) -> None: 32 | """Simulate a short system job that sleeps for 2 seconds.""" 33 | logger.info("Performing simple system task") 34 | await asyncio.sleep(2) 35 | logger.info("System task complete.") 36 | -------------------------------------------------------------------------------- /litestar_saq/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from litestar_saq.base import CronJob, Job, Worker 4 | from litestar_saq.config import PostgresQueueOptions, QueueConfig, RedisQueueOptions, SAQConfig, TaskQueues 5 | from litestar_saq.plugin import SAQPlugin 6 | 7 | __all__ = ( 8 | "CronJob", 9 | "Job", 10 | "PostgresQueueOptions", 11 | "QueueConfig", 12 | "RedisQueueOptions", 13 | "SAQConfig", 14 | "SAQPlugin", 15 | "TaskQueues", 16 | "Worker", 17 | ) 18 | -------------------------------------------------------------------------------- /litestar_saq/__metadata__.py: 
-------------------------------------------------------------------------------- 1 | """Metadata for the Project.""" 2 | 3 | from __future__ import annotations 4 | 5 | from importlib.metadata import PackageNotFoundError, metadata, version 6 | 7 | __all__ = ("__project__", "__version__") 8 | 9 | try: 10 | __version__ = version("litestar_saq") 11 | """Version of the project.""" 12 | __project__ = metadata("litestar_saq")["Name"] 13 | """Name of the project.""" 14 | except PackageNotFoundError: # pragma: no cover 15 | __version__ = "0.0.0" 16 | __project__ = "Litestar SAQ" 17 | finally: 18 | del version, PackageNotFoundError, metadata 19 | -------------------------------------------------------------------------------- /litestar_saq/base.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass, field 3 | from datetime import timezone, tzinfo 4 | from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast 5 | 6 | from litestar.utils.module_loader import import_string 7 | from saq import Job as SaqJob 8 | from saq.job import CronJob as SaqCronJob 9 | from saq.worker import Worker as SaqWorker 10 | 11 | if TYPE_CHECKING: 12 | from collections.abc import Collection 13 | 14 | from saq.queue.base import Queue 15 | from saq.types import Function, PartialTimersDict, ReceivesContext 16 | 17 | JsonDict = dict[str, Any] 18 | 19 | 20 | @dataclass 21 | class Job(SaqJob): 22 | """Job Details""" 23 | 24 | 25 | @dataclass 26 | class CronJob(SaqCronJob): 27 | """Cron Job Details""" 28 | 29 | function: "Union[Function, str]" # type: ignore[assignment] 30 | meta: "dict[str, Any]" = field(default_factory=dict) # pyright: ignore 31 | 32 | def __post_init__(self) -> None: 33 | self.function = self._get_or_import_function(self.function) # pyright: ignore[reportIncompatibleMethodOverride] 34 | 35 | @staticmethod 36 | def _get_or_import_function(function_or_import_string: "Union[str, Function]") -> 
"Function": 37 | if isinstance(function_or_import_string, str): 38 | return cast("Function", import_string(function_or_import_string)) 39 | return function_or_import_string 40 | 41 | 42 | class Worker(SaqWorker): 43 | """Worker.""" 44 | 45 | def __init__( 46 | self, 47 | queue: "Queue", 48 | functions: "Collection[Union[Function, tuple[str, Function]]]", 49 | *, 50 | id: "Optional[str]" = None, # noqa: A002 51 | concurrency: int = 10, 52 | cron_jobs: "Optional[Collection[CronJob]]" = None, 53 | cron_tz: "tzinfo" = timezone.utc, 54 | startup: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None, 55 | shutdown: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None, 56 | before_process: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None, 57 | after_process: "Optional[Union[ReceivesContext, Collection[ReceivesContext]]]" = None, 58 | timers: "Optional[PartialTimersDict]" = None, 59 | dequeue_timeout: float = 0, 60 | burst: bool = False, 61 | max_burst_jobs: "Optional[int]" = None, 62 | metadata: "Optional[JsonDict]" = None, 63 | separate_process: bool = True, 64 | multiprocessing_mode: Literal["multiprocessing", "threading"] = "multiprocessing", 65 | ) -> None: 66 | self.separate_process = separate_process 67 | self.multiprocessing_mode = multiprocessing_mode 68 | super().__init__( 69 | queue, 70 | functions, 71 | id=id, 72 | concurrency=concurrency, 73 | cron_jobs=cron_jobs, 74 | cron_tz=cron_tz, 75 | startup=startup, 76 | shutdown=shutdown, 77 | before_process=before_process, 78 | after_process=after_process, 79 | timers=timers, 80 | dequeue_timeout=dequeue_timeout, 81 | burst=burst, 82 | max_burst_jobs=max_burst_jobs, 83 | metadata=metadata, 84 | ) 85 | 86 | async def on_app_startup(self) -> None: 87 | """Attach the worker to the running event loop.""" 88 | if not self.separate_process: 89 | self.SIGNALS = [] 90 | loop = asyncio.get_running_loop() 91 | self._saq_asyncio_tasks = loop.create_task(self.start()) 92 
| 93 | async def on_app_shutdown(self) -> None: 94 | """Attach the worker to the running event loop.""" 95 | if not self.separate_process: 96 | loop = asyncio.get_running_loop() 97 | self._saq_asyncio_tasks = loop.create_task(self.stop()) 98 | -------------------------------------------------------------------------------- /litestar_saq/cli.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional 2 | 3 | if TYPE_CHECKING: 4 | from click import Group 5 | from litestar import Litestar 6 | from litestar.logging.config import BaseLoggingConfig 7 | 8 | from litestar_saq.base import Worker 9 | from litestar_saq.plugin import SAQPlugin 10 | 11 | 12 | def build_cli_app() -> "Group": # noqa: C901 13 | import asyncio 14 | import multiprocessing 15 | import platform 16 | from typing import cast 17 | 18 | from click import IntRange, group, option 19 | from litestar.cli._utils import LitestarGroup, console # pyright: ignore 20 | 21 | @group(cls=LitestarGroup, name="workers", no_args_is_help=True) 22 | def background_worker_group() -> None: 23 | """Manage background task workers.""" 24 | 25 | @background_worker_group.command( 26 | name="run", 27 | help="Run background worker processes.", 28 | ) 29 | @option( 30 | "--workers", 31 | help="The number of worker processes to start.", 32 | type=IntRange(min=1), 33 | default=1, 34 | required=False, 35 | show_default=True, 36 | ) 37 | @option( 38 | "--queues", 39 | help="List of queue names to process.", 40 | type=str, 41 | multiple=True, 42 | required=False, 43 | show_default=False, 44 | ) 45 | @option("-v", "--verbose", help="Enable verbose logging.", is_flag=True, default=None, type=bool, required=False) 46 | @option("-d", "--debug", help="Enable debugging.", is_flag=True, default=None, type=bool, required=False) 47 | def run_worker( # pyright: ignore[reportUnusedFunction] 48 | app: "Litestar", 49 | workers: int, 50 | queues: "Optional[tuple[str, ...]]", 51 | 
verbose: "Optional[bool]", 52 | debug: "Optional[bool]", 53 | ) -> None: 54 | """Run the API server.""" 55 | console.rule("[yellow]Starting SAQ Workers[/]", align="left") 56 | if platform.system() == "Darwin": 57 | multiprocessing.set_start_method("fork", force=True) 58 | 59 | if app.logging_config is not None: 60 | app.logging_config.configure() 61 | if debug is not None or verbose is not None: 62 | app.debug = True 63 | plugin = get_saq_plugin(app) 64 | if queues: 65 | queue_list = list(queues) 66 | limited_start_up(plugin, queue_list) 67 | show_saq_info(app, workers, plugin) 68 | managed_workers = list(plugin.get_workers().values()) 69 | processes: list[multiprocessing.Process] = [] 70 | if workers > 1: 71 | for _ in range(workers - 1): 72 | for worker in managed_workers: 73 | p = multiprocessing.Process( 74 | target=run_saq_worker, 75 | args=( 76 | worker, 77 | app.logging_config, 78 | ), 79 | ) 80 | p.start() 81 | processes.append(p) 82 | 83 | if len(managed_workers) > 1: 84 | for j in range(len(managed_workers) - 1): 85 | p = multiprocessing.Process(target=run_saq_worker, args=(managed_workers[j + 1], app.logging_config)) 86 | p.start() 87 | processes.append(p) 88 | 89 | try: 90 | run_saq_worker( 91 | worker=managed_workers[0], 92 | logging_config=cast("BaseLoggingConfig", app.logging_config), 93 | ) 94 | except KeyboardInterrupt: 95 | loop = asyncio.get_event_loop() 96 | for w in managed_workers: 97 | loop.run_until_complete(w.stop()) 98 | console.print("[yellow]SAQ workers stopped.[/]") 99 | 100 | @background_worker_group.command( 101 | name="status", 102 | help="Check the status of currently configured workers and queues.", 103 | ) 104 | @option("-v", "--verbose", help="Enable verbose logging.", is_flag=True, default=None, type=bool, required=False) 105 | @option("-d", "--debug", help="Enable debugging.", is_flag=True, default=None, type=bool, required=False) 106 | def worker_status( # pyright: ignore[reportUnusedFunction] 107 | app: "Litestar", 108 | 
verbose: "Optional[bool]", 109 | debug: "Optional[bool]", 110 | ) -> None: 111 | """Check the status of currently configured workers and queues.""" 112 | console.rule("[yellow]Checking SAQ worker status[/]", align="left") 113 | if app.logging_config is not None: 114 | app.logging_config.configure() 115 | if debug is not None or verbose is not None: 116 | app.debug = True 117 | plugin = get_saq_plugin(app) 118 | show_saq_info(app, plugin.config.worker_processes, plugin) 119 | 120 | return background_worker_group 121 | 122 | 123 | def limited_start_up(plugin: "SAQPlugin", queues: "list[str]") -> None: 124 | """Reset the workers and include only the specified queues.""" 125 | plugin.remove_workers() 126 | plugin.config.filter_delete_queues(queues) 127 | 128 | 129 | def get_saq_plugin(app: "Litestar") -> "SAQPlugin": 130 | """Retrieve a SAQ plugin from the Litestar application's plugins. 131 | 132 | This function attempts to find a SAQ plugin instance. 133 | If plugin is not found, it raises an ImproperlyConfiguredException. 134 | 135 | Args: 136 | app: The Litestar application instance. 137 | 138 | Returns: 139 | The SAQ plugin instance. 140 | 141 | Raises: 142 | ImproperConfigurationError: If the SAQ plugin is not found. 143 | """ 144 | from contextlib import suppress 145 | 146 | from litestar_saq.exceptions import ImproperConfigurationError 147 | from litestar_saq.plugin import SAQPlugin 148 | 149 | with suppress(KeyError): 150 | return app.plugins.get(SAQPlugin) 151 | msg = "Failed to initialize SAQ. The required plugin (SAQPlugin) is missing." 
152 | raise ImproperConfigurationError( 153 | msg, 154 | ) 155 | 156 | 157 | def show_saq_info(app: "Litestar", workers: int, plugin: "SAQPlugin") -> None: # pragma: no cover 158 | """Display basic information about the application and its configuration.""" 159 | 160 | from litestar.cli._utils import _format_is_enabled, console # pyright: ignore 161 | from rich.table import Table 162 | from saq import __version__ as saq_version 163 | 164 | table = Table(show_header=False) 165 | table.add_column("title", style="cyan") 166 | table.add_column("value", style="bright_blue") 167 | 168 | table.add_row("SAQ version", saq_version) 169 | table.add_row("Debug mode", _format_is_enabled(app.debug)) 170 | table.add_row("Number of Processes", str(workers)) 171 | table.add_row("Queues", str(len(plugin.config.queue_configs))) 172 | 173 | console.print(table) 174 | 175 | 176 | def run_saq_worker(worker: "Worker", logging_config: "Optional[BaseLoggingConfig]") -> None: 177 | """Run a worker.""" 178 | import asyncio 179 | 180 | loop = asyncio.get_event_loop() 181 | if logging_config is not None: 182 | logging_config.configure() 183 | 184 | async def worker_start(w: "Worker") -> None: 185 | try: 186 | await w.queue.connect() 187 | await w.start() 188 | finally: 189 | await w.queue.disconnect() 190 | 191 | try: 192 | if worker.separate_process: 193 | loop.run_until_complete(loop.create_task(worker_start(worker))) 194 | except KeyboardInterrupt: 195 | loop.run_until_complete(loop.create_task(worker.stop())) 196 | -------------------------------------------------------------------------------- /litestar_saq/config.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator, Collection, Mapping 2 | from dataclasses import dataclass, field 3 | from datetime import timezone, tzinfo 4 | from pathlib import Path 5 | from typing import TYPE_CHECKING, Literal, Optional, TypedDict, TypeVar, Union, cast 6 | 7 | from 
litestar.exceptions import ImproperlyConfiguredException 8 | from litestar.serialization import decode_json, encode_json 9 | from litestar.utils.module_loader import import_string, module_to_os_path 10 | from saq.queue.base import Queue 11 | from saq.types import DumpType, LoadType, PartialTimersDict, QueueInfo, ReceivesContext, WorkerInfo 12 | from typing_extensions import NotRequired 13 | 14 | from litestar_saq.base import CronJob, Job, JsonDict, Worker 15 | 16 | if TYPE_CHECKING: 17 | from typing import Any 18 | 19 | from litestar.types.callable_types import Guard # pyright: ignore[reportUnknownVariableType] 20 | from saq.types import Function 21 | 22 | T = TypeVar("T") 23 | 24 | 25 | def serializer(value: "Any") -> str: 26 | """Serialize JSON field values. 27 | 28 | Args: 29 | value: Any json serializable value. 30 | 31 | Returns: 32 | JSON string. 33 | 34 | """ 35 | return encode_json(value).decode("utf-8") 36 | 37 | 38 | def _get_static_files() -> Path: 39 | return Path(module_to_os_path("saq") / "web" / "static") 40 | 41 | 42 | @dataclass 43 | class TaskQueues: 44 | """Task queues.""" 45 | 46 | queues: "Mapping[str, Queue]" = field(default_factory=dict) # pyright: ignore 47 | 48 | def get(self, name: str) -> "Queue": 49 | """Get a queue by name. 50 | 51 | Args: 52 | name: The name of the queue. 53 | 54 | Returns: 55 | The queue. 56 | 57 | Raises: 58 | ImproperlyConfiguredException: If the queue does not exist. 59 | 60 | """ 61 | queue = self.queues.get(name) 62 | if queue is not None: 63 | return queue 64 | msg = "Could not find the specified queue. Please check your configuration." 65 | raise ImproperlyConfiguredException(msg) 66 | 67 | 68 | @dataclass 69 | class SAQConfig: 70 | """SAQ Configuration.""" 71 | 72 | queue_configs: "Collection[QueueConfig]" = field(default_factory=list) # pyright: ignore 73 | """Configuration for Queues""" 74 | 75 | queue_instances: "Optional[Mapping[str, Queue]]" = None 76 | """Current configured queue instances. 
When None, queues will be auto-created on startup""" 77 | queues_dependency_key: str = field(default="task_queues") 78 | """Key to use for storing dependency information in litestar.""" 79 | worker_processes: int = 1 80 | """The number of worker processes to spawn. 81 | 82 | Default is set to 1. 83 | """ 84 | 85 | json_deserializer: "LoadType" = decode_json 86 | """This is a Python callable that will 87 | convert a JSON string to a Python object. By default, this is set to Litestar's 88 | :attr:`decode_json() <.serialization.decode_json>` function.""" 89 | json_serializer: "DumpType" = serializer 90 | """This is a Python callable that will render a given object as JSON. 91 | By default, Litestar's :attr:`encode_json() <.serialization.encode_json>` is used.""" 92 | static_files: Path = field(default_factory=_get_static_files) 93 | """Location of the static files to serve for the SAQ UI""" 94 | web_enabled: bool = False 95 | """If true, the worker admin UI is launched on worker startup..""" 96 | web_path: str = "/saq" 97 | """Base path to serve the SAQ web UI""" 98 | web_guards: "Optional[list[Guard]]" = field(default=None) 99 | """Guards to apply to web endpoints.""" 100 | web_include_in_schema: bool = False 101 | """Include Queue API endpoints in generated OpenAPI schema""" 102 | use_server_lifespan: bool = False 103 | """Utilize the server lifespan hook to run SAQ.""" 104 | 105 | @property 106 | def signature_namespace(self) -> "dict[str, Any]": 107 | """Return the plugin's signature namespace. 108 | 109 | Returns: 110 | A string keyed dict of names to be added to the namespace for signature forward reference resolution. 111 | """ 112 | return { 113 | "Queue": Queue, 114 | "Worker": Worker, 115 | "QueueInfo": QueueInfo, 116 | "WorkerInfo": WorkerInfo, 117 | "Job": Job, 118 | "TaskQueues": TaskQueues, 119 | } 120 | 121 | async def provide_queues(self) -> "AsyncGenerator[TaskQueues, None]": 122 | """Provide the configured job queues. 
123 | 124 | Yields: 125 | The configured job queues. 126 | """ 127 | queues = self.get_queues() 128 | for queue in queues.queues.values(): 129 | await queue.connect() 130 | yield queues 131 | 132 | def filter_delete_queues(self, queues: "list[str]") -> None: 133 | """Remove all queues except the ones in the given list.""" 134 | new_config = [queue_config for queue_config in self.queue_configs if queue_config.name in queues] 135 | self.queue_configs = new_config 136 | if self.queue_instances is not None: 137 | for queue_name in dict(self.queue_instances): 138 | if queue_name not in queues: 139 | del self.queue_instances[queue_name] # type: ignore 140 | 141 | def get_queues(self) -> "TaskQueues": 142 | """Get the configured SAQ queues. 143 | 144 | Returns: 145 | The configured job queues. 146 | """ 147 | if self.queue_instances is not None: 148 | return TaskQueues(queues=self.queue_instances) 149 | 150 | self.queue_instances = {} 151 | for c in self.queue_configs: 152 | self.queue_instances[c.name] = c.queue_class( # type: ignore 153 | c.get_broker(), 154 | name=c.name, # pyright: ignore[reportCallIssue] 155 | dump=self.json_serializer, 156 | load=self.json_deserializer, 157 | **c._broker_options, # pyright: ignore[reportArgumentType,reportPrivateUsage] # noqa: SLF001 158 | ) 159 | self.queue_instances[c.name]._is_pool_provided = False # type: ignore # noqa: SLF001 160 | return TaskQueues(queues=self.queue_instances) 161 | 162 | 163 | class RedisQueueOptions(TypedDict, total=False): 164 | """Options for the Redis backend.""" 165 | 166 | max_concurrent_ops: NotRequired[int] 167 | """Maximum concurrent operations. 
(default 15) 168 | This throttles calls to `enqueue`, `job`, and `abort` to prevent the Queue 169 | from consuming too many Redis connections.""" 170 | swept_error_message: NotRequired[str] 171 | 172 | 173 | class PostgresQueueOptions(TypedDict, total=False): 174 | """Options for the Postgres backend.""" 175 | 176 | versions_table: NotRequired[str] 177 | jobs_table: NotRequired[str] 178 | stats_table: NotRequired[str] 179 | min_size: NotRequired[int] 180 | max_size: NotRequired[int] 181 | saq_lock_keyspace: NotRequired[int] 182 | job_lock_keyspace: NotRequired[int] 183 | job_lock_sweep: NotRequired[bool] 184 | priorities: NotRequired[tuple[int, int]] 185 | swept_error_message: NotRequired[str] 186 | manage_pool_lifecycle: NotRequired[bool] 187 | 188 | 189 | @dataclass 190 | class QueueConfig: 191 | """SAQ Queue Configuration""" 192 | 193 | dsn: "Optional[str]" = None 194 | """DSN for connecting to backend. e.g. 'redis://...' or 'postgres://...'. 195 | """ 196 | 197 | broker_instance: "Optional[Any]" = None 198 | """An instance of a supported saq backend connection.. 199 | """ 200 | id: "Optional[str]" = None 201 | """An optional ID to supply for the worker.""" 202 | name: str = "default" 203 | """The name of the queue to create.""" 204 | concurrency: int = 10 205 | """Number of jobs to process concurrently.""" 206 | broker_options: "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]" = field(default_factory=dict) # pyright: ignore 207 | """Broker-specific options. 
For Redis or Postgres backends.""" 208 | tasks: "Collection[Union[ReceivesContext, tuple[str, Function], str]]" = field(default_factory=list) # pyright: ignore 209 | """Allowed list of functions to execute in this queue.""" 210 | scheduled_tasks: "Collection[CronJob]" = field(default_factory=list) # pyright: ignore 211 | """Scheduled cron jobs to execute in this queue.""" 212 | cron_tz: "tzinfo" = timezone.utc 213 | """Timezone for cron jobs.""" 214 | startup: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None 215 | """Async callable to call on startup.""" 216 | shutdown: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None 217 | """Async callable to call on shutdown.""" 218 | before_process: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None 219 | """Async callable to call before a job processes.""" 220 | after_process: "Optional[Union[ReceivesContext, str, Collection[Union[ReceivesContext, str]]]]" = None 221 | """Async callable to call after a job processes.""" 222 | timers: "Optional[PartialTimersDict]" = None 223 | """Dict with various timer overrides in seconds 224 | schedule: how often we poll to schedule jobs 225 | stats: how often to update stats 226 | sweep: how often to clean up stuck jobs 227 | abort: how often to check if a job is aborted""" 228 | dequeue_timeout: float = 0 229 | """How long to wait to dequeue.""" 230 | burst: bool = False 231 | """If True, the worker will process jobs in burst mode.""" 232 | max_burst_jobs: "Optional[int]" = None 233 | """The maximum number of jobs to process in burst mode.""" 234 | metadata: "Optional[JsonDict]" = None 235 | """Arbitrary data to pass to the worker which it will register with saq.""" 236 | multiprocessing_mode: 'Literal["multiprocessing", "threading"]' = "multiprocessing" 237 | """Executes with the multiprocessing or threading backend. 
Multi-processing is recommended and how SAQ is designed to work.""" 238 | separate_process: bool = True 239 | """Executes as a separate event loop when True. 240 | Set it False to execute within the Litestar application.""" 241 | 242 | def __post_init__(self) -> None: 243 | if self.dsn and self.broker_instance: 244 | msg = "Cannot specify both `dsn` and `broker_instance`" 245 | raise ImproperlyConfiguredException(msg) 246 | if not self.dsn and not self.broker_instance: 247 | msg = "Must specify either `dsn` or `broker_instance`" 248 | raise ImproperlyConfiguredException(msg) 249 | self.tasks = [self._get_or_import_task(task) for task in self.tasks] 250 | if self.startup is not None and not isinstance(self.startup, Collection): 251 | self.startup = [self.startup] 252 | if self.shutdown is not None and not isinstance(self.shutdown, Collection): 253 | self.shutdown = [self.shutdown] 254 | if self.before_process is not None and not isinstance(self.before_process, Collection): 255 | self.before_process = [self.before_process] 256 | if self.after_process is not None and not isinstance(self.after_process, Collection): 257 | self.after_process = [self.after_process] 258 | self.startup = [self._get_or_import_task(task) for task in self.startup or []] # pyright: ignore 259 | self.shutdown = [self._get_or_import_task(task) for task in self.shutdown or []] # pyright: ignore 260 | self.before_process = [self._get_or_import_task(task) for task in self.before_process or []] # pyright: ignore 261 | self.after_process = [self._get_or_import_task(task) for task in self.after_process or []] # pyright: ignore 262 | self._broker_type: Optional[Literal["redis", "postgres", "http"]] = None 263 | self._queue_class: Optional[type[Queue]] = None 264 | 265 | def get_broker(self) -> "Any": 266 | """Get the configured Broker connection. 267 | 268 | Raises: 269 | ImproperlyConfiguredException: If the broker type is invalid. 270 | 271 | Returns: 272 | Dictionary of queues. 
273 | """ 274 | 275 | if self.broker_instance is not None: 276 | return self.broker_instance 277 | 278 | if self.dsn and self.dsn.startswith("redis"): 279 | from redis.asyncio import from_url as redis_from_url # pyright: ignore[reportUnknownVariableType] 280 | from saq.queue.redis import RedisQueue 281 | 282 | self.broker_instance = redis_from_url(self.dsn) 283 | self._broker_type = "redis" 284 | self._queue_class = RedisQueue 285 | elif self.dsn and self.dsn.startswith("postgresql"): 286 | from psycopg_pool import AsyncConnectionPool 287 | from saq.queue.postgres import PostgresQueue 288 | 289 | self.broker_instance = AsyncConnectionPool(self.dsn, check=AsyncConnectionPool.check_connection, open=False) 290 | self._broker_type = "postgres" 291 | self._queue_class = PostgresQueue 292 | elif self.dsn and self.dsn.startswith("http"): 293 | from saq.queue.http import HttpQueue 294 | 295 | self.broker_instance = HttpQueue(self.dsn) 296 | self._broker_type = "http" 297 | self._queue_class = HttpQueue 298 | else: 299 | msg = "Invalid broker type" 300 | raise ImproperlyConfiguredException(msg) 301 | return self.broker_instance 302 | 303 | @property 304 | def broker_type(self) -> 'Literal["redis", "postgres", "http"]': 305 | """Type of broker to use. 306 | 307 | Raises: 308 | ImproperlyConfiguredException: If the broker type is invalid. 309 | 310 | Returns: 311 | The broker type. 
312 | """ 313 | if self._broker_type is None and self.broker_instance is not None: 314 | if self.broker_instance.__class__.__name__ == "AsyncConnectionPool": 315 | self._broker_type = "postgres" 316 | elif self.broker_instance.__class__.__name__ == "Redis": 317 | self._broker_type = "redis" 318 | elif self.broker_instance.__class__.__name__ == "HttpQueue": 319 | self._broker_type = "http" 320 | if self._broker_type is None: 321 | self.get_broker() 322 | if self._broker_type is None: 323 | msg = "Invalid broker type" 324 | raise ImproperlyConfiguredException(msg) 325 | return self._broker_type 326 | 327 | @property 328 | def _broker_options(self) -> "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]": 329 | """Broker-specific options. 330 | 331 | Returns: 332 | The broker options. 333 | """ 334 | if self._broker_type == "postgres" and "manage_pool_lifecycle" not in self.broker_options: 335 | self.broker_options["manage_pool_lifecycle"] = True # type: ignore[typeddict-unknown-key] 336 | return self.broker_options 337 | 338 | @property 339 | def queue_class(self) -> "type[Queue]": 340 | """Type of queue to use. 341 | 342 | Raises: 343 | ImproperlyConfiguredException: If the queue class is invalid. 344 | 345 | Returns: 346 | The queue class. 
347 | """ 348 | if self._queue_class is None and self.broker_instance is not None: 349 | if self.broker_instance.__class__.__name__ == "AsyncConnectionPool": 350 | from saq.queue.postgres import PostgresQueue 351 | 352 | self._queue_class = PostgresQueue 353 | elif self.broker_instance.__class__.__name__ == "Redis": 354 | from saq.queue.redis import RedisQueue 355 | 356 | self._queue_class = RedisQueue 357 | elif self.broker_instance.__class__.__name__ == "HttpQueue": 358 | from saq.queue.http import HttpQueue 359 | 360 | self._queue_class = HttpQueue 361 | if self._queue_class is None: 362 | self.get_broker() 363 | if self._queue_class is None: 364 | msg = "Invalid queue class" 365 | raise ImproperlyConfiguredException(msg) 366 | return self._queue_class 367 | 368 | @staticmethod 369 | def _get_or_import_task( 370 | task_or_import_string: "Union[str, tuple[str, Function], ReceivesContext]", 371 | ) -> "ReceivesContext": 372 | """Get or import a task. 373 | 374 | Args: 375 | task_or_import_string: The task or import string. 376 | 377 | Returns: 378 | The task. 
379 | """ 380 | if isinstance(task_or_import_string, str): 381 | return cast("ReceivesContext", import_string(task_or_import_string)) 382 | if isinstance(task_or_import_string, tuple): 383 | return task_or_import_string[1] # pyright: ignore 384 | return task_or_import_string 385 | -------------------------------------------------------------------------------- /litestar_saq/controllers.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: PLR6301 2 | from functools import lru_cache 3 | from typing import TYPE_CHECKING, Any, Optional, cast 4 | 5 | from litestar.exceptions import NotFoundException 6 | 7 | if TYPE_CHECKING: 8 | from litestar import Controller 9 | from litestar.types.callable_types import Guard # pyright: ignore[reportUnknownVariableType] 10 | from saq.queue.base import Queue 11 | from saq.types import QueueInfo 12 | 13 | from litestar_saq.base import Job 14 | from litestar_saq.config import TaskQueues 15 | 16 | 17 | async def job_info(queue: "Queue", job_id: str) -> "Job": 18 | job = await queue.job(job_id) 19 | if not job: 20 | msg = f"Could not find job ID {job_id}" 21 | raise NotFoundException(msg) 22 | return cast("Job", job) 23 | 24 | 25 | @lru_cache(typed=True) 26 | def build_controller( # noqa: C901 27 | url_base: str = "/saq", 28 | controller_guards: "Optional[list[Guard]]" = None, # pyright: ignore[reportUnknownParameterType] 29 | include_in_schema_: bool = False, 30 | ) -> "type[Controller]": 31 | from litestar import Controller, MediaType, get, post 32 | from litestar.exceptions import NotFoundException 33 | from litestar.status_codes import HTTP_202_ACCEPTED 34 | 35 | class SAQController(Controller): 36 | tags = ["SAQ"] 37 | guards = controller_guards # pyright: ignore[reportUnknownVariableType] 38 | include_in_schema = include_in_schema_ 39 | 40 | @get( 41 | operation_id="WorkerQueueList", 42 | name="worker:queue-list", 43 | path=[f"{url_base}/api/queues"], 44 | media_type=MediaType.JSON, 
45 | cache=False, 46 | summary="Queue List", 47 | description="List configured worker queues.", 48 | ) 49 | async def queue_list(self, task_queues: "TaskQueues") -> "dict[str, list[QueueInfo]]": 50 | """Get Worker queues. 51 | 52 | Args: 53 | task_queues: The task queues. 54 | 55 | Returns: 56 | The worker queues. 57 | """ 58 | return {"queues": [await queue.info() for queue in task_queues.queues.values()]} 59 | 60 | @get( 61 | operation_id="WorkerQueueDetail", 62 | name="worker:queue-detail", 63 | path=f"{url_base}/api/queues/{{queue_id:str}}", 64 | media_type=MediaType.JSON, 65 | cache=False, 66 | summary="Queue Detail", 67 | description="List queue details.", 68 | ) 69 | async def queue_detail(self, task_queues: "TaskQueues", queue_id: str) -> "dict[str, QueueInfo]": 70 | """Get queue information. 71 | 72 | Args: 73 | task_queues: The task queues. 74 | queue_id: The queue ID. 75 | 76 | Raises: 77 | NotFoundException: If the queue is not found. 78 | 79 | Returns: 80 | The queue information. 81 | """ 82 | queue = task_queues.get(queue_id) 83 | if not queue: 84 | msg = f"Could not find the {queue_id} queue" 85 | raise NotFoundException(msg) 86 | return {"queue": await queue.info(jobs=True)} 87 | 88 | @get( 89 | operation_id="WorkerJobDetail", 90 | name="worker:job-detail", 91 | path=f"{url_base}/api/queues/{{queue_id:str}}/jobs/{{job_id:str}}", 92 | media_type=MediaType.JSON, 93 | cache=False, 94 | summary="Job Details", 95 | description="List job details.", 96 | ) 97 | async def job_detail( 98 | self, task_queues: "TaskQueues", queue_id: str, job_id: str 99 | ) -> "dict[str, dict[str, Any]]": 100 | """Get job information. 101 | 102 | Args: 103 | task_queues: The task queues. 104 | queue_id: The queue ID. 105 | job_id: The job ID. 106 | 107 | Raises: 108 | NotFoundException: If the queue or job is not found. 109 | 110 | Returns: 111 | The job information. 
112 | """ 113 | queue = task_queues.get(queue_id) 114 | if not queue: 115 | msg = f"Could not find the {queue_id} queue" 116 | raise NotFoundException(msg) 117 | job = await job_info(queue, job_id) 118 | job_dict = job.to_dict() 119 | if "kwargs" in job_dict: 120 | job_dict["kwargs"] = repr(job_dict["kwargs"]) 121 | if "result" in job_dict: 122 | job_dict["result"] = repr(job_dict["result"]) 123 | return {"job": job_dict} 124 | 125 | @post( 126 | operation_id="WorkerJobRetry", 127 | name="worker:job-retry", 128 | path=f"{url_base}/api/queues/{{queue_id:str}}/jobs/{{job_id:str}}/retry", 129 | media_type=MediaType.JSON, 130 | cache=False, 131 | summary="Job Retry", 132 | description="Retry a failed job..", 133 | status_code=HTTP_202_ACCEPTED, 134 | ) 135 | async def job_retry(self, task_queues: "TaskQueues", queue_id: str, job_id: str) -> "dict[str, str]": 136 | """Retry job. 137 | 138 | Args: 139 | task_queues: The task queues. 140 | queue_id: The queue ID. 141 | job_id: The job ID. 142 | 143 | Raises: 144 | NotFoundException: If the queue or job is not found. 145 | 146 | Returns: 147 | The job information. 148 | """ 149 | queue = task_queues.get(queue_id) 150 | if not queue: 151 | msg = f"Could not find the {queue_id} queue" 152 | raise NotFoundException(msg) 153 | job = await job_info(queue, job_id) 154 | await job.retry("retried from ui") 155 | return {} 156 | 157 | @post( 158 | operation_id="WorkerJobAbort", 159 | name="worker:job-abort", 160 | path=f"{url_base}/api/queues/{{queue_id:str}}/jobs/{{job_id:str}}/abort", 161 | media_type=MediaType.JSON, 162 | cache=False, 163 | summary="Job Abort", 164 | description="Abort active job.", 165 | status_code=HTTP_202_ACCEPTED, 166 | ) 167 | async def job_abort(self, task_queues: "TaskQueues", queue_id: str, job_id: str) -> "dict[str, str]": 168 | """Abort job. 169 | 170 | Args: 171 | task_queues: The task queues. 172 | queue_id: The queue ID. 173 | job_id: The job ID. 
174 | 175 | Raises: 176 | NotFoundException: If the queue or job is not found. 177 | 178 | Returns: 179 | An empty dictionary. 180 | """ 181 | queue = task_queues.get(queue_id) 182 | if not queue: 183 | msg = f"Could not find the {queue_id} queue" 184 | raise NotFoundException(msg) 185 | job = await job_info(queue, job_id) 186 | await job.abort("aborted from ui") 187 | return {} 188 | 189 | # static site 190 | @get( 191 | [ 192 | f"{url_base}/", 193 | f"{url_base}/queues/{{queue_id:str}}", 194 | f"{url_base}/queues/{{queue_id:str}}/jobs/{{job_id:str}}", 195 | ], 196 | operation_id="WorkerIndex", 197 | name="worker:index", 198 | media_type=MediaType.HTML, 199 | include_in_schema=False, 200 | ) 201 | async def index(self) -> str: 202 | """Serve site root. 203 | 204 | Returns: 205 | The site root. 206 | """ 207 | return f""" 208 | 209 | 210 | 211 | 212 | 213 | 214 | SAQ 215 | 216 | 217 |
218 | 219 | 220 | 221 | 222 | """.strip() 223 | 224 | return SAQController 225 | -------------------------------------------------------------------------------- /litestar_saq/exceptions.py: -------------------------------------------------------------------------------- 1 | class LitestarSaqError(Exception): 2 | """Base exception type for the Litestar Saq.""" 3 | 4 | 5 | class ImproperConfigurationError(LitestarSaqError): 6 | """Improper Configuration error. 7 | 8 | This exception is raised only when a module depends on a dependency that has not been installed. 9 | """ 10 | 11 | 12 | class BackgroundTaskError(Exception): 13 | """Base class for `Task` related exceptions.""" 14 | -------------------------------------------------------------------------------- /litestar_saq/plugin.py: -------------------------------------------------------------------------------- 1 | import signal 2 | import sys 3 | import time 4 | from contextlib import contextmanager 5 | from importlib.util import find_spec 6 | from multiprocessing import Process 7 | from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast 8 | 9 | from litestar.plugins import CLIPlugin, InitPluginProtocol 10 | 11 | from litestar_saq.base import Worker 12 | 13 | if TYPE_CHECKING: 14 | from collections.abc import Collection, Iterator 15 | 16 | from click import Group 17 | from litestar import Litestar 18 | from litestar.config.app import AppConfig 19 | from saq.queue.base import Queue 20 | from saq.types import Function, ReceivesContext 21 | 22 | from litestar_saq.config import SAQConfig, TaskQueues 23 | 24 | T = TypeVar("T") 25 | 26 | STRUCTLOG_INSTALLED = find_spec("structlog") is not None 27 | 28 | 29 | class SAQPlugin(InitPluginProtocol, CLIPlugin): 30 | """SAQ plugin.""" 31 | 32 | __slots__ = ("_config", "_processes", "_worker_instances") 33 | 34 | WORKER_SHUTDOWN_TIMEOUT = 5.0 # seconds 35 | WORKER_JOIN_TIMEOUT = 1.0 # seconds 36 | 37 | def __init__(self, config: "SAQConfig") -> None: 38 | """Initialize 
``SAQPlugin``. 39 | 40 | Args: 41 | config: configure and start SAQ. 42 | """ 43 | self._config = config 44 | self._worker_instances: Optional[dict[str, Worker]] = None 45 | 46 | @property 47 | def config(self) -> "SAQConfig": 48 | return self._config 49 | 50 | def on_cli_init(self, cli: "Group") -> None: 51 | from litestar_saq.cli import build_cli_app 52 | 53 | cli.add_command(build_cli_app()) 54 | return super().on_cli_init(cli) # type: ignore[safe-super] 55 | 56 | def on_app_init(self, app_config: "AppConfig") -> "AppConfig": 57 | """Configure application for use with SAQ. 58 | 59 | Args: 60 | app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. 61 | 62 | Returns: 63 | The :class:`AppConfig <.config.app.AppConfig>` instance. 64 | """ 65 | 66 | from litestar.di import Provide 67 | from litestar.static_files import create_static_files_router # pyright: ignore[reportUnknownVariableType] 68 | 69 | from litestar_saq.controllers import build_controller 70 | 71 | app_config.dependencies.update( 72 | {self._config.queues_dependency_key: Provide(dependency=self._config.provide_queues)} 73 | ) 74 | if self._config.web_enabled: 75 | app_config.route_handlers.append( 76 | create_static_files_router( 77 | directories=[self._config.static_files], 78 | path=f"{self._config.web_path}/static", 79 | name="saq", 80 | html_mode=False, 81 | opt={"exclude_from_auth": True}, 82 | include_in_schema=False, 83 | ), 84 | ) 85 | app_config.route_handlers.append( 86 | build_controller(self._config.web_path, self._config.web_guards, self._config.web_include_in_schema), # type: ignore[arg-type] 87 | ) 88 | app_config.signature_namespace.update(self._config.signature_namespace) 89 | 90 | workers = self.get_workers() 91 | for worker in workers.values(): 92 | app_config.on_startup.append(worker.on_app_startup) 93 | app_config.on_shutdown.append(worker.on_app_shutdown) 94 | app_config.on_shutdown.extend([self.remove_workers]) 95 | return app_config 96 | 97 | def
get_workers(self) -> "dict[str, Worker]": 98 | """Return workers""" 99 | if self._worker_instances is not None: 100 | return self._worker_instances 101 | self._worker_instances = { 102 | queue_config.name: Worker( 103 | queue=self.get_queue(queue_config.name), 104 | id=queue_config.id, 105 | functions=cast("Collection[Function]", queue_config.tasks), 106 | cron_jobs=queue_config.scheduled_tasks, 107 | cron_tz=queue_config.cron_tz, 108 | concurrency=queue_config.concurrency, 109 | startup=cast("Collection[ReceivesContext]", queue_config.startup), 110 | shutdown=cast("Collection[ReceivesContext]", queue_config.shutdown), 111 | before_process=cast("Collection[ReceivesContext]", queue_config.before_process), 112 | after_process=cast("Collection[ReceivesContext]", queue_config.after_process), 113 | timers=queue_config.timers, 114 | dequeue_timeout=queue_config.dequeue_timeout, 115 | separate_process=queue_config.separate_process, 116 | burst=queue_config.burst, 117 | max_burst_jobs=queue_config.max_burst_jobs, 118 | metadata=queue_config.metadata, 119 | ) 120 | for queue_config in self._config.queue_configs 121 | } 122 | 123 | return self._worker_instances 124 | 125 | def remove_workers(self) -> None: 126 | self._worker_instances = None 127 | 128 | def get_queues(self) -> "TaskQueues": 129 | return self._config.get_queues() 130 | 131 | def get_queue(self, name: str) -> "Queue": 132 | return self.get_queues().get(name) 133 | 134 | @contextmanager 135 | def server_lifespan(self, app: "Litestar") -> "Iterator[None]": 136 | import multiprocessing 137 | import platform 138 | 139 | from litestar.cli._utils import console # pyright: ignore 140 | 141 | from litestar_saq.cli import run_saq_worker 142 | 143 | if platform.system() == "Darwin": 144 | multiprocessing.set_start_method("fork", force=True) 145 | 146 | if not self._config.use_server_lifespan: 147 | yield 148 | return 149 | 150 | console.rule("[yellow]Starting SAQ Workers[/]", align="left") 151 | self._processes: 
list[Process] = [] 152 | 153 | def handle_shutdown(_signum: Any, _frame: Any) -> None: 154 | """Handle shutdown signals gracefully.""" 155 | console.print("[yellow]Received shutdown signal, stopping workers...[/]") 156 | self._terminate_workers(self._processes) 157 | sys.exit(0) 158 | 159 | # Register signal handlers 160 | signal.signal(signal.SIGTERM, handle_shutdown) 161 | signal.signal(signal.SIGINT, handle_shutdown) 162 | 163 | try: 164 | for worker_name, worker in self.get_workers().items(): 165 | for i in range(self.config.worker_processes): 166 | console.print(f"[yellow]Starting worker process {i + 1} for {worker_name}[/]") 167 | process = Process( 168 | target=run_saq_worker, 169 | args=( 170 | worker, 171 | app.logging_config, 172 | ), 173 | name=f"worker-{worker_name}-{i + 1}", 174 | ) 175 | process.start() 176 | self._processes.append(process) 177 | 178 | yield 179 | 180 | except Exception as e: 181 | console.print(f"[red]Error in worker processes: {e}[/]") 182 | raise 183 | finally: 184 | console.print("[yellow]Shutting down SAQ workers...[/]") 185 | self._terminate_workers(self._processes) 186 | console.print("[yellow]SAQ workers stopped.[/]") 187 | 188 | @staticmethod 189 | def _terminate_workers(processes: "list[Process]", timeout: float = 5.0) -> None: 190 | """Gracefully terminate worker processes with timeout. 
191 | 192 | Args: 193 | processes: List of worker processes to terminate 194 | timeout: Maximum time to wait for graceful shutdown in seconds 195 | """ 196 | # Send SIGTERM to all processes 197 | from litestar.cli._utils import console # pyright: ignore 198 | 199 | for p in processes: 200 | if p.is_alive(): 201 | p.terminate() 202 | 203 | # Wait for processes to terminate gracefully 204 | termination_start = time.time() 205 | while time.time() - termination_start < timeout: 206 | if not any(p.is_alive() for p in processes): 207 | break 208 | time.sleep(0.1) 209 | 210 | # Force kill any remaining processes 211 | for p in processes: 212 | if p.is_alive(): 213 | try: 214 | p.kill() # Send SIGKILL 215 | p.join(timeout=1.0) 216 | except Exception as e: # noqa: BLE001 217 | console.print(f"[red]Error killing worker process: {e}[/]") 218 | -------------------------------------------------------------------------------- /litestar_saq/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cofin/litestar-saq/329934de25c5a3f7f3459fd8e8d2ca7013d6fe50/litestar_saq/py.typed -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = [{ name = "Cody Fincher", email = "cody.fincher@gmail.com" }] 3 | classifiers = [ 4 | "Development Status :: 3 - Alpha", 5 | "Environment :: Web Environment", 6 | "License :: OSI Approved :: MIT License", 7 | "Natural Language :: English", 8 | "Operating System :: OS Independent", 9 | "Programming Language :: Python :: 3.9", 10 | "Programming Language :: Python :: 3.10", 11 | "Programming Language :: Python :: 3.11", 12 | "Programming Language :: Python :: 3.12", 13 | "Programming Language :: Python :: 3.13", 14 | "Programming Language :: Python", 15 | "Topic :: Software Development", 16 | "Typing :: Typed", 17 | ] 18 | dependencies = [ 
19 | "litestar>=2.0.1", 20 | "saq>=0.24.4", 21 | ] 22 | description = "Litestar integration for SAQ" 23 | keywords = ["litestar", "saq"] 24 | license = { text = "MIT" } 25 | name = "litestar-saq" 26 | readme = "README.md" 27 | requires-python = ">=3.9" 28 | version = "0.5.3" 29 | 30 | [project.optional-dependencies] 31 | hiredis = ["hiredis"] 32 | psycopg = ["psycopg[pool,binary]"] 33 | 34 | [project.urls] 35 | Changelog = "https://cofin.github.io/litesatr-saq/latest/changelog" 36 | Discord = "https://discord.gg/X3FJqy8d2j" 37 | Documentation = "https://cofin.github.io/litesatr-saq/latest/" 38 | Homepage = "https://cofin.github.io/litesatr-saq/latest/" 39 | Issue = "https://github.com/cofin/litestar-saq/issues/" 40 | Source = "https://github.com/cofin/litestar-saq" 41 | 42 | [build-system] 43 | build-backend = "hatchling.build" 44 | requires = ["hatchling"] 45 | 46 | [dependency-groups] 47 | dev = [{include-group = "build"}, {include-group = "linting"}, {include-group = "test"}] 48 | build = ["bump-my-version"] 49 | linting = [ 50 | "pre-commit", 51 | "mypy", 52 | "ruff", 53 | "types-click", 54 | "types-redis", 55 | "types-croniter", 56 | "pyright", 57 | "slotscheck", 58 | ] 59 | test = [ 60 | "pytest", 61 | "pytest-mock", 62 | "httpx", 63 | "pytest-cov", 64 | "coverage", 65 | "pytest-databases", 66 | "pytest-sugar", 67 | "pytest-asyncio", 68 | "pytest-xdist", 69 | "anyio", 70 | "litestar[jinja,redis,standard]", 71 | "psycopg[pool,binary]", 72 | ] 73 | 74 | 75 | [tool.bumpversion] 76 | allow_dirty = true 77 | commit = false 78 | commit_args = "--no-verify" 79 | current_version = "0.5.3" 80 | ignore_missing_files = false 81 | ignore_missing_version = false 82 | message = "chore(release): bump to `v{new_version}`" 83 | parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)" 84 | regex = false 85 | replace = "{new_version}" 86 | search = "{current_version}" 87 | serialize = ["{major}.{minor}.{patch}"] 88 | sign_tags = false 89 | tag = false 90 | tag_message = "chore(release): 
`v{new_version}`" 91 | tag_name = "v{new_version}" 92 | 93 | [[tool.bumpversion.files]] 94 | filename = "pyproject.toml" 95 | replace = 'version = "{new_version}"' 96 | search = 'version = "{current_version}"' 97 | 98 | 99 | [[tool.bumpversion.files]] 100 | filename = "uv.lock" 101 | replace = """ 102 | name = "litestar-saq" 103 | version = "{new_version}" 104 | """ 105 | search = """ 106 | name = "litestar-saq" 107 | version = "{current_version}" 108 | """ 109 | 110 | [tool.pytest.ini_options] 111 | addopts = ["-q", "-ra"] 112 | filterwarnings = [ 113 | "ignore::DeprecationWarning:pkg_resources", 114 | "ignore::DeprecationWarning:xdist.*", 115 | ] 116 | minversion = "6.0" 117 | testpaths = ["tests"] 118 | tmp_path_retention_policy = "failed" 119 | tmp_path_retention_count = 3 120 | asyncio_default_fixture_loop_scope = "function" 121 | asyncio_mode = "auto" 122 | 123 | [tool.coverage.report] 124 | exclude_lines = [ 125 | 'if TYPE_CHECKING:', 126 | 'pragma: no cover', 127 | "if __name__ == .__main__.:", 128 | 'def __repr__', 129 | 'if self\.debug:', 130 | 'if settings\.DEBUG', 131 | 'raise AssertionError', 132 | 'raise NotImplementedError', 133 | 'if 0:', 134 | 'class .*\bProtocol\):', 135 | '@(abc\.)?abstractmethod', 136 | ] 137 | omit = ["*/tests/*"] 138 | show_missing = true 139 | 140 | 141 | [tool.coverage.run] 142 | branch = true 143 | concurrency = ["multiprocessing" ] 144 | omit = ["tests/*"] 145 | parallel = true 146 | 147 | [tool.slotscheck] 148 | strict-imports = false 149 | 150 | [tool.ruff] 151 | exclude = [ 152 | ".bzr", 153 | ".direnv", 154 | ".eggs", 155 | ".git", 156 | ".hg", 157 | ".mypy_cache", 158 | ".nox", 159 | ".pants.d", 160 | ".ruff_cache", 161 | ".svn", 162 | ".tox", 163 | ".venv", 164 | "__pypackages__", 165 | "_build", 166 | "buck-out", 167 | "build", 168 | "dist", 169 | "node_modules", 170 | "venv", 171 | '__pycache__', 172 | ] 173 | fix = true 174 | line-length = 120 175 | lint.fixable = ["ALL"] 176 | lint.ignore = [ 177 | "A003", # 
flake8-builtins - class attribute {name} is shadowing a python builtin 178 | "B010", # flake8-bugbear - do not call setattr with a constant attribute value 179 | "D100", # pydocstyle - missing docstring in public module 180 | "D101", # pydocstyle - missing docstring in public class 181 | "D102", # pydocstyle - missing docstring in public method 182 | "D103", # pydocstyle - missing docstring in public function 183 | "D104", # pydocstyle - missing docstring in public package 184 | "D105", # pydocstyle - missing docstring in magic method 185 | "D106", # pydocstyle - missing docstring in public nested class 186 | "D107", # pydocstyle - missing docstring in __init__ 187 | "D202", # pydocstyle - no blank lines allowed after function docstring 188 | "D205", # pydocstyle - 1 blank line required between summary line and description 189 | "D415", # pydocstyle - first line should end with a period, question mark, or exclamation point 190 | "E501", # pycodestyle line too long, handled by black 191 | "PLW2901", # pylint - for loop variable overwritten by assignment target 192 | "RUF012", # Ruff-specific rule - annotated with classvar 193 | "ANN401", 194 | "FBT", 195 | "PLR0913", # too many arguments 196 | "PT", 197 | "TD", 198 | "ARG002", # ignore for now; investigate 199 | "PERF203", # ignore for now; investigate 200 | "ISC001", 201 | "COM812", 202 | "FA100", # use __future__ import annotations 203 | "CPY001", # copyright at the top of the file 204 | "PLC0415", # import at the top of the file 205 | "PGH003", # Use specific rule codes when ignoring type issues 206 | "PLC2701", # ignore for now; investigate 207 | ] 208 | lint.select = ["ALL"] 209 | # Allow unused variables when underscore-prefixed. 
210 | lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 211 | src = ["litestar_saq", "tests"] 212 | target-version = "py39" 213 | unsafe-fixes = true 214 | 215 | [tool.ruff.lint.pydocstyle] 216 | convention = "google" 217 | 218 | [tool.ruff.lint.mccabe] 219 | max-complexity = 12 220 | 221 | [tool.ruff.lint.pep8-naming] 222 | classmethod-decorators = [ 223 | "sqlalchemy.ext.declarative.declared_attr", 224 | "sqlalchemy.orm.declared_attr.directive", 225 | "sqlalchemy.orm.declared_attr", 226 | ] 227 | 228 | [tool.ruff.lint.per-file-ignores] 229 | "tests/**/*.*" = [ 230 | "A", 231 | "ARG", 232 | "B", 233 | "BLE", 234 | "C901", 235 | "D", 236 | "DTZ", 237 | "EM", 238 | "FBT", 239 | "G", 240 | "N", 241 | "PGH", 242 | "PIE", 243 | "PLR", 244 | "PLW", 245 | "PTH", 246 | "RSE", 247 | "S", 248 | "S101", 249 | "SIM", 250 | "TC", 251 | "TRY", 252 | "UP006", 253 | "SLF001", 254 | "ERA001", 255 | 256 | ] 257 | "tools/*.py" = [ "PLR0911"] 258 | 259 | [tool.ruff.lint.isort] 260 | known-first-party = ["litestar_saq", "tests"] 261 | 262 | [tool.mypy] 263 | disallow_any_generics = false 264 | disallow_untyped_decorators = true 265 | implicit_reexport = false 266 | show_error_codes = true 267 | strict = true 268 | warn_redundant_casts = true 269 | warn_return_any = true 270 | warn_unreachable = true 271 | warn_unused_configs = true 272 | warn_unused_ignores = true 273 | 274 | [[tool.mypy.overrides]] 275 | disable_error_code = "attr-defined" 276 | disallow_untyped_decorators = false 277 | module = "tests.*" 278 | 279 | [tool.pyright] 280 | venvPath = "." 
281 | disableBytesTypePromotions = true 282 | exclude = [ 283 | "docs", 284 | "tests/helpers.py", 285 | ] 286 | include = ["litestar_saq"] 287 | pythonVersion = "3.9" 288 | strict = ["litestar_saq/**/*"] 289 | venv = ".venv" 290 | 291 | [tool.codespell] 292 | ignore-words-list = "selectin" 293 | skip = 'uv.lock,pyproject.toml' 294 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cofin/litestar-saq/329934de25c5a3f7f3459fd8e8d2ca7013d6fe50/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import pytest 6 | from redis.asyncio import Redis 7 | 8 | if TYPE_CHECKING: 9 | from collections.abc import AsyncGenerator 10 | 11 | from pytest_databases.docker.redis import RedisService 12 | 13 | pytestmark = pytest.mark.anyio 14 | pytest_plugins = [ 15 | "pytest_databases.docker", 16 | "pytest_databases.docker.redis", 17 | ] 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def anyio_backend() -> str: 22 | return "asyncio" 23 | 24 | 25 | @pytest.fixture(name="redis", autouse=True) 26 | async def fx_redis(redis_service: RedisService) -> AsyncGenerator[Redis, None]: 27 | """Redis instance for testing. 28 | 29 | Returns: 30 | Redis client instance, function scoped. 
31 | """ 32 | yield Redis(host=redis_service.host, port=redis_service.port) 33 | -------------------------------------------------------------------------------- /tests/test_cli/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | APP_DEFAULT_CONFIG_FILE_CONTENT = """ 4 | from __future__ import annotations 5 | 6 | import asyncio 7 | from logging import getLogger 8 | from typing import TYPE_CHECKING 9 | 10 | from examples import tasks 11 | from litestar import Controller, Litestar, get 12 | 13 | from litestar_saq import CronJob, QueueConfig, SAQConfig, SAQPlugin 14 | 15 | if TYPE_CHECKING: 16 | from saq.types import Context, QueueInfo 17 | 18 | from litestar_saq.config import TaskQueues 19 | 20 | logger = getLogger(__name__) 21 | 22 | 23 | async def system_upkeep(_: Context) -> None: 24 | logger.info("Performing system upkeep operations.") 25 | logger.info("Simulating a long running operation. Sleeping for 60 seconds.") 26 | await asyncio.sleep(3) 27 | logger.info("Simulating an even longer running operation. 
Sleeping for 120 seconds.") 28 | await asyncio.sleep(3) 29 | logger.info("Long running process complete.") 30 | logger.info("Performing system upkeep operations.") 31 | 32 | 33 | async def background_worker_task(_: Context) -> None: 34 | logger.info("Performing background worker task.") 35 | await asyncio.sleep(1) 36 | logger.info("Performing system upkeep operations.") 37 | 38 | 39 | async def system_task(_: Context) -> None: 40 | logger.info("Performing simple system task") 41 | await asyncio.sleep(2) 42 | logger.info("System task complete.") 43 | 44 | 45 | class SampleController(Controller): 46 | @get(path="/samples") 47 | async def samples_queue_info(self, task_queues: TaskQueues) -> QueueInfo: 48 | queue = task_queues.get("samples") 49 | return await queue.info() 50 | 51 | 52 | saq = SAQPlugin( 53 | config=SAQConfig( 54 | web_enabled=True, 55 | use_server_lifespan=True, 56 | queue_configs=[ 57 | QueueConfig( 58 | dsn="redis://localhost:6397/0", 59 | name="samples", 60 | tasks=[tasks.background_worker_task, tasks.system_task, tasks.system_upkeep], 61 | scheduled_tasks=[CronJob(function=tasks.system_upkeep, cron="* * * * *", timeout=600, ttl=2000)], 62 | ), 63 | ], 64 | ), 65 | ) 66 | app = Litestar(plugins=[saq], route_handlers=[SampleController]) 67 | 68 | """ 69 | -------------------------------------------------------------------------------- /tests/test_cli/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib.util 4 | import sys 5 | from collections.abc import Generator 6 | from pathlib import Path 7 | from shutil import rmtree 8 | from typing import TYPE_CHECKING, Callable, Protocol, cast 9 | 10 | import pytest 11 | from _pytest.fixtures import FixtureRequest 12 | from _pytest.monkeypatch import MonkeyPatch 13 | from click.testing import CliRunner 14 | from litestar.cli._utils import _path_to_dotted_path 15 | from pytest_mock import MockerFixture 16 | 
17 | from . import APP_DEFAULT_CONFIG_FILE_CONTENT 18 | 19 | if TYPE_CHECKING: 20 | from unittest.mock import MagicMock 21 | 22 | from litestar.cli._utils import LitestarGroup 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def reset_litestar_app_env(monkeypatch: MonkeyPatch) -> None: 27 | monkeypatch.delenv("LITESTAR_APP", raising=False) 28 | 29 | 30 | @pytest.fixture() 31 | def root_command() -> LitestarGroup: 32 | import litestar.cli.main 33 | 34 | return cast("LitestarGroup", importlib.reload(litestar.cli.main).litestar_group) 35 | 36 | 37 | @pytest.fixture 38 | def patch_autodiscovery_paths(request: FixtureRequest) -> Callable[[list[str]], None]: 39 | def patcher(paths: list[str]) -> None: 40 | from litestar.cli._utils import AUTODISCOVERY_FILE_NAMES 41 | 42 | old_paths = AUTODISCOVERY_FILE_NAMES[::] 43 | AUTODISCOVERY_FILE_NAMES[:] = paths 44 | 45 | def finalizer() -> None: 46 | AUTODISCOVERY_FILE_NAMES[:] = old_paths 47 | 48 | request.addfinalizer(finalizer) 49 | 50 | return patcher 51 | 52 | 53 | @pytest.fixture(autouse=True) 54 | def tmp_project_dir(monkeypatch: MonkeyPatch, tmp_path: Path) -> Path: 55 | path = tmp_path / "project_dir" 56 | path.mkdir(exist_ok=True) 57 | monkeypatch.chdir(path) 58 | return path 59 | 60 | 61 | class CreateAppFileFixture(Protocol): 62 | def __call__( 63 | self, 64 | file: str | Path, 65 | directory: str | Path | None = None, 66 | content: str | None = None, 67 | init_content: str = "", 68 | subdir: str | None = None, 69 | ) -> Path: ... 
70 | 71 | 72 | def _purge_module(module_names: list[str], path: str | Path) -> None: 73 | for name in module_names: 74 | if name in sys.modules: 75 | del sys.modules[name] 76 | Path(importlib.util.cache_from_source(path)).unlink(missing_ok=True) # type: ignore[arg-type] 77 | 78 | 79 | @pytest.fixture 80 | def create_app_file(tmp_project_dir: Path, request: FixtureRequest) -> CreateAppFileFixture: 81 | def _create_app_file( 82 | file: str | Path, 83 | directory: str | Path | None = None, 84 | content: str | None = None, 85 | init_content: str = "", 86 | subdir: str | None = None, 87 | ) -> Path: 88 | base = tmp_project_dir 89 | if directory: 90 | base /= Path(Path(directory) / subdir) if subdir else Path(directory) 91 | base.mkdir(parents=True) 92 | base.joinpath("__init__.py").write_text(init_content) 93 | 94 | tmp_app_file = base / file 95 | tmp_app_file.write_text(content or APP_DEFAULT_CONFIG_FILE_CONTENT) 96 | 97 | if directory: 98 | request.addfinalizer(lambda: rmtree(directory)) 99 | request.addfinalizer( 100 | lambda: _purge_module( 101 | [directory, _path_to_dotted_path(tmp_app_file.relative_to(Path.cwd()))], # type: ignore[list-item] 102 | tmp_app_file, 103 | ), 104 | ) 105 | else: 106 | request.addfinalizer(tmp_app_file.unlink) 107 | request.addfinalizer(lambda: _purge_module([str(file).replace(".py", "")], tmp_app_file)) 108 | return tmp_app_file 109 | 110 | return _create_app_file 111 | 112 | 113 | @pytest.fixture 114 | def app_file(create_app_file: CreateAppFileFixture) -> Path: 115 | return create_app_file("app.py") 116 | 117 | 118 | @pytest.fixture 119 | def runner() -> CliRunner: 120 | return CliRunner() 121 | 122 | 123 | @pytest.fixture 124 | def mock_uvicorn_run(mocker: MockerFixture) -> MagicMock: 125 | return mocker.patch("uvicorn.run") 126 | 127 | 128 | @pytest.fixture() 129 | def mock_subprocess_run(mocker: MockerFixture) -> MagicMock: 130 | return mocker.patch("subprocess.run") 131 | 132 | 133 | @pytest.fixture 134 | def 
mock_confirm_ask(mocker: MockerFixture) -> Generator[MagicMock, None, None]: 135 | yield mocker.patch("rich.prompt.Confirm.ask", return_value=True) 136 | 137 | 138 | @pytest.fixture( 139 | params=[ 140 | pytest.param((APP_DEFAULT_CONFIG_FILE_CONTENT, "app"), id="app_obj"), 141 | ], 142 | ) 143 | def _app_file_content(request: FixtureRequest) -> tuple[str, str]: 144 | return cast("tuple[str, str]", request.param) 145 | 146 | 147 | @pytest.fixture 148 | def app_file_content(_app_file_content: tuple[str, str]) -> str: 149 | return _app_file_content[0] 150 | 151 | 152 | @pytest.fixture 153 | def app_file_app_name(_app_file_content: tuple[str, str]) -> str: 154 | return _app_file_content[1] 155 | -------------------------------------------------------------------------------- /tests/test_cli/test_cli.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from click.testing import CliRunner 3 | from litestar.cli._utils import LitestarGroup 4 | from redis.asyncio import Redis 5 | 6 | from tests.test_cli import APP_DEFAULT_CONFIG_FILE_CONTENT 7 | from tests.test_cli.conftest import CreateAppFileFixture 8 | 9 | pytestmark = pytest.mark.anyio 10 | 11 | 12 | async def test_basic_command( 13 | runner: CliRunner, 14 | create_app_file: CreateAppFileFixture, 15 | root_command: LitestarGroup, 16 | redis_service: None, 17 | redis: Redis, 18 | ) -> None: 19 | app_file = create_app_file("command_test_app.py", content=APP_DEFAULT_CONFIG_FILE_CONTENT) 20 | result = runner.invoke(root_command, ["--app", f"{app_file.stem}:app", "workers"]) 21 | 22 | assert not result.exception 23 | assert "Manage background task workers." 
in result.output 24 | -------------------------------------------------------------------------------- /tests/test_plugin.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | from litestar.cli._utils import LitestarGroup as Group 5 | 6 | from litestar_saq.config import SAQConfig, TaskQueues 7 | from litestar_saq.plugin import SAQPlugin 8 | 9 | # Assuming SAQConfig, Worker, Group, AppConfig, TaskQueues, Queue are available and can be imported 10 | # Assuming there are meaningful __eq__ methods for comparisons where needed 11 | 12 | 13 | # Test on_cli_init method 14 | @pytest.mark.parametrize( 15 | "cli_group", 16 | [ 17 | Mock(Group), 18 | ], 19 | ) 20 | def test_on_cli_init(cli_group: Group) -> None: 21 | # Arrange 22 | config = Mock(SAQConfig) 23 | plugin = SAQPlugin(config) 24 | 25 | # Act 26 | plugin.on_cli_init(cli_group) 27 | 28 | # Assert 29 | cli_group.add_command.assert_called_once() # pyright: ignore[reportFunctionMemberAccess] 30 | 31 | 32 | # Test get_queues method 33 | @pytest.mark.parametrize( 34 | "queues", 35 | [ 36 | Mock(), 37 | ], 38 | ) 39 | def test_get_queues(queues: TaskQueues) -> None: 40 | # Arrange 41 | config = Mock(SAQConfig, get_queues=Mock(return_value=queues)) 42 | plugin = SAQPlugin(config) 43 | 44 | # Act 45 | result = plugin.get_queues() 46 | 47 | # Assert 48 | assert result == queues 49 | --------------------------------------------------------------------------------