├── .github └── workflows │ ├── chaos-test.yml │ ├── publish.yml │ └── unit-test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pydocstyle ├── CONTRIBUTING.md ├── DEVELOPING.md ├── LICENSE ├── README.md ├── alembic.ini ├── chaos-tests ├── conftest.py └── test_workflows.py ├── dbos ├── __init__.py ├── __main__.py ├── _admin_server.py ├── _app_db.py ├── _classproperty.py ├── _client.py ├── _conductor │ ├── conductor.py │ └── protocol.py ├── _context.py ├── _core.py ├── _croniter.py ├── _dbos.py ├── _dbos_config.py ├── _debug.py ├── _docker_pg_helper.py ├── _error.py ├── _event_loop.py ├── _fastapi.py ├── _flask.py ├── _kafka.py ├── _kafka_message.py ├── _logger.py ├── _migrations │ ├── env.py │ ├── script.py.mako │ └── versions │ │ ├── 04ca4f231047_workflow_queues_executor_id.py │ │ ├── 27ac6900c6ad_add_queue_dedup.py │ │ ├── 50f3227f0b4b_fix_job_queue.py │ │ ├── 5c361fc04708_added_system_tables.py │ │ ├── 66478e1b95e5_consolidate_queues.py │ │ ├── 83f3732ae8e7_workflow_timeout.py │ │ ├── 933e86bdac6a_add_queue_priority.py │ │ ├── a3b18ad34abe_added_triggers.py │ │ ├── d76646551a6b_job_queue_limiter.py │ │ ├── d76646551a6c_workflow_queue.py │ │ ├── d994145b47b6_consolidate_inputs.py │ │ ├── eab0cc1d9a14_job_queue.py │ │ └── f4b9b32ba814_functionname_childid_op_outputs.py ├── _outcome.py ├── _queue.py ├── _recovery.py ├── _registrations.py ├── _roles.py ├── _scheduler.py ├── _schemas │ ├── __init__.py │ ├── application_database.py │ └── system_database.py ├── _serialization.py ├── _sys_db.py ├── _templates │ └── dbos-db-starter │ │ ├── README.md │ │ ├── __package │ │ ├── __init__.py │ │ ├── main.py.dbos │ │ └── schema.py │ │ ├── alembic.ini │ │ ├── dbos-config.yaml.dbos │ │ ├── migrations │ │ ├── env.py.dbos │ │ ├── script.py.mako │ │ └── versions │ │ │ └── 2024_07_31_180642_init.py │ │ └── start_postgres_docker.py ├── _tracer.py ├── _utils.py ├── _workflow_commands.py ├── cli │ ├── _github_init.py │ ├── _template_init.py │ └── cli.py ├── dbos-config.schema.json └── py.typed ├── make_release.py ├── pdm.lock ├── pyproject.toml ├── pyrightconfig.test.json ├── tests ├── __init__.py ├── atexit_no_ctor.py ├── atexit_no_launch.py ├── classdefs.py ├── client_collateral.py ├── client_worker.py ├── conftest.py ├── dupname_classdefs1.py ├── dupname_classdefsa.py ├── more_classdefs.py ├── queuedworkflow.py ├── test_admin_server.py ├── test_async.py ├── test_classdecorators.py ├── test_cli.py ├── test_client.py ├── test_concurrency.py ├── test_config.py ├── test_croniter.py ├── test_dbos.py ├── test_debug.py ├── test_docker_secrets.py ├── test_failures.py ├── test_fastapi.py ├── test_fastapi_roles.py ├── test_flask.py ├── test_kafka.py ├── test_outcome.py ├── test_package.py ├── test_queue.py ├── test_scheduler.py ├── test_schema_migration.py ├── test_singleton.py ├── test_spans.py ├── test_sqlalchemy.py ├── test_workflow_introspection.py └── test_workflow_management.py └── version └── __init__.py /.github/workflows/chaos-test.yml: -------------------------------------------------------------------------------- 1 | name: Run Chaos Tests 2 | 3 | on: 4 | schedule: 5 | # Runs every hour on the hour 6 | - cron: '0 * * * *' 7 | workflow_dispatch: 8 | 9 | jobs: 10 | integration: 11 | runs-on: ubuntu-latest 12 | timeout-minutes: 60 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 17 | fetch-tags: true 18 | 19 | - name: Setup PDM 20 | uses: pdm-project/setup-pdm@v4 21 | with: 22 | python-version: '3.13' 23 | architecture: 'x64' 24 | 25 | - name: Install Dependencies 26 | run: pdm 
install 27 | working-directory: ./ 28 | 29 | - name: Run Chaos Tests 30 | run: pdm run pytest --timeout 1200 chaos-tests 31 | working-directory: ./ -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | deploy: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | id-token: write 11 | contents: read 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 17 | fetch-tags: true 18 | 19 | - name: Setup PDM 20 | uses: pdm-project/setup-pdm@v4 21 | with: 22 | python-version: '3.9.x' 23 | architecture: 'x64' 24 | 25 | - name: Install Dependencies 26 | run: pdm install 27 | 28 | - name: Build package 29 | run: pdm build 30 | 31 | - name: Publish package to TestPyPI 32 | uses: pypa/gh-action-pypi-publish@release/v1 33 | with: 34 | packages-dir: dist/ 35 | verbose: true -------------------------------------------------------------------------------- /.github/workflows/unit-test.yml: -------------------------------------------------------------------------------- 1 | name: Run Unit Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - release/* 8 | pull_request: 9 | branches: 10 | - main 11 | - release/* 12 | types: 13 | - ready_for_review 14 | - opened 15 | - reopened 16 | - synchronize 17 | workflow_dispatch: 18 | 19 | jobs: 20 | integration: 21 | runs-on: ubuntu-latest 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 26 | services: 27 | # Postgres service container 28 | postgres: 29 | image: sibedge/postgres-plv8 30 | env: 31 | # Specify the password for Postgres superuser. 32 | POSTGRES_PASSWORD: a!b@c$d()e*_,/:;=?@ff[]22 33 | # Set health checks to wait until postgres has started 34 | options: >- 35 | --health-cmd pg_isready 36 | --health-interval 10s 37 | --health-timeout 5s 38 | --health-retries 5 39 | ports: 40 | # Maps tcp port 5432 on service container to the host 41 | - 5432:5432 42 | # Zookeeper service container (required by Kafka) 43 | zookeeper: 44 | image: bitnami/zookeeper 45 | ports: 46 | - 2181:2181 47 | env: 48 | ALLOW_ANONYMOUS_LOGIN: yes 49 | options: >- 50 | --health-cmd "echo mntr | nc -w 2 -q 2 localhost 2181" 51 | --health-interval 10s 52 | --health-timeout 5s 53 | --health-retries 5 54 | # Kafka service container 55 | kafka: 56 | image: bitnami/kafka:3.9.0 57 | ports: 58 | - 9092:9092 59 | options: >- 60 | --health-cmd "kafka-broker-api-versions.sh --version" 61 | --health-interval 10s 62 | --health-timeout 5s 63 | --health-retries 5 64 | env: 65 | KAFKA_CFG_ZOOKEEPER_CONNECT: zookeeper:2181 66 | ALLOW_PLAINTEXT_LISTENER: yes 67 | KAFKA_CFG_LISTENERS: PLAINTEXT://:9092 68 | KAFKA_CFG_ADVERTISED_LISTENERS: PLAINTEXT://127.0.0.1:9092 69 | 70 | steps: 71 | - uses: actions/checkout@v4 72 | with: 73 | fetch-depth: 0 74 | fetch-tags: true 75 | 76 | - name: Setup PDM 77 | uses: pdm-project/setup-pdm@v4 78 | with: 79 | python-version: ${{ matrix.python-version }} 80 | architecture: 'x64' 81 | 82 | - name: Install Dependencies 83 | run: pdm install 84 | working-directory: ./ 85 | 86 | # Mypy is the main type-checker used to verify the entire code base. 87 | - name: Check Types 88 | run: pdm run mypy . 89 | working-directory: ./ 90 | 91 | # Pyright is used by Pylance, so verify it works with DBOS application code. 
92 | - name: Check Types With Pyright 93 | run: pdm run pyright tests -p pyrightconfig.test.json 94 | working-directory: ./ 95 | 96 | - name: Run Unit Tests 97 | run: pdm run pytest tests 98 | working-directory: ./ 99 | env: 100 | PGPASSWORD: a!b@c$d()e*_,/:;=?@ff[]22 101 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .dbos/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # Editor temp 87 | .*.swp 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 114 | .pdm.toml 115 | .pdm-python 116 | .pdm-build/ 117 | 118 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | .python-version 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | 169 | # IDE files 170 | .vscode/ 171 | .idea/ 172 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 24.4.2 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://github.com/pycqa/isort 8 | rev: 5.13.2 9 | hooks: 10 | - id: isort -------------------------------------------------------------------------------- /.pydocstyle: -------------------------------------------------------------------------------- 1 | [pydocstyle] 2 | ignore = D100,D101,D102,D103,D104,D105,D106,D107,D202,D203,D212,D213,D406,D407,D413 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DBOS Transact Python 2 | 3 | Thank you for considering contributing to DBOS Transact. We welcome contributions from everyone, including bug fixes, feature enhancements, documentation improvements, or any other form of contribution. 4 | 5 | ## How to Contribute 6 | 7 | To get started with DBOS Transact, please read the [README](README.md). 8 | 9 | You can contribute in many ways. Some simple ways are: 10 | * Use the SDK and open issues to report any bugs, questions, concern with the SDK, samples or documentation. 11 | * Respond to issues with advice or suggestions. 12 | * Participate in discussions in our [Discord](https://discord.gg/fMwQjeW5zg) channel. 13 | * Contribute fixes and improvement to code, samples or documentation. 14 | 15 | ### To contribute code, please follow these steps: 16 | 17 | 1. Fork this github repository to your own account. 18 | 19 | 2. Clone the forked repository to your local machine. 20 | 21 | 3. Create a branch. 22 | 23 | 4. Make the necessary change to code, samples or documentation. 24 | 25 | 5. Write tests. 26 | 27 | 6. Commit the changes to your forked repository. 28 | 29 | 7. Submit a pull request to this repository. 30 | In the PR description please include: 31 | * Description of the fix/feature. 32 | * Brief description of implementation. 33 | * Description of how you tested the fix. 
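
For step 5, new tests go in the `tests/` directory and run with `pdm run pytest` (see [DEVELOPING.md](DEVELOPING.md) for setup). A minimal sketch in the style used throughout the test suite — the test takes the `dbos` fixture from `conftest.py` and calls decorated functions directly; the file and function names here are made up for illustration:

```python
# tests/test_my_feature.py (hypothetical file name for illustration)
from dbos import DBOS


def test_my_feature(dbos: DBOS) -> None:
    # Define a step and a workflow, then call the workflow directly
    # and assert on its result, as the existing workflow tests do.
    @DBOS.step()
    def add_one(x: int) -> int:
        return x + 1

    @DBOS.workflow()
    def my_workflow(x: int) -> int:
        return add_one(x)

    assert my_workflow(1) == 2
```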
34 | 35 | ## Requesting features 36 | 37 | If you have a feature request or an idea for an enhancement, feel free to open an issue on GitHub. Describe the feature or enhancement you'd like to see and why it would be valuable. Discuss it with the community on the [Discord](https://discord.gg/fMwQjeW5zg) channel. 38 | 39 | ## Discuss with the community 40 | 41 | If you are stuck, need help, or wondering if a certain contribution will be welcome, please ask! You can reach out to us on [Discord](https://discord.gg/fMwQjeW5zg) or Github discussions. 42 | 43 | ## Code of conduct 44 | 45 | It is important to us that contributing to DBOS will be a pleasant experience, if necessary, please refer to our [code of conduct](CODE_OF_CONDUCT.md) for participation guidelines. -------------------------------------------------------------------------------- /DEVELOPING.md: -------------------------------------------------------------------------------- 1 | ### Setting Up for Development 2 | 3 | #### Installing `pdm` and `venv` 4 | 5 | This package uses [`pdm`](https://pdm-project.org/en/latest/) for package and 6 | virtual environment management. 7 | To install `pdm`, run: 8 | 9 | ``` 10 | curl -sSL https://pdm-project.org/install-pdm.py | python3 - 11 | ``` 12 | 13 | `pdm` is installed in `~/.local/bin`. You might have to add this directory to 14 | the `PATH` variable. 15 | 16 | On Ubuntu, it may be necessary to install the following for Python 3.10 before 17 | installing `pdm`: 18 | 19 | ``` 20 | apt install python3.10-venv 21 | ``` 22 | 23 | For Python 3.12 run instead: 24 | 25 | ``` 26 | apt install python3.12-venv 27 | ``` 28 | 29 | #### Installing Python dependencies with `pdm` 30 | 31 | NOTE: If you already have a virtual environment for this project, activate 32 | it so that the dependencies are installed into your existing virtual 33 | environment. If you do not have a virtual environment, `pdm` creates one 34 | in `.venv`. 35 | 36 | To install dependencies: 37 | 38 | ``` 39 | pdm install 40 | pdm run pre-commit install 41 | ``` 42 | 43 | #### Executing unit tests, checking types, manage system table schema migrations 44 | 45 | To run unit tests: 46 | 47 | ``` 48 | pdm run pytest 49 | ``` 50 | 51 | NOTE: The tests need a Postgres database running on `localhost:5432`. To start 52 | one, run: 53 | 54 | ```bash 55 | export PGPASSWORD=dbos 56 | python3 dbos/_templates/dbos-db-starter/start_postgres_docker.py 57 | ``` 58 | 59 | A successful test run results in the following output: 60 | 61 | ``` 62 | =============================== warnings summary =============================== 63 | :488 64 | :488: DeprecationWarning: Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14. 65 | 66 | :488 67 | :488: DeprecationWarning: Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14. 68 | 69 | -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html 70 | ============ 182 passed, 2 skipped, 2 warnings in 254.62s (0:04:14) ============ 71 | ``` 72 | 73 | The two skipped test cases verify the interaction with Kafka and if Kafka is not available, 74 | the test cases are skipped. 75 | 76 | To check types: 77 | 78 | ``` 79 | pdm run mypy . 
80 | ``` 81 | 82 | A successful check of types results in (the number of files might be different depending 83 | on the changes in the project since this was written): 84 | 85 | ``` 86 | Success: no issues found in 64 source files 87 | ``` 88 | 89 | We use alembic to manage system table schema migrations. 90 | To generate a new migration, run: 91 | 92 | ``` 93 | pdm run alembic revision -m "" 94 | ``` 95 | 96 | This command will add a new file under the `dbos/migrations/versions/` folder. 97 | For more information, 98 | read [alembic tutorial](https://alembic.sqlalchemy.org/en/latest/tutorial.html). 99 | 100 | ### Creating a Release 101 | 102 | To cut a new release, run: 103 | 104 | ```shell 105 | python3 make_release.py [--version_number ] 106 | ``` 107 | 108 | Version numbers follow [semver](https://semver.org/). 109 | This command tags the latest commit with the version number and creates a 110 | release branch for it. 111 | If a version number is not supplied, it automatically generated a version number 112 | by incrementing the last released minor version. 113 | 114 | ### Patching a release 115 | 116 | To patch a release, push the patch as a commit to the appropriate release 117 | branch. 118 | Then, tag it with a version number: 119 | 120 | ```shell 121 | git tag 122 | git push --tags 123 | ``` 124 | 125 | This version must follow semver: It should increment by one the patch number of 126 | the release branch. 127 | 128 | ### Preview Versions 129 | 130 | Preview versions are [PEP440](https://peps.python.org/pep-0440/)-compliant alpha 131 | versions. 132 | They can be published from `main`. 133 | Their version number is 134 | `a`. 135 | You can install the latest preview version with `pip install --pre dbos`. 136 | 137 | ### Test Versions 138 | 139 | Test versions are built from feature branches. 140 | Their version number is 141 | `a+`. 142 | 143 | ### Publishing 144 | 145 | Run the [`Publish to PyPI`](./.github/workflows/publish.yml) GitHub action on 146 | the target branch. 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 DBOS, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # This is only used when generating new migrations. It is not referenced when running migrations programmatically. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | # Use forward slashes (/) also on windows to provide an os agnostic path 6 | script_location = dbos/_migrations 7 | 8 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 9 | -------------------------------------------------------------------------------- /chaos-tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import threading 4 | import time 5 | from typing import Any, Generator, Optional 6 | from urllib.parse import quote 7 | 8 | import pytest 9 | import sqlalchemy as sa 10 | 11 | from dbos import DBOS, DBOSConfig 12 | from dbos._docker_pg_helper import start_docker_pg, stop_docker_pg 13 | 14 | 15 | @pytest.fixture() 16 | def config() -> DBOSConfig: 17 | return { 18 | "name": "test-app", 19 | "database_url": f"postgresql://postgres:{quote(os.environ.get('PGPASSWORD', 'dbos'), safe='')}@localhost:5432/dbostestpy", 20 | } 21 | 22 | 23 | @pytest.fixture() 24 | def postgres(config: DBOSConfig) -> Generator[None, Any, None]: 25 | start_docker_pg() 26 | yield 27 | stop_docker_pg() 28 | 29 | 30 | @pytest.fixture() 31 | def cleanup_test_databases(config: DBOSConfig, postgres: None) -> None: 32 | assert config["database_url"] is not None 33 | engine = sa.create_engine( 34 | sa.make_url(config["database_url"]).set( 35 | drivername="postgresql+psycopg", 36 | database="postgres", 37 | ), 38 | connect_args={ 39 | "connect_timeout": 30, 40 | }, 41 | ) 42 | app_db_name = sa.make_url(config["database_url"]).database 43 | sys_db_name = f"{app_db_name}_dbos_sys" 44 | 45 | with engine.connect() as connection: 46 | connection.execution_options(isolation_level="AUTOCOMMIT") 47 | connection.execute( 48 | sa.text(f"DROP DATABASE IF EXISTS {app_db_name} WITH (FORCE)") 49 | ) 50 | connection.execute( 51 | sa.text(f"DROP DATABASE IF EXISTS {sys_db_name} WITH (FORCE)") 52 | ) 53 | engine.dispose() 54 | 55 | 56 | class PostgresChaosMonkey: 57 | 58 | def __init__(self) -> None: 59 | self.stop_event = threading.Event() 60 | self.chaos_thread: Optional[threading.Thread] = None 61 | 62 | def start(self) -> None: 63 | def _chaos_thread() -> None: 64 | while not self.stop_event.is_set(): 65 | wait_time = random.uniform(5, 40) 66 | if not self.stop_event.wait(wait_time): 67 | print( 68 | f"🐒 ChaosMonkey strikes after {wait_time:.2f} seconds! Restarting Postgres..." 
69 | ) 70 | stop_docker_pg() 71 | down_time = random.uniform(0, 2) 72 | time.sleep(down_time) 73 | start_docker_pg() 74 | 75 | self.stop_event.clear() 76 | self.chaos_thread = threading.Thread(target=_chaos_thread) 77 | self.chaos_thread.start() 78 | 79 | def stop(self) -> None: 80 | if self.chaos_thread is None: 81 | return 82 | self.stop_event.set() 83 | self.chaos_thread.join() 84 | 85 | 86 | @pytest.fixture() 87 | def dbos( 88 | config: DBOSConfig, cleanup_test_databases: None 89 | ) -> Generator[DBOS, Any, None]: 90 | DBOS.destroy(destroy_registry=True) 91 | dbos = DBOS(config=config) 92 | DBOS.launch() 93 | monkey = PostgresChaosMonkey() 94 | monkey.start() 95 | yield dbos 96 | monkey.stop() 97 | DBOS.destroy(destroy_registry=True) 98 | -------------------------------------------------------------------------------- /chaos-tests/test_workflows.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import Any 3 | 4 | import sqlalchemy as sa 5 | 6 | from dbos import DBOS, Queue, SetWorkflowID 7 | 8 | 9 | def test_workflow(dbos: DBOS) -> None: 10 | 11 | @DBOS.step() 12 | def step_one(x: int) -> int: 13 | return x + 1 14 | 15 | @DBOS.step() 16 | def step_two(x: int) -> int: 17 | return x + 2 18 | 19 | @DBOS.transaction() 20 | def txn_one(x: int) -> int: 21 | DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 22 | return x + 3 23 | 24 | @DBOS.workflow() 25 | def workflow(x: int) -> int: 26 | x = step_one(x) 27 | x = step_two(x) 28 | x = txn_one(x) 29 | return x 30 | 31 | num_workflows = 5000 32 | 33 | for i in range(num_workflows): 34 | assert workflow(i) == i + 6 35 | 36 | 37 | def test_recv(dbos: DBOS) -> None: 38 | 39 | topic = "test_topic" 40 | 41 | @DBOS.workflow() 42 | def recv_workflow() -> Any: 43 | return DBOS.recv(topic, timeout_seconds=10) 44 | 45 | num_workflows = 10000 46 | 47 | for i in range(num_workflows): 48 | handle = DBOS.start_workflow(recv_workflow) 49 | 50 | value = str(uuid.uuid4()) 51 | DBOS.send(handle.workflow_id, value, topic) 52 | 53 | assert handle.get_result() == value 54 | 55 | 56 | def test_events(dbos: DBOS) -> None: 57 | 58 | key = "test_key" 59 | 60 | @DBOS.workflow() 61 | def event_workflow() -> str: 62 | value = str(uuid.uuid4()) 63 | DBOS.set_event(key, value) 64 | return value 65 | 66 | num_workflows = 5000 67 | 68 | for i in range(num_workflows): 69 | id = str(uuid.uuid4()) 70 | with SetWorkflowID(id): 71 | value = event_workflow() 72 | assert DBOS.get_event(id, key, timeout_seconds=0) == value 73 | 74 | 75 | def test_queues(dbos: DBOS) -> None: 76 | 77 | queue = Queue("test_queue") 78 | 79 | @DBOS.step() 80 | def step_one(x: int) -> int: 81 | return x + 1 82 | 83 | @DBOS.step() 84 | def step_two(x: int) -> int: 85 | return x + 2 86 | 87 | @DBOS.transaction() 88 | def txn_one(x: int) -> int: 89 | DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 90 | return x + 3 91 | 92 | @DBOS.workflow() 93 | def workflow(x: int) -> int: 94 | x = queue.enqueue(step_one, x).get_result() 95 | x = queue.enqueue(step_two, x).get_result() 96 | x = queue.enqueue(txn_one, x).get_result() 97 | return x 98 | 99 | num_workflows = 30 100 | 101 | for i in range(num_workflows): 102 | assert queue.enqueue(workflow, i).get_result() == i + 6 103 | -------------------------------------------------------------------------------- /dbos/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import _error as error 2 | from ._client import DBOSClient, EnqueueOptions 3 | from ._context import ( 4 | DBOSContextEnsure, 5 | DBOSContextSetAuth, 6 | SetEnqueueOptions, 7 | SetWorkflowID, 8 | SetWorkflowTimeout, 9 | ) 10 | from ._dbos import DBOS, DBOSConfiguredInstance, WorkflowHandle, WorkflowHandleAsync 11 | from ._dbos_config import DBOSConfig 12 | from ._kafka_message import KafkaMessage 13 | from ._queue import Queue 14 | from ._sys_db import GetWorkflowsInput, WorkflowStatus, WorkflowStatusString 15 | 16 | __all__ = [ 17 | "DBOSConfig", 18 | "DBOS", 19 | "DBOSClient", 20 | "DBOSConfiguredInstance", 21 | "DBOSContextEnsure", 22 | "DBOSContextSetAuth", 23 | "EnqueueOptions", 24 | "GetWorkflowsInput", 25 | "KafkaMessage", 26 | "SetWorkflowID", 27 | "SetWorkflowTimeout", 28 | "SetEnqueueOptions", 29 | "WorkflowHandle", 30 | "WorkflowHandleAsync", 31 | "WorkflowStatus", 32 | "WorkflowStatusString", 33 | "error", 34 | "Queue", 35 | ] 36 | -------------------------------------------------------------------------------- /dbos/__main__.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | from typing import NoReturn, Optional, Union 4 | 5 | from dbos.cli.cli import app 6 | 7 | # This is used by the debugger to execute DBOS as a module. 8 | # Never used otherwise. 9 | 10 | 11 | def main() -> NoReturn: 12 | # Modify sys.argv[0] to remove script or executable extensions 13 | sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0]) 14 | 15 | retval: Optional[Union[str, int]] = 1 16 | try: 17 | app() 18 | retval = None 19 | except SystemExit as e: 20 | retval = e.code 21 | except Exception as e: 22 | print(f"Error: {e}", file=sys.stderr) 23 | retval = 1 24 | finally: 25 | sys.exit(retval) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /dbos/_classproperty.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Generic, Optional, TypeVar 2 | 3 | G = TypeVar("G") # A generic type for ClassPropertyDescriptor getters 4 | 5 | 6 | class ClassPropertyDescriptor(Generic[G]): 7 | def __init__(self, fget: Callable[..., G]) -> None: 8 | self.fget = fget 9 | 10 | def __get__(self, obj: Any, objtype: Optional[Any] = None) -> G: 11 | if objtype is None: 12 | objtype = type(obj) 13 | if self.fget is None: 14 | raise AttributeError("unreadable attribute") 15 | return self.fget(objtype) 16 | 17 | 18 | def classproperty(func: Callable[..., G]) -> ClassPropertyDescriptor[G]: 19 | return ClassPropertyDescriptor(func) 20 | -------------------------------------------------------------------------------- /dbos/_debug.py: -------------------------------------------------------------------------------- 1 | import re 2 | import runpy 3 | import sys 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | from fastapi_cli.discover import get_module_data_from_path 8 | 9 | from dbos import DBOS 10 | 11 | 12 | class PythonModule: 13 | def __init__(self, module_name: str): 14 | self.module_name = module_name 15 | 16 | 17 | def debug_workflow(workflow_id: str, entrypoint: Union[str, PythonModule]) -> None: 18 | if isinstance(entrypoint, str): 19 | # ensure the entrypoint parent directory is in sys.path 20 | parent = str(Path(entrypoint).parent) 21 | if parent not in sys.path: 22 | sys.path.insert(0, parent) 23 | runpy.run_path(entrypoint) 24 | elif isinstance(entrypoint, PythonModule): 
25 | runpy.run_module(entrypoint.module_name) 26 | else: 27 | raise ValueError("Invalid entrypoint type. Must be a string or PythonModule.") 28 | 29 | DBOS.logger.info(f"Debugging workflow {workflow_id}...") 30 | DBOS.launch(debug_mode=True) 31 | handle = DBOS._execute_workflow_id(workflow_id) 32 | handle.get_result() 33 | DBOS.logger.info("Workflow Debugging complete. Exiting process.") 34 | 35 | 36 | def parse_start_command(command: str) -> Union[str, PythonModule]: 37 | match = re.match(r"fastapi\s+run\s+(\.?[\w/]+\.py)", command) 38 | if match: 39 | # Mirror the logic in fastapi's run command by converting the path argument to a module 40 | mod_data = get_module_data_from_path(Path(match.group(1))) 41 | sys.path.insert(0, str(mod_data.extra_sys_path)) 42 | return PythonModule(mod_data.module_import_str) 43 | match = re.match(r"python3?\s+(\.?[\w/]+\.py)", command) 44 | if match: 45 | return match.group(1) 46 | match = re.match(r"python3?\s+-m\s+([\w\.]+)", command) 47 | if match: 48 | return PythonModule(match.group(1)) 49 | raise ValueError( 50 | "Invalid command format. Must be 'fastapi run 58 | 59 | 60 |

Welcome to DBOS! 61 | 62 | Visit the route /greeting/{name} to be greeted! 63 | For example, visit /greeting/dbos 64 | The counter increments with each page visit. 65 | 66 |

69 | 70 | 71 | """ 72 | return HTMLResponse(readme) 73 | 74 | 75 | # Finally, we'll launch DBOS then start the FastAPI server. 76 | 77 | if __name__ == "__main__": 78 | DBOS.launch() 79 | uvicorn.run(app, host="0.0.0.0", port=8000) 80 | -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/__package/schema.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, MetaData, String, Table 2 | 3 | metadata = MetaData() 4 | 5 | dbos_hello = Table( 6 | "dbos_hello", 7 | metadata, 8 | Column("greet_count", Integer, primary_key=True, autoincrement=True), 9 | Column("name", String, nullable=False), 10 | ) 11 | -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | # Use forward slashes (/) also on windows to provide an os agnostic path 6 | script_location = migrations 7 | 8 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s 9 | # Uncomment the line below if you want the files to be prepended with date and time 10 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file 11 | # for all available tokens 12 | file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d%%(second).2d_%%(slug)s 13 | 14 | # sys.path path, will be prepended to sys.path if present. 15 | # defaults to the current working directory. 16 | prepend_sys_path = . 17 | 18 | # timezone to use when rendering the date within the migration file 19 | # as well as the filename. 20 | # If specified, requires the python>=3.9 or backports.zoneinfo library. 21 | # Any required deps can installed by adding `alembic[tz]` to the pip requirements 22 | # string value is passed to ZoneInfo() 23 | # leave blank for localtime 24 | # timezone = 25 | 26 | # max length of characters to apply to the "slug" field 27 | # truncate_slug_length = 40 28 | 29 | # set to 'true' to run the environment during 30 | # the 'revision' command, regardless of autogenerate 31 | # revision_environment = false 32 | 33 | # set to 'true' to allow .pyc and .pyo files without 34 | # a source .py file to be detected as revisions in the 35 | # versions/ directory 36 | # sourceless = false 37 | 38 | # version location specification; This defaults 39 | # to migrations/versions. When using multiple version 40 | # directories, initial revisions must be specified with --version-path. 41 | # The path separator used here should be the separator specified by "version_path_separator" below. 42 | # version_locations = %(here)s/bar:%(here)s/bat:migrations/versions 43 | 44 | # version path separator; As mentioned above, this is the character used to split 45 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 46 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 47 | # Valid values for version_path_separator are: 48 | # 49 | # version_path_separator = : 50 | # version_path_separator = ; 51 | # version_path_separator = space 52 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 
53 | 54 | # set to 'true' to search source files recursively 55 | # in each "version_locations" directory 56 | # new in Alembic version 1.10 57 | # recursive_version_locations = false 58 | 59 | # the output encoding used when revision files 60 | # are written from script.py.mako 61 | # output_encoding = utf-8 62 | 63 | # sqlalchemy.url is set in env.py from the dbos-config.yaml file 64 | 65 | 66 | [post_write_hooks] 67 | # post_write_hooks defines scripts or Python functions that are run 68 | # on newly generated revision scripts. See the documentation for further 69 | # detail and examples 70 | 71 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 72 | # hooks = black 73 | # black.type = console_scripts 74 | # black.entrypoint = black 75 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 76 | 77 | # lint with attempts to fix using "ruff" - use the exec runner, execute a binary 78 | # hooks = ruff 79 | # ruff.type = exec 80 | # ruff.executable = %(here)s/.venv/bin/ruff 81 | # ruff.options = --fix REVISION_SCRIPT_FILENAME 82 | 83 | # Logging configuration 84 | [loggers] 85 | keys = root,sqlalchemy,alembic 86 | 87 | [handlers] 88 | keys = console 89 | 90 | [formatters] 91 | keys = generic 92 | 93 | [logger_root] 94 | level = WARN 95 | handlers = console 96 | qualname = 97 | 98 | [logger_sqlalchemy] 99 | level = WARN 100 | handlers = 101 | qualname = sqlalchemy.engine 102 | 103 | [logger_alembic] 104 | level = INFO 105 | handlers = 106 | qualname = alembic 107 | 108 | [handler_console] 109 | class = StreamHandler 110 | args = (sys.stderr,) 111 | level = NOTSET 112 | formatter = generic 113 | 114 | [formatter_generic] 115 | format = %(levelname)-5.5s [%(name)s] %(message)s 116 | datefmt = %H:%M:%S 117 | -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/dbos-config.yaml.dbos: -------------------------------------------------------------------------------- 1 | # To enable auto-completion and validation for this file in VSCode, install the RedHat YAML extension 2 | # https://marketplace.visualstudio.com/items?itemName=redhat.vscode-yaml 3 | 4 | # yaml-language-server: $schema=https://raw.githubusercontent.com/dbos-inc/dbos-transact-py/main/dbos/dbos-config.schema.json 5 | 6 | name: ${project_name} 7 | language: python 8 | runtimeConfig: 9 | start: 10 | - "${start_command}" 11 | database_url: ${DBOS_DATABASE_URL} 12 | ${migration_section} -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/migrations/env.py.dbos: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from logging.config import fileConfig 4 | 5 | from alembic import context 6 | from sqlalchemy import engine_from_config, pool 7 | 8 | # this is the Alembic Config object, which provides 9 | # access to the values within the .ini file in use. 10 | config = context.config 11 | 12 | # Interpret the config file for Python logging. 13 | # This line sets up loggers basically. 
14 | if config.config_file_name is not None: 15 | fileConfig(config.config_file_name) 16 | 17 | # Programmatically set the sqlalchemy.url field to the DBOS application database URL 18 | conn_string = os.environ.get("DBOS_DATABASE_URL", "postgresql+psycopg://postgres:dbos@localhost:5432/${default_db_name}?connect_timeout=5") 19 | # Alembic requires the % in URL-escaped parameters to itself be escaped to %% 20 | escaped_conn_string = re.sub( 21 | r"%(?=[0-9A-Fa-f]{2})", 22 | "%%", 23 | conn_string, 24 | ) 25 | config.set_main_option("sqlalchemy.url", escaped_conn_string) 26 | 27 | # add your model's MetaData object here 28 | # for 'autogenerate' support 29 | from ${package_name}.schema import metadata 30 | target_metadata = metadata 31 | 32 | # other values from the config, defined by the needs of env.py, 33 | # can be acquired: 34 | # my_important_option = config.get_main_option("my_important_option") 35 | # ... etc. 36 | 37 | 38 | def run_migrations_offline() -> None: 39 | """Run migrations in 'offline' mode. 40 | 41 | This configures the context with just a URL 42 | and not an Engine, though an Engine is acceptable 43 | here as well. By skipping the Engine creation 44 | we don't even need a DBAPI to be available. 45 | 46 | Calls to context.execute() here emit the given string to the 47 | script output. 48 | 49 | """ 50 | url = config.get_main_option("sqlalchemy.url") 51 | context.configure( 52 | url=url, 53 | target_metadata=target_metadata, 54 | literal_binds=True, 55 | dialect_opts={"paramstyle": "named"}, 56 | ) 57 | 58 | with context.begin_transaction(): 59 | context.run_migrations() 60 | 61 | 62 | def run_migrations_online() -> None: 63 | """Run migrations in 'online' mode. 64 | 65 | In this scenario we need to create an Engine 66 | and associate a connection with the context. 67 | 68 | """ 69 | connectable = engine_from_config( 70 | config.get_section(config.config_ini_section, {}), 71 | prefix="sqlalchemy.", 72 | poolclass=pool.NullPool, 73 | ) 74 | 75 | with connectable.connect() as connection: 76 | context.configure(connection=connection, target_metadata=target_metadata) 77 | 78 | with context.begin_transaction(): 79 | context.run_migrations() 80 | 81 | 82 | if context.is_offline_mode(): 83 | run_migrations_offline() 84 | else: 85 | run_migrations_online() 86 | -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/migrations/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | ${imports if imports else ""} 13 | 14 | # revision identifiers, used by Alembic. 
15 | revision: str = ${repr(up_revision)} 16 | down_revision: Union[str, None] = ${repr(down_revision)} 17 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} 18 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} 19 | 20 | 21 | def upgrade() -> None: 22 | ${upgrades if upgrades else "pass"} 23 | 24 | 25 | def downgrade() -> None: 26 | ${downgrades if downgrades else "pass"} 27 | -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/migrations/versions/2024_07_31_180642_init.py: -------------------------------------------------------------------------------- 1 | """ 2 | Initialize application database. 3 | 4 | Revision ID: c6b516e182b2 5 | Revises: 6 | Create Date: 2024-07-31 18:06:42.500040 7 | """ 8 | 9 | from typing import Sequence, Union 10 | 11 | import sqlalchemy as sa 12 | from alembic import op 13 | 14 | # revision identifiers, used by Alembic. 15 | revision: str = "c6b516e182b2" 16 | down_revision: Union[str, None] = None 17 | branch_labels: Union[str, Sequence[str], None] = None 18 | depends_on: Union[str, Sequence[str], None] = None 19 | 20 | 21 | def upgrade() -> None: 22 | # ### commands auto generated by Alembic - please adjust! ### 23 | op.create_table( 24 | "dbos_hello", 25 | sa.Column("greet_count", sa.Integer(), autoincrement=True, nullable=False), 26 | sa.Column("name", sa.String(), nullable=False), 27 | sa.PrimaryKeyConstraint("greet_count"), 28 | ) 29 | # ### end Alembic commands ### 30 | 31 | 32 | def downgrade() -> None: 33 | # ### commands auto generated by Alembic - please adjust! ### 34 | op.drop_table("dbos_hello") 35 | # ### end Alembic commands ### 36 | -------------------------------------------------------------------------------- /dbos/_templates/dbos-db-starter/start_postgres_docker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | import time 5 | 6 | # Default PostgreSQL port 7 | port = "5432" 8 | 9 | # Set the host PostgreSQL port with the -p/--port flag 10 | for i, arg in enumerate(sys.argv): 11 | if arg in ["-p", "--port"]: 12 | if i + 1 < len(sys.argv): 13 | port = sys.argv[i + 1] 14 | 15 | if "PGPASSWORD" not in os.environ: 16 | print("Error: PGPASSWORD is not set.") 17 | sys.exit(1) 18 | 19 | try: 20 | subprocess.run( 21 | [ 22 | "docker", 23 | "run", 24 | "--rm", 25 | "--name=dbos-db", 26 | f'--env=POSTGRES_PASSWORD={os.environ["PGPASSWORD"]}', 27 | "--env=PGDATA=/var/lib/postgresql/data", 28 | "--volume=/var/lib/postgresql/data", 29 | "-p", 30 | f"{port}:5432", 31 | "-d", 32 | "pgvector/pgvector:pg16", 33 | ], 34 | check=True, 35 | ) 36 | 37 | print("Waiting for PostgreSQL to start...") 38 | attempts = 30 39 | 40 | while attempts > 0: 41 | try: 42 | subprocess.run( 43 | [ 44 | "docker", 45 | "exec", 46 | "dbos-db", 47 | "psql", 48 | "-U", 49 | "postgres", 50 | "-c", 51 | "SELECT 1;", 52 | ], 53 | check=True, 54 | capture_output=True, 55 | ) 56 | print("PostgreSQL started!") 57 | print("Database started successfully!") 58 | break 59 | except subprocess.CalledProcessError: 60 | attempts -= 1 61 | time.sleep(1) 62 | 63 | if attempts == 0: 64 | print("Failed to start PostgreSQL.") 65 | 66 | except subprocess.CalledProcessError as error: 67 | print(f"Error starting PostgreSQL in Docker: {error}") 68 | -------------------------------------------------------------------------------- /dbos/_tracer.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from typing import TYPE_CHECKING, Optional 3 | 4 | from opentelemetry import trace 5 | from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter 6 | from opentelemetry.sdk.resources import Resource 7 | from opentelemetry.sdk.trace import TracerProvider 8 | from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter 9 | from opentelemetry.semconv.resource import ResourceAttributes 10 | from opentelemetry.trace import Span 11 | 12 | from dbos._utils import GlobalParams 13 | 14 | from ._dbos_config import ConfigFile 15 | 16 | if TYPE_CHECKING: 17 | from ._context import TracedAttributes 18 | 19 | 20 | class DBOSTracer: 21 | 22 | def __init__(self) -> None: 23 | self.app_id = os.environ.get("DBOS__APPID", None) 24 | self.provider: Optional[TracerProvider] = None 25 | 26 | def config(self, config: ConfigFile) -> None: 27 | if not isinstance(trace.get_tracer_provider(), TracerProvider): 28 | resource = Resource( 29 | attributes={ 30 | ResourceAttributes.SERVICE_NAME: config["name"], 31 | } 32 | ) 33 | 34 | provider = TracerProvider(resource=resource) 35 | if os.environ.get("DBOS__CONSOLE_TRACES", None) is not None: 36 | processor = BatchSpanProcessor(ConsoleSpanExporter()) 37 | provider.add_span_processor(processor) 38 | otlp_traces_endpoints = ( 39 | config.get("telemetry", {}).get("OTLPExporter", {}).get("tracesEndpoint") # type: ignore 40 | ) 41 | if otlp_traces_endpoints: 42 | for e in otlp_traces_endpoints: 43 | processor = BatchSpanProcessor(OTLPSpanExporter(endpoint=e)) 44 | provider.add_span_processor(processor) 45 | trace.set_tracer_provider(provider) 46 | 47 | def set_provider(self, provider: Optional[TracerProvider]) -> None: 48 | self.provider = provider 49 | 50 | def start_span( 51 | self, attributes: "TracedAttributes", parent: Optional[Span] = None 52 | ) -> Span: 53 | tracer = ( 54 | self.provider.get_tracer("dbos-tracer") 55 | if self.provider is not None 56 | else trace.get_tracer("dbos-tracer") 57 | ) 58 | context = trace.set_span_in_context(parent) if parent else None 59 | span: Span = tracer.start_span(name=attributes["name"], context=context) 60 | attributes["applicationID"] = self.app_id 61 | attributes["applicationVersion"] = GlobalParams.app_version 62 | attributes["executorID"] = GlobalParams.executor_id 63 | for k, v in attributes.items(): 64 | if k != "name" and v is not None and isinstance(v, (str, bool, int, float)): 65 | span.set_attribute(k, v) 66 | return span 67 | 68 | def end_span(self, span: Span) -> None: 69 | span.end() 70 | 71 | 72 | dbos_tracer = DBOSTracer() 73 | -------------------------------------------------------------------------------- /dbos/_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import os 3 | 4 | import psycopg 5 | from sqlalchemy.exc import DBAPIError 6 | 7 | INTERNAL_QUEUE_NAME = "_dbos_internal_queue" 8 | 9 | request_id_header = "x-request-id" 10 | 11 | 12 | class GlobalParams: 13 | app_version: str = os.environ.get("DBOS__APPVERSION", "") 14 | executor_id: str = os.environ.get("DBOS__VMID", "local") 15 | try: 16 | # Only works on Python >= 3.8 17 | dbos_version = importlib.metadata.version("dbos") 18 | except importlib.metadata.PackageNotFoundError: 19 | # If package is not installed or during development 20 | dbos_version = "unknown" 21 | 22 | 23 | def retriable_postgres_exception(e: DBAPIError) -> bool: 24 | if 
e.connection_invalidated: 25 | return True 26 | if isinstance(e.orig, psycopg.OperationalError): 27 | driver_error: psycopg.OperationalError = e.orig 28 | pgcode = driver_error.sqlstate or "" 29 | # Failure to establish connection 30 | if "connection failed" in str(driver_error): 31 | return True 32 | # Error within database transaction 33 | elif "server closed the connection unexpectedly" in str(driver_error): 34 | return True 35 | # Connection timeout 36 | if isinstance(driver_error, psycopg.errors.ConnectionTimeout): 37 | return True 38 | # Insufficient resources 39 | elif pgcode.startswith("53"): 40 | return True 41 | # Connection exception 42 | elif pgcode.startswith("08"): 43 | return True 44 | # Operator intervention 45 | elif pgcode.startswith("57"): 46 | return True 47 | else: 48 | return False 49 | else: 50 | return False 51 | -------------------------------------------------------------------------------- /dbos/_workflow_commands.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING, List, Optional 5 | 6 | from dbos._context import get_local_dbos_context 7 | 8 | from ._app_db import ApplicationDatabase 9 | from ._sys_db import ( 10 | GetQueuedWorkflowsInput, 11 | GetWorkflowsInput, 12 | StepInfo, 13 | SystemDatabase, 14 | WorkflowStatus, 15 | WorkflowStatusString, 16 | ) 17 | 18 | if TYPE_CHECKING: 19 | from ._dbos import DBOS 20 | 21 | 22 | def list_workflows( 23 | sys_db: SystemDatabase, 24 | *, 25 | workflow_ids: Optional[List[str]] = None, 26 | status: Optional[str] = None, 27 | start_time: Optional[str] = None, 28 | end_time: Optional[str] = None, 29 | name: Optional[str] = None, 30 | app_version: Optional[str] = None, 31 | user: Optional[str] = None, 32 | limit: Optional[int] = None, 33 | offset: Optional[int] = None, 34 | sort_desc: bool = False, 35 | workflow_id_prefix: Optional[str] = None, 36 | ) -> List[WorkflowStatus]: 37 | input = GetWorkflowsInput() 38 | input.workflow_ids = workflow_ids 39 | input.authenticated_user = user 40 | input.start_time = start_time 41 | input.end_time = end_time 42 | input.status = status 43 | input.application_version = app_version 44 | input.limit = limit 45 | input.name = name 46 | input.offset = offset 47 | input.sort_desc = sort_desc 48 | input.workflow_id_prefix = workflow_id_prefix 49 | 50 | infos: List[WorkflowStatus] = sys_db.get_workflows(input) 51 | 52 | return infos 53 | 54 | 55 | def list_queued_workflows( 56 | sys_db: SystemDatabase, 57 | *, 58 | queue_name: Optional[str] = None, 59 | status: Optional[str] = None, 60 | start_time: Optional[str] = None, 61 | end_time: Optional[str] = None, 62 | name: Optional[str] = None, 63 | limit: Optional[int] = None, 64 | offset: Optional[int] = None, 65 | sort_desc: bool = False, 66 | ) -> List[WorkflowStatus]: 67 | input: GetQueuedWorkflowsInput = { 68 | "queue_name": queue_name, 69 | "start_time": start_time, 70 | "end_time": end_time, 71 | "status": status, 72 | "limit": limit, 73 | "name": name, 74 | "offset": offset, 75 | "sort_desc": sort_desc, 76 | } 77 | 78 | infos: List[WorkflowStatus] = sys_db.get_queued_workflows(input) 79 | return infos 80 | 81 | 82 | def get_workflow(sys_db: SystemDatabase, workflow_id: str) -> Optional[WorkflowStatus]: 83 | input = GetWorkflowsInput() 84 | input.workflow_ids = [workflow_id] 85 | 86 | infos: List[WorkflowStatus] = sys_db.get_workflows(input) 87 | if not infos: 88 | return None 89 | 90 | return infos[0] 91 | 
92 | 93 | def list_workflow_steps( 94 | sys_db: SystemDatabase, app_db: ApplicationDatabase, workflow_id: str 95 | ) -> List[StepInfo]: 96 | steps = sys_db.get_workflow_steps(workflow_id) 97 | transactions = app_db.get_transactions(workflow_id) 98 | merged_steps = steps + transactions 99 | merged_steps.sort(key=lambda step: step["function_id"]) 100 | return merged_steps 101 | 102 | 103 | def fork_workflow( 104 | sys_db: SystemDatabase, 105 | app_db: ApplicationDatabase, 106 | workflow_id: str, 107 | start_step: int, 108 | *, 109 | application_version: Optional[str], 110 | ) -> str: 111 | 112 | ctx = get_local_dbos_context() 113 | if ctx is not None and len(ctx.id_assigned_for_next_workflow) > 0: 114 | forked_workflow_id = ctx.id_assigned_for_next_workflow 115 | ctx.id_assigned_for_next_workflow = "" 116 | else: 117 | forked_workflow_id = str(uuid.uuid4()) 118 | app_db.clone_workflow_transactions(workflow_id, forked_workflow_id, start_step) 119 | sys_db.fork_workflow( 120 | workflow_id, 121 | forked_workflow_id, 122 | start_step, 123 | application_version=application_version, 124 | ) 125 | return forked_workflow_id 126 | 127 | 128 | def garbage_collect( 129 | dbos: "DBOS", 130 | cutoff_epoch_timestamp_ms: Optional[int], 131 | rows_threshold: Optional[int], 132 | ) -> None: 133 | if cutoff_epoch_timestamp_ms is None and rows_threshold is None: 134 | return 135 | result = dbos._sys_db.garbage_collect( 136 | cutoff_epoch_timestamp_ms=cutoff_epoch_timestamp_ms, 137 | rows_threshold=rows_threshold, 138 | ) 139 | if result is not None: 140 | cutoff_epoch_timestamp_ms, pending_workflow_ids = result 141 | dbos._app_db.garbage_collect(cutoff_epoch_timestamp_ms, pending_workflow_ids) 142 | 143 | 144 | def global_timeout(dbos: "DBOS", cutoff_epoch_timestamp_ms: int) -> None: 145 | cutoff_iso = datetime.fromtimestamp(cutoff_epoch_timestamp_ms / 1000).isoformat() 146 | for workflow in dbos.list_workflows( 147 | status=WorkflowStatusString.PENDING.value, end_time=cutoff_iso 148 | ): 149 | dbos.cancel_workflow(workflow.workflow_id) 150 | for workflow in dbos.list_workflows( 151 | status=WorkflowStatusString.ENQUEUED.value, end_time=cutoff_iso 152 | ): 153 | dbos.cancel_workflow(workflow.workflow_id) 154 | -------------------------------------------------------------------------------- /dbos/cli/_github_init.py: -------------------------------------------------------------------------------- 1 | import os 2 | from base64 import b64decode 3 | from typing import List, TypedDict 4 | 5 | import requests 6 | 7 | DEMO_REPO_API = "https://api.github.com/repos/dbos-inc/dbos-demo-apps" 8 | PY_DEMO_PATH = "python/" 9 | BRANCH = "main" 10 | 11 | 12 | class GitHubTreeItem(TypedDict): 13 | path: str 14 | mode: str 15 | type: str 16 | sha: str 17 | url: str 18 | size: int 19 | 20 | 21 | class GitHubTree(TypedDict): 22 | sha: str 23 | url: str 24 | tree: List[GitHubTreeItem] 25 | truncated: bool 26 | 27 | 28 | class GitHubItem(TypedDict): 29 | sha: str 30 | node_id: str 31 | url: str 32 | content: str 33 | encoding: str 34 | size: int 35 | 36 | 37 | def _fetch_github(url: str) -> requests.Response: 38 | headers = {} 39 | github_token = os.getenv("GITHUB_TOKEN") 40 | if github_token: 41 | headers["Authorization"] = f"Bearer {github_token}" 42 | 43 | response = requests.get(url, headers=headers) 44 | 45 | if not response.ok: 46 | if response.headers.get("x-ratelimit-remaining") == "0": 47 | raise Exception( 48 | "Error fetching from GitHub API: rate limit exceeded.\n" 49 | "Please wait a few minutes and try again.\n" 50 | 
"To increase the limit, you can create a personal access token and set it in the GITHUB_TOKEN environment variable.\n" 51 | "Details: https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api" 52 | ) 53 | elif response.status_code == 401: 54 | raise Exception( 55 | f"Error fetching content from GitHub {url}: {response.status_code} {response.reason}.\n" 56 | "Please ensure your GITHUB_TOKEN environment variable is set to a valid personal access token." 57 | ) 58 | raise Exception( 59 | f"Error fetching content from GitHub {url}: {response.status_code} {response.reason}" 60 | ) 61 | 62 | return response 63 | 64 | 65 | def _fetch_github_tree(tag: str) -> List[GitHubTreeItem]: 66 | response = _fetch_github(f"{DEMO_REPO_API}/git/trees/{tag}?recursive=1") 67 | tree_data: GitHubTree = response.json() 68 | return tree_data["tree"] 69 | 70 | 71 | def _fetch_github_item(url: str) -> str: 72 | response = _fetch_github(url) 73 | item: GitHubItem = response.json() 74 | return b64decode(item["content"]).decode("utf-8") 75 | 76 | 77 | def create_template_from_github(app_name: str, template_name: str) -> None: 78 | print( 79 | f"Creating a new application named {app_name} from the template {template_name}" 80 | ) 81 | 82 | tree = _fetch_github_tree(BRANCH) 83 | template_path = f"{PY_DEMO_PATH}{template_name}/" 84 | 85 | files_to_download = [ 86 | item 87 | for item in tree 88 | if item["path"].startswith(template_path) and item["type"] == "blob" 89 | ] 90 | 91 | # Download every file from the template 92 | for item in files_to_download: 93 | raw_content = _fetch_github_item(item["url"]) 94 | file_path = item["path"].replace(template_path, "") 95 | target_path = os.path.join(".", file_path) 96 | 97 | # Create directory if it doesn't exist 98 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 99 | 100 | # Write file with proper permissions 101 | with open(target_path, "w", encoding="utf-8") as f: 102 | f.write(raw_content) 103 | os.chmod(target_path, int(item["mode"], 8)) 104 | 105 | print( 106 | f"Downloaded {len(files_to_download)} files from the template GitHub repository" 107 | ) 108 | -------------------------------------------------------------------------------- /dbos/cli/_template_init.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import typing 4 | from os import path 5 | from typing import Any 6 | 7 | import tomlkit 8 | from rich import print 9 | 10 | from dbos._dbos_config import _app_name_to_db_name 11 | 12 | 13 | def get_templates_directory() -> str: 14 | import dbos 15 | 16 | package_dir = path.abspath(path.dirname(dbos.__file__)) 17 | return path.join(package_dir, "_templates") 18 | 19 | 20 | def _copy_dbos_template(src: str, dst: str, ctx: dict[str, str]) -> None: 21 | with open(src, "r") as f: 22 | content = f.read() 23 | 24 | for key, value in ctx.items(): 25 | content = content.replace(f"${{{key}}}", value) 26 | 27 | with open(dst, "w") as f: 28 | f.write(content) 29 | 30 | 31 | def _copy_template_dir(src_dir: str, dst_dir: str, ctx: dict[str, str]) -> None: 32 | 33 | for root, dirs, files in os.walk(src_dir, topdown=True): 34 | dirs[:] = [d for d in dirs if d != "__package"] 35 | 36 | dst_root = path.join(dst_dir, path.relpath(root, src_dir)) 37 | if len(dirs) == 0: 38 | os.makedirs(dst_root, exist_ok=True) 39 | else: 40 | for dir in dirs: 41 | os.makedirs(path.join(dst_root, dir), exist_ok=True) 42 | 43 | for file in files: 44 | src = path.join(root, file) 45 | base, ext = 
path.splitext(file) 46 | 47 | dst = path.join(dst_root, base if ext == ".dbos" else file) 48 | if path.exists(dst): 49 | print(f"[yellow]File {dst} already exists, skipping[/yellow]") 50 | continue 51 | 52 | if ext == ".dbos": 53 | _copy_dbos_template(src, dst, ctx) 54 | else: 55 | shutil.copy(src, dst) 56 | 57 | 58 | def copy_template(src_dir: str, project_name: str, config_mode: bool) -> None: 59 | 60 | dst_dir = path.abspath(".") 61 | 62 | package_name = project_name.replace("-", "_") 63 | default_migration_section = """database: 64 | migrate: 65 | - alembic upgrade head 66 | """ 67 | ctx = { 68 | "project_name": project_name, 69 | "default_db_name": _app_name_to_db_name(project_name), 70 | "package_name": package_name, 71 | "start_command": f"python3 -m {package_name}.main", 72 | "migration_section": default_migration_section, 73 | } 74 | 75 | if config_mode: 76 | ctx["start_command"] = "python3 main.py" 77 | ctx["migration_section"] = "" 78 | _copy_dbos_template( 79 | os.path.join(src_dir, "dbos-config.yaml.dbos"), 80 | os.path.join(dst_dir, "dbos-config.yaml"), 81 | ctx, 82 | ) 83 | else: 84 | _copy_template_dir(src_dir, dst_dir, ctx) 85 | _copy_template_dir( 86 | path.join(src_dir, "__package"), path.join(dst_dir, package_name), ctx 87 | ) 88 | 89 | 90 | def get_project_name() -> typing.Union[str, None]: 91 | name = None 92 | try: 93 | with open("pyproject.toml", "rb") as file: 94 | pyproj = typing.cast(dict[str, Any], tomlkit.load(file)) 95 | name = typing.cast(str, pyproj["project"]["name"]) 96 | except: 97 | pass 98 | 99 | if name == None: 100 | try: 101 | _, parent = path.split(path.abspath(".")) 102 | name = parent 103 | except: 104 | pass 105 | 106 | return name 107 | -------------------------------------------------------------------------------- /dbos/dbos-config.schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "title": "DBOS Config", 4 | "type": "object", 5 | "additionalProperties": false, 6 | "properties": { 7 | "name": { 8 | "type": "string", 9 | "description": "The name of your application" 10 | }, 11 | "language": { 12 | "type": "string", 13 | "description": "The language used in your application", 14 | "enum": [ 15 | "python" 16 | ] 17 | }, 18 | "database_url": { 19 | "type": ["string", "null"], 20 | "description": "The URL of the application database" 21 | }, 22 | "database": { 23 | "type": "object", 24 | "additionalProperties": false, 25 | "properties": { 26 | "hostname": { 27 | "type": ["string", "null"], 28 | "description": "The hostname or IP address of the application database. DEPRECATED: Use database_url instead", 29 | "deprecated": true 30 | }, 31 | "port": { 32 | "type": ["number", "null"], 33 | "description": "The port number of the application database. DEPRECATED: Use database_url instead", 34 | "deprecated": true 35 | }, 36 | "username": { 37 | "type": ["string", "null"], 38 | "description": "The username to use when connecting to the application database. DEPRECATED: Use database_url instead", 39 | "not": { 40 | "enum": ["dbos"] 41 | }, 42 | "deprecated": true 43 | }, 44 | "password": { 45 | "type": ["string", "null"], 46 | "description": "The password to use when connecting to the application database. Developers are strongly encouraged to use environment variable substitution (${VAR_NAME}) or Docker secrets (${DOCKER_SECRET:SECRET_NAME}) to avoid storing secrets in source. 
DEPRECATED: Use database_url instead", 47 | "deprecated": true 48 | }, 49 | "connectionTimeoutMillis": { 50 | "type": ["number", "null"], 51 | "description": "The number of milliseconds the system waits before timing out when connecting to the application database. DEPRECATED: Use database_url instead", 52 | "deprecated": true 53 | }, 54 | "app_db_name": { 55 | "type": ["string", "null"], 56 | "description": "The name of the application database. DEPRECATED: Use database_url instead", 57 | "deprecated": true 58 | }, 59 | "sys_db_name": { 60 | "type": "string", 61 | "description": "The name of the system database" 62 | }, 63 | "ssl": { 64 | "type": ["boolean", "null"], 65 | "description": "Use SSL/TLS to securely connect to the database (default: true). DEPRECATED: Use database_url instead", 66 | "deprecated": true 67 | }, 68 | "ssl_ca": { 69 | "type": ["string", "null"], 70 | "description": "If using SSL/TLS to securely connect to a database, path to an SSL root certificate file. DEPRECATED: Use database_url instead", 71 | "deprecated": true 72 | }, 73 | "migrate": { 74 | "type": "array", 75 | "description": "Specify a list of user DB migration commands to run" 76 | }, 77 | "rollback": { 78 | "type": "array", 79 | "description": "Specify a list of user DB rollback commands to run. DEPRECATED", 80 | "deprecated": true 81 | } 82 | } 83 | }, 84 | "telemetry": { 85 | "type": "object", 86 | "additionalProperties": false, 87 | "properties": { 88 | "logs": { 89 | "type": "object", 90 | "additionalProperties": false, 91 | "properties": { 92 | "addContextMetadata": { 93 | "type": "boolean", 94 | "description": "Adds contextual information, such as workflow UUID, to each log entry" 95 | }, 96 | "logLevel": { 97 | "type": "string", 98 | "description": "A filter on what logs should be printed to the standard output" 99 | }, 100 | "silent": { 101 | "type": "boolean", 102 | "description": "Silences the logger such that nothing is printed to the standard output" 103 | } 104 | } 105 | }, 106 | "OTLPExporter": { 107 | "type": "object", 108 | "additionalProperties": false, 109 | "properties": { 110 | "logsEndpoint": { 111 | "type": "string", 112 | "description": "The URL of an OTLP collector to which to export logs" 113 | }, 114 | "tracesEndpoint": { 115 | "type": "string", 116 | "description": "The URL of an OTLP collector to which to export traces" 117 | } 118 | } 119 | } 120 | } 121 | }, 122 | "runtimeConfig": { 123 | "type": "object", 124 | "additionalProperties": false, 125 | "properties": { 126 | "entrypoints": { 127 | "type": "array", 128 | "items": { 129 | "type": "string" 130 | } 131 | }, 132 | "port": { 133 | "type": "number" 134 | }, 135 | "start": { 136 | "type": "array", 137 | "description": "Specify commands to run to start your application (Python only)" 138 | }, 139 | "setup": { 140 | "type": "array", 141 | "items": { 142 | "type": "string" 143 | }, 144 | "description": "Commands to setup the application execution environment" 145 | }, 146 | "admin_port": { 147 | "type": "number", 148 | "description": "The port number of the admin server (Default: 3001)" 149 | } 150 | } 151 | }, 152 | "http": { 153 | "type": "object", 154 | "additionalProperties": false, 155 | "properties": { 156 | "cors_middleware": { 157 | "type": "boolean" 158 | }, 159 | "credentials": { 160 | "type": "boolean" 161 | }, 162 | "allowed_origins": { 163 | "type": "array", 164 | "items": { 165 | "type": "string" 166 | } 167 | } 168 | } 169 | }, 170 | "application": {}, 171 | "env": {}, 172 | "version": { 173 | "type": 
"string", 174 | "deprecated": true 175 | } 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /dbos/py.typed: -------------------------------------------------------------------------------- 1 | # Yes, we typed everything. You're welcome. 2 | -------------------------------------------------------------------------------- /make_release.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import typer 5 | from git import Optional, Repo 6 | 7 | app = typer.Typer() 8 | 9 | 10 | @app.command() 11 | def make_release(version_number: Optional[str] = None) -> None: 12 | repo = Repo(os.getcwd()) 13 | if repo.is_dirty(): 14 | raise Exception("Local git repository is not clean") 15 | if repo.active_branch.name != "main": 16 | raise Exception("Can only make a release from main") 17 | remote_branch = repo.references[f"origin/{repo.active_branch.name}"] 18 | local_commit = repo.active_branch.commit 19 | remote_commit = remote_branch.commit 20 | if local_commit != remote_commit: 21 | raise Exception( 22 | f"Your local branch {repo.active_branch.name} is not up to date with origin." 23 | ) 24 | 25 | if version_number is None: 26 | version_number = guess_next_version(repo) 27 | version_pattern = r"^\d+\.\d+\.\d+$" 28 | if not re.match(version_pattern, version_number): 29 | raise Exception(f"Invalid version number: {version_number}") 30 | 31 | create_and_push_release_tag(repo=repo, version_number=version_number) 32 | create_and_push_release_branch(repo=repo, version_number=version_number) 33 | 34 | 35 | def guess_next_version(repo: Repo) -> str: 36 | tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True) 37 | for tag in tags: 38 | if repo.is_ancestor(tag.commit, repo.heads.main.commit): 39 | last_tag = tag.name 40 | break 41 | if last_tag is None: 42 | raise Exception("No previous tags found") 43 | 44 | major, minor, patch = map(int, last_tag.split(".")) 45 | minor += 1 46 | return f"{major}.{minor}.{patch}" 47 | 48 | 49 | def create_and_push_release_tag(repo: Repo, version_number: str) -> None: 50 | release_tag = repo.create_tag(version_number) 51 | push_info = repo.remote("origin").push(release_tag) 52 | if push_info[0].flags & push_info[0].ERROR: 53 | raise Exception(f"Failed to push tags: {push_info[0].summary}") 54 | print(f"Release tag pushed: {version_number}") 55 | 56 | 57 | def create_and_push_release_branch(repo: Repo, version_number: str) -> None: 58 | branch_name = f"release/v{version_number}" 59 | new_branch = repo.create_head(branch_name, repo.heads["main"]) 60 | new_branch.checkout() 61 | push_info = repo.remote("origin").push(new_branch) 62 | if push_info[0].flags & push_info[0].ERROR: 63 | raise Exception(f"Failed to push branch: {push_info[0].summary}") 64 | print(f"Release branch pushed: {branch_name}") 65 | 66 | 67 | if __name__ == "__main__": 68 | app() 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dbos" 3 | dynamic = ["version"] 4 | description = "Ultra-lightweight durable execution in Python" 5 | authors = [ 6 | {name = "DBOS, Inc.", email = "contact@dbos.dev"}, 7 | ] 8 | dependencies = [ 9 | "pyyaml>=6.0.2", 10 | "jsonschema>=4.23.0", 11 | "alembic>=1.13.3", 12 | "typing-extensions>=4.12.2; python_version < \"3.10\"", 13 | "typer>=0.12.5", 14 | "jsonpickle>=3.3.0", 15 | 
"opentelemetry-api>=1.27.0", 16 | "opentelemetry-sdk>=1.27.0", 17 | "opentelemetry-exporter-otlp-proto-http>=1.27.0", 18 | "python-dateutil>=2.9.0.post0", 19 | "fastapi[standard]>=0.115.2", 20 | "tomlkit>=0.13.2", 21 | "psycopg[binary]>=3.1", # Keep compatibility with 3.1--older Python installations/machines can't always install 3.2 22 | "docker>=7.1.0", 23 | "cryptography>=43.0.3", 24 | "rich>=13.9.4", 25 | "pyjwt>=2.10.1", 26 | "websockets>=14.0", 27 | ] 28 | requires-python = ">=3.9" 29 | readme = "README.md" 30 | license = {text = "MIT"} 31 | 32 | [project.scripts] 33 | dbos = "dbos.cli.cli:app" 34 | 35 | [build-system] 36 | requires = ["pdm-backend"] 37 | build-backend = "pdm.backend" 38 | 39 | [tool.pdm] 40 | distribution = true 41 | 42 | [tool.pdm.version] 43 | source = "scm" 44 | version_format = "version:format_version" 45 | 46 | [tool.black] 47 | line-length = 88 48 | 49 | [tool.isort] 50 | profile = "black" 51 | filter_files = true 52 | atomic = true 53 | 54 | [tool.mypy] 55 | strict = true 56 | 57 | [tool.pytest.ini_options] 58 | addopts = "-s" 59 | log_cli_format = "%(asctime)s [%(levelname)8s] (%(name)s:%(filename)s:%(lineno)s) %(message)s" 60 | log_cli_level = "INFO" 61 | log_cli = true 62 | timeout = 120 # Terminate any test that takes longer than 120 seconds 63 | 64 | [dependency-groups] 65 | dev = [ 66 | "pytest>=8.3.3", 67 | "mypy>=1.15.0", 68 | "pytest-mock>=3.14.0", 69 | "types-PyYAML>=6.0.12.20240808", 70 | "types-jsonschema>=4.23.0.20240813", 71 | "black>=24.10.0", 72 | "pre-commit>=4.0.1", 73 | "isort>=5.13.2", 74 | "requests>=2.32.3", 75 | "types-requests>=2.32.0.20240914", 76 | "httpx>=0.27.2", 77 | "pytz>=2024.2", 78 | "GitPython>=3.1.43", 79 | "confluent-kafka>=2.6.0", 80 | "types-confluent-kafka>=1.2.2", 81 | "flask>=3.0.3", 82 | "pytest-order>=1.3.0", 83 | "pdm-backend>=2.4.2", 84 | "pytest-asyncio>=0.25.0", 85 | "pyright>=1.1.398", 86 | "types-docker>=7.1.0.20241229", 87 | "pytest-timeout>=2.3.1", 88 | ] 89 | -------------------------------------------------------------------------------- /pyrightconfig.test.json: -------------------------------------------------------------------------------- 1 | { 2 | "typeCheckingMode": "standard", 3 | "reportTypedDictNotRequiredAccess": false, 4 | "reportRedeclaration": false 5 | } -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbos-inc/dbos-transact-py/89acf789007f46819b4f1eac70eaf592cc9a6719/tests/__init__.py -------------------------------------------------------------------------------- /tests/atexit_no_ctor.py: -------------------------------------------------------------------------------- 1 | from dbos import DBOS 2 | 3 | 4 | @DBOS.workflow() 5 | def my_function(foo: str) -> str: 6 | return foo 7 | -------------------------------------------------------------------------------- /tests/atexit_no_launch.py: -------------------------------------------------------------------------------- 1 | from dbos import DBOS 2 | from dbos._dbos_config import DBOSConfig 3 | 4 | 5 | @DBOS.workflow() 6 | def my_function(foo: str) -> str: 7 | return foo 8 | 9 | 10 | def default_config() -> DBOSConfig: 11 | return { 12 | "name": "forgot-launch", 13 | "database_url": f"postgresql://postgres:doesntmatter@localhost:5432/notneeded", 14 | } 15 | 16 | 17 | DBOS(config=default_config()) 18 | -------------------------------------------------------------------------------- 
/tests/classdefs.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sa 2 | 3 | # Public API 4 | from dbos import DBOS, DBOSConfiguredInstance 5 | 6 | 7 | @DBOS.dbos_class() 8 | class DBOSTestClass(DBOSConfiguredInstance): 9 | txn_counter_c = 0 10 | wf_counter_c = 0 11 | step_counter_c = 0 12 | 13 | def __init__(self) -> None: 14 | super().__init__("myconfig") 15 | self.txn_counter: int = 0 16 | self.wf_counter: int = 0 17 | self.step_counter: int = 0 18 | 19 | @classmethod 20 | @DBOS.workflow() 21 | def test_workflow_cls(cls, var: str, var2: str) -> str: 22 | cls.wf_counter_c += 1 23 | res = DBOSTestClass.test_transaction_cls(var2) 24 | res2 = DBOSTestClass.test_step_cls(var) 25 | return res + res2 26 | 27 | @classmethod 28 | @DBOS.transaction() 29 | def test_transaction_cls(cls, var2: str) -> str: 30 | rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 31 | cls.txn_counter_c += 1 32 | return var2 + str(rows[0][0]) 33 | 34 | @classmethod 35 | @DBOS.step() 36 | def test_step_cls(cls, var: str) -> str: 37 | cls.step_counter_c += 1 38 | return var 39 | 40 | @DBOS.workflow() 41 | def test_workflow(self, var: str, var2: str) -> str: 42 | self.wf_counter += 1 43 | res = self.test_transaction(var2) 44 | res2 = self.test_step(var) 45 | return res + res2 46 | 47 | @DBOS.transaction() 48 | def test_transaction(self, var2: str) -> str: 49 | rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 50 | self.txn_counter += 1 51 | return var2 + str(rows[0][0]) 52 | 53 | @DBOS.step() 54 | def test_step(self, var: str) -> str: 55 | self.step_counter += 1 56 | return var 57 | 58 | @DBOS.workflow() 59 | @DBOS.required_roles(["admin"]) 60 | def test_func_admin(self, var: str) -> str: 61 | assert DBOS.assumed_role == "admin" 62 | return self.config_name + ":" + var 63 | 64 | 65 | @DBOS.default_required_roles(["user"]) 66 | class DBOSTestRoles(DBOSConfiguredInstance): 67 | @staticmethod 68 | @DBOS.workflow() 69 | def greetfunc(name: str) -> str: 70 | return f"Hello {name}" 71 | 72 | 73 | @DBOS.dbos_class() 74 | class DBOSSendRecv: 75 | send_counter: int = 0 76 | recv_counter: int = 0 77 | 78 | @staticmethod 79 | @DBOS.workflow() 80 | def test_send_workflow(dest_uuid: str, topic: str) -> str: 81 | DBOS.send(dest_uuid, "test1") 82 | DBOS.send(dest_uuid, "test2", topic=topic) 83 | DBOS.send(dest_uuid, "test3") 84 | DBOSSendRecv.send_counter += 1 85 | return dest_uuid 86 | 87 | @staticmethod 88 | @DBOS.workflow() 89 | def test_recv_workflow(topic: str) -> str: 90 | msg1 = DBOS.recv(topic, timeout_seconds=10) 91 | msg2 = DBOS.recv(timeout_seconds=10) 92 | msg3 = DBOS.recv(timeout_seconds=10) 93 | DBOSSendRecv.recv_counter += 1 94 | return "-".join([str(msg1), str(msg2), str(msg3)]) 95 | -------------------------------------------------------------------------------- /tests/client_collateral.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Optional, TypedDict, cast 3 | 4 | from dbos import DBOS, Queue 5 | 6 | 7 | class Person(TypedDict): 8 | first: str 9 | last: str 10 | age: int 11 | 12 | 13 | queue = Queue("test_queue") 14 | inorder_queue = Queue("inorder_queue", 1, priority_enabled=True) 15 | inorder_results: List[str] = [] 16 | 17 | 18 | @DBOS.workflow() 19 | def enqueue_test(numVal: int, strVal: str, person: Person) -> str: 20 | return f"{numVal}-{strVal}-{json.dumps(person)}" 21 | 22 | 23 | @DBOS.workflow() 24 | def send_test(topic: Optional[str] = 
None) -> str: 25 |     return cast(str, DBOS.recv(topic, 60)) 26 | 27 | 28 | @DBOS.workflow() 29 | def retrieve_test(value: str) -> str: 30 |     DBOS.sleep(5) 31 |     inorder_results.append(value) 32 |     return value 33 | 34 | 35 | @DBOS.workflow() 36 | def event_test(key: str, value: str, update: Optional[int] = None) -> str: 37 |     DBOS.set_event(key, value) 38 |     if update is not None: 39 |         DBOS.sleep(update) 40 |         DBOS.set_event(key, f"updated-{value}") 41 |     return f"{key}-{value}" 42 | 43 | 44 | @DBOS.workflow() 45 | def blocked_workflow() -> None: 46 |     while True: 47 |         DBOS.sleep(0.1) 48 | 49 | 50 | @DBOS.transaction() 51 | def test_txn(x: int) -> int: 52 |     return x 53 | 54 | 55 | @DBOS.step() 56 | def test_step(x: int) -> int: 57 |     return x 58 | 59 | 60 | @DBOS.workflow() 61 | def fork_test(x: int) -> int: 62 |     return test_txn(x) + test_step(x) 63 | -------------------------------------------------------------------------------- /tests/client_worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from dbos import DBOS, SetWorkflowID 5 | from tests import client_collateral as cc 6 | from tests.conftest import default_config 7 | 8 | app_vers = os.environ.get("DBOS__APPVERSION") 9 | if app_vers is None: 10 |     DBOS.logger.error("DBOS__APPVERSION not set") 11 |     os._exit(1) 12 | else: 13 |     DBOS.logger.info(f"DBOS__APPVERSION: {app_vers}") 14 | 15 | if len(sys.argv) < 2: 16 |     DBOS.logger.error("Usage: client_worker wfid <topic>") 17 |     os._exit(1) 18 | 19 | wfid = sys.argv[1] 20 | topic = sys.argv[2] if len(sys.argv) > 2 else None 21 | 22 | config = default_config() 23 | DBOS(config=config) 24 | DBOS.launch() 25 | 26 | DBOS.logger.info(f"Starting send_test with WF ID: {wfid}") 27 | with SetWorkflowID(wfid): 28 |     DBOS.start_workflow(cc.send_test, topic) 29 | 30 | os._exit(0) 31 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | import time 5 | from typing import Any, Generator, Tuple 6 | from urllib.parse import quote 7 | 8 | import pytest 9 | import sqlalchemy as sa 10 | from fastapi import FastAPI 11 | from flask import Flask 12 | 13 | from dbos import DBOS, DBOSClient, DBOSConfig 14 | from dbos._app_db import ApplicationDatabase 15 | from dbos._schemas.system_database import SystemSchema 16 | from dbos._sys_db import SystemDatabase 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def build_wheel() -> str: 21 |     subprocess.check_call(["pdm", "build"]) 22 |     wheel_files = glob.glob(os.path.join("dist", "*.whl")) 23 |     assert len(wheel_files) == 1 24 |     return wheel_files[0] 25 | 26 | 27 | def default_config() -> DBOSConfig: 28 |     return { 29 |         "name": "test-app", 30 |         "database_url": f"postgresql://postgres:{quote(os.environ.get('PGPASSWORD', 'dbos'), safe='')}@localhost:5432/dbostestpy", 31 |     } 32 | 33 | 34 | @pytest.fixture() 35 | def config() -> DBOSConfig: 36 |     return default_config() 37 | 38 | 39 | @pytest.fixture() 40 | def sys_db(config: DBOSConfig) -> Generator[SystemDatabase, Any, None]: 41 |     assert config["database_url"] is not None 42 |     sys_db = SystemDatabase( 43 |         database_url=config["database_url"], 44 |         engine_kwargs={ 45 |             "pool_timeout": 30, 46 |             "max_overflow": 0, 47 |             "pool_size": 2, 48 |             "connect_args": {"connect_timeout": 30}, 49 |         }, 50 |     ) 51 |     sys_db.run_migrations() 52 |     yield sys_db 53 |     sys_db.destroy() 54 | 55 | 56 | 
@pytest.fixture() 57 | def app_db(config: DBOSConfig) -> Generator[ApplicationDatabase, Any, None]: 58 | assert config["database_url"] is not None 59 | app_db = ApplicationDatabase( 60 | database_url=config["database_url"], 61 | engine_kwargs={ 62 | "pool_timeout": 30, 63 | "max_overflow": 0, 64 | "pool_size": 2, 65 | "connect_args": {"connect_timeout": 30}, 66 | }, 67 | ) 68 | app_db.run_migrations() 69 | yield app_db 70 | app_db.destroy() 71 | 72 | 73 | @pytest.fixture(scope="session") 74 | def postgres_db_engine() -> sa.Engine: 75 | cfg = default_config() 76 | assert cfg["database_url"] is not None 77 | return sa.create_engine( 78 | sa.make_url(cfg["database_url"]).set( 79 | drivername="postgresql+psycopg", 80 | database="postgres", 81 | ), 82 | connect_args={ 83 | "connect_timeout": 30, 84 | }, 85 | ) 86 | 87 | 88 | @pytest.fixture() 89 | def cleanup_test_databases(config: DBOSConfig, postgres_db_engine: sa.Engine) -> None: 90 | assert config["database_url"] is not None 91 | app_db_name = sa.make_url(config["database_url"]).database 92 | sys_db_name = f"{app_db_name}_dbos_sys" 93 | 94 | with postgres_db_engine.connect() as connection: 95 | connection.execution_options(isolation_level="AUTOCOMMIT") 96 | connection.execute( 97 | sa.text(f"DROP DATABASE IF EXISTS {app_db_name} WITH (FORCE)") 98 | ) 99 | connection.execute( 100 | sa.text(f"DROP DATABASE IF EXISTS {sys_db_name} WITH (FORCE)") 101 | ) 102 | 103 | # Clean up environment variables 104 | os.environ.pop("DBOS__VMID") if "DBOS__VMID" in os.environ else None 105 | os.environ.pop("DBOS__APPVERSION") if "DBOS__APPVERSION" in os.environ else None 106 | os.environ.pop("DBOS__APPID") if "DBOS__APPID" in os.environ else None 107 | 108 | 109 | @pytest.fixture() 110 | def dbos( 111 | config: DBOSConfig, cleanup_test_databases: None 112 | ) -> Generator[DBOS, Any, None]: 113 | DBOS.destroy(destroy_registry=True) 114 | 115 | # This launches for test convenience. 116 | # Tests add to running DBOS and then call stuff without adding 117 | # launch themselves. 118 | # If your test is tricky and has a problem with this, use a different 119 | # fixture that does not launch. 120 | dbos = DBOS(config=config) 121 | DBOS.launch() 122 | 123 | yield dbos 124 | DBOS.destroy(destroy_registry=True) 125 | 126 | 127 | @pytest.fixture() 128 | def client(config: DBOSConfig, dbos: DBOS) -> Generator[DBOSClient, Any, None]: 129 | assert config["database_url"] is not None 130 | client = DBOSClient(config["database_url"]) 131 | yield client 132 | client.destroy() 133 | 134 | 135 | @pytest.fixture() 136 | def dbos_fastapi( 137 | config: DBOSConfig, cleanup_test_databases: None 138 | ) -> Generator[Tuple[DBOS, FastAPI], Any, None]: 139 | DBOS.destroy(destroy_registry=True) 140 | app = FastAPI() 141 | dbos = DBOS(fastapi=app, config=config) 142 | 143 | # This is for test convenience. 144 | # Usually fastapi itself does launch, but we are not completing the fastapi lifecycle 145 | DBOS.launch() 146 | 147 | yield dbos, app 148 | DBOS.destroy(destroy_registry=True) 149 | 150 | 151 | @pytest.fixture() 152 | def dbos_flask( 153 | config: DBOSConfig, cleanup_test_databases: None 154 | ) -> Generator[Tuple[DBOS, Flask], Any, None]: 155 | DBOS.destroy(destroy_registry=True) 156 | app = Flask(__name__) 157 | 158 | dbos = DBOS(flask=app, config=config) 159 | 160 | # This is for test convenience. 
161 | # Usually fastapi itself does launch, but we are not completing the fastapi lifecycle 162 | DBOS.launch() 163 | 164 | yield dbos, app 165 | DBOS.destroy(destroy_registry=True) 166 | 167 | 168 | # Pretty-print test names 169 | def pytest_collection_modifyitems(session: Any, config: Any, items: Any) -> None: 170 | for item in items: 171 | item._nodeid = "\n" + item.nodeid + "\n" 172 | 173 | 174 | def queue_entries_are_cleaned_up(dbos: DBOS) -> bool: 175 | max_tries = 10 176 | success = False 177 | for i in range(max_tries): 178 | with dbos._sys_db.engine.begin() as c: 179 | query = ( 180 | sa.select(sa.func.count()) 181 | .select_from(SystemSchema.workflow_status) 182 | .where( 183 | sa.and_( 184 | SystemSchema.workflow_status.c.queue_name.isnot(None), 185 | SystemSchema.workflow_status.c.status.in_( 186 | ["ENQUEUED", "PENDING"] 187 | ), 188 | ) 189 | ) 190 | ) 191 | row = c.execute(query).fetchone() 192 | assert row is not None 193 | count = row[0] 194 | if count == 0: 195 | success = True 196 | break 197 | time.sleep(1) 198 | return success 199 | -------------------------------------------------------------------------------- /tests/dupname_classdefs1.py: -------------------------------------------------------------------------------- 1 | from dbos import DBOS, DBOSConfiguredInstance 2 | 3 | 4 | @DBOS.dbos_class(class_name="AnotherDBOSTestRegDup") 5 | class DBOSTestRegDup(DBOSConfiguredInstance): 6 | """DBOSTestRegDup duplicates the name of a class defined dupname_classdefsa.py""" 7 | 8 | def __init__(self, instance_name: str) -> None: 9 | super().__init__(instance_name) 10 | -------------------------------------------------------------------------------- /tests/dupname_classdefsa.py: -------------------------------------------------------------------------------- 1 | from dbos import DBOS, DBOSConfiguredInstance 2 | 3 | 4 | @DBOS.dbos_class() 5 | class DBOSTestRegDup(DBOSConfiguredInstance): 6 | """DBOSTestRegDup duplicates the name of a class defined dupname_classdefs1.py""" 7 | 8 | def __init__(self, instance_name: str) -> None: 9 | super().__init__(instance_name) 10 | -------------------------------------------------------------------------------- /tests/more_classdefs.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | # Public API 4 | from dbos import DBOS 5 | 6 | 7 | @DBOS.dbos_class() 8 | class DBOSWFEvents: 9 | @staticmethod 10 | @DBOS.workflow() 11 | def test_setevent_workflow() -> None: 12 | DBOS.set_event("key1", "value1") 13 | DBOS.set_event("key2", "value2") 14 | DBOS.set_event("key3", None) 15 | 16 | @staticmethod 17 | @DBOS.workflow() 18 | def test_getevent_workflow( 19 | target_uuid: str, key: str, timeout_seconds: float = 10 20 | ) -> Optional[str]: 21 | msg = DBOS.get_event(target_uuid, key, timeout_seconds) 22 | return str(msg) if msg is not None else None 23 | 24 | 25 | @DBOS.workflow() 26 | def wfFunc(arg: str) -> str: 27 | assert DBOS.workflow_id == "wfid" 28 | return arg + "1" 29 | -------------------------------------------------------------------------------- /tests/queuedworkflow.py: -------------------------------------------------------------------------------- 1 | # Public API 2 | import os 3 | from urllib.parse import quote 4 | 5 | from dbos import DBOS, DBOSConfig, Queue, SetWorkflowID 6 | 7 | 8 | def default_config() -> DBOSConfig: 9 | return { 10 | "name": "test-app", 11 | "database_url": f"postgresql://postgres:{quote(os.environ.get('PGPASSWORD', 'dbos'), 
safe='')}@localhost:5432/dbostestpy", 12 | } 13 | 14 | 15 | q = Queue("testq", concurrency=1, limiter={"limit": 1, "period": 1}) 16 | 17 | 18 | @DBOS.dbos_class() 19 | class WF: 20 | @staticmethod 21 | @DBOS.workflow() 22 | def queued_task() -> int: 23 | DBOS.sleep(0.1) 24 | return 1 25 | 26 | @staticmethod 27 | @DBOS.workflow() 28 | def enqueue_5_tasks() -> int: 29 | for i in range(5): 30 | print(f"Iteration {i + 1}") 31 | wfh = DBOS.start_workflow(WF.queued_task) 32 | wfh.get_result() 33 | DBOS.sleep(0.9) 34 | 35 | if i == 3 and "DIE_ON_PURPOSE" in os.environ: 36 | print("CRASH") 37 | os._exit(1) 38 | return 5 39 | 40 | x = 5 41 | 42 | 43 | def main() -> None: 44 | DBOS(config=default_config()) 45 | DBOS.launch() 46 | DBOS._recover_pending_workflows() 47 | 48 | with SetWorkflowID("testqueuedwfcrash"): 49 | WF.enqueue_5_tasks() 50 | 51 | DBOS.destroy() 52 | os._exit(0) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from dbos.cli._template_init import get_templates_directory 2 | from dbos.cli.cli import _resolve_project_name_and_template 3 | 4 | 5 | def test_resolve_project_name_and_template() -> None: 6 | git_templates = ["dbos-toolbox", "dbos-app-starter", "dbos-cron-starter"] 7 | templates_dir = get_templates_directory() 8 | 9 | # dbos init my-app -t dbos-toolbox 10 | project_name, template = _resolve_project_name_and_template( 11 | project_name="my-app", 12 | template="dbos-toolbox", 13 | config=False, 14 | git_templates=git_templates, 15 | templates_dir=templates_dir, 16 | ) 17 | assert project_name == "my-app" 18 | assert template == "dbos-toolbox" 19 | 20 | # dbos init -t dbos-toolbox 21 | project_name, template = _resolve_project_name_and_template( 22 | project_name=None, 23 | template="dbos-toolbox", 24 | config=False, 25 | git_templates=git_templates, 26 | templates_dir=templates_dir, 27 | ) 28 | assert project_name == "dbos-toolbox" 29 | assert template == "dbos-toolbox" 30 | -------------------------------------------------------------------------------- /tests/test_concurrency.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | import uuid 4 | from concurrent.futures import Future, ThreadPoolExecutor 5 | from typing import Tuple, cast 6 | 7 | from sqlalchemy import text 8 | 9 | # Public API 10 | from dbos import DBOS, SetWorkflowID 11 | 12 | 13 | def test_concurrent_workflows(dbos: DBOS) -> None: 14 | @DBOS.workflow() 15 | def test_workflow() -> str: 16 | time.sleep(1) 17 | return DBOS.workflow_id 18 | 19 | def test_thread(id: str) -> str: 20 | with SetWorkflowID(id): 21 | return test_workflow() 22 | 23 | num_threads = 10 24 | with ThreadPoolExecutor(max_workers=num_threads) as executor: 25 | futures: list[Tuple[str, Future[str]]] = [] 26 | for _ in range(num_threads): 27 | id = str(uuid.uuid4()) 28 | futures.append((id, executor.submit(test_thread, id))) 29 | for id, future in futures: 30 | assert id == future.result() 31 | 32 | 33 | def test_concurrent_conflict_uuid(dbos: DBOS) -> None: 34 | condition = threading.Condition() 35 | step_count = 0 36 | txn_count = 0 37 | 38 | @DBOS.step() 39 | def test_step() -> str: 40 | nonlocal step_count 41 | condition.acquire() 42 | step_count += 1 43 | if step_count % 2 == 1: 44 | # Wait for the other one to notify 45 | condition.wait() 46 | else: 47 | # Notify the other 
one 48 | condition.notify() 49 | condition.release() 50 | 51 | return DBOS.workflow_id 52 | 53 | @DBOS.workflow() 54 | def test_workflow() -> str: 55 | res = test_step() 56 | return res 57 | 58 | def test_comm_thread(id: str) -> str: 59 | with SetWorkflowID(id): 60 | return test_step() 61 | 62 | # Need to set isolation level to a lower one, otherwise it gets serialization error instead (we already handle it correctly by automatic retries). 63 | @DBOS.transaction(isolation_level="REPEATABLE READ") 64 | def test_transaction() -> str: 65 | DBOS.sql_session.execute(text("SELECT 1")).fetchall() 66 | nonlocal txn_count 67 | condition.acquire() 68 | txn_count += 1 69 | if txn_count % 2 == 1: 70 | # Wait for the other one to notify 71 | condition.wait() 72 | else: 73 | # Notify the other one 74 | condition.notify() 75 | condition.release() 76 | 77 | return DBOS.workflow_id 78 | 79 | def test_txn_thread(id: str) -> str: 80 | with SetWorkflowID(id): 81 | return test_transaction() 82 | 83 | wfuuid = str(uuid.uuid4()) 84 | with SetWorkflowID(wfuuid): 85 | wf_handle1 = dbos.start_workflow(test_workflow) 86 | 87 | with SetWorkflowID(wfuuid): 88 | wf_handle2 = dbos.start_workflow(test_workflow) 89 | 90 | # These two workflows should run concurrently, but both should succeed. 91 | assert wf_handle1.get_result() == wfuuid 92 | assert wf_handle2.get_result() == wfuuid 93 | 94 | # Make sure temp workflows can handle conflicts as well. 95 | wfuuid = str(uuid.uuid4()) 96 | with ThreadPoolExecutor(max_workers=2) as executor: 97 | future1 = executor.submit(test_comm_thread, wfuuid) 98 | future2 = executor.submit(test_comm_thread, wfuuid) 99 | 100 | assert future1.result() == wfuuid 101 | assert future2.result() == wfuuid 102 | 103 | # Make sure temp transactions can handle conflicts as well. 
104 | wfuuid = str(uuid.uuid4()) 105 | with ThreadPoolExecutor(max_workers=2) as executor: 106 | future1 = executor.submit(test_txn_thread, wfuuid) 107 | future2 = executor.submit(test_txn_thread, wfuuid) 108 | 109 | assert future1.result() == wfuuid 110 | assert future2.result() == wfuuid 111 | 112 | 113 | def test_concurrent_recv(dbos: DBOS) -> None: 114 | condition = threading.Condition() 115 | counter = 0 116 | 117 | @DBOS.workflow() 118 | def test_workflow(topic: str) -> str: 119 | nonlocal counter 120 | condition.acquire() 121 | counter += 1 122 | if counter % 2 == 1: 123 | # Wait for the other one to notify 124 | condition.wait() 125 | else: 126 | # Notify the other one 127 | condition.notify() 128 | condition.release() 129 | m = cast(str, DBOS.recv(topic, 5)) 130 | return m 131 | 132 | def test_thread(id: str, topic: str) -> str: 133 | with SetWorkflowID(id): 134 | return test_workflow(topic) 135 | 136 | wfuuid = str(uuid.uuid4()) 137 | topic = "test_topic" 138 | with ThreadPoolExecutor(max_workers=2) as executor: 139 | future1 = executor.submit(test_thread, wfuuid, topic) 140 | future2 = executor.submit(test_thread, wfuuid, topic) 141 | 142 | expected_message = "test message" 143 | DBOS.send(wfuuid, expected_message, topic) 144 | # Both should return the same message 145 | assert future1.result() == future2.result() 146 | assert future1.result() == expected_message 147 | # Make sure the notification map is empty 148 | assert not dbos._sys_db.notifications_map._dict 149 | 150 | 151 | def test_concurrent_getevent(dbos: DBOS) -> None: 152 | @DBOS.workflow() 153 | def test_workflow(event_name: str, value: str) -> str: 154 | DBOS.set_event(event_name, value) 155 | return value 156 | 157 | def test_thread(id: str, event_name: str) -> str: 158 | return cast(str, DBOS.get_event(id, event_name, 5)) 159 | 160 | wfuuid = str(uuid.uuid4()) 161 | event_name = "test_event" 162 | with ThreadPoolExecutor(max_workers=2) as executor: 163 | future1 = executor.submit(test_thread, wfuuid, event_name) 164 | future2 = executor.submit(test_thread, wfuuid, event_name) 165 | 166 | expected_message = "test message" 167 | with SetWorkflowID(wfuuid): 168 | test_workflow(event_name, expected_message) 169 | 170 | # Both should return the same message 171 | assert future1.result() == future2.result() 172 | assert future1.result() == expected_message 173 | # Make sure the event map is empty 174 | assert not dbos._sys_db.workflow_events_map._dict 175 | -------------------------------------------------------------------------------- /tests/test_debug.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | 6 | from dbos import DBOS, DBOSConfig, SetWorkflowID 7 | from dbos._dbos import _get_dbos_instance 8 | from dbos._debug import PythonModule, parse_start_command 9 | from dbos._schemas.system_database import SystemSchema 10 | 11 | 12 | def test_parse_fast_api_command() -> None: 13 | command = "fastapi run app/main.py" 14 | actual = parse_start_command(command) 15 | assert isinstance(actual, PythonModule) 16 | assert actual.module_name == "main" 17 | 18 | 19 | def test_parse_python_command() -> None: 20 | command = "python app/main.py" 21 | expected = "app/main.py" 22 | actual = parse_start_command(command) 23 | assert actual == expected 24 | 25 | 26 | def test_parse_python3_command() -> None: 27 | command = "python3 app/main.py" 28 | expected = "app/main.py" 29 | actual = parse_start_command(command) 30 | assert 
actual == expected 31 | 32 | 33 | def test_parse_python_module_command() -> None: 34 | command = "python -m some_module" 35 | actual = parse_start_command(command) 36 | assert isinstance(actual, PythonModule) 37 | assert actual.module_name == "some_module" 38 | 39 | 40 | def test_parse_python3_module_command() -> None: 41 | command = "python3 -m some_module" 42 | actual = parse_start_command(command) 43 | assert isinstance(actual, PythonModule) 44 | assert actual.module_name == "some_module" 45 | 46 | 47 | def get_recovery_attempts(wfuuid: str) -> int: 48 | dbos = _get_dbos_instance() 49 | with dbos._sys_db.engine.connect() as c: 50 | stmt = sa.select( 51 | SystemSchema.workflow_status.c.recovery_attempts, 52 | SystemSchema.workflow_status.c.created_at, 53 | SystemSchema.workflow_status.c.updated_at, 54 | ).where(SystemSchema.workflow_status.c.workflow_uuid == wfuuid) 55 | result = c.execute(stmt).fetchone() 56 | assert result is not None 57 | recovery_attempts, created_at, updated_at = result 58 | return int(recovery_attempts) 59 | 60 | 61 | def test_wf_debug(dbos: DBOS, config: DBOSConfig) -> None: 62 | wf_counter: int = 0 63 | step_counter: int = 0 64 | 65 | @DBOS.workflow() 66 | def test_workflow(var: str) -> str: 67 | nonlocal wf_counter 68 | wf_counter += 1 69 | res = test_step(var) 70 | DBOS.logger.info("I'm test_workflow") 71 | return res 72 | 73 | @DBOS.step() 74 | def test_step(var: str) -> str: 75 | nonlocal step_counter 76 | step_counter += 1 77 | DBOS.logger.info("I'm test_step") 78 | return var 79 | 80 | wfuuid = str(uuid.uuid4()) 81 | with SetWorkflowID(wfuuid): 82 | handle = DBOS.start_workflow(test_workflow, "test") 83 | result = handle.get_result() 84 | assert result == "test" 85 | assert wf_counter == 1 86 | assert step_counter == 1 87 | 88 | expected_retry_attempts = get_recovery_attempts(wfuuid) 89 | 90 | DBOS.destroy() 91 | DBOS(config=config) 92 | DBOS.launch(debug_mode=True) 93 | 94 | handle = DBOS._execute_workflow_id(wfuuid) 95 | result = handle.get_result() 96 | assert result == "test" 97 | assert wf_counter == 2 98 | assert step_counter == 1 99 | 100 | actual_retry_attempts = get_recovery_attempts(wfuuid) 101 | assert actual_retry_attempts == expected_retry_attempts 102 | 103 | 104 | def test_wf_debug_exception(dbos: DBOS, config: DBOSConfig) -> None: 105 | wf_counter: int = 0 106 | step_counter: int = 0 107 | 108 | @DBOS.workflow() 109 | def test_workflow(var: str) -> str: 110 | nonlocal wf_counter 111 | wf_counter += 1 112 | res = test_step(var) 113 | DBOS.logger.info("I'm test_workflow") 114 | raise Exception("test_wf_debug_exception") 115 | 116 | @DBOS.step() 117 | def test_step(var: str) -> str: 118 | nonlocal step_counter 119 | step_counter += 1 120 | DBOS.logger.info("I'm test_step") 121 | return var 122 | 123 | wfuuid = str(uuid.uuid4()) 124 | with SetWorkflowID(wfuuid): 125 | handle = DBOS.start_workflow(test_workflow, "test") 126 | with pytest.raises(Exception) as excinfo: 127 | handle.get_result() 128 | assert str(excinfo.value) == "test_wf_debug_exception" 129 | 130 | assert wf_counter == 1 131 | assert step_counter == 1 132 | 133 | expected_retry_attempts = get_recovery_attempts(wfuuid) 134 | 135 | DBOS.destroy() 136 | DBOS(config=config) 137 | DBOS.launch(debug_mode=True) 138 | 139 | handle = DBOS._execute_workflow_id(wfuuid) 140 | with pytest.raises(Exception) as excinfo: 141 | handle.get_result() 142 | assert str(excinfo.value) == "test_wf_debug_exception" 143 | assert wf_counter == 2 144 | assert step_counter == 1 145 | 146 | 
actual_retry_attempts = get_recovery_attempts(wfuuid) 147 | assert actual_retry_attempts == expected_retry_attempts 148 | -------------------------------------------------------------------------------- /tests/test_fastapi.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import uuid 4 | from contextlib import asynccontextmanager 5 | from typing import Any, Tuple 6 | 7 | import httpx 8 | import pytest 9 | import sqlalchemy as sa 10 | import uvicorn 11 | from fastapi import FastAPI 12 | from fastapi.testclient import TestClient 13 | 14 | # Public API 15 | from dbos import DBOS, DBOSConfig, Queue 16 | 17 | # Private API because this is a unit test 18 | from dbos._context import assert_current_dbos_context 19 | 20 | 21 | def test_simple_endpoint( 22 | dbos_fastapi: Tuple[DBOS, FastAPI], caplog: pytest.LogCaptureFixture 23 | ) -> None: 24 | dbos, app = dbos_fastapi 25 | client = TestClient(app) 26 | 27 | @app.get("/endpoint/{var1}/{var2}") 28 | def test_endpoint(var1: str, var2: str) -> str: 29 | result = test_workflow(var1, var2) 30 | ctx = assert_current_dbos_context() 31 | assert not ctx.is_within_workflow() 32 | return result 33 | 34 | @app.get("/workflow/{var1}/{var2}") 35 | @DBOS.workflow() 36 | def test_workflow(var1: str, var2: str) -> str: 37 | DBOS.span.set_attribute("test_key", "test_value") 38 | res1 = test_transaction(var1) 39 | res2 = test_step(var2) 40 | return res1 + res2 41 | 42 | @app.get("/transaction/{var}") 43 | @DBOS.transaction() 44 | def test_transaction(var: str) -> str: 45 | rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 46 | return var + str(rows[0][0]) 47 | 48 | @DBOS.step() 49 | def test_step(var: str) -> str: 50 | return var 51 | 52 | original_propagate = logging.getLogger("dbos").propagate 53 | caplog.set_level(logging.WARNING, "dbos") 54 | logging.getLogger("dbos").propagate = True 55 | 56 | response = client.get("/workflow/bob/bob") 57 | assert response.status_code == 200 58 | assert response.text == '"bob1bob"' 59 | assert caplog.text == "" 60 | 61 | response = client.get("/endpoint/bob/bob") 62 | assert response.status_code == 200 63 | assert response.text == '"bob1bob"' 64 | assert caplog.text == "" 65 | 66 | response = client.get("/transaction/bob") 67 | assert response.status_code == 200 68 | assert response.text == '"bob1"' 69 | assert caplog.text == "" 70 | 71 | # Reset logging 72 | logging.getLogger("dbos").propagate = original_propagate 73 | 74 | 75 | def test_start_workflow(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None: 76 | dbos, app = dbos_fastapi 77 | client = TestClient(app) 78 | 79 | @app.get("/{var1}/{var2}") 80 | def test_endpoint(var1: str, var2: str) -> str: 81 | handle = dbos.start_workflow(test_workflow, var1, var2) 82 | context = assert_current_dbos_context() 83 | assert not context.is_within_workflow() 84 | return handle.get_result() 85 | 86 | @DBOS.workflow() 87 | def test_workflow(var1: str, var2: str) -> str: 88 | DBOS.span.set_attribute("test_key", "test_value") 89 | res1 = test_transaction(var1) 90 | res2 = test_step(var2) 91 | return res1 + res2 92 | 93 | @DBOS.transaction() 94 | def test_transaction(var: str) -> str: 95 | rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 96 | return var + str(rows[0][0]) 97 | 98 | @DBOS.step() 99 | def test_step(var: str) -> str: 100 | return var 101 | 102 | response = client.get("/bob/bob") 103 | assert response.status_code == 200 104 | assert response.text == '"bob1bob"' 105 | 106 | 107 | def 
test_endpoint_recovery(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None: 108 | dbos, app = dbos_fastapi 109 | client = TestClient(app) 110 | 111 | workflow_id = str(uuid.uuid4()) 112 | 113 | @DBOS.workflow() 114 | def test_workflow(var1: str) -> tuple[str, str]: 115 | return var1, DBOS.workflow_id 116 | 117 | @app.get("/{var1}/{var2}") 118 | def test_endpoint(var1: str, var2: str) -> dict[str, str]: 119 | res1, id1 = test_workflow(var1) 120 | res2, id2 = test_workflow(var2) 121 | return {"res1": res1, "res2": res2, "id1": id1, "id2": id2} 122 | 123 | response = client.get("/a/b", headers={"dbos-idempotency-key": workflow_id}) 124 | assert response.status_code == 200 125 | assert response.json().get("res1") == "a" 126 | assert response.json().get("res2") == "b" 127 | assert response.json().get("id1") == workflow_id 128 | assert response.json().get("id2") != workflow_id 129 | 130 | # Change the workflow status to pending 131 | dbos._sys_db.update_workflow_outcome(workflow_id, "PENDING") 132 | 133 | # Recovery should execute the workflow again but skip the transaction 134 | workflow_handles = DBOS._recover_pending_workflows() 135 | assert len(workflow_handles) == 1 136 | assert workflow_handles[0].get_result() == ("a", workflow_id) 137 | 138 | 139 | @pytest.mark.asyncio 140 | async def test_custom_lifespan( 141 | config: DBOSConfig, cleanup_test_databases: None 142 | ) -> None: 143 | resource = None 144 | port = 8000 145 | 146 | @asynccontextmanager # pyright: ignore 147 | async def lifespan(app: FastAPI) -> Any: 148 | nonlocal resource 149 | resource = 1 150 | yield 151 | resource = None 152 | 153 | app = FastAPI(lifespan=lifespan) 154 | 155 | DBOS.destroy() 156 | DBOS(fastapi=app, config=config) 157 | 158 | queue = Queue("queue") 159 | 160 | @app.get("/") 161 | @DBOS.workflow() 162 | async def resource_workflow() -> Any: 163 | handle = await queue.enqueue_async(queue_workflow) 164 | return { 165 | "resource": resource, 166 | "loop": id(asyncio.get_event_loop()), 167 | "queue_loop": await handle.get_result(), 168 | } 169 | 170 | @DBOS.workflow() 171 | async def queue_workflow() -> int: 172 | return id(asyncio.get_event_loop()) 173 | 174 | uvicorn_config = uvicorn.Config( 175 | app=app, host="127.0.0.1", port=port, log_level="error" 176 | ) 177 | server = uvicorn.Server(config=uvicorn_config) 178 | 179 | # Run server in background task 180 | server_task = asyncio.create_task(server.serve()) 181 | await asyncio.sleep(0.2) # Give server time to start 182 | 183 | async with httpx.AsyncClient() as client: 184 | r = await client.get(f"http://127.0.0.1:{port}") 185 | assert r.json()["resource"] == 1 186 | # Verify that both the FastAPI and enqueued workflows run in the main event loop 187 | assert r.json()["loop"] == id(asyncio.get_event_loop()) 188 | assert r.json()["queue_loop"] == id(asyncio.get_event_loop()) 189 | 190 | server.should_exit = True 191 | await server_task 192 | assert resource is None 193 | 194 | 195 | def test_stacked_decorators_wf(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None: 196 | dbos, app = dbos_fastapi 197 | client = TestClient(app) 198 | 199 | @app.get("/endpoint/{var1}/{var2}") 200 | @DBOS.workflow() 201 | async def test_endpoint(var1: str, var2: str) -> str: 202 | return f"{var1}, {var2}!" 
203 | 204 |     response = client.get("/endpoint/plums/deify") 205 |     assert response.status_code == 200 206 |     assert response.text == '"plums, deify!"' 207 | 208 | 209 | def test_stacked_decorators_step(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None: 210 |     dbos, app = dbos_fastapi 211 |     client = TestClient(app) 212 | 213 |     @app.get("/endpoint/{var1}/{var2}") 214 |     @DBOS.step() 215 |     async def test_endpoint(var1: str, var2: str) -> str: 216 |         return f"{var1}, {var2}!" 217 | 218 |     response = client.get("/endpoint/plums/deify") 219 |     assert response.status_code == 200 220 |     assert response.text == '"plums, deify!"' 221 | -------------------------------------------------------------------------------- /tests/test_flask.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import uuid 3 | from typing import Tuple 4 | 5 | import pytest 6 | import sqlalchemy as sa 7 | from flask import Flask, Response, jsonify 8 | 9 | from dbos import DBOS 10 | from dbos._context import assert_current_dbos_context 11 | 12 | 13 | def test_flask_endpoint( 14 |     dbos_flask: Tuple[DBOS, Flask], caplog: pytest.LogCaptureFixture 15 | ) -> None: 16 |     _, app = dbos_flask 17 | 18 |     @app.route("/endpoint/<var1>/<var2>") 19 |     def test_endpoint(var1: str, var2: str) -> Response: 20 |         ctx = assert_current_dbos_context() 21 |         assert not ctx.is_within_workflow() 22 |         return test_workflow(var1, var2) 23 | 24 |     @app.route("/workflow/<var1>/<var2>") 25 |     @DBOS.workflow() 26 |     def test_workflow(var1: str, var2: str) -> Response: 27 |         res1 = test_transaction(var1) 28 |         res2 = test_step(var2) 29 |         result = res1 + res2 30 |         return jsonify({"result": result}) 31 | 32 |     @app.route("/transaction/<var>") 33 |     @DBOS.transaction() 34 |     def test_transaction(var: str) -> str: 35 |         rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall() 36 |         return var + str(rows[0][0]) 37 | 38 |     @DBOS.step() 39 |     def test_step(var: str) -> str: 40 |         return var 41 | 42 |     app.config["TESTING"] = True 43 |     client = app.test_client() 44 | 45 |     original_propagate = logging.getLogger("dbos").propagate 46 |     caplog.set_level(logging.WARNING, "dbos") 47 |     logging.getLogger("dbos").propagate = True 48 | 49 |     response = client.get("/endpoint/a/b") 50 |     assert response.status_code == 200 51 |     assert response.json == {"result": "a1b"} 52 |     assert caplog.text == "" 53 | 54 |     response = client.get("/workflow/a/b") 55 |     assert response.status_code == 200 56 |     assert response.json == {"result": "a1b"} 57 |     assert caplog.text == "" 58 | 59 |     response = client.get("/transaction/bob") 60 |     assert response.status_code == 200 61 |     assert response.text == "bob1" 62 |     assert caplog.text == "" 63 | 64 |     # Reset logging 65 |     logging.getLogger("dbos").propagate = original_propagate 66 | 67 | 68 | def test_endpoint_recovery(dbos_flask: Tuple[DBOS, Flask]) -> None: 69 |     dbos, app = dbos_flask 70 | 71 |     wfuuid = str(uuid.uuid4()) 72 | 73 |     @DBOS.workflow() 74 |     def test_workflow(var1: str) -> tuple[str, str]: 75 |         return var1, DBOS.workflow_id 76 | 77 |     @app.route("/<var1>/<var2>") 78 |     def test_endpoint(var1: str, var2: str) -> dict[str, str]: 79 |         res1, id1 = test_workflow(var1) 80 |         res2, id2 = test_workflow(var2) 81 |         return {"res1": res1, "res2": res2, "id1": id1, "id2": id2} 82 | 83 |     app.config["TESTING"] = True 84 |     client = app.test_client() 85 | 86 |     response = client.get("/a/b", headers={"dbos-idempotency-key": wfuuid}) 87 |     assert response.status_code == 200 88 |     assert response.json is not None 89 |     assert response.json.get("res1") == "a" 90 |     assert 
response.json.get("res2") == "b" 91 | assert response.json.get("id1") == wfuuid 92 | assert response.json.get("id2") != wfuuid 93 | 94 | # Change the workflow status to pending 95 | dbos._sys_db.update_workflow_outcome(wfuuid, "PENDING") 96 | 97 | # Recovery should execute the workflow again but skip the transaction 98 | workflow_handles = DBOS._recover_pending_workflows() 99 | assert len(workflow_handles) == 1 100 | assert workflow_handles[0].get_result() == ("a", wfuuid) 101 | -------------------------------------------------------------------------------- /tests/test_kafka.py: -------------------------------------------------------------------------------- 1 | import random 2 | import threading 3 | import time 4 | from typing import NoReturn 5 | 6 | import pytest 7 | from confluent_kafka import KafkaError, Producer 8 | 9 | from dbos import DBOS, KafkaMessage 10 | 11 | # These tests require local Kafka to run. 12 | # Without it, they're automatically skipped. 13 | # Here's a docker-compose script you can use to set up local Kafka: 14 | 15 | # version: "3.7" 16 | # services: 17 | # broker: 18 | # image: bitnami/kafka:latest 19 | # hostname: broker 20 | # container_name: broker 21 | # ports: 22 | # - '9092:9092' 23 | # environment: 24 | # KAFKA_CFG_NODE_ID: 1 25 | # KAFKA_CFG_LISTENER_SECURITY_PROTOCOL_MAP: 'CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT,PLAINTEXT_HOST:PLAINTEXT' 26 | # KAFKA_CFG_ADVERTISED_LISTENERS: 'PLAINTEXT_HOST://localhost:9092,PLAINTEXT://broker:19092' 27 | # KAFKA_CFG_PROCESS_ROLES: 'broker,controller' 28 | # KAFKA_CFG_CONTROLLER_QUORUM_VOTERS: '1@broker:29093' 29 | # KAFKA_CFG_LISTENERS: 'CONTROLLER://:29093,PLAINTEXT_HOST://:9092,PLAINTEXT://:19092' 30 | # KAFKA_CFG_INTER_BROKER_LISTENER_NAME: 'PLAINTEXT' 31 | # KAFKA_CFG_CONTROLLER_LISTENER_NAMES: 'CONTROLLER' 32 | 33 | 34 | NUM_EVENTS = 3 35 | 36 | 37 | def send_test_messages(server: str, topic: str) -> bool: 38 | 39 | try: 40 | 41 | def on_error(err: KafkaError) -> NoReturn: 42 | raise Exception(err) 43 | 44 | producer = Producer({"bootstrap.servers": server, "error_cb": on_error}) 45 | 46 | for i in range(NUM_EVENTS): 47 | producer.produce( 48 | topic, key=f"test message key {i}", value=f"test message value {i}" 49 | ) 50 | 51 | producer.poll(10) 52 | producer.flush(10) 53 | return True 54 | except Exception as e: 55 | return False 56 | finally: 57 | pass 58 | 59 | 60 | def test_kafka(dbos: DBOS) -> None: 61 | event = threading.Event() 62 | kafka_count = 0 63 | server = "localhost:9092" 64 | topic = f"dbos-kafka-{random.randrange(1_000_000_000)}" 65 | 66 | if not send_test_messages(server, topic): 67 | pytest.skip("Kafka not available") 68 | 69 | @DBOS.kafka_consumer( 70 | { 71 | "bootstrap.servers": server, 72 | "group.id": "dbos-test", 73 | "auto.offset.reset": "earliest", 74 | }, 75 | [topic], 76 | ) 77 | @DBOS.workflow() 78 | def test_kafka_workflow(msg: KafkaMessage) -> None: 79 | nonlocal kafka_count 80 | kafka_count += 1 81 | assert b"test message key" in msg.key # type: ignore 82 | assert b"test message value" in msg.value # type: ignore 83 | print(msg) 84 | if kafka_count == 3: 85 | event.set() 86 | 87 | wait = event.wait(timeout=10) 88 | assert wait 89 | assert kafka_count == 3 90 | 91 | 92 | def test_kafka_in_order(dbos: DBOS) -> None: 93 | event = threading.Event() 94 | kafka_count = 0 95 | server = "localhost:9092" 96 | topic = f"dbos-kafka-{random.randrange(1_000_000_000)}" 97 | 98 | if not send_test_messages(server, topic): 99 | pytest.skip("Kafka not available") 100 | 101 | @DBOS.kafka_consumer( 102 
| { 103 | "bootstrap.servers": server, 104 | "group.id": "dbos-test", 105 | "auto.offset.reset": "earliest", 106 | }, 107 | [topic], 108 | in_order=True, 109 | ) 110 | @DBOS.workflow() 111 | def test_kafka_workflow(msg: KafkaMessage) -> None: 112 | time.sleep(random.uniform(0, 2)) 113 | nonlocal kafka_count 114 | kafka_count += 1 115 | assert f"test message key {kafka_count - 1}".encode() == msg.key 116 | print(msg) 117 | if kafka_count == 3: 118 | event.set() 119 | 120 | wait = event.wait(timeout=15) 121 | assert wait 122 | assert kafka_count == 3 123 | time.sleep(2) # Wait for things to clean up 124 | 125 | 126 | def test_kafka_no_groupid(dbos: DBOS) -> None: 127 | event = threading.Event() 128 | kafka_count = 0 129 | server = "localhost:9092" 130 | topic1 = f"dbos-kafka-{random.randrange(1_000_000_000, 2_000_000_000)}" 131 | topic2 = f"dbos-kafka-{random.randrange(2_000_000_000, 3_000_000_000)}" 132 | 133 | if not send_test_messages(server, topic1): 134 | pytest.skip("Kafka not available") 135 | 136 | if not send_test_messages(server, topic2): 137 | pytest.skip("Kafka not available") 138 | 139 | @DBOS.kafka_consumer( 140 | { 141 | "bootstrap.servers": server, 142 | "auto.offset.reset": "earliest", 143 | }, 144 | [topic1, topic2], 145 | ) 146 | @DBOS.workflow() 147 | def test_kafka_workflow(msg: KafkaMessage) -> None: 148 | nonlocal kafka_count 149 | kafka_count += 1 150 | assert b"test message key" in msg.key # type: ignore 151 | assert b"test message value" in msg.value # type: ignore 152 | print(msg) 153 | if kafka_count == 6: 154 | event.set() 155 | 156 | wait = event.wait(timeout=10) 157 | assert wait 158 | assert kafka_count == 6 159 | -------------------------------------------------------------------------------- /tests/test_outcome.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | from typing import Callable 4 | 5 | import pytest 6 | 7 | from dbos._outcome import Immediate, Outcome, Pending 8 | 9 | 10 | def after(result: Callable[[], int]) -> str: 11 | return f"Result: {result()}" 12 | 13 | 14 | def before(func: Callable[[], None]) -> Callable[[Callable[[], int]], str]: 15 | func() 16 | return after 17 | 18 | 19 | class ExceededRetries(Exception): 20 | pass 21 | 22 | 23 | def test_immediate() -> None: 24 | 25 | func_called = False 26 | 27 | def adder(a: int, b: int) -> int: 28 | assert func_called # ensure func called before adder 29 | return a + b 30 | 31 | def func() -> None: 32 | nonlocal func_called 33 | func_called = True 34 | 35 | o1 = Outcome[int].make(lambda: adder(10, 20)).wrap(lambda: before(func)) 36 | 37 | assert isinstance(o1, Immediate) 38 | 39 | output1 = o1() 40 | assert func_called 41 | assert output1 == "Result: 30" 42 | 43 | o3 = Outcome[int].make(lambda: adder(30, 40)).then(after) 44 | assert isinstance(o3, Immediate) 45 | out2 = o3() 46 | assert out2 == "Result: 70" 47 | 48 | 49 | def test_immediate_retry() -> None: 50 | 51 | count = 0 52 | 53 | def raiser() -> int: 54 | nonlocal count 55 | count += 1 56 | raise Exception("Error") 57 | 58 | o1 = Outcome[int].make(raiser) 59 | o2 = o1.retry(3, lambda i, e: 0.1, lambda i, e: ExceededRetries()) 60 | 61 | assert isinstance(o2, Immediate) 62 | with pytest.raises(ExceededRetries): 63 | o2() 64 | 65 | assert count == 3 66 | 67 | 68 | @pytest.mark.asyncio 69 | async def test_pending() -> None: 70 | func_called = False 71 | 72 | async def adder(a: int, b: int) -> int: 73 | assert func_called # ensure func called before adder 74 | await 
asyncio.sleep(0.1) # simulate async operation 75 | return a + b 76 | 77 | def func() -> None: 78 | nonlocal func_called 79 | func_called = True 80 | 81 | o1 = Outcome[int].make(functools.partial(adder, 10, 20)) 82 | o2 = o1.wrap(lambda: before(func)) 83 | 84 | assert isinstance(o1, Pending) 85 | assert isinstance(o2, Pending) 86 | 87 | output = await o2() 88 | assert func_called 89 | assert output == "Result: 30" 90 | 91 | o3 = Outcome[int].make(functools.partial(adder, 30, 40)).then(after) 92 | assert isinstance(o3, Pending) 93 | out2 = await o3() 94 | assert out2 == "Result: 70" 95 | 96 | 97 | @pytest.mark.asyncio 98 | async def test_pending_retry() -> None: 99 | 100 | count = 0 101 | 102 | async def raiser() -> int: 103 | nonlocal count 104 | count += 1 105 | raise Exception("Error") 106 | 107 | o1 = Outcome[int].make(raiser) 108 | o2 = o1.retry(3, lambda i, e: 0.1, lambda i, e: ExceededRetries()) 109 | 110 | assert isinstance(o2, Pending) 111 | with pytest.raises(ExceededRetries): 112 | await o2() 113 | 114 | assert count == 3 115 | -------------------------------------------------------------------------------- /tests/test_schema_migration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import pytest 5 | import sqlalchemy as sa 6 | from alembic import command 7 | from alembic.config import Config 8 | 9 | # Public API 10 | from dbos import DBOS, DBOSConfig 11 | 12 | # Private API because this is a unit test 13 | from dbos._schemas.system_database import SystemSchema 14 | 15 | 16 | def test_systemdb_migration(dbos: DBOS) -> None: 17 | # Make sure all tables exist 18 | with dbos._sys_db.engine.connect() as connection: 19 | sql = SystemSchema.workflow_status.select() 20 | result = connection.execute(sql) 21 | assert result.fetchall() == [] 22 | 23 | sql = SystemSchema.operation_outputs.select() 24 | result = connection.execute(sql) 25 | assert result.fetchall() == [] 26 | 27 | sql = SystemSchema.workflow_events.select() 28 | result = connection.execute(sql) 29 | assert result.fetchall() == [] 30 | 31 | sql = SystemSchema.notifications.select() 32 | result = connection.execute(sql) 33 | assert result.fetchall() == [] 34 | 35 | # Test migrating down 36 | rollback_system_db( 37 | sysdb_url=dbos._sys_db.engine.url.render_as_string(hide_password=False) 38 | ) 39 | 40 | with dbos._sys_db.engine.connect() as connection: 41 | with pytest.raises(sa.exc.ProgrammingError) as exc_info: 42 | sql = SystemSchema.workflow_status.select() 43 | result = connection.execute(sql) 44 | assert "does not exist" in str(exc_info.value) 45 | 46 | 47 | def test_custom_sysdb_name_migration( 48 | config: DBOSConfig, postgres_db_engine: sa.Engine 49 | ) -> None: 50 | sysdb_name = "custom_sysdb_name" 51 | config["sys_db_name"] = sysdb_name 52 | 53 | # Clean up from previous runs 54 | with postgres_db_engine.connect() as connection: 55 | connection.execution_options(isolation_level="AUTOCOMMIT") 56 | connection.execute(sa.text(f"DROP DATABASE IF EXISTS {sysdb_name}")) 57 | 58 | # Test migrating up 59 | DBOS.destroy() # In case of other tests leaving it 60 | dbos = DBOS(config=config) 61 | DBOS.launch() 62 | 63 | # Make sure all tables exist 64 | with dbos._sys_db.engine.connect() as connection: 65 | sql = SystemSchema.workflow_status.select() 66 | result = connection.execute(sql) 67 | assert result.fetchall() == [] 68 | 69 | # Test migrating down 70 | rollback_system_db( 71 | sysdb_url=dbos._sys_db.engine.url.render_as_string(hide_password=False) 72 
| ) 73 | 74 | with dbos._sys_db.engine.connect() as connection: 75 | with pytest.raises(sa.exc.ProgrammingError) as exc_info: 76 | sql = SystemSchema.workflow_status.select() 77 | result = connection.execute(sql) 78 | assert "does not exist" in str(exc_info.value) 79 | DBOS.destroy() 80 | 81 | 82 | def rollback_system_db(sysdb_url: str) -> None: 83 | migration_dir = os.path.join( 84 | os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 85 | "dbos", 86 | "_migrations", 87 | ) 88 | alembic_cfg = Config() 89 | alembic_cfg.set_main_option("script_location", migration_dir) 90 | escaped_conn_string = re.sub( 91 | r"%(?=[0-9A-Fa-f]{2})", 92 | "%%", 93 | sysdb_url, 94 | ) 95 | alembic_cfg.set_main_option("sqlalchemy.url", escaped_conn_string) 96 | command.downgrade(alembic_cfg, "base") # Rollback all migrations 97 | 98 | 99 | def test_reset(config: DBOSConfig, postgres_db_engine: sa.Engine) -> None: 100 | DBOS.destroy() 101 | dbos = DBOS(config=config) 102 | DBOS.launch() 103 | 104 | # Make sure the system database exists 105 | with dbos._sys_db.engine.connect() as c: 106 | sql = SystemSchema.workflow_status.select() 107 | result = c.execute(sql) 108 | assert result.fetchall() == [] 109 | sysdb_name = dbos._sys_db.engine.url.database 110 | 111 | DBOS.destroy() 112 | dbos = DBOS(config=config) 113 | DBOS.reset_system_database() 114 | 115 | with postgres_db_engine.connect() as c: 116 | c.execution_options(isolation_level="AUTOCOMMIT") 117 | count: int = c.execute( 118 | sa.text(f"SELECT COUNT(*) FROM pg_database WHERE datname = '{sysdb_name}'") 119 | ).scalar_one() 120 | assert count == 0 121 | 122 | DBOS.launch() 123 | 124 | # Make sure the system database is recreated 125 | with dbos._sys_db.engine.connect() as c: 126 | sql = SystemSchema.workflow_status.select() 127 | result = c.execute(sql) 128 | assert result.fetchall() == [] 129 | 130 | # Verify that resetting after launch throws 131 | with pytest.raises(AssertionError): 132 | DBOS.reset_system_database() 133 | -------------------------------------------------------------------------------- /tests/test_singleton.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | import time 5 | from os import path 6 | from urllib.parse import quote 7 | 8 | import pytest 9 | 10 | # Public API 11 | from dbos import DBOS, DBOSConfig, SetWorkflowID, WorkflowHandle 12 | 13 | # Private API used because this is a test 14 | from dbos._context import DBOSContextEnsure, assert_current_dbos_context 15 | from tests.conftest import default_config 16 | 17 | 18 | def test_dbos_singleton(cleanup_test_databases: None) -> None: 19 | # Initialize singleton 20 | DBOS.destroy() # In case of other tests leaving it 21 | 22 | # Simulate an app that does some imports of its own code, then defines DBOS, 23 | # then imports more 24 | from tests.classdefs import DBOSSendRecv, DBOSTestClass, DBOSTestRoles 25 | 26 | dbos: DBOS = DBOS(config=default_config()) 27 | 28 | from tests.more_classdefs import DBOSWFEvents, wfFunc 29 | 30 | DBOS.launch() # Usually framework (fastapi) does this via lifecycle event 31 | 32 | # Basics 33 | with SetWorkflowID("wfid"): 34 | res = wfFunc("f") 35 | assert res == "f1" 36 | 37 | res = DBOSTestClass.test_workflow_cls("a", "b") 38 | assert res == "b1a" 39 | 40 | inst = DBOSTestClass() 41 | res = inst.test_workflow("c", "d") 42 | assert res == "d1c" 43 | 44 | wh = DBOS.start_workflow(DBOSTestClass.test_workflow_cls, "a", "b") 45 | assert wh.get_result() == "b1a" 
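# start_workflow also accepts bound instance methods; the workflow status checked
# below records the configured instance name ("myconfig") and the class name.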
46 | 47 | wh = DBOS.start_workflow(inst.test_workflow, "c", "d") 48 | assert wh.get_result() == "d1c" 49 | stati = DBOS.get_workflow_status(wh.get_workflow_id()) 50 | assert stati 51 | assert stati.config_name == "myconfig" 52 | assert stati.class_name and "DBOSTestClass" in stati.class_name 53 | wfhr: WorkflowHandle[str] = DBOS.retrieve_workflow(wh.get_workflow_id()) 54 | assert wfhr.workflow_id == wh.get_workflow_id() 55 | 56 | # Roles 57 | 58 | with DBOSContextEnsure(): 59 | with pytest.raises(Exception) as exc_info: 60 | DBOSTestRoles.greetfunc("Nobody") 61 | assert "DBOS Error 8" in str(exc_info.value) 62 | 63 | ctx = assert_current_dbos_context() 64 | ctx.authenticated_roles = ["user", "admin"] 65 | res = inst.test_func_admin("admin") 66 | assert res == "myconfig:admin" 67 | 68 | assert DBOSTestRoles.greetfunc("user") == "Hello user" 69 | 70 | # Send / Recv 71 | dest_uuid = str("sruuid1") 72 | 73 | with SetWorkflowID(dest_uuid): 74 | handle = dbos.start_workflow(DBOSSendRecv.test_recv_workflow, "testtopic") 75 | assert handle.get_workflow_id() == dest_uuid 76 | 77 | send_uuid = str("sruuid2") 78 | with SetWorkflowID(send_uuid): 79 | res = DBOSSendRecv.test_send_workflow(handle.get_workflow_id(), "testtopic") 80 | assert res == dest_uuid 81 | 82 | begin_time = time.time() 83 | assert handle.get_result() == "test2-test1-test3" 84 | duration = time.time() - begin_time 85 | assert duration < 3.0 86 | 87 | # Events 88 | wfuuid = str("sendwf1") 89 | with SetWorkflowID(wfuuid): 90 | DBOSWFEvents.test_setevent_workflow() 91 | with SetWorkflowID(wfuuid): 92 | DBOSWFEvents.test_setevent_workflow() 93 | 94 | value1 = DBOSWFEvents.test_getevent_workflow(wfuuid, "key1") 95 | assert value1 == "value1" 96 | 97 | value2 = DBOSWFEvents.test_getevent_workflow(wfuuid, "key2") 98 | assert value2 == "value2" 99 | 100 | # Run getEvent outside of a workflow 101 | value1 = DBOS.get_event(wfuuid, "key1") 102 | assert value1 == "value1" 103 | 104 | value2 = DBOS.get_event(wfuuid, "key2") 105 | assert value2 == "value2" 106 | 107 | begin_time = time.time() 108 | value3 = DBOSWFEvents.test_getevent_workflow(wfuuid, "key3") 109 | assert value3 is None 110 | duration = time.time() - begin_time 111 | assert duration < 1 # None is from the event not from the timeout 112 | 113 | DBOS.destroy() 114 | 115 | 116 | def test_dbos_singleton_negative(cleanup_test_databases: None) -> None: 117 | # Initialize singleton 118 | DBOS.destroy() # In case of other tests leaving it 119 | 120 | # Simulate an app that does some imports of its own code, then defines DBOS, 121 | # then imports more 122 | from tests.classdefs import DBOSTestClass 123 | 124 | DBOS(config=default_config()) 125 | 126 | # Something should have launched 127 | with pytest.raises(Exception) as exc_info: 128 | DBOSTestClass.test_workflow_cls("a", "b") 129 | assert "launch" in str(exc_info.value) 130 | 131 | DBOS.destroy() 132 | 133 | 134 | def test_dbos_atexit_no_dbos(cleanup_test_databases: None) -> None: 135 | # Run the .py as a separate process 136 | result = subprocess.run( 137 | [sys.executable, path.join("tests", "atexit_no_ctor.py")], 138 | capture_output=True, 139 | text=True, 140 | ) 141 | 142 | # Assert that the output contains the warning message 143 | assert "DBOS exiting; functions were registered" in result.stdout 144 | 145 | 146 | def test_dbos_atexit_no_launch(cleanup_test_databases: None) -> None: 147 | # Run the .py as a separate process 148 | result = subprocess.run( 149 | [sys.executable, path.join("tests", "atexit_no_launch.py")], 150 | 
capture_output=True, 151 | text=True, 152 | ) 153 | 154 | # Assert that the output contains the warning message 155 | assert "DBOS exists but launch() was not called" in result.stdout 156 | -------------------------------------------------------------------------------- /tests/test_spans.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | from fastapi import FastAPI 5 | from fastapi.testclient import TestClient 6 | from opentelemetry.sdk import trace as tracesdk 7 | from opentelemetry.sdk.trace.export import SimpleSpanProcessor 8 | from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter 9 | 10 | from dbos import DBOS 11 | from dbos._tracer import dbos_tracer 12 | from dbos._utils import GlobalParams 13 | 14 | 15 | def test_spans(dbos: DBOS) -> None: 16 | 17 | @DBOS.workflow() 18 | def test_workflow() -> None: 19 | test_step() 20 | current_span = DBOS.span 21 | subspan = DBOS.tracer.start_span({"name": "a new span"}, parent=current_span) 22 | subspan.add_event("greeting_event", {"name": "a new event"}) 23 | DBOS.tracer.end_span(subspan) 24 | 25 | @DBOS.step() 26 | def test_step() -> None: 27 | return 28 | 29 | exporter = InMemorySpanExporter() 30 | span_processor = SimpleSpanProcessor(exporter) 31 | provider = tracesdk.TracerProvider() 32 | provider.add_span_processor(span_processor) 33 | dbos_tracer.set_provider(provider) 34 | 35 | test_workflow() 36 | test_step() 37 | 38 | spans = exporter.get_finished_spans() 39 | 40 | assert len(spans) == 4 41 | 42 | for span in spans: 43 | assert span.attributes is not None 44 | assert span.attributes["applicationVersion"] == GlobalParams.app_version 45 | assert span.attributes["executorID"] == GlobalParams.executor_id 46 | assert span.context is not None 47 | 48 | assert spans[0].name == test_step.__name__ 49 | assert spans[1].name == "a new span" 50 | assert spans[2].name == test_workflow.__name__ 51 | assert spans[3].name == test_step.__name__ 52 | 53 | assert spans[0].parent.span_id == spans[2].context.span_id # type: ignore 54 | assert spans[1].parent.span_id == spans[2].context.span_id # type: ignore 55 | assert spans[2].parent == None 56 | assert spans[3].parent == None 57 | 58 | 59 | @pytest.mark.asyncio 60 | async def test_spans_async(dbos: DBOS) -> None: 61 | 62 | @DBOS.workflow() 63 | async def test_workflow() -> None: 64 | await test_step() 65 | current_span = DBOS.span 66 | subspan = DBOS.tracer.start_span({"name": "a new span"}, parent=current_span) 67 | subspan.add_event("greeting_event", {"name": "a new event"}) 68 | DBOS.tracer.end_span(subspan) 69 | 70 | @DBOS.step() 71 | async def test_step() -> None: 72 | return 73 | 74 | exporter = InMemorySpanExporter() 75 | span_processor = SimpleSpanProcessor(exporter) 76 | provider = tracesdk.TracerProvider() 77 | provider.add_span_processor(span_processor) 78 | dbos_tracer.set_provider(provider) 79 | 80 | await test_workflow() 81 | await test_step() 82 | 83 | spans = exporter.get_finished_spans() 84 | 85 | assert len(spans) == 4 86 | 87 | for span in spans: 88 | assert span.attributes is not None 89 | assert span.attributes["applicationVersion"] == GlobalParams.app_version 90 | assert span.attributes["executorID"] == GlobalParams.executor_id 91 | assert span.context is not None 92 | 93 | assert spans[0].name == test_step.__name__ 94 | assert spans[1].name == "a new span" 95 | assert spans[2].name == test_workflow.__name__ 96 | assert spans[3].name == test_step.__name__ 97 | 
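# As in the sync test: the step span and the manually created sub-span are children
# of the workflow span, while the direct workflow call and the bare step call are root spans.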
98 | assert spans[0].parent.span_id == spans[2].context.span_id # type: ignore 99 | assert spans[1].parent.span_id == spans[2].context.span_id # type: ignore 100 | assert spans[2].parent == None 101 | assert spans[3].parent == None 102 | 103 | 104 | def test_temp_wf_fastapi(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None: 105 | dbos, app = dbos_fastapi 106 | 107 | @app.get("/step") 108 | @DBOS.step() 109 | def test_step_endpoint() -> str: 110 | return "test" 111 | 112 | exporter = InMemorySpanExporter() 113 | span_processor = SimpleSpanProcessor(exporter) 114 | provider = tracesdk.TracerProvider() 115 | provider.add_span_processor(span_processor) 116 | dbos_tracer.set_provider(provider) 117 | 118 | client = TestClient(app) 119 | response = client.get("/step") 120 | assert response.status_code == 200 121 | assert response.text == '"test"' 122 | 123 | spans = exporter.get_finished_spans() 124 | 125 | assert len(spans) == 2 126 | 127 | for span in spans: 128 | assert span.attributes is not None 129 | assert span.attributes["applicationVersion"] == GlobalParams.app_version 130 | assert span.context is not None 131 | 132 | assert spans[0].name == test_step_endpoint.__name__ 133 | assert spans[1].name == "/step" 134 | 135 | assert spans[0].parent.span_id == spans[1].context.span_id # type:ignore 136 | assert spans[1].parent == None 137 | -------------------------------------------------------------------------------- /tests/test_sqlalchemy.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 6 | 7 | # Public API 8 | from dbos import DBOS, SetWorkflowID 9 | 10 | 11 | # Declare a SQLAlchemy ORM base class 12 | class Base(DeclarativeBase): 13 | pass 14 | 15 | 16 | # Declare a SQLAlchemy ORM class for accessing the database table. 
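# The DBOS.transaction functions below read and write this table through DBOS.sql_session.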
17 | class Hello(Base): 18 | __tablename__ = "dbos_hello" 19 | greet_count: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) 20 | name: Mapped[str] = mapped_column(nullable=False) 21 | 22 | def __repr__(self) -> str: 23 | return f"Hello(greet_count={self.greet_count!r}, name={self.name!r})" 24 | 25 | 26 | def test_simple_transaction(dbos: DBOS, postgres_db_engine: sa.Engine) -> None: 27 | txn_counter: int = 0 28 | assert dbos._app_db_field is not None 29 | Base.metadata.drop_all(dbos._app_db_field.engine) 30 | Base.metadata.create_all(dbos._app_db_field.engine) 31 | 32 | @DBOS.transaction() 33 | def test_transaction(name: str) -> str: 34 | new_greeting = Hello(name=name) 35 | DBOS.sql_session.add(new_greeting) 36 | stmt = ( 37 | sa.select(Hello) 38 | .where(Hello.name == name) 39 | .order_by(Hello.greet_count.desc()) 40 | .limit(1) 41 | ) 42 | row = DBOS.sql_session.scalar(stmt) 43 | assert row is not None 44 | greet_count = row.greet_count 45 | nonlocal txn_counter 46 | txn_counter += 1 47 | return name + str(greet_count) 48 | 49 | assert test_transaction("alice") == "alice1" 50 | assert test_transaction("alice") == "alice2" 51 | assert txn_counter == 2 52 | 53 | # Test OAOO 54 | wfuuid = str(uuid.uuid4()) 55 | with SetWorkflowID(wfuuid): 56 | assert test_transaction("alice") == "alice3" 57 | with SetWorkflowID(wfuuid): 58 | assert test_transaction("alice") == "alice3" 59 | assert txn_counter == 3 # Only increment once 60 | 61 | Base.metadata.drop_all(dbos._app_db_field.engine) 62 | 63 | # Make sure no transactions are left open 64 | with postgres_db_engine.begin() as conn: 65 | result = conn.execute( 66 | sa.text( 67 | "select * from pg_stat_activity where state = 'idle in transaction'" 68 | ) 69 | ).fetchall() 70 | assert len(result) == 0 71 | 72 | 73 | def test_error_transaction(dbos: DBOS, postgres_db_engine: sa.Engine) -> None: 74 | txn_counter: int = 0 75 | assert dbos._app_db_field is not None 76 | # Drop the database but don't re-create. Should fail. 
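# (drop_all only removes the dbos_hello table; each transaction attempt below should
# then fail with 'relation "dbos_hello" does not exist'.)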
77 | Base.metadata.drop_all(dbos._app_db_field.engine) 78 | 79 | @DBOS.transaction() 80 | def test_transaction(name: str) -> str: 81 | nonlocal txn_counter 82 | txn_counter += 1 83 | new_greeting = Hello(name=name) 84 | DBOS.sql_session.add(new_greeting) 85 | return name 86 | 87 | with pytest.raises(Exception) as exc_info: 88 | test_transaction("alice") 89 | assert 'relation "dbos_hello" does not exist' in str(exc_info.value) 90 | assert txn_counter == 1 91 | 92 | # Test OAOO 93 | wfuuid = str(uuid.uuid4()) 94 | with SetWorkflowID(wfuuid): 95 | with pytest.raises(Exception) as exc_info: 96 | test_transaction("alice") 97 | assert 'relation "dbos_hello" does not exist' in str(exc_info.value) 98 | assert txn_counter == 2 99 | 100 | with SetWorkflowID(wfuuid): 101 | with pytest.raises(Exception) as exc_info: 102 | test_transaction("alice") 103 | assert 'relation "dbos_hello" does not exist' in str(exc_info.value) 104 | assert txn_counter == 2 105 | 106 | # Make sure no transactions are left open 107 | with postgres_db_engine.begin() as conn: 108 | result = conn.execute( 109 | sa.text( 110 | "select * from pg_stat_activity where state = 'idle in transaction'" 111 | ) 112 | ).fetchall() 113 | assert len(result) == 0 114 | -------------------------------------------------------------------------------- /version/__init__.py: -------------------------------------------------------------------------------- 1 | from pdm.backend.hooks.version import SCMVersion 2 | 3 | 4 | def format_version(git_version: SCMVersion) -> str: 5 | """ 6 | Format version into string. 7 | 8 | 1. Release versions may only be published from release branches. Their version is a git tag. 9 | 2. Preview versions are published from main. They are PEP440 alpha releases whose version is the 10 | next release version number followed by "a" followed by the number of commits since the last release. 11 | If the last release was 1.2.3 and there have been ten commits since, the preview version is 1.2.3a10 12 | 3. Test versions are published from feature branches. They are PEP440 local versions tagged with a git hash. 13 | """ 14 | assert git_version.branch is not None 15 | is_release = "release" in git_version.branch 16 | is_preview = git_version.branch == "main" 17 | 18 | next_version = guess_next_version(str(git_version.version)) 19 | 20 | if git_version.distance is None: 21 | if is_release: 22 | version = str(git_version.version) 23 | elif is_preview: 24 | version = f"{next_version}a0" 25 | else: 26 | version = f"{next_version}a0+{git_version.node}" 27 | else: 28 | if is_release: 29 | raise Exception( 30 | f"Release branches may only publish tagged releases. Distance: {git_version.distance}" 31 | ) 32 | elif is_preview: 33 | version = f"{next_version}a{git_version.distance}" 34 | else: 35 | version = f"{next_version}a{git_version.distance}+{git_version.node}" 36 | 37 | return version 38 | 39 | 40 | def guess_next_version(version_number: str) -> str: 41 | major, minor, patch = map(int, version_number.split(".")) 42 | minor += 1 43 | return f"{major}.{minor}.{patch}" 44 | --------------------------------------------------------------------------------
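A minimal sketch of how the version-formatting rules described in the docstring of version/__init__.py play out. It assumes the package is importable as `version` from the repository root and uses a SimpleNamespace stand-in rather than the real pdm-backend SCMVersion; the branch names, distance, and node hash are made up for illustration.

from types import SimpleNamespace

from version import format_version, guess_next_version

# Hypothetical git states; format_version only reads version, branch, distance, and node.
tagged_release = SimpleNamespace(version="1.2.3", branch="release/1.2", distance=None, node=None)
preview = SimpleNamespace(version="1.2.3", branch="main", distance=10, node="abc1234")
feature = SimpleNamespace(version="1.2.3", branch="feature/foo", distance=2, node="abc1234")

# guess_next_version bumps the minor component of the last tag.
assert guess_next_version("1.2.3") == "1.3.3"

assert format_version(tagged_release) == "1.2.3"        # release branches publish the tag itself
assert format_version(preview) == "1.3.3a10"            # main publishes PEP 440 alpha previews
assert format_version(feature) == "1.3.3a2+abc1234"     # feature branches add a local version tag

In the real build, pdm-backend passes an actual SCMVersion derived from git metadata, so these stand-ins are only for tracing the branching logic by hand.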