├── .github └── workflows │ ├── main.yml │ └── release.yml ├── .gitignore ├── .pyup.yml ├── .travis.yml ├── LICENSE ├── README.md ├── bin └── release.sh ├── requirements ├── develop.txt └── production.txt ├── setup.cfg ├── setup.py ├── sqlalchemy-fsm.sublime-project ├── src └── sqlalchemy_fsm │ ├── __init__.py │ ├── bound.py │ ├── cache.py │ ├── events.py │ ├── exc.py │ ├── meta.py │ ├── sqltypes.py │ ├── transition.py │ └── util.py ├── test ├── __init__.py ├── conftest.py ├── test_basic.py ├── test_conditionals.py ├── test_events.py ├── test_invalids.py ├── test_multi_source.py ├── test_performance.py └── test_transition_classes.py └── tox.ini /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: CI 4 | 5 | on: 6 | push: 7 | branches: 8 | - '**' 9 | pull_request: 10 | branches: 11 | - '**' 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | jobs: 17 | debug: 18 | # This is just a debug job for assorted github vars I am using 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: GitHub context 22 | env: 23 | GITHUB_CONTEXT: ${{ toJSON(github) }} 24 | run: echo "$GITHUB_CONTEXT" 25 | test: 26 | runs-on: ubuntu-latest 27 | strategy: 28 | matrix: 29 | python-version: [3.6, 3.9, '3.10'] 30 | sqlalchemy-version: ['>=1.2,<1.3', '>=1.3,<1.4', '>=1.4'] 31 | steps: 32 | - uses: actions/checkout@v2 33 | - name: Set up Python ${{ matrix.python-version }} 34 | uses: actions/setup-python@v2 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install tox tox-gh-actions 41 | - name: Test with tox 42 | env: 43 | SQLALCHEMY_VERSION_SPEC: ${{ matrix.sqlalchemy-version }} 44 | run: tox -------------------------------------------------------------------------------- 
/.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | workflow_run: 4 | workflows: ["CI"] 5 | branches: [main] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | publish: 11 | if: ${{ github.event.workflow_run.conclusion == 'success' }} 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | - name: Install dependencies 18 | run: pip install build 19 | - name: Build packages 20 | run: python -m build --outdir dist/ 21 | - name: Publish distribution to Test PyPI 22 | uses: pypa/gh-action-pypi-publish@master 23 | with: 24 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 25 | repository_url: https://test.pypi.org/legacy/ 26 | - name: Publish distribution to PRODUCTION PyPI 27 | # if: false 28 | uses: pypa/gh-action-pypi-publish@master 29 | with: 30 | password: ${{ secrets.PYPI_API_TOKEN }} 31 | bump-version: # Bump version for the next release 32 | needs: publish 33 | runs-on: ubuntu-latest 34 | # Bump version in the PR that targets the "main" branch 35 | steps: 36 | - uses: actions/checkout@v2 37 | - name: Set up Python 38 | uses: actions/setup-python@v2 39 | - name: Install dependencies 40 | run: | 41 | pip install bump2version 42 | - name: Bump version 43 | run: | 44 | git config --global user.email "github+actions@gmail.com" 45 | git config --global user.name "Actions" 46 | bump2version patch 47 | git push --follow-tags -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/vim,python,windows,linux,osx,sublimetext,git 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python,windows,linux,osx,sublimetext,git 4 | 5 | ### Git ### 6 | # Created by git for backups. 
To disable backups in Git: 7 | # $ git config --global mergetool.keepBackup false 8 | *.orig 9 | 10 | # Created by git when using merge tools for conflicts 11 | *.BACKUP.* 12 | *.BASE.* 13 | *.LOCAL.* 14 | *.REMOTE.* 15 | *_BACKUP_*.txt 16 | *_BASE_*.txt 17 | *_LOCAL_*.txt 18 | *_REMOTE_*.txt 19 | 20 | ### Linux ### 21 | *~ 22 | 23 | # temporary files which can be created if a process still has a handle open of a deleted file 24 | .fuse_hidden* 25 | 26 | # KDE directory preferences 27 | .directory 28 | 29 | # Linux trash folder which might appear on any partition or disk 30 | .Trash-* 31 | 32 | # .nfs files are created when an open file is removed but is still being accessed 33 | .nfs* 34 | 35 | ### OSX ### 36 | # General 37 | .DS_Store 38 | .AppleDouble 39 | .LSOverride 40 | 41 | # Icon must end with two \r 42 | Icon 43 | 44 | 45 | # Thumbnails 46 | ._* 47 | 48 | # Files that might appear in the root of a volume 49 | .DocumentRevisions-V100 50 | .fseventsd 51 | .Spotlight-V100 52 | .TemporaryItems 53 | .Trashes 54 | .VolumeIcon.icns 55 | .com.apple.timemachine.donotpresent 56 | 57 | # Directories potentially created on remote AFP share 58 | .AppleDB 59 | .AppleDesktop 60 | Network Trash Folder 61 | Temporary Items 62 | .apdisk 63 | 64 | ### Python ### 65 | # Byte-compiled / optimized / DLL files 66 | __pycache__/ 67 | *.py[cod] 68 | *$py.class 69 | 70 | # C extensions 71 | *.so 72 | 73 | # Distribution / packaging 74 | .Python 75 | build/ 76 | develop-eggs/ 77 | dist/ 78 | downloads/ 79 | eggs/ 80 | .eggs/ 81 | lib/ 82 | lib64/ 83 | parts/ 84 | sdist/ 85 | var/ 86 | wheels/ 87 | share/python-wheels/ 88 | *.egg-info/ 89 | .installed.cfg 90 | *.egg 91 | MANIFEST 92 | 93 | # PyInstaller 94 | # Usually these files are written by a python script from a template 95 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
96 | *.manifest 97 | *.spec 98 | 99 | # Installer logs 100 | pip-log.txt 101 | pip-delete-this-directory.txt 102 | 103 | # Unit test / coverage reports 104 | htmlcov/ 105 | .tox/ 106 | .nox/ 107 | .coverage 108 | .coverage.* 109 | .cache 110 | nosetests.xml 111 | coverage.xml 112 | *.cover 113 | *.py,cover 114 | .hypothesis/ 115 | .pytest_cache/ 116 | cover/ 117 | 118 | # Translations 119 | *.mo 120 | *.pot 121 | 122 | # Django stuff: 123 | *.log 124 | local_settings.py 125 | db.sqlite3 126 | db.sqlite3-journal 127 | 128 | # Flask stuff: 129 | instance/ 130 | .webassets-cache 131 | 132 | # Scrapy stuff: 133 | .scrapy 134 | 135 | # Sphinx documentation 136 | docs/_build/ 137 | 138 | # PyBuilder 139 | .pybuilder/ 140 | target/ 141 | 142 | # Jupyter Notebook 143 | .ipynb_checkpoints 144 | 145 | # IPython 146 | profile_default/ 147 | ipython_config.py 148 | 149 | # pyenv 150 | # For a library or package, you might want to ignore these files since the code is 151 | # intended to run in multiple environments; otherwise, check them in: 152 | # .python-version 153 | 154 | # pipenv 155 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 156 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 157 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 158 | # install all needed dependencies. 159 | #Pipfile.lock 160 | 161 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 162 | __pypackages__/ 163 | 164 | # Celery stuff 165 | celerybeat-schedule 166 | celerybeat.pid 167 | 168 | # SageMath parsed files 169 | *.sage.py 170 | 171 | # Environments 172 | .env 173 | .venv 174 | env/ 175 | venv/ 176 | ENV/ 177 | env.bak/ 178 | venv.bak/ 179 | 180 | # Spyder project settings 181 | .spyderproject 182 | .spyproject 183 | 184 | # Rope project settings 185 | .ropeproject 186 | 187 | # mkdocs documentation 188 | /site 189 | 190 | # mypy 191 | .mypy_cache/ 192 | .dmypy.json 193 | dmypy.json 194 | 195 | # Pyre type checker 196 | .pyre/ 197 | 198 | # pytype static type analyzer 199 | .pytype/ 200 | 201 | # Cython debug symbols 202 | cython_debug/ 203 | 204 | ### SublimeText ### 205 | # Cache files for Sublime Text 206 | *.tmlanguage.cache 207 | *.tmPreferences.cache 208 | *.stTheme.cache 209 | 210 | # Workspace files are user-specific 211 | *.sublime-workspace 212 | 213 | # Project files should be checked into the repository, unless a significant 214 | # proportion of contributors will probably not be using Sublime Text 215 | # *.sublime-project 216 | 217 | # SFTP configuration file 218 | sftp-config.json 219 | sftp-config-alt*.json 220 | 221 | # Package control specific files 222 | Package Control.last-run 223 | Package Control.ca-list 224 | Package Control.ca-bundle 225 | Package Control.system-ca-bundle 226 | Package Control.cache/ 227 | Package Control.ca-certs/ 228 | Package Control.merged-ca-bundle 229 | Package Control.user-ca-bundle 230 | oscrypto-ca-bundle.crt 231 | bh_unicode_properties.cache 232 | 233 | # Sublime-github package stores a github token in this file 234 | # https://packagecontrol.io/packages/sublime-github 235 | GitHub.sublime-settings 236 | 237 | ### Vim ### 238 | # Swap 239 | [._]*.s[a-v][a-z] 240 | !*.svg # comment out if you don't need vector files 241 | [._]*.sw[a-p] 242 | [._]s[a-rt-v][a-z] 243 | [._]ss[a-gi-z] 244 | [._]sw[a-p] 245 | 246 | # Session 247 | Session.vim 248 | Sessionx.vim 
249 | 250 | # Temporary 251 | .netrwhist 252 | # Auto-generated tag files 253 | tags 254 | # Persistent undo 255 | [._]*.un~ 256 | 257 | ### Windows ### 258 | # Windows thumbnail cache files 259 | Thumbs.db 260 | Thumbs.db:encryptable 261 | ehthumbs.db 262 | ehthumbs_vista.db 263 | 264 | # Dump file 265 | *.stackdump 266 | 267 | # Folder config file 268 | [Dd]esktop.ini 269 | 270 | # Recycle Bin used on file shares 271 | $RECYCLE.BIN/ 272 | 273 | # Windows Installer files 274 | *.cab 275 | *.msi 276 | *.msix 277 | *.msm 278 | *.msp 279 | 280 | # Windows shortcuts 281 | *.lnk 282 | 283 | # End of https://www.toptal.com/developers/gitignore/api/vim,python,windows,linux,osx,sublimetext,git 284 | -------------------------------------------------------------------------------- /.pyup.yml: -------------------------------------------------------------------------------- 1 | # configure updates globally 2 | # default: all 3 | # allowed: all, insecure, False 4 | update: insecure 5 | 6 | # configure dependency pinning globally 7 | # default: True 8 | # allowed: True, False 9 | pin: True 10 | 11 | # set the default branch 12 | # default: empty, the default branch on GitHub 13 | branch: develop 14 | 15 | # update schedule 16 | # default: empty 17 | # allowed: "every day", "every week", .. 
18 | schedule: "every week" 19 | 20 | # search for requirement files 21 | # default: True 22 | # allowed: True, False 23 | search: True 24 | 25 | # add a label to pull requests, default is not set 26 | # requires private repo permissions, even on public repos 27 | # default: empty 28 | label_prs: update 29 | 30 | # configure the branch prefix the bot is using 31 | # default: pyup- 32 | branch_prefix: pyup/ 33 | 34 | # allow to close stale PRs 35 | # default: True 36 | close_prs: True -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "2.7" 5 | - "3.6" 6 | env: 7 | - USE_MIN_PACKAGE_VERSIONS=no 8 | - USE_MIN_PACKAGE_VERSIONS=yes 9 | 10 | install: 11 | - pip install -r requirements/develop.txt 12 | - pip install codeclimate-test-reporter 13 | - | 14 | if [ "${USE_MIN_PACKAGE_VERSIONS}" == "yes" ]; 15 | then 16 | pip install "SQLAlchemy==1.0.0" "six==1.10.0"; 17 | fi 18 | - pip freeze # Print versions of all installed packages for logging purposes 19 | script: 20 | - py.test -m pep8 21 | - py.test 22 | after_success: 23 | - codeclimate-test-reporter --token 98492b2b5bbaf1cb15691d14f937120784e2bb8ba613005c4cf92251de7d1807 24 | 25 | deploy: 26 | provider: pypi 27 | user: ilja_o 28 | password: 29 | secure: 
"GHQ6y8EilsMhLt41/lJUL2BYT7uKQFEAlKUFtyRPJasVKKYEJ3DLw34GPg0lyArBCcsuO0wXk1swulXbYn5lcD6dVhBveod01NIhvIVahkIcjvI9A26r4vIh7zJsKhG4+aJvMGSW8kAmLkIofhaFXgaKJD8WtqKD9NPApKteNFbctyWCxklPpyZ4XhQ42q0UyfyP54aSDM1tQtJMWVIzwvHCOiyFbPBRBkCoZsE6c4hND0pFllmGodvt9TFeBg/1R2du3L/wfaLvRc62SE5FfDPyOkvhZtCL32hDWkmGDeFG/JzvbGnJ8RKSijy+6kG88653gVJ85XiU3OFyYrto1XMQTNIhiOZeIZvHVYRpNDVW9erpWrcXG5KuSkQlR467w6J+OAMCiKG+LyT7rTsg7Y1Xz5lqxURd14ZTes8xDcH1sv8UMncblrpmEurkkVKvyb2Cqn3iH2XzXeLiKjJw1vnxPVuOd+5Ib37/X3ezyFN3YkuwD9NgKYHMrv4TnSuzMl+4L0RSDzdC03NFwTdtocvO2nfbkFTbz4EP5lmcJv6/2MhoX3QMgBBnUIGBa9EKD8oPmkwgJbzingZgfegtuQUk8PL8JjGGSyfAsMA9OZPffkUe7nahgphCwoV2q4ZYqkFoQ9KuhULVUlKxNDSCdU/hvFugJevaMp3Brlxm2uY=" 30 | on: 31 | branch: production 32 | python: 3.6 33 | condition: "${USE_MIN_PACKAGE_VERSIONS} == no" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | copyright (c) 2010 Mikhail Podgurskiy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://badge.fury.io/py/sqlalchemy-fsm.svg)](https://badge.fury.io/py/sqlalchemy-fsm) 2 | [![Build Status](https://travis-ci.org/VRGhost/sqlalchemy-fsm.svg?branch=master)](https://travis-ci.org/VRGhost/sqlalchemy-fsm) 3 | 4 | Finite state machine field for sqlalchemy 5 | ============================================================== 6 | 7 | sqlalchemy-fsm adds declarative states management for sqlalchemy models. 8 | Instead of adding some state field to a model, and manage its 9 | values by hand, you could use FSMState field and mark model methods 10 | with the `transition` decorator. Your method will contain the side-effects 11 | of the state change. 12 | 13 | The decorator also takes a list of conditions, all of which must be met 14 | before a transition is allowed. 15 | 16 | Usage 17 | ----- 18 | 19 | Add FSMState field to you model 20 | 21 | ```python 22 | from sqlalchemy_fsm import FSMField, transition 23 | 24 | class BlogPost(db.Model): 25 | state = db.Column(FSMField, nullable = False) 26 | ``` 27 | 28 | Use the `transition` decorator to annotate model methods 29 | 30 | ```python 31 | @transition(source='new', target='published') 32 | def published(self): 33 | """ 34 | This function may contain side-effects, 35 | like updating caches, notifying users, etc. 36 | The return value will be discarded. 37 | """ 38 | ``` 39 | 40 | `source` parameter accepts a list of states, or an individual state. 41 | You can use `*` for source, to allow switching to `target` from any state. 
42 | 43 | `@transition`- annotated methods have the following API: 44 | 1. `<SqlAlchemy table class>.method()` - returns an SqlAlchemy filter condition that can be used for querying the database (e.g. `session.query(BlogPost).filter(BlogPost.published())`) 45 | 1. `<SqlAlchemy table class>.method.is_(<bool>)` - same as `<SqlAlchemy table class>.method() == <bool>` 46 | 1. `<SqlAlchemy record>.method()` - returns boolean value that tells if this particular record is in the target state for that method() (e.g. `if not blog.published():`) 47 | 1. `<SqlAlchemy record>.method.set(*args, **kwargs)` - changes the state of the record object to the transitions' target state (or raises an exception if it is not able to do so) 48 | 1. `<SqlAlchemy record>.method.can_proceed(*args, **kwargs)` - returns `True` if calling `<SqlAlchemy record>.method.set(*args, **kwargs)` (with same `*args, **kwargs`) should succeed. 49 | 50 | You can also use `None` as a source state (e.g. when the state column is nullable). 51 | However, it is _not possible_ to create transition with `None` as target state due to religious reasons. 52 | 53 | Transition can also be used on a class object to create a group of handlers 54 | for the same target state. 55 | 56 | ```python 57 | @transition(target='published') 58 | class PublishHandler(object): 59 | 60 | @transition(source='new') 61 | def do_one(self, instance, value): 62 | instance.side_effect = "published from new" 63 | 64 | @transition(source='draft') 65 | def do_two(self, instance, value): 66 | instance.side_effect = "published from draft" 67 | 68 | 69 | class BlogPost(db.Model): 70 | ... 71 | published = PublishHandler 72 | ``` 73 | 74 | The transition is still to be invoked by calling the model's `published.set()` method. 
75 | 76 | An alternative inline class syntax is supported too: 77 | 78 | ```python 79 | @transition(target='published') 80 | class published(object): 81 | 82 | @transition(source='new') 83 | def do_one(self, instance, value): 84 | instance.side_effect = "published from new" 85 | 86 | @transition(source='draft') 87 | def do_two(self, instance, value): 88 | instance.side_effect = "published from draft" 89 | ``` 90 | 91 | If calling `published.set()` succeeds without raising an exception, the state field 92 | will be changed, but not written to the database. 93 | 94 | ```python 95 | def publish_view(request, post_id): 96 | post = get_object_or_404(BlogPost, pk=post_id) 97 | if not post.published.can_proceed(): 98 | raise Http404 99 | 100 | post.published.set() 101 | post.save() 102 | return redirect('/') 103 | ``` 104 | 105 | If your given function requires arguments to validate, you need to include them 106 | when calling `can_proceed` as well as including them when you call the function 107 | normally. Say `published.set()` required a date for some reason: 108 | 109 | ```python 110 | if not post.published.can_proceed(the_date): 111 | raise Http404 112 | else: 113 | post.published.set(the_date) 114 | ``` 115 | 116 | If your code needs to know the state the model is currently in, you can just call 117 | the main function. 118 | 119 | ```python 120 | if post.deleted(): 121 | raise Http404 122 | ``` 123 | 124 | If you require some conditions to be met before changing state, use the 125 | `conditions` argument to `transition`. `conditions` must be a list of functions 126 | that take one argument, the model instance. The function must return either 127 | `True` or `False` or a value that evaluates to `True` or `False`. If all 128 | functions return `True`, all conditions are considered to be met and transition 129 | is allowed to happen. If one of the functions returns `False`, the transition 130 | will not happen. These functions should not have any side effects. 
131 | 132 | You can use ordinary functions 133 | 134 | ```python 135 | def can_publish(instance): 136 | # No publishing after 17 hours 137 | if datetime.datetime.now().hour > 17: 138 | return False 139 | return True 140 | ``` 141 | 142 | Or model methods 143 | 144 | ```python 145 | def can_destroy(self): 146 | return not self.is_under_investigation() 147 | ``` 148 | 149 | Use the conditions like this: 150 | 151 | ```python 152 | @transition(source='new', target='published', conditions=[can_publish]) 153 | def publish(self): 154 | """ 155 | Side effects galore 156 | """ 157 | 158 | @transition(source='*', target='destroyed', conditions=[can_destroy]) 159 | def destroy(self): 160 | """ 161 | Side effects galore 162 | """ 163 | ``` 164 | 165 | You can also use FSM handlers to query the database. E.g. 166 | 167 | ```python 168 | session.query(BlogCls).filter(BlogCls.publish()) 169 | ``` 170 | 171 | will return all "Blog" objects whose current state matches "publish"'es target state. 172 | 173 | Events 174 | ------ 175 | 176 | Sqlalchemy-fsm integrates with sqlalchemy's event system. 177 | The library exposes two events `before_state_change` and `after_state_change` that are fired up 178 | at the expected points of state's lifecycle. 179 | 180 | You can subscribe event listeners via standard SQLAlchemy interface of 181 | `listens_for` or `listen`. 182 | 183 | ```python 184 | from sqlalchemy.event import listens_for 185 | 186 | @listens_for(Blog, 'before_state_change') 187 | def on_state_change(instance, source, target): 188 | ... 189 | ``` 190 | 191 | Or 192 | 193 | ```python 194 | from sqlalchemy import event 195 | 196 | def on_state_change(instance, source, target): 197 | ... 198 | 199 | event.listen(Blog, 'after_state_change', on_state_change) 200 | ``` 201 | 202 | It is possible to de-register an event listener call with `sqlalchemy.event.remove()` method. 203 | 204 | How does sqlalchemy-fsm diverge from django-fsm? 
205 | ------------------------------------------------ 206 | 207 | * Can't commit data from within transition-decorated functions 208 | 209 | * Does support arguments to conditions functions 210 | -------------------------------------------------------------------------------- /bin/release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | function require_clean_work_tree () { 4 | # Update the index 5 | git update-index -q --ignore-submodules --refresh 6 | err=0 7 | 8 | # Disallow unstaged changes in the working tree 9 | if ! git diff-files --quiet --ignore-submodules -- 10 | then 11 | echo >&2 "cannot $1: you have unstaged changes." 12 | git diff-files --name-status -r --ignore-submodules -- >&2 13 | err=1 14 | fi 15 | 16 | # Disallow uncommitted changes in the index 17 | if ! git diff-index --cached --quiet HEAD --ignore-submodules -- 18 | then 19 | echo >&2 "cannot $1: your index contains uncommitted changes." 20 | git diff-index --cached --name-status -r --ignore-submodules HEAD -- >&2 21 | err=1 22 | fi 23 | 24 | if [ $err = 1 ] 25 | then 26 | echo >&2 "Please commit or stash them." 27 | exit 1 28 | fi 29 | } 30 | 31 | echo "This script increments build version and merges current master into production" 32 | echo " with an appropriate tag." 33 | echo "" 34 | echo "Pass a name of version number to be incremented ('major', 'minor' or 'patch')" 35 | 36 | PROJECT_DIR="$(dirname "${BASH_SOURCE[0]}")/.." 37 | BUMPED_VERSION="$1" 38 | 39 | if [ -z "${BUMPED_VERSION}" ]; then 40 | echo >&2 "You must specify what version number to increment" 41 | exit 1 42 | fi 43 | 44 | CURRENT_BRANCH=$(git symbolic-ref -q HEAD) 45 | CURRENT_BRANCH=${CURRENT_BRANCH##refs/heads/} 46 | 47 | if [ "${CURRENT_BRANCH}" != "master" ]; then 48 | echo >&2 "You must be on 'master' branch." 
49 | exit 1 50 | fi 51 | 52 | require_clean_work_tree 53 | 54 | git checkout production 55 | git merge -X theirs --squash master 56 | 57 | # Execute tests (just in case) 58 | python "${PROJECT_DIR}/setup.py" test 59 | bumpversion --allow-dirty --message 'New release on {utcnow}: {new_version}' "${BUMPED_VERSION}" 60 | 61 | git push origin production --tags 62 | 63 | git checkout master 64 | git merge -X theirs production -m "Updating version number(s)" # Update version string(s) 65 | git push -------------------------------------------------------------------------------- /requirements/develop.txt: -------------------------------------------------------------------------------- 1 | black==22.6.0 2 | bump2version 3 | coverage 4 | flake8 5 | flake8-bugbear 6 | flake8-import-order 7 | pep8-naming 8 | pytest-benchmark 9 | pytest-cov 10 | pytest-runner 11 | pytest 12 | 13 | -r production.txt -------------------------------------------------------------------------------- /requirements/production.txt: -------------------------------------------------------------------------------- 1 | six>=1.10.0 2 | SQLAlchemy>=1.0.0 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 2.0.13 3 | commit = True 4 | tag = True 5 | 6 | [aliases] 7 | test = pytest 8 | 9 | [metadata] 10 | description-file = README.md 11 | 12 | [bumpversion:file:src/sqlalchemy_fsm/__init__.py] 13 | 14 | [bumpversion:file:setup.py] 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import setuptools 4 | 5 | 6 | setuptools.setup( 7 | name="sqlalchemy_fsm", 8 | packages=setuptools.find_packages(where="src"), 9 | package_dir={"": "src"}, 10 | py_modules=["sqlalchemy_fsm"], 11 | description="Finite 
state machine field for sqlalchemy", 12 | long_description=open("README.md").read(), 13 | long_description_content_type="text/markdown", 14 | author="Ilja & Peter", 15 | author_email="ilja@wise.fish", 16 | license="MIT", 17 | classifiers=[ 18 | "Development Status :: 5 - Production/Stable", 19 | "Intended Audience :: Developers", 20 | "License :: OSI Approved :: MIT License", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python :: 3", 23 | "Topic :: Database", 24 | ], 25 | keywords="sqlalchemy finite state machine fsm", 26 | version="2.0.13", 27 | url="https://github.com/VRGhost/sqlalchemy-fsm", 28 | install_requires=[ 29 | "SQLAlchemy>=1.0.0", 30 | "six>=1.10.0", 31 | ], 32 | python_requires=">=3.6", 33 | setup_requires=["pytest-runner"], 34 | tests_require=["pytest"], 35 | ) 36 | -------------------------------------------------------------------------------- /sqlalchemy-fsm.sublime-project: -------------------------------------------------------------------------------- 1 | { 2 | "folders": 3 | [ 4 | { 5 | "path": "." 6 | } 7 | ], 8 | "settings": { 9 | "tab_size": 4, 10 | "translate_tabs_to_spaces": true, 11 | "detect_indentation": false, 12 | "remove_trailing_whitespace_on_save": true, 13 | "ensure_single_trailing_newline": false, 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/sqlalchemy_fsm/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | events, 3 | exc, 4 | ) 5 | from .sqltypes import FSMField 6 | from .transition import transition 7 | 8 | __version__ = "2.0.13" 9 | -------------------------------------------------------------------------------- /src/sqlalchemy_fsm/bound.py: -------------------------------------------------------------------------------- 1 | """ 2 | Non-meta objects that are bound to a particular table & sqlalchemy instance. 
class SqlAlchemyHandle(object):
    """SQLAlchemy-side context shared by the bound FSM objects.

    Holds the mapped class, the (optional) record being operated on, the
    model's single ``FSMField`` column and that column's name.

    When a record is supplied, an event dispatcher bound to that record is
    stored in ``dispatch``. For class-level handles (no record) the
    ``dispatch`` slot is intentionally left unset, so reading it raises
    ``AttributeError``.
    """

    __slots__ = (
        "table_class",
        "record",
        "fsm_column",
        "dispatch",
        "column_name",
    )

    def __init__(self, table_class, table_record_instance=None):
        """Build a handle for ``table_class`` and, optionally, one record.

        Args:
            table_class: Mapped model class that owns the FSM column.
            table_record_instance: Record of ``table_class`` to bind to, or
                ``None`` for a class-level (query-only) handle.
        """
        self.table_class = table_class
        self.record = table_record_instance
        # Looked up through the module-level cache keyed by the model class.
        self.fsm_column = column_cache.get_value(table_class)
        self.column_name = self.fsm_column.name

        # Compare against None explicitly: a record whose class defines a
        # falsy __bool__/__len__ must still receive its event dispatcher.
        if table_record_instance is not None:
            self.dispatch = events.BoundFSMDispatcher(table_record_instance)
def conditions_met(self, args, kwargs):
    """Check the transition's preconditions against the given call.

    Every condition is invoked with this object's bound arguments
    (``self.my_args``) followed by ``args``, plus ``kwargs``. A condition
    whose signature cannot accept those arguments counts as failed.

    Returns:
        ``True`` when there are no conditions; otherwise the (falsy) result
        of the first failing condition, or the result of the last condition
        when all of them pass.

    Raises:
        exc.SetupError: all conditions passed, but the transition handler
            itself cannot be called with the same arguments.
    """
    checks = self.meta.conditions
    if not checks:
        # Performance - nothing to evaluate
        return True

    call_args = self.my_args + tuple(args)
    call_kwargs = dict(kwargs)

    verdict = True
    for check in checks:
        if self.get_call_iface_error(check, call_args, call_kwargs):
            # The condition is not callable with the provided arguments
            verdict = False
        else:
            verdict = check(*call_args, **call_kwargs)

        if not verdict:
            # Preconditions failed - no need to look any further
            return verdict

    # Every condition passed: the handler must accept the very same call.
    iface_error = self.get_call_iface_error(self.set_func, call_args, call_kwargs)
    if iface_error:
        warnings.warn(
            "Failure to validate handler call args: {}".format(iface_error)
        )
        # `checks` is known to be non-empty here, so a mismatch between the
        # conditions and the handler is always a setup error.
        raise exc.SetupError(
            "Mismatch beteen args accepted by preconditons "
            "({!r}) & handler ({!r})".format(self.meta.conditions, self.set_func)
        )
    return verdict
class TansitionStateArtithmetics(object):
    """Combine the parameters of two transition metas (``meta_a``, ``meta_b``)."""

    def __init__(self, meta_a, meta_b):
        self.meta_a, self.meta_b = meta_a, meta_b

    def source_intersection(self):
        """Return the source states valid for both metas.

        A ``"*"`` wildcard on either side yields the other side's sources
        unchanged. Otherwise ``meta_a``'s sources must be a superset of
        ``meta_b``'s; if they are not, ``False`` is returned.
        """
        first = self.meta_a.sources
        second = self.meta_b.sources

        if "*" in first:
            return second
        if "*" in second:
            return first
        if not first.issuperset(second):
            return False
        return first.intersection(second)

    def target_intersection(self):
        """Return the single target both metas agree on, or ``None``.

        Equal targets (including both being ``None``) are returned as is.
        If exactly one side is ``None`` the other side's value wins.
        Two different non-``None`` targets give ``None``.
        """
        first = self.meta_a.target
        second = self.meta_b.target

        if first == second:
            return first
        if None in (first, second):
            # Pick whichever side actually carries a value
            return first or second
        return None

    def joint_conditions(self):
        """Return ``meta_a``'s conditions followed by ``meta_b``'s."""
        combined = self.meta_a.conditions
        combined = combined + self.meta_b.conditions
        return combined

    def joint_args(self):
        """Return ``meta_a``'s extra call args followed by ``meta_b``'s."""
        combined = self.meta_a.extra_call_args
        combined = combined + self.meta_b.extra_call_args
        return combined
Skip non-fsm methods 217 | continue 218 | return sub_handlers 219 | 220 | def _get_bound_sub_metas(child_cls, sub_transitions, parent_meta): 221 | out = [] 222 | 223 | for (_name, transition) in sub_transitions: 224 | sub_meta = transition._sa_fsm_meta 225 | arithmetics = TansitionStateArtithmetics(parent_meta, sub_meta) 226 | 227 | sub_sources = arithmetics.source_intersection() 228 | if not sub_sources: 229 | raise exc.SetupError( 230 | "Source state superset {super} " 231 | "and subset {sub} are not compatable".format( 232 | super=parent_meta.sources, sub=sub_meta.sources 233 | ) 234 | ) 235 | 236 | sub_target = arithmetics.target_intersection() 237 | if not sub_target: 238 | raise exc.SetupError( 239 | "Targets {super} and {sub} are not compatable".format( 240 | super=parent_meta.target, sub=sub_meta.target 241 | ) 242 | ) 243 | 244 | merged_sub_meta = meta.FSMMeta( 245 | sub_sources, 246 | sub_target, 247 | arithmetics.joint_conditions(), 248 | arithmetics.joint_args(), 249 | sub_meta.bound_cls, 250 | ) 251 | out.append((merged_sub_meta, transition._sa_fsm_transition_fn)) 252 | 253 | return out 254 | 255 | out_cls = type( 256 | "{}::sqlalchemy_handle".format( 257 | child_cls.__name__, 258 | ), 259 | (child_cls,), 260 | { 261 | "_sa_fsm_sqlalchemy_handle": None, 262 | "_sa_fsm_sqlalchemy_metas": (), 263 | }, 264 | ) 265 | sub_transitions = _get_sub_transitions(out_cls) 266 | out_cls._sa_fsm_sqlalchemy_metas = tuple( 267 | _get_bound_sub_metas(out_cls, sub_transitions, parent_meta) 268 | ) 269 | 270 | return out_cls 271 | 272 | 273 | class BoundFSMClass(BoundFSMBase): 274 | 275 | __slots__ = BoundFSMBase.__slots__ + ("bound_sub_metas", "_target_cached") 276 | 277 | def __init__(self, meta, sqlalchemy_handle, child_cls, extra_call_args): 278 | super().__init__(meta, sqlalchemy_handle, extra_call_args) 279 | child_cls = inherited_bound_classes.get_value((child_cls, meta)) 280 | child_object = child_cls() 281 | child_object._sa_fsm_sqlalchemy_handle = 
"""Caching tools/classes"""

import weakref


class DictCache(object):
    """Memoise a single-argument function in a caller-supplied mapping.

    ``cache`` is the backing dict-like object and ``get_default`` is the
    function called with the key whenever the key is not cached yet.
    """

    __slots__ = ("cache", "get_default")

    def __init__(self, dict_object, get_default):
        self.cache, self.get_default = dict_object, get_default

    def get_value(self, key):
        """Return the cached value for ``key``, computing it on a miss.

        Exposed as a plain method rather than ``__getitem__`` because a
        method call is faster.
        """
        store = self.cache
        try:
            return store[key]
        except KeyError:
            pass
        # Cache miss: compute once, remember, and hand the value back.
        value = self.get_default(key)
        store[key] = value
        return value


def weak_value_cache(get_func):
    """Decorator: wrap ``get_func`` in a DictCache backed by a WeakValueDictionary."""
    backing = weakref.WeakValueDictionary()
    return DictCache(backing, get_func)


def dict_cache(get_func):
    """Decorator: wrap ``get_func`` in a DictCache backed by a plain dict."""
    backing = {}
    return DictCache(backing, get_func)
class BoundFSMDispatcher(object):
    """Utility class that simplifies sqlalchemy event dispatch.

    Every event name looked up on an instance is resolved on the
    class-level dispatcher, pre-bound to a reference to the model
    instance, and cached as an instance attribute so that
    ``__getattr__`` runs at most once per event name.
    """

    # Prefix Python's name mangling gives to the ``__ref`` and
    # ``__cls_dispatcher`` attributes assigned in ``__init__``.
    _PRIVATE_PREFIX = "_BoundFSMDispatcher__"

    def __init__(self, instance):
        self.__ref = InstanceRef(instance)
        self.__cls_dispatcher = get_class_bound_dispatcher(type(instance))
        for fsm_handle in ("before_state_change", "after_state_change"):
            # Precompute fsm handles
            getattr(self, fsm_handle)

    def __getattr__(self, name):
        """Return (and cache) the event handle called ``name``.

        Raises:
            AttributeError: for dunder names, for this class's own private
                attributes, and for names the class dispatcher lacks.
        """
        # ``__getattr__`` only runs when normal lookup failed.  If one of our
        # own private attributes is missing, the object has not been through
        # ``__init__`` yet (``copy``/``pickle`` build instances that way);
        # reading ``self.__cls_dispatcher`` below would then re-enter this
        # method forever.  Dunder probes such as ``__setstate__`` are protocol
        # lookups, never event names, so they must not become event handles.
        if name.startswith(("__", self._PRIVATE_PREFIX)):
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(
                    type(self).__name__, name
                )
            )
        handle = partial(getattr(self.__cls_dispatcher, name), self.__ref)
        # Cache on the instance so later lookups bypass __getattr__.
        setattr(self, name, handle)
        return handle
class FSMMeta(object):
    """Static description of one FSM transition.

    Attributes (all assigned in ``__init__``):
        target: state the transition leads to, or ``None``.
        conditions: tuple of condition callables.
        sources: frozenset of the source states the transition accepts.
        bound_cls: class instantiated by ``get_bound``.
        extra_call_args: tuple of additional call arguments.
    """

    __slots__ = (
        "target",
        "conditions",
        "sources",
        "bound_cls",
        "extra_call_args",
    )

    def __init__(self, source, target, conditions, extra_args, bound_cls):
        """Validate and store the transition parameters.

        ``source`` is either a single source state or an iterable of them.

        Raises:
            NotImplementedError: when ``target`` or any source state does
                not pass the ``util`` validity checks.
        """
        self.bound_cls = bound_cls
        self.conditions = tuple(conditions)
        self.extra_call_args = tuple(extra_args)

        if target is not None and not util.is_valid_fsm_state(target):
            raise NotImplementedError(target)
        self.target = target

        if util.is_valid_source_state(source):
            all_sources = (source,)
        elif isinstance(source, (str, bytes)):
            # Strings are iterable too, so an invalid one (e.g. "") used to
            # fall through to the branch below and turn into an empty set of
            # sources instead of being rejected.
            raise NotImplementedError(source)
        elif isinstance(source, collections.abc.Iterable):
            all_sources = tuple(source)

            if not all(util.is_valid_source_state(el) for el in all_sources):
                raise NotImplementedError(all_sources)
        else:
            raise NotImplementedError(source)

        self.sources = frozenset(all_sources)

    def get_bound(self, sqlalchemy_handle, set_func, extra_args):
        """Instantiate ``bound_cls`` for this meta with the given arguments."""
        return self.bound_cls(self, sqlalchemy_handle, set_func, extra_args)

    def __repr__(self):
        return (
            "<{} sources={!r} target={!r} conditions={!r} "
            "extra call args={!r}>".format(
                self.__class__.__name__,
                self.sources,
                self.target,
                self.conditions,
                self.extra_call_args,
            )
        )
""" 2 | import inspect as py_inspect 3 | import warnings 4 | 5 | from sqlalchemy.ext.hybrid import HYBRID_METHOD 6 | from sqlalchemy.orm.interfaces import InspectionAttrInfo 7 | 8 | from . import bound, cache, exc 9 | from .meta import FSMMeta 10 | 11 | 12 | @cache.dict_cache 13 | def sql_equality_cache(key): 14 | """It takes a bit of time for sqlalchemy to generate these. 15 | 16 | So I'm caching them. 17 | """ 18 | (column, target) = key 19 | assert target, "Target must be defined." 20 | return column == target 21 | 22 | 23 | class ClassBoundFsmTransition(object): 24 | 25 | __slots__ = ( 26 | "_sa_fsm_meta", 27 | "_sa_fsm_owner_cls", 28 | "_sa_fsm_sqla_handle", 29 | "_sa_fsm_transition_fn", 30 | ) 31 | 32 | def __init__(self, meta, sqla_handle, paylaod_func, owner_cls): 33 | self._sa_fsm_meta = meta 34 | self._sa_fsm_owner_cls = owner_cls 35 | self._sa_fsm_sqla_handle = sqla_handle 36 | self._sa_fsm_transition_fn = paylaod_func 37 | 38 | def __call__(self): 39 | """Return a SQLAlchemy filter for this particular state.""" 40 | column = self._sa_fsm_sqla_handle.fsm_column 41 | target = self._sa_fsm_meta.target 42 | return sql_equality_cache.get_value((column, target)) 43 | 44 | def is_(self, value): 45 | if isinstance(value, bool): 46 | out = self().is_(value) 47 | else: 48 | warnings.warn("Unexpected is_ argument: {!r}".format(value)) 49 | # Can be used as sqlalchemy filer. 
class InstanceBoundFsmTransition(object):
    """Transition handle returned when the attribute is read on an instance.

    Calling the handle reports whether the object currently is in the
    transition's target state; ``set`` performs the transition and
    ``can_proceed`` checks whether it would be allowed.
    """

    __slots__ = ClassBoundFsmTransition.__slots__ + (
        "_sa_fsm_self",
        "_sa_fsm_bound_meta",
    )

    def __init__(self, meta, sqla_handle, transition_fn, owner_cls, instance):
        self._sa_fsm_meta = meta
        self._sa_fsm_owner_cls = owner_cls
        # This slot comes from ClassBoundFsmTransition.__slots__ but used to be
        # left unassigned here, so reading it raised AttributeError.
        self._sa_fsm_sqla_handle = sqla_handle
        self._sa_fsm_transition_fn = transition_fn
        self._sa_fsm_self = instance
        self._sa_fsm_bound_meta = meta.get_bound(sqla_handle, transition_fn, ())

    def __call__(self):
        """Check if this is the current state of the object."""
        bound_meta = self._sa_fsm_bound_meta
        return bound_meta.target_state == bound_meta.current_state

    def set(self, *args, **kwargs):
        """Transition the FSM to this new state.

        Raises:
            exc.InvalidSourceStateError: the bound meta reports that the
                transition is not possible from the current state.
            exc.PreconditionError: the bound meta reports that the
                conditions are not met for ``args`` / ``kwargs``.
        """
        bound_meta = self._sa_fsm_bound_meta
        func = self._sa_fsm_transition_fn

        if not bound_meta.transition_possible():
            raise exc.InvalidSourceStateError(
                "Unable to switch from {} using method {}".format(
                    bound_meta.current_state, func.__name__
                )
            )
        if not bound_meta.conditions_met(args, kwargs):
            raise exc.PreconditionError("Preconditions are not satisfied.")
        return bound_meta.to_next_state(args, kwargs)

    def can_proceed(self, *args, **kwargs):
        """Return whether ``set`` with the same arguments would be accepted."""
        bound_meta = self._sa_fsm_bound_meta
        return bound_meta.transition_possible() and bound_meta.conditions_met(
            args, kwargs
        )
def is_valid_fsm_state(value):
    """Return True when ``value`` is a non-empty string.

    Always returns a real ``bool``; the earlier ``isinstance(...) and value``
    form leaked the string itself (or ``""``) to callers.
    """
    return isinstance(value, str) and bool(value)


def is_valid_source_state(value):
    """Return True when ``value`` may be used as a transition source state.

    Besides ordinary states this makes exceptions for special source states:
    it explicitly allows '*' (for any state) and `None` (as this is the
    default value for sqlalchemy columns).
    """
    # '*' is itself a non-empty string, so the generic state check already
    # accepts it.  Skipping ``value == "*"`` also avoids running a custom
    # ``__eq__`` on arbitrary objects.
    return value is None or is_valid_fsm_state(value)
@transition(source="published", target="hidden") 24 | def hidden(self): 25 | pass 26 | 27 | @transition(source="new", target="removed") 28 | def removed(self): 29 | raise Exception("No rights to delete %s" % self) 30 | 31 | @transition(source=["published", "hidden"], target="stolen") 32 | def stolen(self): 33 | pass 34 | 35 | @transition(source="*", target="moderated") 36 | def moderated(self): 37 | pass 38 | 39 | 40 | class TestFSMField(object): 41 | @pytest.fixture 42 | def model(self): 43 | return BlogPost() 44 | 45 | def test_initial_state_instatiated(self, model): 46 | assert model.state == "new" 47 | 48 | def test_meta_attached(self, model): 49 | assert model.published._sa_fsm_meta 50 | assert "FSMMeta" in repr(model.published._sa_fsm_meta) 51 | 52 | def test_known_transition_should_succeed(self, model): 53 | assert not model.published() # Model is not publish-ed yet 54 | assert model.published.can_proceed() 55 | model.published.set() 56 | assert model.state == "published" 57 | # model is publish-ed now 58 | assert model.published() 59 | 60 | assert model.hidden.can_proceed() 61 | model.hidden.set() 62 | assert model.state == "hidden" 63 | 64 | def test_unknow_transition_fails(self, model): 65 | assert not model.hidden.can_proceed() 66 | with pytest.raises(NotImplementedError) as err: 67 | model.hidden.set() 68 | assert "Unable to switch from" in str(err) 69 | 70 | def test_state_non_changed_after_fail(self, model): 71 | with pytest.raises(Exception) as err: 72 | model.removed.set() 73 | assert "No rights to delete" in str(err) 74 | assert model.removed.can_proceed() 75 | assert model.state == "new" 76 | 77 | def test_mutiple_source_support_path_1_works(self, model): 78 | model.published.set() 79 | model.stolen.set() 80 | assert model.state == "stolen" 81 | 82 | def test_mutiple_source_support_path_2_works(self, model): 83 | model.published.set() 84 | model.hidden.set() 85 | model.stolen.set() 86 | assert model.state == "stolen" 87 | 88 | def 
test_star_shortcut_succeed(self, model): 89 | assert model.moderated.can_proceed() 90 | model.moderated.set() 91 | assert model.state == "moderated" 92 | 93 | def test_query_filter(self, session): 94 | model1 = BlogPost() 95 | model2 = BlogPost() 96 | model3 = BlogPost() 97 | model4 = BlogPost() 98 | model3.published.set() 99 | model4.published.set() 100 | 101 | session.add_all([model1, model2, model3, model4]) 102 | session.commit() 103 | 104 | ids = [model1.id, model2.id, model3.id, model4.id] 105 | 106 | # Check that one can query by fsm handler 107 | query_results = ( 108 | session.query(BlogPost) 109 | .filter( 110 | BlogPost.published(), 111 | BlogPost.id.in_(ids), 112 | ) 113 | .all() 114 | ) 115 | assert len(query_results) == 2, query_results 116 | assert model3 in query_results 117 | assert model4 in query_results 118 | 119 | negated_query_results = ( 120 | session.query(BlogPost) 121 | .filter( 122 | ~BlogPost.published(), 123 | BlogPost.id.in_(ids), 124 | ) 125 | .all() 126 | ) 127 | assert len(negated_query_results) == 2, query_results 128 | assert model1 in negated_query_results 129 | assert model2 in negated_query_results 130 | 131 | 132 | class InvalidModel(Base): 133 | __tablename__ = "invalidmodel" 134 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 135 | state = sqlalchemy.Column(FSMField) 136 | action = sqlalchemy.Column(FSMField) 137 | 138 | def __init__(self, *args, **kwargs): 139 | self.state = "new" 140 | self.action = "no" 141 | super(InvalidModel, self).__init__(*args, **kwargs) 142 | 143 | @transition(source="new", target="no") 144 | def validated(self): 145 | pass 146 | 147 | 148 | class TestInvalidModel(object): 149 | def test_two_fsmfields_in_one_model_not_allowed(self): 150 | model = InvalidModel() 151 | with pytest.raises(SetupError) as err: 152 | model.validated() 153 | assert "More than one FSMField found" in str(err) 154 | 155 | 156 | class Document(Base): 157 | __tablename__ = "document" 158 | id = 
class TestNullSource(object):
    """Transitions of ``NullSource``, whose ``status`` starts out as ``None``."""

    @pytest.fixture
    def model(self):
        """Provide a fresh ``NullSource`` instance for each test."""
        return NullSource()

    def test_null_to_end(self, model):
        """A ``source="*"`` transition can be applied while status is None."""
        assert model.status is None
        model.end_from_all.set()
        assert model.status == "end"

    def test_null_pub_end(self, model):
        """A ``source=None`` transition applies first, then the wildcard one."""
        assert model.status is None
        model.pub_from_none.set()
        assert model.status == "published"
        model.end_from_all.set()
        assert model.status == "end"

    def test_null_new_pub_end(self, model):
        """Walk None -> new -> published -> end.

        The middle step uses a transition whose source list mixes a named
        state with ``None``.
        """
        assert model.status is None
        model.new_from_none.set()
        assert model.status == "new"
        model.pub_from_either.set()
        assert model.status == "published"
        model.end_from_all.set()
        assert model.status == "end"
"""Tests for transitions guarded by ``conditions``."""
import unittest

import sqlalchemy


from sqlalchemy_fsm import FSMField, transition
from sqlalchemy_fsm.exc import (
    PreconditionError,
)

from .conftest import Base


def condition_func(instance):
    """Module-level condition that always passes."""
    return True


class BlogPostWithConditions(Base):
    """Model whose transitions each carry two conditions."""

    __tablename__ = "BlogPostWithConditions"
    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
    state = sqlalchemy.Column(FSMField)

    def __init__(self, *args, **kwargs):
        # Every new post starts in the "new" state.
        self.state = "new"
        super(BlogPostWithConditions, self).__init__(*args, **kwargs)

    def model_condition(self):
        """Condition defined on the model itself; always passes."""
        return True

    def unmet_condition(self):
        """Condition defined on the model itself; always fails."""
        return False

    @transition(
        source="new", target="published", conditions=[condition_func, model_condition]
    )
    def published(self):
        pass

    @transition(
        source="published",
        target="destroyed",
        conditions=[condition_func, unmet_condition],
    )
    def destroyed(self):
        pass


class TestConditional(unittest.TestCase):
    """Exercise the conditional transitions of ``BlogPostWithConditions``."""

    def setUp(self):
        self.model = BlogPostWithConditions()

    def test_initial_staet(self):
        self.assertEqual(self.model.state, "new")

    def test_known_transition_should_succeed(self):
        """Both conditions of ``published`` return True, so it goes through."""
        self.assertTrue(self.model.published.can_proceed())
        self.model.published.set()
        self.assertEqual(self.model.state, "published")

    def test_unmet_condition(self):
        """``destroyed`` is refused and the state is left untouched."""
        self.model.published.set()
        self.assertEqual(self.model.state, "published")
        self.assertFalse(self.model.destroyed.can_proceed())
        self.assertRaises(PreconditionError, self.model.destroyed.set)
        self.assertEqual(self.model.state, "published")
-------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy 3 | import sqlalchemy.event 4 | 5 | import sqlalchemy_fsm 6 | 7 | from .conftest import Base 8 | 9 | 10 | class EventModel(Base): 11 | __tablename__ = "event_model" 12 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 13 | state = sqlalchemy.Column(sqlalchemy_fsm.FSMField) 14 | 15 | def __init__(self, *args, **kwargs): 16 | self.state = "new" 17 | super(EventModel, self).__init__(*args, **kwargs) 18 | 19 | @sqlalchemy_fsm.transition(source="*", target="state_a") 20 | def state_a(self): 21 | pass 22 | 23 | @sqlalchemy_fsm.transition(source="*", target="state_b") 24 | def state_b(self): 25 | pass 26 | 27 | 28 | class TestEventListener(object): 29 | @pytest.fixture 30 | def model(self): 31 | return EventModel() 32 | 33 | @pytest.mark.parametrize( 34 | "event_name", 35 | [ 36 | "before_state_change", 37 | "after_state_change", 38 | ], 39 | ) 40 | def test_events(self, model, event_name): 41 | 42 | listener_result = [] 43 | 44 | def on_update(instance, source, target): 45 | listener_result.append((source, target)) 46 | 47 | sqlalchemy.event.listen(EventModel, event_name, on_update) 48 | 49 | expected_result = [] 50 | assert listener_result == expected_result 51 | 52 | for handle_name in ("state_a", "state_b", "state_a", "state_a", "state_b"): 53 | expected_result.append((model.state, handle_name)) 54 | if handle_name == "state_a": 55 | handle = model.state_a 56 | else: 57 | handle = model.state_b 58 | handle.set() 59 | assert listener_result == expected_result 60 | 61 | # Remove the listener & check that it had an effect 62 | sqlalchemy.event.remove(EventModel, event_name, on_update) 63 | # Call the state handle & ensure that listener had not been called. 
class TransitionClassEventModel(Base):
    """Model mixing function transitions with a class-based transition.

    ``StateClass`` groups two handlers that share the ``state_class`` target
    but start from different source states; each handler records which one
    ran in the ``side_effect`` column.
    """

    __tablename__ = "transition_class_event_model"
    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
    state = sqlalchemy.Column(sqlalchemy_fsm.FSMField)
    side_effect = sqlalchemy.Column(sqlalchemy.String)

    def __init__(self, *args, **kwargs):
        # Every new instance starts in the "new" state.
        self.state = "new"
        super(TransitionClassEventModel, self).__init__(*args, **kwargs)

    @sqlalchemy_fsm.transition(source="*", target="state_a")
    def state_a(self):
        pass

    @sqlalchemy_fsm.transition(source="*", target="state_b")
    def state_b(self):
        pass

    @sqlalchemy_fsm.transition(target="state_class")
    class StateClass(object):
        """Class-based transition to ``state_class`` with per-source handlers."""

        @sqlalchemy_fsm.transition(source="state_a")
        def from_a(self, instance):
            # ``instance`` is the model object being transitioned.
            instance.side_effect = "from_a"

        @sqlalchemy_fsm.transition(source="state_b")
        def from_b(self, instance):
            instance.side_effect = "from_b"
TestTransitionClassEvents(object): 131 | @pytest.fixture 132 | def model(self): 133 | return TransitionClassEventModel() 134 | 135 | @pytest.mark.parametrize( 136 | "event_name", 137 | [ 138 | "before_state_change", 139 | "after_state_change", 140 | ], 141 | ) 142 | def test_events(self, model, event_name): 143 | 144 | listener_result = [] 145 | 146 | @sqlalchemy.event.listens_for(TransitionClassEventModel, event_name) 147 | def on_update(instance, source, target): 148 | listener_result.append(target) 149 | 150 | expected_result = [] 151 | assert listener_result == expected_result 152 | 153 | for handle_name in ("state_a", "state_b", "state_a", "state_a", "state_b"): 154 | expected_result.append(handle_name) 155 | if handle_name == "state_a": 156 | handle = model.state_a 157 | else: 158 | handle = model.state_b 159 | handle.set() 160 | assert listener_result == expected_result 161 | model.StateClass.set() 162 | 163 | if handle_name == "state_a": 164 | expected_side = "from_a" 165 | else: 166 | expected_side = "from_b" 167 | 168 | expected_result.append("state_class") 169 | 170 | assert model.side_effect == expected_side 171 | assert listener_result == expected_result 172 | 173 | # Remove the listener & check that it had an effect 174 | sqlalchemy.event.remove(TransitionClassEventModel, event_name, on_update) 175 | # Call the state handle & ensure that listener had not been called. 
class TestEventsLeakage(object):
    """Ensure that multiple FSM models do not mix their events up."""

    @pytest.mark.parametrize(
        "event_name",
        [
            "before_state_change",
            "after_state_change",
        ],
    )
    def test_leakage(self, event_name):
        """Listeners fire only for the model class they were registered on.

        Three listeners are attached: one per model class and one shared by
        both classes.  After every transition the counters show which of
        them were invoked.
        """
        event_model = EventModel()
        tr_cls_model = TransitionClassEventModel()

        event_result = []
        tr_cls_result = []
        joint_result = []

        @sqlalchemy.event.listens_for(EventModel, event_name)
        def on_evt_update(instance, source, target):
            event_result.append(target)

        @sqlalchemy.event.listens_for(TransitionClassEventModel, event_name)
        def on_tr_update(instance, source, target):
            tr_cls_result.append(target)

        # The same function is registered on both model classes.
        @sqlalchemy.event.listens_for(TransitionClassEventModel, event_name)
        @sqlalchemy.event.listens_for(EventModel, event_name)
        def on_both_update(instance, source, target):
            joint_result.append(target)

        assert len(event_result) == 0
        assert len(tr_cls_result) == 0
        assert len(joint_result) == 0

        # Transitions on EventModel: only its own and the joint listener fire.
        event_model.state_a.set()
        assert len(event_result) == 1
        assert len(tr_cls_result) == 0
        assert len(joint_result) == 1

        event_model.state_b.set()
        assert len(event_result) == 2
        assert len(tr_cls_result) == 0
        assert len(joint_result) == 2

        # Transitions on the other model leave EventModel's listener untouched.
        tr_cls_model.state_a.set()
        assert len(event_result) == 2
        assert len(tr_cls_result) == 1
        assert len(joint_result) == 3

        tr_cls_model.state_a.set()
        assert len(event_result) == 2
        assert len(tr_cls_result) == 2
        assert len(joint_result) == 4
2 | import sqlalchemy 3 | 4 | from sqlalchemy_fsm import exc, FSMField, transition 5 | 6 | 7 | from .conftest import Base 8 | 9 | 10 | class NotFsm(Base): 11 | __tablename__ = "NotFsm" 12 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 13 | 14 | @transition(source="*", target="blah") 15 | def change_state(self): 16 | pass 17 | 18 | def not_transition(self): 19 | pass 20 | 21 | 22 | def test_not_fsm(): 23 | with pytest.raises(exc.SetupError) as err: 24 | NotFsm().change_state() 25 | assert "No FSMField found in model" in str(err) 26 | 27 | 28 | def test_not_transition(): 29 | with pytest.raises(AttributeError): 30 | NotFsm.not_transition.can_proceed() 31 | 32 | 33 | class TooMuchFsm(Base): 34 | __tablename__ = "TooMuchFsm" 35 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 36 | state1 = sqlalchemy.Column(FSMField) 37 | state2 = sqlalchemy.Column(FSMField) 38 | 39 | @transition(source="*", target="blah") 40 | def change_state(self): 41 | pass 42 | 43 | 44 | def test_too_much_fsm(): 45 | with pytest.raises(exc.SetupError) as err: 46 | TooMuchFsm().change_state() 47 | assert "More than one FSMField found in model" in str(err) 48 | 49 | 50 | def test_transition_raises_on_unknown(): 51 | class MyCallable(object): 52 | def __call__(*args): 53 | pass 54 | 55 | with pytest.raises(NotImplementedError) as err: 56 | 57 | wrapper = transition(source="*", target="blah") 58 | wrapper(MyCallable()) 59 | 60 | assert "Do not know how to" in str(err) 61 | 62 | 63 | def test_transition_raises_on_invalid_state(): 64 | with pytest.raises(NotImplementedError) as err: 65 | 66 | @transition(source=42, target="blah") 67 | def func1(): 68 | pass 69 | 70 | assert "42" in str(err) 71 | 72 | with pytest.raises(NotImplementedError) as err: 73 | 74 | @transition(source="*", target=42) 75 | def func2(): 76 | pass 77 | 78 | assert "42" in str(err) 79 | 80 | with pytest.raises(NotImplementedError) as err: 81 | 82 | @transition(source=["str", 42], target="blah") 83 | def 
func3(): 84 | pass 85 | 86 | assert "42" in str(err) 87 | 88 | 89 | def one_arg_condition(): 90 | def one_arg_condition(instance, arg1): 91 | return True 92 | 93 | return one_arg_condition 94 | 95 | 96 | class MisconfiguredTransitions(Base): 97 | __tablename__ = "MisconfiguredTransitions" 98 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 99 | state = sqlalchemy.Column(FSMField) 100 | 101 | @transition(source="*", target="blah", conditions=[one_arg_condition()]) 102 | def change_state(self): 103 | """Condition accepts one arg, state handler doesn't -> exception.""" 104 | pass 105 | 106 | @transition(source="*", target="blah") 107 | class MultiHandlerTransition(object): 108 | """The system won't know which transition{1,2} handler to chose.""" 109 | 110 | @transition() 111 | def transition1(self, instance): 112 | pass 113 | 114 | @transition() 115 | def transition2(self, instance): 116 | pass 117 | 118 | @transition(source="*", target="blah") 119 | class IncompatibleTargets(object): 120 | """The system won't know which transition{1,2} handler to chose.""" 121 | 122 | @transition(target="not-blah") 123 | def transition1(self, instance): 124 | pass 125 | 126 | @transition(source=["src1", "src2"], target="blah") 127 | class IncompatibleSources(object): 128 | """The system won't know which transition{1,2} handler to chose.""" 129 | 130 | @transition(source=["src3", "src4"]) 131 | def transition1(self, instance): 132 | pass 133 | 134 | @transition(source="*", target="blah") 135 | class NoConflictDueToPreconditionArgCount(object): 136 | @transition(conditions=[lambda self, instance, arg1: True]) 137 | def change_state(self, instance, arg1): 138 | pass 139 | 140 | @transition() 141 | def no_arg_condition(self, instance): 142 | pass 143 | 144 | 145 | class TestMisconfiguredTransitions(object): 146 | @pytest.fixture 147 | def model(self): 148 | return MisconfiguredTransitions() 149 | 150 | def test_misconfigured_transitions(self, model): 151 | with 
pytest.raises(exc.SetupError) as err: 152 | with pytest.warns(UserWarning): 153 | model.change_state.set(42) 154 | assert "Mismatch beteen args accepted" in str(err) 155 | 156 | def test_multi_transition_handlers(self, model): 157 | with pytest.raises(exc.SetupError) as err: 158 | model.MultiHandlerTransition.set() 159 | assert "Can transition with multiple handlers" in str(err) 160 | 161 | def test_incompatible_targets(self, model): 162 | with pytest.raises(exc.SetupError) as err: 163 | model.IncompatibleTargets.set() 164 | assert "are not compatable" in str(err) 165 | 166 | def test_incompatable_sources(self, model): 167 | with pytest.raises(exc.SetupError) as err: 168 | model.IncompatibleSources.set() 169 | assert "are not compatable" in str(err) 170 | 171 | def test_no_conflict_due_to_precondition_arg_count(self, model): 172 | assert model.NoConflictDueToPreconditionArgCount.can_proceed() 173 | 174 | 175 | def test_unexpected_is__type(session): 176 | model = MisconfiguredTransitions() 177 | session.add(model) 178 | session.commit() 179 | with pytest.warns(UserWarning) as warn: 180 | result = ( 181 | session.query(MisconfiguredTransitions) 182 | .filter(MisconfiguredTransitions.change_state.is_("hello world")) 183 | .all() 184 | ) 185 | assert not result 186 | assert "Unexpected is_ argument: 'hello world'" in str(warn.list[0].message) 187 | -------------------------------------------------------------------------------- /test/test_multi_source.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy 3 | 4 | 5 | from sqlalchemy_fsm import FSMField, transition 6 | from sqlalchemy_fsm.exc import ( 7 | InvalidSourceStateError, 8 | PreconditionError, 9 | ) 10 | 11 | from .conftest import Base 12 | 13 | 14 | def val_eq_condition(expected_value): 15 | def bound_val_eq_condition(self, instance, actual_value): 16 | return expected_value == actual_value 17 | 18 | return bound_val_eq_condition 19 | 20 | 
21 | def val_contains_condition(expected_values): 22 | def bound_val_contains_condition(self, instance, actual_value): 23 | return actual_value in expected_values 24 | 25 | return bound_val_contains_condition 26 | 27 | 28 | def three_argument_condition(expected1, expected2, expected3): 29 | def bound_three_argument_condition(self, instance, arg1, arg2, arg3): 30 | return (arg1, arg2, arg3) == (expected1, expected2, expected3) 31 | 32 | return bound_three_argument_condition 33 | 34 | 35 | class MultiSourceBlogPost(Base): 36 | 37 | __tablename__ = "MultiSourceBlogPost" 38 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 39 | state = sqlalchemy.Column(FSMField) 40 | side_effect = sqlalchemy.Column(sqlalchemy.String) 41 | 42 | def __init__(self, *args, **kwargs): 43 | self.state = "new" 44 | self.side_effect = "default" 45 | super(MultiSourceBlogPost, self).__init__(*args, **kwargs) 46 | 47 | @transition(source="new", target="hidden") 48 | def hide(self): 49 | self.side_effect = "did_hide" 50 | 51 | @transition(source="*", target="deleted") 52 | def delete(self): 53 | self.side_effect = "deleted" 54 | 55 | @transition(target="published", conditions=[val_contains_condition([1, 2])]) 56 | class publish(object): # noqa: N801 57 | @transition(source="new", conditions=[val_eq_condition(1)]) 58 | def do_one(self, instance, value): 59 | instance.side_effect = "did_one" 60 | 61 | @transition(source="new", conditions=[val_contains_condition([2, 42])]) 62 | def do_two(self, instance, value): 63 | instance.side_effect = "did_two" 64 | 65 | @transition(source="hidden") 66 | def do_unhide(self, instance, value): 67 | instance.side_effect = "did_unhide: {}".format(value) 68 | 69 | @transition(source="published") 70 | def do_publish_loop(self, instance, value): 71 | instance.side_effect = "do_publish_loop: {}".format(value) 72 | 73 | @transition(target="published", source=["new", "something"]) 74 | class noPreFilterPublish(object): # noqa: N801 75 | 
@transition(source="*", conditions=[three_argument_condition(1, 2, 3)]) 76 | def do_three_arg_mk1(self, instance, val1, val2, val3): 77 | instance.side_effect = "did_three_arg_mk1::{}".format([val1, val2, val3]) 78 | 79 | @transition(source="new", conditions=[three_argument_condition("str", -1, 42)]) 80 | def do_three_arg_mk2(self, instance, val1, val2, val3): 81 | instance.side_effect = "did_three_arg_mk2::{}".format([val1, val2, val3]) 82 | 83 | 84 | class TestMultiSourceBlogPost(object): 85 | @pytest.fixture 86 | def model(self): 87 | return MultiSourceBlogPost() 88 | 89 | def test_transition_one(self, model): 90 | assert model.publish.can_proceed(1) 91 | 92 | model.publish.set(1) 93 | assert model.state == "published" 94 | assert model.side_effect == "did_one" 95 | 96 | def test_transition_two(self, model): 97 | assert model.publish.can_proceed(2) 98 | 99 | model.publish.set(2) 100 | assert model.state == "published" 101 | assert model.side_effect == "did_two" 102 | 103 | def test_three_arg_transition_mk1(self, model): 104 | assert model.noPreFilterPublish.can_proceed(1, 2, 3) 105 | model.noPreFilterPublish.set(1, 2, 3) 106 | assert model.state == "published" 107 | assert model.side_effect == "did_three_arg_mk1::[1, 2, 3]" 108 | 109 | def test_three_arg_transition_mk2(self, model): 110 | assert model.noPreFilterPublish.can_proceed("str", -1, 42) 111 | model.noPreFilterPublish.set("str", -1, 42) 112 | assert model.state == "published" 113 | assert model.side_effect == "did_three_arg_mk2::['str', -1, 42]" 114 | 115 | def unable_to_proceed_with_invalid_kwargs(self, model): 116 | assert not model.noPreFilterPublish.can_proceed("str", -1, tomato="potato") 117 | 118 | def test_transition_two_incorrect_arg(self, model): 119 | # Transition should be rejected because of 120 | # top-level `val_contains_condition([1,2])` constraint 121 | with pytest.raises(PreconditionError) as err: 122 | model.publish.set(42) 123 | assert "Preconditions are not satisfied" in str(err) 124 
| assert model.state == "new" 125 | assert model.side_effect == "default" 126 | 127 | # Verify that the exception can still be avoided 128 | # with can_proceed() call 129 | assert not model.publish.can_proceed(42) 130 | assert not model.publish.can_proceed(4242) 131 | 132 | def test_hide(self, model): 133 | model.hide.set() 134 | assert model.state == "hidden" 135 | assert model.side_effect == "did_hide" 136 | 137 | model.publish.set(2) 138 | assert model.state == "published" 139 | assert model.side_effect == "did_unhide: 2" 140 | 141 | def test_publish_loop(self, model): 142 | model.publish.set(1) 143 | assert model.state == "published" 144 | assert model.side_effect == "did_one" 145 | 146 | for arg in (1, 2, 1, 1, 2): 147 | model.publish.set(arg) 148 | assert model.state == "published" 149 | assert model.side_effect == "do_publish_loop: {}".format(arg) 150 | 151 | def test_delete_new(self, model): 152 | model.delete.set() 153 | assert model.state == "deleted" 154 | 155 | # Can not switch from deleted to published 156 | assert not model.publish.can_proceed(2) 157 | with pytest.raises(InvalidSourceStateError) as err: 158 | model.publish.set(2) 159 | assert "Unable to switch" in str(err) 160 | assert model.state == "deleted" 161 | -------------------------------------------------------------------------------- /test/test_performance.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy 3 | 4 | import sqlalchemy_fsm 5 | 6 | from .conftest import Base 7 | 8 | 9 | class Benchmarked(Base): 10 | __tablename__ = "benchmark_test" 11 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 12 | state = sqlalchemy.Column(sqlalchemy_fsm.FSMField) 13 | 14 | def __init__(self, *args, **kwargs): 15 | self.state = "new" 16 | super(Benchmarked, self).__init__(*args, **kwargs) 17 | 18 | @sqlalchemy_fsm.transition(source="*", target="published") 19 | def published(self): 20 | pass 21 | 22 | 
@sqlalchemy_fsm.transition(source="*", target="hidden") 23 | def hidden(self): 24 | pass 25 | 26 | @sqlalchemy_fsm.transition(target="cls_transition") 27 | class cls_move(object): # noqa: N801 28 | @sqlalchemy_fsm.transition(source="new") 29 | def from_new(self, instance): 30 | pass 31 | 32 | @sqlalchemy_fsm.transition(source="published") 33 | def from_pub(self, instance): 34 | pass 35 | 36 | @sqlalchemy_fsm.transition(source="hidden") 37 | def from_hidden(self, instance): 38 | pass 39 | 40 | @sqlalchemy_fsm.transition(source="cls_transition") 41 | def loop(self, instance): 42 | pass 43 | 44 | 45 | # Only enable this when profiling 46 | 47 | 48 | @pytest.mark.skip 49 | class TestPerformanceSimple(object): 50 | @pytest.fixture 51 | def model(self, session): 52 | out = Benchmarked() 53 | session.add(out) 54 | session.commit() 55 | return out 56 | 57 | @pytest.mark.parametrize("in_expected_state", [True, False]) 58 | def test_state_check(self, in_expected_state, benchmark, model, session): 59 | if in_expected_state: 60 | model.published.set() 61 | else: 62 | model.hidden.set() 63 | session.commit() 64 | # Expected state - published 65 | rv = benchmark.pedantic(lambda: model.published(), rounds=10000) 66 | assert rv == in_expected_state 67 | 68 | def test_cls_selector(self, benchmark): 69 | benchmark.pedantic(lambda: Benchmarked.published(), rounds=10000) 70 | 71 | def test_set_performance(self, benchmark, model): 72 | def set_fn(): 73 | """Cycle through two set() ops.""" 74 | 75 | model.published.set() 76 | model.hidden.set() 77 | 78 | benchmark.pedantic(set_fn, rounds=10000) 79 | 80 | def test_cls_performance(self, benchmark, model): 81 | def set_fn(): 82 | """Cycle through two set() ops.""" 83 | model.cls_move.set() 84 | model.published.set() 85 | # model.cls_move.set() 86 | # model.hidden.set() 87 | # model.cls_move.set() 88 | 89 | benchmark.pedantic(set_fn, rounds=10000) 90 | -------------------------------------------------------------------------------- 
/test/test_transition_classes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy 3 | 4 | from sqlalchemy_fsm import FSMField, transition 5 | 6 | from .conftest import Base 7 | 8 | 9 | # Alternative syntax - separately defined transaction and sqlalchemy classes 10 | class SeparatePublishHandler(object): 11 | @transition(source="new") 12 | def do_one(self, instance): 13 | instance.side_effect = "SeparatePublishHandler::did_one" 14 | 15 | @transition(source="hidden") 16 | def do_two(self, instance): 17 | instance.side_effect = "SeparatePublishHandler::did_two" 18 | 19 | 20 | @transition(target="pre_decorated_publish") 21 | class SeparateDecoratedPublishHandler(object): 22 | @transition(source="new") 23 | def do_one(self, instance): 24 | instance.side_effect = "SeparatePublishHandler::did_one" 25 | 26 | @transition(target="pre_decorated_publish", source="hidden") 27 | def do_two(self, instance): 28 | instance.side_effect = "SeparatePublishHandler::did_two" 29 | 30 | 31 | class AltSyntaxBlogPost(Base): 32 | 33 | __tablename__ = "AltSyntaxBlogPost" 34 | id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) 35 | state = sqlalchemy.Column(FSMField) 36 | side_effect = sqlalchemy.Column(sqlalchemy.String) 37 | 38 | def __init__(self, *args, **kwargs): 39 | self.state = "new" 40 | self.side_effect = "default" 41 | super(AltSyntaxBlogPost, self).__init__(*args, **kwargs) 42 | 43 | @transition(source="new", target="hidden") 44 | def hide(self): 45 | pass 46 | 47 | pre_decorated_publish = SeparateDecoratedPublishHandler 48 | post_decorated_publish = transition(target="post_decorated_publish")( 49 | SeparatePublishHandler 50 | ) 51 | 52 | 53 | class TestAltSyntaxBlogPost(object): 54 | @pytest.fixture 55 | def model(self): 56 | return AltSyntaxBlogPost() 57 | 58 | def test_pre_decorated_publish(self, model): 59 | model.pre_decorated_publish.set() 60 | assert model.state == "pre_decorated_publish" 61 | 
assert model.side_effect == "SeparatePublishHandler::did_one" 62 | 63 | def test_pre_decorated_publish_from_hidden(self, model): 64 | model.hide.set() 65 | assert model.state == "hidden" 66 | assert model.hide() 67 | assert not model.pre_decorated_publish() 68 | model.pre_decorated_publish.set() 69 | assert model.state == "pre_decorated_publish" 70 | assert model.pre_decorated_publish() 71 | assert model.side_effect == "SeparatePublishHandler::did_two" 72 | 73 | def test_post_decorated_from_hidden(self, model): 74 | model.post_decorated_publish.set() 75 | assert model.state == "post_decorated_publish" 76 | assert model.side_effect == "SeparatePublishHandler::did_one" 77 | 78 | def test_post_decorated_publish_from_hidden(self, model): 79 | model.hide.set() 80 | assert model.state == "hidden" 81 | model.post_decorated_publish.set() 82 | assert model.state == "post_decorated_publish" 83 | assert model.side_effect == "SeparatePublishHandler::did_two" 84 | 85 | def mk_records(self, session, count): 86 | records = [AltSyntaxBlogPost() for idx in range(10)] 87 | session.add_all(records) 88 | return records 89 | 90 | @pytest.mark.parametrize("query_method", ["call", "is_"]) 91 | def test_class_query(self, session, query_method): 92 | hidden_records = self.mk_records(session, 5) 93 | pre_decorated_published = self.mk_records(session, 5) 94 | post_decorated_published = self.mk_records(session, 5) 95 | 96 | [el.hide.set() for el in hidden_records] 97 | [el.pre_decorated_publish.set() for el in pre_decorated_published] 98 | [el.post_decorated_publish.set() for el in post_decorated_published] 99 | 100 | session.commit() 101 | 102 | all_ids = [ 103 | el.id 104 | for el in ( 105 | hidden_records + pre_decorated_published + post_decorated_published 106 | ) 107 | ] 108 | for (handler, expected_group) in [ 109 | ("hide", hidden_records), 110 | ("pre_decorated_publish", pre_decorated_published), 111 | ("post_decorated_publish", post_decorated_published), 112 | ]: 113 | expected_ids = 
set(el.id for el in expected_group) 114 | attr = getattr(AltSyntaxBlogPost, handler) 115 | 116 | if query_method == "call": 117 | attr_filter = { 118 | True: attr(), 119 | False: ~attr(), 120 | } 121 | elif query_method == "is_": 122 | attr_filter = { 123 | True: attr.is_(True), 124 | False: attr.is_(False), 125 | } 126 | else: 127 | raise NotImplementedError(query_method) 128 | 129 | matching = ( 130 | session.query(AltSyntaxBlogPost) 131 | .filter( 132 | attr_filter[True], 133 | AltSyntaxBlogPost.id.in_(all_ids), 134 | ) 135 | .all() 136 | ) 137 | assert len(matching) == len(expected_group) 138 | assert set(el.id for el in matching) == expected_ids 139 | 140 | not_matching = ( 141 | session.query(AltSyntaxBlogPost) 142 | .filter( 143 | attr_filter[False], 144 | AltSyntaxBlogPost.id.in_(all_ids), 145 | ) 146 | .all() 147 | ) 148 | assert len(not_matching) == (len(all_ids) - len(expected_group)) 149 | assert not expected_ids.intersection( 150 | el.id for el in not_matching 151 | ), expected_ids.intersection(el.id for el in not_matching) 152 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist=lint,py34,py36,py39,py310 3 | 4 | [testenv] 5 | deps = 6 | -r requirements/develop.txt 7 | SQLAlchemy{env:SQLALCHEMY_VERSION_SPEC:} 8 | commands = 9 | pytest --cov=sqlalchemy_fsm --cov-append --cov-report=term-missing --no-cov-on-fail {posargs} 10 | 11 | 12 | [testenv:lint] 13 | skip_install = true 14 | deps = 15 | -r requirements/develop.txt 16 | commands = 17 | black --check src/ test/ 18 | flake8 {posargs} src/ test/ 19 | 20 | [testenv:black] 21 | skip_install = true 22 | deps = 23 | black 24 | commands = black setup.py src/ test/ {posargs} 25 | 26 | [pytest] 27 | addopts = -v -l --color=yes --cov=sqlalchemy_fsm --no-cov-on-fail 28 | testpaths = test 29 | 30 | [flake8] 31 | max-line-length = 90 32 | import-order-style = edited 33 | 
application-import-names = sqlalchemy_fsm, tests 34 | per-file-ignores = 35 | # imported but unused 36 | */__init__.py: F401 37 | 38 | ## GitHub CI 39 | [gh-actions] 40 | python = 41 | 3.4: py34 42 | 3.6: py36 43 | 3.9: py39, lint 44 | 3.10: py310 --------------------------------------------------------------------------------