├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── .gitignore ├── LICENSE ├── README.md ├── django_test_apps.txt ├── django_test_suite.sh ├── django_tidb ├── __init__.py ├── base.py ├── features.py ├── fields │ ├── __init__.py │ └── vector.py ├── introspection.py ├── operations.py ├── patch.py ├── schema.py └── version.py ├── pyproject.toml ├── run_testing_worker.py ├── tests ├── tidb │ ├── __init__.py │ ├── models.py │ ├── test_tidb_auto_id_cache.py │ ├── test_tidb_auto_random.py │ ├── test_tidb_ddl.py │ └── test_tidb_explain.py ├── tidb_field_defaults │ ├── README.md │ ├── __init__.py │ ├── models.py │ └── tests.py └── tidb_vector │ ├── __init__.py │ ├── models.py │ └── test_vector.py ├── tidb_settings.py └── tox.ini /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | pull_request: 3 | push: 4 | branches: 5 | - main 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | lint: 13 | name: lint 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v3 18 | 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | python -m pip install tox 23 | 24 | - name: Run lint 25 | run: | 26 | tox -e lint 27 | 28 | tests: 29 | strategy: 30 | fail-fast: false 31 | matrix: 32 | python-version: 33 | - '3.10' 34 | - '3.11' 35 | - '3.12' 36 | - '3.13' 37 | django-version: 38 | - '5.2.1' 39 | tidb-version: 40 | - 'v8.5.1' 41 | - 'v8.1.2' 42 | - 'v7.5.6' 43 | - 'v7.1.6' 44 | - 'v6.5.9' 45 | - 'v5.4.3' 46 | exclude: 47 | # Django introduced the `debug_transaction` feature in version 4.2.x, 48 | # but it does not consider databases that do not support savepoints(TiDB < 6.2.0), 49 | # as a result, all `assertNumQueries` in test cases failed. 
50 | # https://github.com/django/django/commit/798e38c2b9c46ab72e2ee8c33dc822f01b194b1e 51 | - django-version: '5.2.1' 52 | tidb-version: 'v5.4.3' 53 | 54 | name: py${{ matrix.python-version }}_tidb${{ matrix.tidb-version }}_django${{ matrix.django-version }} 55 | runs-on: ubuntu-latest 56 | 57 | services: 58 | tidb: 59 | image: wangdi4zm/tind:${{ matrix.tidb-version }}-standalone 60 | ports: 61 | - 4000:4000 62 | 63 | steps: 64 | - name: Checkout 65 | uses: actions/checkout@v3 66 | 67 | - name: Setup Python 68 | uses: actions/setup-python@v4 69 | with: 70 | python-version: ${{ matrix.python-version }} 71 | 72 | - name: Install dependencies 73 | run: | 74 | python -m pip install --upgrade pip 75 | python -m pip install tox tox-gh-actions 76 | sudo apt-get update 77 | sudo apt-get install -y libmemcached-dev zlib1g-dev 78 | 79 | - name: Run tests 80 | run: tox 81 | env: 82 | DJANGO_VERSION: ${{ matrix.django-version }} 83 | 84 | vector-tests: 85 | strategy: 86 | fail-fast: false 87 | matrix: 88 | python-version: 89 | - '3.13' 90 | django-version: 91 | - '5.2.1' 92 | 93 | name: vector-py${{ matrix.python-version }}_django${{ matrix.django-version }} 94 | runs-on: ubuntu-latest 95 | services: 96 | tidb: 97 | image: wangdi4zm/tind:v8.5.1-with-tiflash 98 | ports: 99 | - 4000:4000 100 | steps: 101 | - name: Checkout 102 | uses: actions/checkout@v3 103 | 104 | - name: Setup Python 105 | uses: actions/setup-python@v4 106 | with: 107 | python-version: ${{ matrix.python-version }} 108 | 109 | - name: Install dependencies 110 | run: | 111 | python -m pip install --upgrade pip 112 | python -m pip install tox tox-gh-actions 113 | sudo apt-get update 114 | sudo apt-get install -y libmemcached-dev zlib1g-dev 115 | 116 | - name: Hack for vector tests 117 | run: | 118 | sed '27a cp -rT ./tests/tidb_vector $DJANGO_TESTS_DIR/django/tests/tidb_vector' -i django_test_suite.sh 119 | sed '31a pip install numpy~=1.0' -i django_test_suite.sh 120 | cat django_test_suite.sh 121 | echo -n 
"tidb_vector" > django_test_apps.txt 122 | cat django_test_apps.txt 123 | 124 | - name: Run tests 125 | run: tox 126 | env: 127 | DJANGO_VERSION: ${{ matrix.django-version }} 128 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python 🐍 distributions 📦 to PyPI 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Set up Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.x' 18 | - name: Install pypa/build 19 | run: >- 20 | python3 -m 21 | pip install 22 | build 23 | --user 24 | - name: Build a binary wheel and a source tarball 25 | run: >- 26 | python3 -m 27 | build 28 | --sdist 29 | --wheel 30 | --outdir dist/ 31 | . 32 | - name: Publish distribution 📦 to PyPI 33 | if: startsWith(github.ref, 'refs/tags') 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | with: 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | verbose: true 38 | print-hash: true 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .idea/ 141 | django_tests_dir 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TiDB dialect for Django 2 | 3 | ![PyPI](https://img.shields.io/pypi/v/django-tidb) 4 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/django-tidb) 5 | ![PyPI - Downloads](https://img.shields.io/pypi/dw/django-tidb) 6 | [![.github/workflows/ci.yml](https://github.com/pingcap/django-tidb/actions/workflows/ci.yml/badge.svg)](https://github.com/pingcap/django-tidb/actions/workflows/ci.yml) 7 | 8 | This adds compatibility for [TiDB](https://github.com/pingcap/tidb) to Django. 9 | 10 | ## Installation Guide 11 | 12 | ### Prerequisites 13 | 14 | Before installing django-tidb, ensure you have a MySQL driver installed. You can choose either `mysqlclient`(recommended) or `pymysql`(at your own risk). 
15 | 16 | #### Install mysqlclient (Recommended) 17 | 18 | Please refer to the [mysqlclient official guide](https://github.com/PyMySQL/mysqlclient#install) 19 | 20 | #### Install pymysql (At your own risk) 21 | 22 | > django-tidb has not been tested with pymysql 23 | 24 | ```bash 25 | pip install pymysql 26 | ``` 27 | 28 | Then add the following code at the beginning of your Django's `settings.py`: 29 | 30 | ```python 31 | import pymysql 32 | 33 | pymysql.install_as_MySQLdb() 34 | ``` 35 | 36 | ### Installing django-tidb 37 | 38 | To install django-tidb, you need to select the version that corresponds with your Django version. Please refer to the table below for guidance: 39 | 40 | > The minor release number of Django doesn't correspond to the minor release number of django-tidb. Use the latest minor release of each. 41 | 42 | |django|django-tidb|install command| 43 | |:----:|:---------:|:-------------:| 44 | |v5.2.x|v5.2.x|`pip install 'django-tidb~=5.2.0'`| 45 | |v5.1.x|v5.1.x|`pip install 'django-tidb~=5.1.0'`| 46 | |v5.0.x|v5.0.x|`pip install 'django-tidb~=5.0.0'`| 47 | |v4.2.x|v4.2.x|`pip install 'django-tidb~=4.2.0'`| 48 | |v4.1.x|v4.1.x|`pip install 'django-tidb~=4.1.0'`| 49 | |v3.2.x|v3.2.x|`pip install 'django-tidb~=3.2.0'`| 50 | 51 | ## Usage 52 | 53 | Set `'ENGINE': 'django_tidb'` in your settings to this: 54 | 55 | ```python 56 | DATABASES = { 57 | 'default': { 58 | 'ENGINE': 'django_tidb', 59 | 'NAME': 'django', 60 | 'USER': 'root', 61 | 'PASSWORD': '', 62 | 'HOST': '127.0.0.1', 63 | 'PORT': 4000, 64 | }, 65 | } 66 | DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' 67 | USE_TZ = False 68 | SECRET_KEY = 'django_tests_secret_key' 69 | ``` 70 | 71 | - [AUTO_RANDOM](#using-auto_random) 72 | - [AUTO_ID_CACHE](#using-auto_id_cache) 73 | - [Vector (Beta)](#vector-beta) 74 | 75 | ### Using `AUTO_RANDOM` 76 | 77 | [`AUTO_RANDOM`](https://docs.pingcap.com/tidb/stable/auto-random) is a feature in TiDB that generates unique IDs for a table automatically. 
It is similar to `AUTO_INCREMENT`, but it can avoid write hotspots in a single storage node caused by TiDB assigning consecutive IDs. It also has some restrictions; please refer to the [documentation](https://docs.pingcap.com/tidb/stable/auto-random#restrictions). 78 | 79 | To use `AUTO_RANDOM` in Django, you can use either of the following two ways: 80 | 81 | 1. Declare it globally in `settings.py` as shown below; it will affect all models: 82 | 83 | ```python 84 | DEFAULT_AUTO_FIELD = 'django_tidb.fields.BigAutoRandomField' 85 | ``` 86 | 87 | 2. Manually declare it in the model as shown below: 88 | 89 | ```python 90 | from django_tidb.fields import BigAutoRandomField 91 | 92 | class MyModel(models.Model): 93 | id = BigAutoRandomField(primary_key=True) 94 | title = models.CharField(max_length=200) 95 | ``` 96 | 97 | `BigAutoRandomField` is a subclass of `BigAutoField`. It can only be used for a primary key, and its behavior can be controlled by setting the parameters `shard_bits` and `range`. For detailed information, please refer to the [documentation](https://docs.pingcap.com/tidb/stable/auto-random#basic-concepts). 98 | 99 | Migrate from `AUTO_INCREMENT` to `AUTO_RANDOM`: 100 | 101 | 1. Check if the original column is `BigAutoField(bigint)`; if not, migrate it to `BigAutoField(bigint)` first. 102 | 2. In the database configuration (`settings.py`), define [`SET @@tidb_allow_remove_auto_inc = ON`](https://docs.pingcap.com/tidb/stable/system-variables#tidb_allow_remove_auto_inc-new-in-v2118-and-v304) in the `init_command`. You can remove it after completing the migration. 103 | 104 | ```python 105 | # settings.py 106 | DATABASES = { 107 | 'default': { 108 | 'ENGINE': 'django_tidb', 109 | ... 110 | 'OPTIONS': { 111 | 'init_command': 'SET @@tidb_allow_remove_auto_inc = ON', 112 | } 113 | 114 | } 115 | } 116 | ``` 117 | 118 | 3. Finally, migrate it to `BigAutoRandomField(bigint)`.
119 | 120 | > **Note** 121 | > 122 | > `AUTO_RANDOM` is supported since TiDB v3.1.0, and defining it with `range` is only supported since v6.3.0, so `range` will be ignored if the TiDB version is lower than v6.3.0. 123 | 124 | ### Using `AUTO_ID_CACHE` 125 | 126 | [`AUTO_ID_CACHE`](https://docs.pingcap.com/tidb/stable/auto-increment#auto_id_cache) allows users to set the cache size for allocating auto-increment IDs. As you may know, TiDB guarantees that `AUTO_INCREMENT` values are monotonic (always increasing) on a per-server basis, but the value may appear to jump dramatically if an INSERT operation is performed against another TiDB server. This is because each server has its own cache, whose size is controlled by `AUTO_ID_CACHE`. Starting from v6.4.0, TiDB provides a centralized auto-increment ID allocating service, so you can enable [*MySQL compatibility mode*](https://docs.pingcap.com/tidb/stable/auto-increment#mysql-compatibility-mode) by setting `AUTO_ID_CACHE` to `1` when creating a table, without losing performance. 127 | 128 | To use `AUTO_ID_CACHE` in Django, you can specify `tidb_auto_id_cache` in the model's `Meta` class as shown below when creating a new table: 129 | 130 | ```python 131 | class MyModel(models.Model): 132 | title = models.CharField(max_length=200) 133 | 134 | class Meta: 135 | tidb_auto_id_cache = 1 136 | ``` 137 | 138 | But there are some limitations: 139 | 140 | - `tidb_auto_id_cache` only takes effect when the table is created; after that, it is ignored even if you change it. 141 | - `tidb_auto_id_cache` only affects the `AUTO_INCREMENT` column. 142 | 143 | ### Vector (Beta) 144 | 145 | Currently, only TiDB Cloud Serverless clusters support the vector data type; see [Integrating Vector Search into TiDB Serverless for AI Applications](https://www.pingcap.com/blog/integrating-vector-search-into-tidb-for-ai-applications/). 146 | 147 | `VectorField` is still in beta, and the API may change in the future.
148 | 149 | To use `VectorField` in Django, you need to install `django-tidb` with the `vector` extra: 150 | 151 | ```bash 152 | pip install 'django-tidb[vector]' 153 | ``` 154 | 155 | Then you can use `VectorField` in your model: 156 | 157 | ```python 158 | from django.db import models 159 | from django_tidb.fields.vector import VectorField 160 | 161 | class Test(models.Model): 162 | embedding = VectorField(dimensions=3) 163 | ``` 164 | 165 | You can also add an HNSW index when creating the table. For more information, please refer to the [documentation](https://docs.google.com/document/d/15eAO0xrvEd6_tTxW_zEko4CECwnnSwQg8GGrqK1Caiw). 166 | 167 | ```python 168 | class Test(models.Model): 169 | embedding = VectorField(dimensions=3) 170 | class Meta: 171 | indexes = [ 172 | VectorIndex(L2Distance("embedding"), name='idx_l2'), 173 | ] 174 | ``` 175 | 176 | #### Create a record 177 | 178 | ```python 179 | Test.objects.create(embedding=[1, 2, 3]) 180 | ``` 181 | 182 | #### Get instances with vector field 183 | 184 | TiDB Vector supports the following distance functions: 185 | 186 | - `L1Distance` 187 | - `L2Distance` 188 | - `CosineDistance` 189 | - `NegativeInnerProduct` 190 | 191 | Get instances with a vector field and calculate the distance to a given vector: 192 | 193 | ```python 194 | Test.objects.annotate(distance=CosineDistance('embedding', [3, 1, 2])) 195 | ``` 196 | 197 | Get instances with a vector field, calculate the distance to a given vector, and filter by distance: 198 | 199 | ```python 200 | Test.objects.alias(distance=CosineDistance('embedding', [3, 1, 2])).filter(distance__lt=5) 201 | ``` 202 | 203 | ## Supported versions 204 | 205 | - TiDB 5.4 and newer (https://www.pingcap.com/tidb-release-support-policy/) 206 | - Django 3.2, 4.1, 4.2, 5.0, 5.1 and 5.2 207 | - Python 3.6 and newer (must match Django's Python version requirement) 208 | 209 | ## Test 210 | 211 | Create your virtualenv with: 212 | 213 | ```bash 214 | $ virtualenv venv 215 | $ source venv/bin/activate 216 |
``` 217 | 218 | you can use the command ```deactivate``` to exit from the virtual environment. 219 | 220 | run all integration tests. 221 | 222 | ```bash 223 | $ DJANGO_VERSION=3.2.12 python run_testing_worker.py 224 | ``` 225 | 226 | ## Migrate from previous versions 227 | 228 | Releases on PyPi before 3.0.0 are published from repository https://github.com/blacktear23/django_tidb. This repository is a new implementation and released under versions from 3.0.0. No backwards compatibility is ensured. The most significant points are: 229 | 230 | - Engine name is `django_tidb` instead of `django_tidb.tidb`. 231 | 232 | ## Known issues 233 | 234 | - TiDB before v6.6.0 does not support FOREIGN KEY constraints([#18209](https://github.com/pingcap/tidb/issues/18209)). 235 | - TiDB before v6.2.0 does not support SAVEPOINT([#6840](https://github.com/pingcap/tidb/issues/6840)). 236 | - TiDB has limited support for default value expressions, please refer to the [documentation](https://docs.pingcap.com/tidb/dev/data-type-default-values#specify-expressions-as-default-values). 
237 | -------------------------------------------------------------------------------- /django_test_apps.txt: -------------------------------------------------------------------------------- 1 | tidb 2 | tidb_field_defaults 3 | admin_changelist 4 | admin_custom_urls 5 | admin_docs 6 | admin_filters 7 | admin_inlines 8 | admin_ordering 9 | admin_utils 10 | admin_views 11 | aggregation 12 | aggregation_regress 13 | annotations 14 | auth_tests 15 | backends 16 | basic 17 | bulk_create 18 | cache 19 | check_framework 20 | conditional_processing 21 | constraints 22 | contenttypes_tests 23 | custom_columns 24 | custom_lookups 25 | custom_managers 26 | custom_methods 27 | custom_migration_operations 28 | custom_pk 29 | datatypes 30 | dates 31 | datetimes 32 | db_typecasts 33 | db_utils 34 | db_functions 35 | defer 36 | defer_regress 37 | delete 38 | delete_regress 39 | distinct_on_fields 40 | empty 41 | expressions_case 42 | expressions_window 43 | extra_regress 44 | field_subclassing 45 | file_storage 46 | file_uploads 47 | filtered_relation 48 | fixtures 49 | fixtures_model_package 50 | fixtures_regress 51 | force_insert_update 52 | foreign_object 53 | forms_tests 54 | from_db_value 55 | generic_inline_admin 56 | generic_relations 57 | generic_relations_regress 58 | generic_views 59 | get_earliest_or_latest 60 | get_object_or_404 61 | get_or_create 62 | i18n 63 | indexes 64 | inline_formsets 65 | inspectdb 66 | introspection 67 | invalid_models_tests 68 | known_related_objects 69 | lookup 70 | m2m_and_m2o 71 | m2m_intermediary 72 | m2m_multiple 73 | m2m_recursive 74 | m2m_regress 75 | m2m_signals 76 | m2m_through 77 | m2m_through_regress 78 | m2o_recursive 79 | managers_regress 80 | many_to_many 81 | many_to_one 82 | many_to_one_null 83 | max_lengths 84 | migrate_signals 85 | migration_test_data_persistence 86 | migrations 87 | model_fields 88 | model_forms 89 | model_formsets 90 | model_formsets_regress 91 | model_indexes 92 | model_inheritance 93 | 
#!/bin/bash
# Copyright 2021 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# Clones Django at $DJANGO_VERSION and runs every app listed in $DJANGO_TEST_APPS against TiDB.
set -xo pipefail
# NOTE: .github/workflows/ci.yml patches this file by line number (sed '27a', '31a'); keep lines 1-31 in place.
# Disable buffering, so that the logs stream through.
export PYTHONUNBUFFERED=1
# Django is checked out below this directory (it is listed in .gitignore).
export DJANGO_TESTS_DIR="django_tests_dir"
mkdir -p "$DJANGO_TESTS_DIR"
# Install django-tidb, fetch Django, then copy the settings and the TiDB-specific test apps into it.
pip3 install -e .
git clone --depth 1 --branch "$DJANGO_VERSION" https://github.com/django/django.git "$DJANGO_TESTS_DIR/django"
cp tidb_settings.py "$DJANGO_TESTS_DIR/django/tidb_settings.py"
cp tidb_settings.py "$DJANGO_TESTS_DIR/django/tests/tidb_settings.py"
cp -rT ./tests/tidb "$DJANGO_TESTS_DIR/django/tests/tidb"
cp -rT ./tests/tidb_field_defaults "$DJANGO_TESTS_DIR/django/tests/tidb_field_defaults"
# Install Django and its test requirements in a subshell, so the working directory is restored even if a step fails.
(cd "$DJANGO_TESTS_DIR/django" && pip3 install -e . && pip3 install -r tests/requirements/py3.txt && pip3 install -r tests/requirements/mysql.txt)
cd "$DJANGO_TESTS_DIR/django/tests" || exit 1

# Run each app separately and stop at the first failure, propagating its exit status.
# $DJANGO_TEST_APPS is intentionally unquoted: it is a whitespace-separated list of app names.
for DJANGO_TEST_APP in $DJANGO_TEST_APPS
do
    python3 runtests.py "$DJANGO_TEST_APP" --noinput --settings tidb_settings || exit $?
done
exit 0
class DatabaseWrapper(MysqlDatabaseWrapper):
    """Database wrapper for TiDB, layered on Django's MySQL wrapper.

    The TiDB-specific features, introspection, operations and schema editor
    classes are plugged in, and server metadata (version string, SQL mode, ...)
    is read once per wrapper and cached.
    """

    # Django has some hard code for mysql in `JSONFields` and tests through check vendor name,
    # as TiDB is compatible with MySQL, so setting vendor name to mysql is ok.
    vendor = "mysql"
    display_name = "TiDB"

    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_database_version(self):
        """Return the parsed TiDB version (see ``tidb_version``)."""
        return self.tidb_version

    @cached_property
    def data_type_check_constraints(self):
        """Map field type names to their CHECK constraint SQL templates.

        Returns an empty dict when the backend features report that column
        check constraints are not supported.
        """
        if not self.features.supports_column_check_constraints:
            return {}
        non_negative = "`%(column)s` >= 0"
        return {
            "PositiveBigIntegerField": non_negative,
            "PositiveIntegerField": non_negative,
            "PositiveSmallIntegerField": non_negative,
            # Unlike the MySQL/MariaDB backend, the JSON_VALID check is
            # always part of the mapping here.
            "JSONField": "JSON_VALID(`%(column)s`)",
        }

    @cached_property
    def tidb_server_data(self):
        """Fetch server-level settings with a single query and return them as a dict.

        Keys: ``version``, ``sql_mode``, ``default_storage_engine``,
        ``sql_auto_is_null``, ``lower_case_table_names`` and
        ``has_zoneinfo_database`` (the last three coerced to ``bool``).
        """
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
            """
            )
            (
                version,
                sql_mode,
                storage_engine,
                auto_is_null,
                lower_case_names,
                has_zoneinfo,
            ) = cursor.fetchone()
        return {
            "version": version,
            "sql_mode": sql_mode,
            "default_storage_engine": storage_engine,
            "sql_auto_is_null": bool(auto_is_null),
            "lower_case_table_names": bool(lower_case_names),
            "has_zoneinfo_database": bool(has_zoneinfo),
        }

    @cached_property
    def tidb_server_info(self):
        """Raw version string reported by the server's ``VERSION()``."""
        return self.tidb_server_data["version"]

    @cached_property
    def tidb_version(self):
        """Parse ``tidb_server_info`` with the module-level ``server_version`` helper.

        Raises:
            Exception: if the version string cannot be matched.
        """
        if server_version.match(self.tidb_server_info):
            return server_version.version
        raise Exception(
            "Unable to determine Tidb version from version string %r"
            % self.tidb_server_info
        )

    @cached_property
    def sql_mode(self):
        """Return the server's ``@@sql_mode`` as a set of mode names (empty if unset)."""
        raw_mode = self.tidb_server_data["sql_mode"]
        if not raw_mode:
            return set()
        return set(raw_mode.split(","))
self.connection.tidb_version >= (6, 6, 0): 52 | return True 53 | return False 54 | 55 | @cached_property 56 | def supports_transactions(self): 57 | # https://code.djangoproject.com/ticket/28263 58 | if self.connection.tidb_version >= (6, 2, 0): 59 | return True 60 | return False 61 | 62 | @cached_property 63 | def uses_savepoints(self): 64 | if self.connection.tidb_version >= (6, 2, 0): 65 | return True 66 | return False 67 | 68 | @cached_property 69 | def can_release_savepoints(self): 70 | if self.connection.tidb_version >= (6, 2, 0): 71 | return True 72 | return False 73 | 74 | @cached_property 75 | def django_test_skips(self): 76 | skips = { 77 | "This doesn't work on MySQL.": { 78 | "db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround", 79 | "db_functions.comparison.test_least.LeastTests.test_coalesce_workaround", 80 | # UPDATE ... ORDER BY syntax on MySQL/MariaDB does not support ordering by related fields 81 | "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc", 82 | }, 83 | "MySQL doesn't support functional indexes on a function that " 84 | "returns JSON": { 85 | "schema.tests.SchemaTests.test_func_index_json_key_transform", 86 | }, 87 | "MySQL supports multiplying and dividing DurationFields by a " 88 | "scalar value but it's not implemented (#25287).": { 89 | "expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide", 90 | }, 91 | "tidb": { 92 | # Unknown column 'annotations_publisher.id' in 'where clause' 93 | # https://github.com/pingcap/tidb/issues/45181 94 | "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_filter_with_subquery", 95 | # Designed for MySQL only 96 | "backends.mysql.test_features.TestFeatures.test_supports_transactions", 97 | "backends.mysql.tests.Tests.test_check_database_version_supported", 98 | "backends.mysql.test_introspection.StorageEngineTests.test_get_storage_engine", 99 | "check_framework.test_database.DatabaseCheckTests.test_mysql_strict_mode", 100 | # 
Unsupported add column and foreign key in single statement 101 | "indexes.tests.SchemaIndexesMySQLTests.test_no_index_for_foreignkey", 102 | # TiDB does not support `JSON` format for `EXPLAIN ANALYZE` 103 | "queries.test_explain.ExplainTests.test_mysql_analyze", 104 | "queries.test_explain.ExplainTests.test_mysql_text_to_traditional", 105 | # TiDB cannot guarantee to always rollback the main thread txn when deadlock occurs 106 | "transactions.tests.AtomicMySQLTests.test_implicit_savepoint_rollback", 107 | "filtered_relation.tests.FilteredRelationTests.test_union", 108 | # [planner:3065]Expression #1 of ORDER BY clause is not in SELECT list, references column '' which is 109 | # not in SELECT list; this is incompatible with 110 | "ordering.tests.OrderingTests.test_orders_nulls_first_on_filtered_subquery", 111 | # Unsupported modify column: this column has primary key flag 112 | "schema.tests.SchemaTests.test_alter_auto_field_to_char_field", 113 | # Unsupported modify column: can't remove auto_increment without @@tidb_allow_remove_auto_inc enabled 114 | "schema.tests.SchemaTests.test_alter_auto_field_to_integer_field", 115 | # Found wrong number (0) of check constraints for schema_author.height 116 | "schema.tests.SchemaTests.test_alter_field_default_dropped", 117 | # Unsupported modify column: can't set auto_increment 118 | "schema.tests.SchemaTests.test_alter_int_pk_to_autofield_pk", 119 | "schema.tests.SchemaTests.test_alter_int_pk_to_bigautofield_pk", 120 | # Unsupported drop primary key when the table's pkIsHandle is true 121 | "schema.tests.SchemaTests.test_alter_int_pk_to_int_unique", 122 | # Unsupported drop integer primary key 123 | "schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key", 124 | # Unsupported modify column: can't set auto_increment 125 | "schema.tests.SchemaTests.test_alter_smallint_pk_to_smallautofield_pk", 126 | # Unsupported modify column: this column has primary key flag 127 | 
"schema.tests.SchemaTests.test_char_field_pk_to_auto_field", 128 | # Unsupported modify charset from utf8mb4 to utf8 129 | "schema.tests.SchemaTests.test_ci_cs_db_collation", 130 | # Unsupported drop integer primary key 131 | "schema.tests.SchemaTests.test_primary_key", 132 | "schema.tests.SchemaTests.test_add_auto_field", 133 | "schema.tests.SchemaTests.test_alter_autofield_pk_to_smallautofield_pk", 134 | "schema.tests.SchemaTests.test_alter_primary_key_db_collation", 135 | "schema.tests.SchemaTests.test_alter_primary_key_the_same_name", 136 | "schema.tests.SchemaTests.test_autofield_to_o2o", 137 | "update.tests.AdvancedTests.test_update_ordered_by_inline_m2m_annotation", 138 | "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation", 139 | # IntegrityError not raised 140 | "constraints.tests.CheckConstraintTests.test_database_constraint", 141 | "constraints.tests.CheckConstraintTests.test_database_constraint_unicode", 142 | # Result of function ROUND(x, d) is different from MySQL 143 | # https://github.com/pingcap/tidb/issues/26993 144 | "db_functions.math.test_round.RoundTests.test_integer_with_negative_precision", 145 | "db_functions.text.test_chr.ChrTests.test_transform", 146 | "db_functions.text.test_chr.ChrTests.test_non_ascii", 147 | "db_functions.text.test_chr.ChrTests.test_basic", 148 | "db_functions.comparison.test_collate.CollateTests.test_collate_filter_ci", 149 | # Unsupported modifying collation of column from 'utf8mb4_general_ci' to 'utf8mb4_bin' 150 | # when index is defined on it. 
151 | "migrations.test_operations.OperationTests.test_alter_field_pk_fk_db_collation", 152 | "migrations.test_executor.ExecutorTests.test_alter_id_type_with_fk", 153 | "migrations.test_operations.OperationTests.test_alter_field_pk", 154 | "migrations.test_operations.OperationTests.test_alter_field_reloads_state_on_fk_target_changes", 155 | "migrations.test_operations.OperationTests.test_rename_field_reloads_state_on_fk_target_changes", 156 | # Unsupported drop primary key when the table is using clustered index 157 | "migrations.test_operations.OperationTests.test_composite_pk_operations", 158 | # 'Adding generated stored column through ALTER TABLE' is not supported for generated columns 159 | "migrations.test_operations.OperationTests.test_generated_field_changes_output_field", 160 | # Unsupported modifying the Reorg-Data types on the primary key 161 | "migrations.test_operations.OperationTests.test_alter_field_pk_fk", 162 | "migrations.test_operations.OperationTests.test_alter_field_pk_fk_char_to_int", 163 | "migrations.test_operations.OperationTests.test_add_constraint", 164 | "migrations.test_operations.OperationTests.test_add_constraint_combinable", 165 | "migrations.test_operations.OperationTests.test_add_constraint_percent_escaping", 166 | "migrations.test_operations.OperationTests.test_add_or_constraint", 167 | "migrations.test_operations.OperationTests.test_create_model_with_constraint", 168 | "migrations.test_operations.OperationTests.test_remove_constraint", 169 | "migrations.test_operations.OperationTests.test_alter_field_pk_mti_and_fk_to_base", 170 | "migrations.test_operations.OperationTests.test_alter_field_pk_mti_fk", 171 | "migrations.test_operations.OperationTests.test_create_model_with_boolean_expression_in_check_constraint", 172 | # TiDB doesn't support drop integer primary key 173 | "migrations.test_operations.OperationTests.test_alter_id_pk_to_uuid_pk", 174 | # TiDB doesn't allow renaming columns referenced by generated columns (same as MySQL) 
175 | "migrations.test_operations.OperationTests.test_invalid_generated_field_changes_on_rename_virtual", 176 | # Unsupported adding a stored generated column through ALTER TABLE 177 | "migrations.test_operations.OperationTests.test_invalid_generated_field_changes_on_rename_stored", 178 | "migrations.test_operations.OperationTests.test_add_field_after_generated_field", 179 | "migrations.test_operations.OperationTests.test_add_generated_field_stored", 180 | "migrations.test_operations.OperationTests.test_invalid_generated_field_changes_stored", 181 | "migrations.test_operations.OperationTests.test_invalid_generated_field_persistency_change", 182 | "migrations.test_operations.OperationTests.test_remove_generated_field_stored", 183 | "schema.tests.SchemaTests.test_add_generated_field_contains", 184 | # Failed to modify column's default value when has expression index 185 | # https://github.com/pingcap/tidb/issues/52355 186 | "migrations.test_operations.OperationTests.test_alter_field_with_func_index", 187 | # TiDB has limited support for default value expressions 188 | # https://docs.pingcap.com/tidb/dev/data-type-default-values#specify-expressions-as-default-values 189 | "migrations.test_operations.OperationTests.test_add_field_database_default_function", 190 | "schema.tests.SchemaTests.test_add_text_field_with_db_default", 191 | "schema.tests.SchemaTests.test_db_default_equivalent_sql_noop", 192 | "schema.tests.SchemaTests.test_db_default_output_field_resolving", 193 | # about Pessimistic/Optimistic Transaction Model 194 | "select_for_update.tests.SelectForUpdateTests.test_raw_lock_not_available", 195 | # Wrong referenced_table_schema in information_schema.key_column_usage 196 | # https://github.com/pingcap/tidb/issues/52350 197 | "backends.mysql.test_introspection.TestCrossDatabaseRelations.test_omit_cross_database_relations", 198 | # https://github.com/pingcap/tidb/issues/61091 199 | "model_fields.test_jsonfield.TestQuerying.test_lookups_special_chars", 200 | 
"model_fields.test_jsonfield.TestQuerying.test_lookups_special_chars_double_quotes", 201 | }, 202 | } 203 | if self.connection.tidb_version < (5,): 204 | skips.update( 205 | { 206 | "tidb4": { 207 | # Unsupported modify column 208 | "schema.tests.SchemaTests.test_rename", 209 | "schema.tests.SchemaTests.test_m2m_rename_field_in_target_model", 210 | "schema.tests.SchemaTests.test_alter_textual_field_keep_null_status", 211 | "schema.tests.SchemaTests.test_alter_text_field_to_time_field", 212 | "schema.tests.SchemaTests.test_alter_text_field_to_datetime_field", 213 | "schema.tests.SchemaTests.test_alter_text_field_to_date_field", 214 | "schema.tests.SchemaTests.test_alter_field_type_and_db_collation", 215 | # wrong result 216 | "expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank", 217 | "migrations.test_operations.OperationTests.test_alter_fk_non_fk", 218 | "migrations.test_operations.OperationTests" 219 | ".test_alter_field_reloads_state_on_fk_with_to_field_target_changes", 220 | "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", 221 | } 222 | } 223 | ) 224 | if self.connection.tidb_version < (6, 3): 225 | skips.update( 226 | { 227 | "auto_random": { 228 | "tidb.test_tidb_auto_random.TiDBAutoRandomMigrateTests" 229 | ".test_create_table_explicit_auto_random_field_with_shard_bits_and_range", 230 | "tidb.test_tidb_auto_random.TiDBAutoRandomMigrateTests" 231 | ".test_create_table_explicit_auto_random_field_with_range", 232 | } 233 | } 234 | ) 235 | if self.connection.tidb_version < (6, 6): 236 | skips.update( 237 | { 238 | "tidb653": { 239 | "fixtures_regress.tests.TestFixtures.test_loaddata_raises_error_when_fixture_has_invalid_foreign_key", 240 | "migrations.test_operations.OperationTests.test_autofield__bigautofield_foreignfield_growth", 241 | "migrations.test_operations.OperationTests.test_smallfield_autofield_foreignfield_growth", 242 | 
"migrations.test_operations.OperationTests.test_smallfield_bigautofield_foreignfield_growth", 243 | "migrations.test_commands.MigrateTests.test_migrate_syncdb_app_label", 244 | "migrations.test_commands.MigrateTests.test_migrate_syncdb_deferred_sql_executed_with_schemaeditor", 245 | "schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references", 246 | "schema.tests.SchemaTests.test_rename_table_renames_deferred_sql_references", 247 | } 248 | } 249 | ) 250 | if self.connection.tidb_version < (7, 2): 251 | skips.update( 252 | { 253 | "tidb72": { 254 | # TiDB support CHECK constraint from v7.2 255 | # https://github.com/pingcap/tidb/issues/41711 256 | "migrations.test_operations.OperationTests.test_create_model_constraint_percent_escaping", 257 | } 258 | } 259 | ) 260 | if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode: 261 | skips.update( 262 | { 263 | "GROUP BY cannot contain nonaggregated column when " 264 | "ONLY_FULL_GROUP_BY mode is enabled on TiDB.": { 265 | "aggregation.tests.AggregateTestCase.test_group_by_nested_expression_with_params", 266 | }, 267 | } 268 | ) 269 | if not self.supports_foreign_keys: 270 | skips.update( 271 | { 272 | # Django does not check if the database supports foreign keys. 
273 | "django42_db_unsupport_foreign_keys": { 274 | "inspectdb.tests.InspectDBTestCase.test_same_relations", 275 | }, 276 | } 277 | ) 278 | return skips 279 | 280 | @cached_property 281 | def update_can_self_select(self): 282 | return True 283 | 284 | @cached_property 285 | def can_introspect_foreign_keys(self): 286 | if self.connection.tidb_version >= (6, 6, 0): 287 | return True 288 | return False 289 | 290 | @cached_property 291 | def can_return_columns_from_insert(self): 292 | return False 293 | 294 | can_return_rows_from_bulk_insert = property( 295 | operator.attrgetter("can_return_columns_from_insert") 296 | ) 297 | 298 | @cached_property 299 | def has_zoneinfo_database(self): 300 | return self.connection.tidb_server_data["has_zoneinfo_database"] 301 | 302 | @cached_property 303 | def is_sql_auto_is_null_enabled(self): 304 | return self.connection.tidb_server_data["sql_auto_is_null"] 305 | 306 | @cached_property 307 | def supports_over_clause(self): 308 | return True 309 | 310 | supports_frame_range_fixed_distance = property( 311 | operator.attrgetter("supports_over_clause") 312 | ) 313 | 314 | @cached_property 315 | def supports_column_check_constraints(self): 316 | return True 317 | 318 | @cached_property 319 | def supports_expression_defaults(self): 320 | # TiDB has limited support for default value expressions 321 | # https://docs.pingcap.com/tidb/dev/data-type-default-values#specify-expressions-as-default-values 322 | return True 323 | 324 | supports_table_check_constraints = property( 325 | operator.attrgetter("supports_column_check_constraints") 326 | ) 327 | 328 | @cached_property 329 | def can_introspect_check_constraints(self): 330 | return False 331 | 332 | @cached_property 333 | def has_select_for_update_skip_locked(self): 334 | return False 335 | 336 | @cached_property 337 | def has_select_for_update_nowait(self): 338 | return False 339 | 340 | @cached_property 341 | def has_select_for_update_of(self): 342 | return False 343 | 344 | 
class BigAutoRandomField(BigAutoField):
    """``BigAutoField`` variant whose column is declared as ``bigint AUTO_RANDOM(...)``.

    Args:
        verbose_name: forwarded to ``BigAutoField``.
        name: forwarded to ``BigAutoField``.
        shard_bits: first AUTO_RANDOM argument; system checks require 1..15.
        range: second AUTO_RANDOM argument; system checks require 32..64.
            Only emitted in the column type on TiDB >= 6.3.
        **kwargs: forwarded to ``BigAutoField``.
    """

    def __init__(
        self,
        verbose_name=None,
        name=None,
        shard_bits=5,
        range=64,
        **kwargs,
    ):
        self.shard_bits = shard_bits
        self.range = range
        super().__init__(verbose_name, name, **kwargs)

    def get_internal_type(self):
        return "BigAutoRandomField"

    def check(self, **kwargs):
        """Run the parent checks followed by the ``range`` and ``shard_bits`` checks."""
        errors = list(super().check(**kwargs))
        errors.extend(self._check_range())
        errors.extend(self._check_shard_bits())
        return errors

    def _check_shard_bits(self):
        """Return a list with one error unless ``shard_bits`` is an int-like value in 1..15."""
        missing = checks.Error(
            "BigAutoRandomField must define a 'shard_bits' attribute.",
            obj=self,
        )
        invalid = checks.Error(
            "BigAutoRandomField 'shard_bits' attribute must be an integer between 1 and 15.",
            obj=self,
        )
        try:
            bits = int(self.shard_bits)
        except TypeError:
            # e.g. shard_bits=None
            return [missing]
        except ValueError:
            # e.g. a non-numeric string
            return [invalid]
        return [] if 1 <= bits <= 15 else [invalid]

    def _check_range(self):
        """Return a list with one error unless ``range`` is an int-like value in 32..64."""
        missing = checks.Error(
            "BigAutoRandomField must define a 'range' attribute.",
            obj=self,
        )
        invalid = checks.Error(
            "BigAutoRandomField 'range' attribute must be an integer between 32 and 64.",
            obj=self,
        )
        try:
            width = int(self.range)
        except TypeError:
            return [missing]
        except ValueError:
            return [invalid]
        return [] if 32 <= width <= 64 else [invalid]

    def deconstruct(self):
        """Extend the parent deconstruction with ``shard_bits`` and ``range`` when set."""
        name, path, args, kwargs = super().deconstruct()
        for option in ("shard_bits", "range"):
            current = getattr(self, option)
            if current is not None:
                kwargs[option] = current
        return name, path, args, kwargs

    def db_type(self, connection):
        """Return the column type, including the range argument only on TiDB >= 6.3."""
        params = self.db_type_parameters(connection)
        template = "bigint AUTO_RANDOM(%(shard_bits)s, %(range)s)"
        if connection.tidb_version < (6, 3):
            # TiDB < 6.3 doesn't support define AUTO_RANDOM with range
            template = "bigint AUTO_RANDOM(%(shard_bits)s)"
        return template % params
def decode_vector(value):
    """Convert a vector's text representation into a float32 NumPy array.

    Args:
        value: ``None``, an ``np.ndarray`` (both returned unchanged), or a
            ``str``/``bytes`` literal such as ``"[1,2,3]"``. Bytes are decoded
            as UTF-8 first.

    Returns:
        ``None``, the same ndarray, or a 1-D ``np.float32`` array. The empty
        literal ``"[]"`` (which ``encode_vector([])`` produces) decodes to an
        empty array.

    Raises:
        ValueError: if an element cannot be parsed as a float.
    """
    if value is None or isinstance(value, np.ndarray):
        return value

    if isinstance(value, bytes):
        value = value.decode("utf-8")

    inner = value[1:-1]
    # "".split(",") yields [""], which NumPy cannot convert to a float, so the
    # empty vector literal needs its own path.
    if value[:1] == "[" and value[-1:] == "]" and not inner.strip():
        return np.array([], dtype=np.float32)

    return np.array(inner.split(","), dtype=np.float32)
"vector(%d)" % self.dimensions 89 | 90 | def from_db_value(self, value, expression, connection): 91 | return decode_vector(value) 92 | 93 | def to_python(self, value): 94 | if isinstance(value, list): 95 | return np.array(value, dtype=np.float32) 96 | return decode_vector(value) 97 | 98 | def get_prep_value(self, value): 99 | return encode_vector(value) 100 | 101 | def value_to_string(self, obj): 102 | return self.get_prep_value(self.value_from_object(obj)) 103 | 104 | def validate(self, value, model_instance): 105 | if isinstance(value, np.ndarray): 106 | value = value.tolist() 107 | super().validate(value, model_instance) 108 | 109 | def run_validators(self, value): 110 | if isinstance(value, np.ndarray): 111 | value = value.tolist() 112 | super().run_validators(value) 113 | 114 | def formfield(self, **kwargs): 115 | return super().formfield(form_class=VectorFormField, **kwargs) 116 | 117 | def check(self, **kwargs): 118 | return [ 119 | *super().check(**kwargs), 120 | *self._check_dimensions(), 121 | ] 122 | 123 | def _check_dimensions(self): 124 | if self.dimensions is not None and ( 125 | self.dimensions < MIN_DIM_LENGTH or self.dimensions > MAX_DIM_LENGTH 126 | ): 127 | return [ 128 | checks.Error( 129 | f"Vector dimensions must be in the range [{MIN_DIM_LENGTH}, {MAX_DIM_LENGTH}]", 130 | obj=self, 131 | ) 132 | ] 133 | return [] 134 | 135 | 136 | class VectorIndex(Index): 137 | """ 138 | Example: 139 | ```python 140 | from django.db import models 141 | from django_tidb.fields.vector import VectorField, VectorIndex, CosineDistance 142 | 143 | class Document(models.Model): 144 | content = models.TextField() 145 | embedding = VectorField(dimensions=3) 146 | class Meta: 147 | indexes = [ 148 | VectorIndex(CosineDistance("embedding"), name='idx_cos'), 149 | ] 150 | 151 | # Create a document 152 | Document.objects.create( 153 | content="test content", 154 | embedding=[1, 2, 3], 155 | ) 156 | 157 | # Query with distance 158 | Document.objects.alias( 159 | 
class DistanceBase(Func):
    """Common base for the vector distance SQL functions.

    Subclasses only set ``function`` to the SQL function name; the result is
    exposed as a ``FloatField``.
    """

    output_field = FloatField()

    def __init__(self, expression, vector=None, **extra):
        """
        expression: the name of a field, or an expression returning a vector
        vector: a vector to compare against; plain values are encoded with
            ``encode_vector`` and wrapped in ``Value``, while objects exposing
            ``resolve_expression`` are passed through untouched
        """
        if vector is None:
            # When using the distance function as expression in the vector
            # index statement, there is no vector to compare against.
            super().__init__(expression, **extra)
            return

        if hasattr(vector, "resolve_expression"):
            operand = vector
        else:
            operand = Value(encode_vector(vector))
        super().__init__(expression, operand, **extra)

from collections import namedtuple

from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.mysql.introspection import (
    DatabaseIntrospection as MysqlDatabaseIntrospection,
)
from django.db.models import Index
from django.utils.datastructures import OrderedSet

# Django's base FieldInfo extended with the extra per-column attributes that
# get_table_description() fills in below.
FieldInfo = namedtuple(
    "FieldInfo",
    BaseFieldInfo._fields
    + ("extra", "is_unsigned", "has_json_constraint", "comment", "data_type"),
)
# One row of the information_schema.columns query issued by
# get_table_description(); the field order matches its SELECT list.
InfoLine = namedtuple(
    "InfoLine",
    "col_name data_type max_len num_prec num_scale extra column_default "
    "collation is_unsigned comment",
)


class DatabaseIntrospection(MysqlDatabaseIntrospection):
    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        The result is a list of FieldInfo tuples, one per entry of
        ``cursor.description`` for ``SELECT * FROM <table> LIMIT 1``, merged
        with the matching row from information_schema.columns.
        """
        json_constraints = {}
        if self.connection.features.can_introspect_json_field:
            # JSON data type is an alias for LONGTEXT in MariaDB, select
            # JSON_VALID() constraints to introspect JSONField.
            cursor.execute(
                """
                SELECT c.constraint_name AS column_name
                FROM information_schema.check_constraints AS c
                WHERE
                    c.table_name = %s AND
                    LOWER(c.check_clause) = 'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
                    c.constraint_schema = DATABASE()
                """,
                [table_name],
            )
            json_constraints = {row[0] for row in cursor.fetchall()}
        # A default collation for the given table.
        cursor.execute(
            """
            SELECT  table_collation
            FROM    information_schema.tables
            WHERE   table_schema = DATABASE()
            AND     table_name = %s
            """,
            [table_name],
        )
        row = cursor.fetchone()
        default_column_collation = row[0] if row else ""
        # information_schema database gives more accurate results for some figures:
        # - varchar length returned by cursor.description is an internal length,
        #   not visible length (#5725)
        # - precision and scale (for decimal fields) (#5014)
        # - auto_increment is not available in cursor.description
        cursor.execute(
            """
            SELECT
                column_name, data_type, character_maximum_length,
                numeric_precision, numeric_scale, extra, column_default,
                CASE
                    WHEN collation_name = %s THEN NULL
                    ELSE collation_name
                END AS collation_name,
                CASE
                    WHEN column_type LIKE '%% unsigned' THEN 1
                    ELSE 0
                END AS is_unsigned,
                column_comment
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()
            """,
            [default_column_collation, table_name],
        )
        # Index the information_schema rows by column name.
        field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}

        cursor.execute(
            "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
        )

        def to_int(i):
            # Preserve None so that the ``or`` fallbacks below can kick in.
            return int(i) if i is not None else i

        fields = []
        for line in cursor.description:
            info = field_info[line[0]]
            fields.append(
                FieldInfo(
                    *line[:2],
                    # Prefer the information_schema figures and fall back to
                    # the cursor.description entries when they are missing.
                    to_int(info.max_len) or line[2],
                    to_int(info.max_len) or line[3],
                    to_int(info.num_prec) or line[4],
                    to_int(info.num_scale) or line[5],
                    line[6],
                    info.column_default,
                    info.collation,
                    info.extra,
                    info.is_unsigned,
                    line[0] in json_constraints,
                    info.comment,
                    info.data_type,
                )
            )
        return fields

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.

        Returns a dict keyed by constraint/index name; each value is a dict
        describing the columns and the kind of constraint.
        """
        constraints = {}
        # Get the actual constraint names and columns
        name_query = """
            SELECT kc.`constraint_name`, kc.`column_name`,
                kc.`referenced_table_name`, kc.`referenced_column_name`,
                c.`constraint_type`
            FROM
                information_schema.key_column_usage AS kc,
                information_schema.table_constraints AS c
            WHERE
                kc.table_schema = DATABASE() AND
                c.table_schema = kc.table_schema AND
                c.constraint_name = kc.constraint_name AND
                c.constraint_type != 'CHECK' AND
                kc.table_name = %s
            ORDER BY kc.`ordinal_position`
        """
        cursor.execute(name_query, [table_name])
        for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": OrderedSet(),
                    "primary_key": kind == "PRIMARY KEY",
                    "unique": kind in {"PRIMARY KEY", "UNIQUE"},
                    "index": False,
                    "check": False,
                    "foreign_key": (ref_table, ref_column) if ref_column else None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[constraint]["orders"] = []
            constraints[constraint]["columns"].add(column)
        # Add check constraints.
        if self.connection.features.can_introspect_check_constraints:
            unnamed_constraints_index = 0
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            type_query = """
                SELECT cc.constraint_name, cc.check_clause
                FROM
                    information_schema.check_constraints AS cc,
                    information_schema.table_constraints AS tc
                WHERE
                    cc.constraint_schema = DATABASE() AND
                    tc.table_schema = cc.constraint_schema AND
                    cc.constraint_name = tc.constraint_name AND
                    tc.constraint_type = 'CHECK' AND
                    tc.table_name = %s
            """
            cursor.execute(type_query, [table_name])
            for constraint, check_clause in cursor.fetchall():
                constraint_columns = self._parse_constraint_columns(
                    check_clause, columns
                )
                # Ensure uniqueness of unnamed constraints. Unnamed unique
                # and check columns constraints have the same name as
                # a column.
                if set(constraint_columns) == {constraint}:
                    unnamed_constraints_index += 1
                    constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
                constraints[constraint] = {
                    "columns": constraint_columns,
                    "primary_key": False,
                    "unique": False,
                    "index": False,
                    "check": True,
                    "foreign_key": None,
                }
        # Now add in the indexes
        cursor.execute(
            "SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
        )
        # Only the first six columns and the eleventh one of each SHOW INDEX
        # row are used.
        for table, non_unique, index, colseq, column, order, type_ in [
            x[:6] + (x[10],) for x in cursor.fetchall()
        ]:
            if index not in constraints:
                constraints[index] = {
                    "columns": OrderedSet(),
                    "primary_key": False,
                    "unique": not non_unique,
                    "check": False,
                    "foreign_key": None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[index]["orders"] = []
            constraints[index]["index"] = True
            constraints[index]["type"] = (
                Index.suffix if type_ == "BTREE" else type_.lower()
            )
            constraints[index]["columns"].add(column)
            if self.connection.features.supports_index_column_ordering:
                constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
        # Convert the sorted sets to lists
        for constraint in constraints.values():
            constraint["columns"] = list(constraint["columns"])
        return constraints

# --------------------------------------------------------------------------------
# /django_tidb/operations.py:
# --------------------------------------------------------------------------------
# Copyright 2021 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from django.db.backends.mysql.operations import (
    DatabaseOperations as MysqlDatabaseOperations,
)


class DatabaseOperations(MysqlDatabaseOperations):
    # The MySQL ranges plus the signed 64-bit range for BigAutoRandomField.
    integer_field_ranges = {
        **MysqlDatabaseOperations.integer_field_ranges,
        "BigAutoRandomField": (-9223372036854775808, 9223372036854775807),
    }

    def explain_query_prefix(self, format=None, **options):
        """
        Build the prefix placed in front of a query to explain it.

        ``format`` defaults to "ROW"; "TEXT" (any case) is accepted as an
        alias for "ROW". The only recognised option is ``analyze``.

        Raises ValueError when the format is not listed in
        ``connection.features.supported_explain_formats`` or when unknown
        options are passed.
        """
        # Alias TiDB's "ROW" format to "TEXT" for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "ROW"
        elif not format:
            format = "ROW"

        # Check if the format is supported by TiDB.
        supported_formats = self.connection.features.supported_explain_formats
        normalized_format = format.upper()
        if normalized_format not in supported_formats:
            msg = "%s is not a recognized format." % normalized_format
            if supported_formats:
                msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
            else:
                msg += f" {self.connection.display_name} does not support any formats."
            raise ValueError(msg)

        analyze = options.pop("analyze", False)
        if options:
            raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))

        prefix = self.explain_prefix
        if analyze:
            prefix += " ANALYZE"
        # NOTE(review): the format is emitted as given by the caller, not in
        # its upper-cased form used for validation above.
        prefix += ' FORMAT="%s"' % format
        return prefix

    def regex_lookup(self, lookup_type):
        """Return the SQL template for the "regex" (binary) or other regex lookups."""
        # REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
        # doesn't exist in MySQL 5.x or in MariaDB.
        if lookup_type == "regex":
            return "%s REGEXP BINARY %s"
        return "%s REGEXP %s"

# --------------------------------------------------------------------------------
# /django_tidb/patch.py:
# --------------------------------------------------------------------------------
# Copyright 2021 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from django.db.models.functions import Chr
from django.db.models import options
from django.db.migrations import state


def char(self, compiler, connection, **extra_context):
    """Compile ``Chr`` as ``CHAR(<expressions> USING utf8mb4)``."""
    # TiDB doesn't support utf16
    return self.as_sql(
        compiler,
        connection,
        function="CHAR",
        template="%(function)s(%(expressions)s USING utf8mb4)",
        **extra_context,
    )


def patch_model_functions():
    """Install the TiDB-specific compilation of ``Chr`` as its MySQL variant."""
    Chr.as_mysql = char


def patch_model_options():
    """
    Register ``tidb_auto_id_cache`` as a valid ``Meta`` option.

    The function is idempotent: calling it more than once does not append
    duplicate entries to the ``DEFAULT_NAMES`` tuples.
    """
    # Patch `tidb_auto_id_cache` to options.DEFAULT_NAMES,
    # so that user can define it in model's Meta class.
    if "tidb_auto_id_cache" not in options.DEFAULT_NAMES:
        options.DEFAULT_NAMES += ("tidb_auto_id_cache",)
    # Django's migration state module imports DEFAULT_NAMES by name, so it
    # holds its own reference that must be patched separately.
    # Django will record `tidb_auto_id_cache` in migration files,
    # and then restore it when applying migrations.
    if "tidb_auto_id_cache" not in state.DEFAULT_NAMES:
        state.DEFAULT_NAMES += ("tidb_auto_id_cache",)


def monkey_patch():
    """Apply every TiDB-specific patch to Django."""
    patch_model_functions()
    patch_model_options()

# --------------------------------------------------------------------------------
# /django_tidb/schema.py:
# --------------------------------------------------------------------------------
# Copyright 2021 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from django.db.backends.mysql.schema import (
    DatabaseSchemaEditor as MysqlDatabaseSchemaEditor,
)


class DatabaseSchemaEditor(MysqlDatabaseSchemaEditor):
    # Unsupported add column and foreign key in single statement
    # https://github.com/pingcap/tidb/issues/45474
    sql_create_column_inline_fk = None

    @property
    def sql_delete_check(self):
        return "ALTER TABLE %(table)s DROP CHECK %(name)s"

    @property
    def sql_rename_column(self):
        return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"

    def skip_default_on_alter(self, field):
        """Return True when the default of ``field`` must not be set via ALTER COLUMN."""
        if self._is_limited_data_type(field):
            # TiDB doesn't support defaults for BLOB/TEXT/JSON in the
            # ALTER COLUMN statement.
            return True
        return False

    @property
    def _supports_limited_data_type_defaults(self):
        return False

    def _field_should_be_indexed(self, model, field):
        """Return True when a separate index has to be created for ``field``."""
        if not field.db_index or field.unique:
            return False
        # No need to create an index for ForeignKey fields except if
        # db_constraint=False because the index from that constraint won't be
        # created.
        if field.get_internal_type() == "ForeignKey" and field.db_constraint:
            return False
        return not self._is_limited_data_type(field)

    def add_field(self, model, field):
        """
        Add ``field`` to the table of ``model``.

        For a unique field the column is added first without the unique flag
        and the unique constraint is created by a second statement. The
        field's ``_unique`` flag is always restored, even when adding the
        column fails, so the shared field object is never left altered.
        """
        if field._unique:
            # TiDB does not support multiple operations with a single DDL statement,
            # so we need to execute the unique constraint creation separately.
            field._unique = False
            # Django set `cached_property` decorator for `unique` property,
            # so we need to clear the cached value.
            if "unique" in field.__dict__:
                del field.unique
            try:
                super().add_field(model, field)
            finally:
                # Restore the flag whether or not the column was added.
                field._unique = True
                if "unique" in field.__dict__:
                    del field.unique
            self.execute(self._create_unique_sql(model, [field]))
        else:
            super().add_field(model, field)

    def table_sql(self, model):
        """
        Return the ``(sql, params)`` pair that creates the table of ``model``,
        with ``AUTO_ID_CACHE <n>`` appended when the model's Meta defines
        ``tidb_auto_id_cache``.
        """
        sql, params = super().table_sql(model)
        tidb_auto_id_cache = getattr(model._meta, "tidb_auto_id_cache", None)
        if tidb_auto_id_cache is not None:
            sql += " AUTO_ID_CACHE %s" % tidb_auto_id_cache
        return sql, params

# --------------------------------------------------------------------------------
# /django_tidb/version.py:
# --------------------------------------------------------------------------------
# Copyright 2021 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.


# TiDBVersion deal with tidb's version string.
16 | # Our tidb version string is got from ```select version();``` 17 | # it look like this: 18 | # 5.7.25-TiDB-v5.1.0-64-gfb0eaf7b4 19 | # or 5.7.25-TiDB-v5.2.0-alpha-385-g0f0b06ab5 20 | class TiDBVersion: 21 | _version = (0, 0, 0) 22 | 23 | def match(self, version): 24 | version_list = version.split("-") 25 | if len(version_list) < 3: 26 | return False 27 | tidb_version_list = version_list[2].lstrip("v").split(".") 28 | self._version = tuple(int(x) for x in tidb_version_list) 29 | return True 30 | 31 | @property 32 | def version(self): 33 | return self._version 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "django-tidb" 7 | authors = [ 8 | { name="Xiang Zhang", email="zhangxiang02@pingcap.com" }, 9 | { name="Di Wang", email="wangdi@pingcap.com" } 10 | ] 11 | description = "Django backend for TiDB" 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | classifiers = [ 15 | "Development Status :: 5 - Production/Stable", 16 | "Framework :: Django", 17 | "Framework :: Django :: 5.2", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: OS Independent", 20 | "Programming Language :: Python", 21 | "Programming Language :: Python :: 3", 22 | "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Programming Language :: Python :: 3.12", 25 | "Programming Language :: Python :: 3.13", 26 | ] 27 | dynamic = ["version"] 28 | 29 | [project.urls] 30 | "Homepage" = "https://github.com/pingcap/tidb" 31 | "Bug Reports" = "https://github.com/pingcap/django-tidb/issues" 32 | "Source" = "https://github.com/pingcap/django-tidb" 33 | 34 | [project.optional-dependencies] 35 | vector = ["numpy~=1.0"] 36 | 37 | [tool.setuptools] 38 | packages = 
["django_tidb", "django_tidb.fields"] 39 | 40 | [tool.setuptools.dynamic] 41 | version = {attr = "django_tidb.__version__"} -------------------------------------------------------------------------------- /run_testing_worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2020 Google LLC. 4 | 5 | # Use of this source code is governed by a BSD-style 6 | # license that can be found in the LICENSE file or at 7 | # https://developers.google.com/open-source/licenses/bsd 8 | 9 | # NOTE: The code in this file is based on code from the 10 | # googleapis/python-spanner-django project, licensed under BSD 11 | # 12 | # https://github.com/googleapis/python-spanner-django/blob/0544208d6f9ef81b290cf5c4ee304ba0ec0e95c4/run_testing_worker.py 13 | # 14 | 15 | # Copyright 2021 PingCAP, Inc. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # See the License for the specific language governing permissions and 26 | # limitations under the License. 

import os

# Read the list of Django test apps, one per line. Splitting on whitespace
# drops blank lines and the empty entry produced by a trailing newline, so an
# empty file really yields an empty list and the guard below can trigger.
with open("django_test_apps.txt", "r") as file:
    all_apps = file.read().split()

print("test apps: ", all_apps)

if not all_apps:
    exit()

# Run the test suite script with the apps passed through DJANGO_TEST_APPS and
# propagate its exit status.
exitcode = os.WEXITSTATUS(
    os.system(
        """DJANGO_TEST_APPS="{apps}" bash ./django_test_suite.sh""".format(
            apps=" ".join(all_apps)
        )
    )
)
exit(exitcode)

# --------------------------------------------------------------------------------
# /tests/tidb/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/pingcap/django-tidb/d149d1014eb138c006638ecb1ad4c144aab3fefb/tests/tidb/__init__.py
# --------------------------------------------------------------------------------
# /tests/tidb/models.py:
# --------------------------------------------------------------------------------
from django.db import models

from django_tidb.fields import BigAutoRandomField


class Course(models.Model):
    name = models.CharField(max_length=100)


class BigAutoRandomModel(models.Model):
    value = BigAutoRandomField(primary_key=True)
    tag = models.CharField(max_length=100, blank=True, null=True)


class BigAutoRandomExplicitInsertModel(models.Model):
    value = BigAutoRandomField(primary_key=True)
    tag = models.CharField(max_length=100, blank=True, null=True)

# --------------------------------------------------------------------------------
# /tests/tidb/test_tidb_auto_id_cache.py:
# --------------------------------------------------------------------------------
import re

from django.db import models, connection
from django.db.utils import ProgrammingError
from django.test import TransactionTestCase
from django.test.utils import isolate_apps


AUTO_ID_CACHE_PATTERN = re.compile(r"\/\*T!\[auto_id_cache\] AUTO_ID_CACHE=(\d+) \*\/")


class TiDBAutoIDCacheTests(TransactionTestCase):
    available_apps = ["tidb"]

    def get_auto_id_cache_info(self, table):
        """
        Return the AUTO_ID_CACHE value of ``table`` as a string, or None when
        the table does not exist or has no such option.
        """
        with connection.cursor() as cursor:
            cursor.execute(
                # It seems that SHOW CREATE TABLE is the only way to get the auto_id_cache info.
                # Use parameterized query will add quotes to the table name, which will cause syntax error.
                f"SHOW CREATE TABLE {table}",
            )
            row = cursor.fetchone()
            if row is None:
                return None
            match = AUTO_ID_CACHE_PATTERN.search(row[1])
            if match:
                return match.groups()[0]
            return None

    @isolate_apps("tidb")
    def test_create_table_with_tidb_auto_id_cache_1(self):
        class AutoIDCacheNode1(models.Model):
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"
                tidb_auto_id_cache = 1

        with connection.schema_editor() as editor:
            editor.create_model(AutoIDCacheNode1)
        self.assertEqual(
            self.get_auto_id_cache_info(AutoIDCacheNode1._meta.db_table), "1"
        )

    @isolate_apps("tidb")
    def test_create_table_with_tidb_auto_id_cache_non_1(self):
        class AutoIDCacheNode2(models.Model):
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"
                tidb_auto_id_cache = 10

        with connection.schema_editor() as editor:
            editor.create_model(AutoIDCacheNode2)
        self.assertEqual(
            self.get_auto_id_cache_info(AutoIDCacheNode2._meta.db_table), "10"
        )

    @isolate_apps("tidb")
    def test_create_table_with_invalid_tidb_auto_id_cache(self):
        class AutoIDCacheNode3(models.Model):
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"
                tidb_auto_id_cache = "invalid"

        with self.assertRaises(ProgrammingError):
            with connection.schema_editor() as editor:
                editor.create_model(AutoIDCacheNode3)

    @isolate_apps("tidb")
    def test_create_table_without_tidb_auto_id_cache(self):
        class AutoIDCacheNode4(models.Model):
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"

        with connection.schema_editor() as editor:
            editor.create_model(AutoIDCacheNode4)
        self.assertIsNone(self.get_auto_id_cache_info(AutoIDCacheNode4._meta.db_table))

# --------------------------------------------------------------------------------
# /tests/tidb/test_tidb_auto_random.py:
# --------------------------------------------------------------------------------
import re
from contextlib import contextmanager

from django.db import models, connection
from django.core import checks, validators
from django.core.exceptions import ValidationError
from django.test import TestCase, TransactionTestCase, override_settings
from django.test.utils import isolate_apps

from django_tidb.fields import BigAutoRandomField
from .models import BigAutoRandomModel, BigAutoRandomExplicitInsertModel


class TiDBBigAutoRandomFieldTests(TestCase):
    model = BigAutoRandomModel
    explicit_insert_model = BigAutoRandomExplicitInsertModel
    documented_range = (-9223372036854775808, 9223372036854775807)
    rel_db_type_class = BigAutoRandomField

    @contextmanager
    def explicit_insert_allowed(self):
        """
        Allow explicit inserts into AUTO_RANDOM columns for the duration of
        the ``with`` block. The session variable is switched back even when
        the block raises, so a failing test cannot leak the setting.
        """
        with connection.cursor() as cursor:
            cursor.execute("SET @@allow_auto_random_explicit_insert=true")
            try:
                yield
            finally:
                cursor.execute("SET @@allow_auto_random_explicit_insert=false")

    @property
    def backend_range(self):
        field = self.model._meta.get_field("value")
        internal_type = field.get_internal_type()
        return connection.ops.integer_field_range(internal_type)

    def test_documented_range(self):
        """
        Values within the documented safe range pass validation, and can be
        saved and retrieved without corruption.
        """
        min_value, max_value = self.documented_range

        with self.explicit_insert_allowed():
            instance = self.explicit_insert_model(value=min_value)
            instance.full_clean()
            instance.save()
            qs = self.explicit_insert_model.objects.filter(value__lte=min_value)
            self.assertEqual(qs.count(), 1)
            self.assertEqual(qs[0].value, min_value)

            instance = self.explicit_insert_model(value=max_value)
            instance.full_clean()
            instance.save()
            qs = self.explicit_insert_model.objects.filter(value__gte=max_value)
            self.assertEqual(qs.count(), 1)
            self.assertEqual(qs[0].value, max_value)

    def test_backend_range_save(self):
        """
        Backend specific ranges can be saved without corruption.
        """
        min_value, max_value = self.backend_range
        with self.explicit_insert_allowed():
            if min_value is not None:
                instance = self.explicit_insert_model(value=min_value)
                instance.full_clean()
                instance.save()
                qs = self.explicit_insert_model.objects.filter(value__lte=min_value)
                self.assertEqual(qs.count(), 1)
                self.assertEqual(qs[0].value, min_value)

            if max_value is not None:
                instance = self.explicit_insert_model(value=max_value)
                instance.full_clean()
                instance.save()
                qs = self.explicit_insert_model.objects.filter(value__gte=max_value)
                self.assertEqual(qs.count(), 1)
                self.assertEqual(qs[0].value, max_value)

    def test_backend_range_validation(self):
        """
        Backend specific ranges are enforced at the model validation level
        (#12030).
        """
        min_value, max_value = self.backend_range

        if min_value is not None:
            instance = self.model(value=min_value - 1)
            expected_message = validators.MinValueValidator.message % {
                "limit_value": min_value,
            }
            with self.assertRaisesMessage(ValidationError, expected_message):
                instance.full_clean()
            instance.value = min_value
            instance.full_clean()

        if max_value is not None:
            instance = self.model(value=max_value + 1)
            expected_message = validators.MaxValueValidator.message % {
                "limit_value": max_value,
            }
            with self.assertRaisesMessage(ValidationError, expected_message):
                instance.full_clean()
            instance.value = max_value
            instance.full_clean()

    def test_redundant_backend_range_validators(self):
        """
        If there are stricter validators than the ones from the database
        backend then the backend validators aren't added.
        """
        min_backend_value, max_backend_value = self.backend_range

        for callable_limit in (True, False):
            with self.subTest(callable_limit=callable_limit):
                if min_backend_value is not None:
                    min_custom_value = min_backend_value + 1
                    limit_value = (
                        (lambda: min_custom_value)
                        if callable_limit
                        else min_custom_value
                    )
                    ranged_value_field = self.model._meta.get_field("value").__class__(
                        validators=[validators.MinValueValidator(limit_value)]
                    )
                    field_range_message = validators.MinValueValidator.message % {
                        "limit_value": min_custom_value,
                    }
                    with self.assertRaisesMessage(
                        ValidationError, "[%r]" % field_range_message
                    ):
                        ranged_value_field.run_validators(min_backend_value - 1)

                if max_backend_value is not None:
                    max_custom_value = max_backend_value - 1
                    limit_value = (
                        (lambda: max_custom_value)
                        if callable_limit
                        else max_custom_value
                    )
                    ranged_value_field = self.model._meta.get_field("value").__class__(
                        validators=[validators.MaxValueValidator(limit_value)]
                    )
                    field_range_message = validators.MaxValueValidator.message % {
                        "limit_value": max_custom_value,
                    }
                    with self.assertRaisesMessage(
                        ValidationError, "[%r]" % field_range_message
                    ):
                        ranged_value_field.run_validators(max_backend_value + 1)

    def test_types(self):
        instance = self.model(tag="a")
        instance.save()
        self.assertIsInstance(instance.value, int)
        instance = self.model.objects.get()
        self.assertIsInstance(instance.value, int)

    def test_invalid_value(self):
        tests = [
            (TypeError, ()),
            (TypeError, []),
            (TypeError, {}),
            (TypeError, set()),
            (TypeError, object()),
            (TypeError, complex()),
            (ValueError, "non-numeric string"),
            (ValueError, b"non-numeric byte-string"),
        ]
        for exception, value in tests:
            with self.subTest(value):
                msg = "Field 'value' expected a number but got %r." % (value,)
                with self.assertRaisesMessage(exception, msg):
                    self.explicit_insert_model.objects.create(value=value)

    def test_rel_db_type(self):
        field = self.model._meta.get_field("value")
        rel_db_type = field.rel_db_type(connection)
        # Currently, We can't find a general way to get the auto_random info from the field.
        self.assertEqual(rel_db_type, "bigint")
        if connection.tidb_version < (6, 3):
            self.assertEqual(
                self.rel_db_type_class().db_type(connection), "bigint AUTO_RANDOM(5)"
            )
        else:
            self.assertEqual(
                self.rel_db_type_class().db_type(connection),
                "bigint AUTO_RANDOM(5, 64)",
            )


AUTO_RANDOM_PATTERN = re.compile(
    r"\/\*T!\[auto_rand\] AUTO_RANDOM\((\d+)(?:, (\d+))?\) \*\/"
)


class TiDBAutoRandomMigrateTests(TransactionTestCase):
    available_apps = ["tidb"]

    def get_primary_key(self, table):
        """Return the first primary key column of ``table``, or None."""
        with connection.cursor() as cursor:
            primary_key_columns = connection.introspection.get_primary_key_columns(
                cursor, table
            )
            return primary_key_columns[0] if primary_key_columns else None

    def get_auto_random_info(self, table):
        # return (shard_bits, range)
        with connection.cursor() as cursor:
            cursor.execute(
                # It seems that SHOW CREATE TABLE is the only way to get the auto_random info.
                # Use parameterized query will add quotes to the table name, which will cause syntax error.
                f"SHOW CREATE TABLE {table}",
            )
            row = cursor.fetchone()
            if row is None:
                return (None, None)
            for line in row[1].splitlines():
                match = AUTO_RANDOM_PATTERN.search(line)
                if match:
                    return match.groups()
            return (None, None)

    @isolate_apps("tidb")
    @override_settings(DEFAULT_AUTO_FIELD="django_tidb.fields.BigAutoRandomField")
    def test_create_table_with_default_auto_field(self):
        class AutoRandomNode1(models.Model):
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"

        with connection.schema_editor() as editor:
            editor.create_model(AutoRandomNode1)
        self.assertEqual(self.get_primary_key(AutoRandomNode1._meta.db_table), "id")
        self.assertIsInstance(AutoRandomNode1._meta.pk, BigAutoRandomField)
        self.assertEqual(
            self.get_auto_random_info(AutoRandomNode1._meta.db_table), ("5", None)
        )

    @isolate_apps("tidb")
    def test_create_table_explicit_auto_random_field(self):
        class AutoRandomNode2(models.Model):
            id = BigAutoRandomField(primary_key=True)
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"

        with connection.schema_editor() as editor:
            editor.create_model(AutoRandomNode2)
        self.assertEqual(self.get_primary_key(AutoRandomNode2._meta.db_table), "id")
        self.assertIsInstance(AutoRandomNode2._meta.pk, BigAutoRandomField)
        self.assertEqual(
            self.get_auto_random_info(AutoRandomNode2._meta.db_table), ("5", None)
        )

    @isolate_apps("tidb")
    def test_create_table_explicit_auto_random_field_with_shard_bits(self):
        class AutoRandomNode3(models.Model):
            id = BigAutoRandomField(primary_key=True, shard_bits=10)
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"

        with connection.schema_editor() as editor:
            editor.create_model(AutoRandomNode3)
        self.assertEqual(self.get_primary_key(AutoRandomNode3._meta.db_table), "id")
        self.assertIsInstance(AutoRandomNode3._meta.pk, BigAutoRandomField)
        self.assertEqual(
            self.get_auto_random_info(AutoRandomNode3._meta.db_table), ("10", None)
        )

    @isolate_apps("tidb")
    def test_create_table_explicit_auto_random_field_with_shard_bits_and_range(self):
        class AutoRandomNode4(models.Model):
            id = BigAutoRandomField(primary_key=True, shard_bits=10, range=60)
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"

        with connection.schema_editor() as editor:
            editor.create_model(AutoRandomNode4)
        self.assertEqual(self.get_primary_key(AutoRandomNode4._meta.db_table), "id")
        self.assertIsInstance(AutoRandomNode4._meta.pk, BigAutoRandomField)
        self.assertEqual(
            self.get_auto_random_info(AutoRandomNode4._meta.db_table), ("10", "60")
        )

    @isolate_apps("tidb")
    def test_create_table_explicit_auto_random_field_with_range(self):
        class AutoRandomNode5(models.Model):
            id = BigAutoRandomField(primary_key=True, range=60)
            title = models.CharField(max_length=255)

            class Meta:
                app_label = "tidb"

        with connection.schema_editor() as editor:
            editor.create_model(AutoRandomNode5)
        self.assertEqual(self.get_primary_key(AutoRandomNode5._meta.db_table), "id")
        self.assertIsInstance(AutoRandomNode5._meta.pk, BigAutoRandomField)
        self.assertEqual(
            self.get_auto_random_info(AutoRandomNode5._meta.db_table), ("5", "60")
        )

    @isolate_apps("tidb")
    def test_create_table_explicit_auto_random_field_with_invalid_range(self):
        class AutoRandomNode6(models.Model):
            id = BigAutoRandomField(primary_key=True, range=31)

            class Meta:
                app_label = "tidb"

        id = AutoRandomNode6._meta.get_field("id")
self.assertEqual( 316 | id.check(), 317 | [ 318 | checks.Error( 319 | "BigAutoRandomField 'range' attribute must be an integer between 32 and 64.", 320 | obj=id, 321 | ) 322 | ], 323 | ) 324 | 325 | class AutoRandomNode7(models.Model): 326 | id = BigAutoRandomField(primary_key=True, range=None) 327 | 328 | class Meta: 329 | app_label = "tidb" 330 | 331 | id = AutoRandomNode7._meta.get_field("id") 332 | self.assertEqual( 333 | id.check(), 334 | [ 335 | checks.Error( 336 | "BigAutoRandomField must define a 'range' attribute.", 337 | obj=id, 338 | ) 339 | ], 340 | ) 341 | 342 | @isolate_apps("tidb") 343 | def test_create_table_explicit_auto_random_field_with_invalid_shard_bits(self): 344 | class AutoRandomNode8(models.Model): 345 | id = BigAutoRandomField(primary_key=True, shard_bits=16) 346 | 347 | class Meta: 348 | app_label = "tidb" 349 | 350 | id = AutoRandomNode8._meta.get_field("id") 351 | self.assertEqual( 352 | id.check(), 353 | [ 354 | checks.Error( 355 | "BigAutoRandomField 'shard_bits' attribute must be an integer between 1 and 15.", 356 | obj=id, 357 | ) 358 | ], 359 | ) 360 | 361 | class AutoRandomNode9(models.Model): 362 | id = BigAutoRandomField(primary_key=True, shard_bits=None) 363 | 364 | class Meta: 365 | app_label = "tidb" 366 | 367 | id = AutoRandomNode9._meta.get_field("id") 368 | self.assertEqual( 369 | id.check(), 370 | [ 371 | checks.Error( 372 | "BigAutoRandomField must define a 'shard_bits' attribute.", 373 | obj=id, 374 | ) 375 | ], 376 | ) 377 | -------------------------------------------------------------------------------- /tests/tidb/test_tidb_ddl.py: -------------------------------------------------------------------------------- 1 | from django.test import TransactionTestCase 2 | from django.test.utils import isolate_apps 3 | from django.db import models, connection 4 | 5 | 6 | class TiDBDDLTests(TransactionTestCase): 7 | available_apps = ["tidb"] 8 | 9 | def get_indexes(self, table): 10 | """ 11 | Get the indexes on the table using a 
new cursor. 12 | """ 13 | with connection.cursor() as cursor: 14 | return [ 15 | c["columns"][0] 16 | for c in connection.introspection.get_constraints( 17 | cursor, table 18 | ).values() 19 | if c["index"] and len(c["columns"]) == 1 20 | ] 21 | 22 | def get_uniques(self, table): 23 | with connection.cursor() as cursor: 24 | return [ 25 | c["columns"][0] 26 | for c in connection.introspection.get_constraints( 27 | cursor, table 28 | ).values() 29 | if c["unique"] and len(c["columns"]) == 1 30 | ] 31 | 32 | @isolate_apps("tidb") 33 | def test_should_create_db_index(self): 34 | # When defining a model with db_index=True, TiDB should create a db index 35 | class Tag(models.Model): 36 | title = models.CharField(max_length=255, db_index=True) 37 | 38 | class Meta: 39 | app_label = "tidb" 40 | 41 | with connection.schema_editor() as editor: 42 | editor.create_model(Tag) 43 | self.assertIn("title", self.get_indexes("tidb_tag")) 44 | 45 | new_field = models.CharField(max_length=255, db_index=True) 46 | new_field.set_attributes_from_name("new_field") 47 | with connection.schema_editor() as editor: 48 | editor.add_field(Tag, new_field) 49 | self.assertIn("new_field", self.get_indexes("tidb_tag")) 50 | 51 | @isolate_apps("tidb") 52 | def test_should_create_db_index_for_foreign_key_with_no_db_constraint(self): 53 | # When defining a model with a ForeignKey, TiDB should still create a db index even with db_constraint=False 54 | class Node1(models.Model): 55 | title = models.CharField(max_length=255) 56 | 57 | class Meta: 58 | app_label = "tidb" 59 | 60 | class Node2(models.Model): 61 | node1 = models.ForeignKey( 62 | Node1, on_delete=models.CASCADE, db_constraint=False 63 | ) 64 | 65 | class Meta: 66 | app_label = "tidb" 67 | 68 | with connection.schema_editor() as editor: 69 | editor.create_model(Node1) 70 | editor.create_model(Node2) 71 | 72 | self.assertIn("node1_id", self.get_indexes("tidb_node2")) 73 | 74 | @isolate_apps("tidb") 75 | def test_add_unique_field(self): 76 | # issue:
https://github.com/pingcap/django-tidb/issues/48 77 | class Node3(models.Model): 78 | title = models.CharField(max_length=255, unique=True) 79 | 80 | class Meta: 81 | app_label = "tidb" 82 | 83 | with connection.schema_editor() as editor: 84 | editor.create_model(Node3) 85 | self.assertIn("title", self.get_uniques("tidb_node3")) 86 | 87 | new_field = models.CharField(max_length=255, unique=True) 88 | new_field.set_attributes_from_name("new_field") 89 | with connection.schema_editor() as editor: 90 | editor.add_field(Node3, new_field) 91 | self.assertIn("new_field", self.get_uniques("tidb_node3")) 92 | 93 | parent = models.OneToOneField(Node3, models.CASCADE) 94 | parent.set_attributes_from_name("parent") 95 | with connection.schema_editor() as editor: 96 | editor.add_field(Node3, parent) 97 | self.assertIn("parent_id", self.get_uniques("tidb_node3")) 98 | -------------------------------------------------------------------------------- /tests/tidb/test_tidb_explain.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from django.test import TestCase 4 | from django.test.utils import CaptureQueriesContext 5 | from django.db import connection, transaction 6 | 7 | from .models import Course 8 | 9 | 10 | class TiDBExplainTests(TestCase): 11 | SUPPORTED_FORMATS = {"TRADITIONAL", "ROW", "BRIEF", "DOT", "TIDB_JSON"} 12 | 13 | def test_explain_with_supported_format(self): 14 | for format in self.SUPPORTED_FORMATS: 15 | with self.subTest(format=format), transaction.atomic(): 16 | with CaptureQueriesContext(connection) as captured_queries: 17 | result = Course.objects.filter(name="test").explain(format=format) 18 | self.assertTrue( 19 | captured_queries[0]["sql"].startswith( 20 | connection.ops.explain_prefix + f' FORMAT="{format}"' 21 | ) 22 | ) 23 | if format == "TIDB_JSON": 24 | try: 25 | json.loads(result) 26 | except json.JSONDecodeError as e: 27 | self.fail( 28 | f"QuerySet.explain() result is not valid JSON: {e}" 
29 | ) 30 | 31 | def test_explain_analyze_with_supported_format(self): 32 | for format in self.SUPPORTED_FORMATS: 33 | with self.subTest(format=format), transaction.atomic(): 34 | with CaptureQueriesContext(connection) as captured_queries: 35 | result = Course.objects.filter(name="test").explain( 36 | analyze=True, format=format 37 | ) 38 | self.assertTrue( 39 | captured_queries[0]["sql"].startswith( 40 | connection.ops.explain_prefix 41 | + f' ANALYZE FORMAT="{format}"' 42 | ) 43 | ) 44 | if format == "TIDB_JSON": 45 | try: 46 | json.loads(result) 47 | except json.JSONDecodeError as e: 48 | self.fail( 49 | f"QuerySet.explain() result is not valid JSON: {e}" 50 | ) 51 | 52 | def test_explain_with_unsupported_format(self): 53 | with self.assertRaises(ValueError): 54 | Course.objects.filter(name="test").explain(format="JSON") 55 | 56 | def test_explain_analyze_with_unsupported_format(self): 57 | with self.assertRaises(ValueError): 58 | Course.objects.filter(name="test").explain(analyze=True, format="JSON") 59 | 60 | def test_explain_with_unkonwn_option(self): 61 | with self.assertRaises(ValueError): 62 | Course.objects.filter(name="test").explain(unknown_option=True) 63 | 64 | def test_explain_with_default_params(self): 65 | with transaction.atomic(): 66 | with CaptureQueriesContext(connection) as captured_queries: 67 | Course.objects.filter(name="test").explain() 68 | self.assertTrue( 69 | captured_queries[0]["sql"].startswith( 70 | connection.ops.explain_prefix + ' FORMAT="ROW"' 71 | ) 72 | ) 73 | -------------------------------------------------------------------------------- /tests/tidb_field_defaults/README.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | This test is copied from the Django [field_defaults](https://github.com/django/django/tree/main/tests/field_defaults), as TiDB has some [limitations](https://docs.pingcap.com/tidb/dev/data-type-default-values#specify-expressions-as-default-values) on the 
default expression of the field: it does not support as many expressions as MySQL. 4 | -------------------------------------------------------------------------------- /tests/tidb_field_defaults/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pingcap/django-tidb/d149d1014eb138c006638ecb1ad4c144aab3fefb/tests/tidb_field_defaults/__init__.py -------------------------------------------------------------------------------- /tests/tidb_field_defaults/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Callable defaults 3 | 4 | You can pass callable objects as the ``default`` parameter to a field. When 5 | the object is created without an explicit value passed in, Django will call 6 | the method to determine the default value. 7 | 8 | This example uses ``datetime.datetime.now`` as the default for the ``pub_date`` 9 | field. 10 | """ 11 | 12 | from datetime import datetime 13 | from decimal import Decimal 14 | 15 | from django.db import models 16 | from django.db.models.functions import Random, Now 17 | 18 | 19 | class Article(models.Model): 20 | headline = models.CharField(max_length=100, default="Default headline") 21 | pub_date = models.DateTimeField(default=datetime.now) 22 | 23 | def __str__(self): 24 | return self.headline 25 | 26 | 27 | class DBArticle(models.Model): 28 | """ 29 | Values or expressions can be passed as the db_default parameter to a field. 30 | When the object is created without an explicit value passed in, the 31 | database will insert the default value automatically.
32 | """ 33 | 34 | headline = models.CharField(max_length=100, db_default="Default headline") 35 | pub_date = models.DateTimeField(db_default=Now()) 36 | cost = models.DecimalField( 37 | max_digits=3, decimal_places=2, db_default=Decimal("3.33") 38 | ) 39 | 40 | class Meta: 41 | required_db_features = {"supports_expression_defaults"} 42 | 43 | 44 | class DBDefaults(models.Model): 45 | both = models.IntegerField(default=1, db_default=2) 46 | null = models.FloatField(null=True, db_default=1.1) 47 | 48 | 49 | # This model has too many db_default expressions that TiDB does not support 50 | # class DBDefaultsFunction(models.Model): 51 | # number = models.FloatField(db_default=Pi()) 52 | # year = models.IntegerField(db_default=ExtractYear(Now())) 53 | # added = models.FloatField(db_default=Pi() + 4.5) 54 | # multiple_subfunctions = models.FloatField(db_default=Coalesce(4.5, Pi())) 55 | # case_when = models.IntegerField( 56 | # db_default=models.Case(models.When(GreaterThan(2, 1), then=3), default=4) 57 | # ) 58 | 59 | # class Meta: 60 | # required_db_features = {"supports_expression_defaults"} 61 | 62 | 63 | class TiDBDefaultsFunction(models.Model): 64 | number = models.DecimalField(max_digits=3, decimal_places=2, db_default=Random()) 65 | created_at = models.DateTimeField(db_default=Now()) 66 | 67 | 68 | class DBDefaultsPK(models.Model): 69 | language_code = models.CharField(primary_key=True, max_length=2, db_default="en") 70 | 71 | 72 | class DBDefaultsFK(models.Model): 73 | language_code = models.ForeignKey( 74 | DBDefaultsPK, db_default="fr", on_delete=models.CASCADE 75 | ) 76 | -------------------------------------------------------------------------------- /tests/tidb_field_defaults/tests.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from decimal import Decimal 3 | 4 | from django.core.exceptions import ValidationError 5 | from django.db import connection 6 | from django.db.models import 
Case, F, FloatField, Value, When 7 | from django.db.models.expressions import ( 8 | Expression, 9 | ExpressionList, 10 | ExpressionWrapper, 11 | Func, 12 | OrderByList, 13 | RawSQL, 14 | ) 15 | from django.db.models.functions import Collate 16 | from django.db.models.lookups import GreaterThan 17 | from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature 18 | 19 | from .models import ( 20 | Article, 21 | DBArticle, 22 | DBDefaults, 23 | DBDefaultsFK, 24 | # DBDefaultsFunction, 25 | TiDBDefaultsFunction, 26 | DBDefaultsPK, 27 | ) 28 | 29 | 30 | class DefaultTests(TestCase): 31 | def test_field_defaults(self): 32 | a = Article() 33 | now = datetime.now() 34 | a.save() 35 | 36 | self.assertIsInstance(a.id, int) 37 | self.assertEqual(a.headline, "Default headline") 38 | self.assertLess((now - a.pub_date).seconds, 5) 39 | 40 | @skipUnlessDBFeature( 41 | "can_return_columns_from_insert", "supports_expression_defaults" 42 | ) 43 | def test_field_db_defaults_returning(self): 44 | a = DBArticle() 45 | a.save() 46 | self.assertIsInstance(a.id, int) 47 | self.assertEqual(a.headline, "Default headline") 48 | self.assertIsInstance(a.pub_date, datetime) 49 | self.assertEqual(a.cost, Decimal("3.33")) 50 | 51 | @skipIfDBFeature("can_return_columns_from_insert") 52 | @skipUnlessDBFeature("supports_expression_defaults") 53 | def test_field_db_defaults_refresh(self): 54 | a = DBArticle() 55 | a.save() 56 | a.refresh_from_db() 57 | self.assertIsInstance(a.id, int) 58 | self.assertEqual(a.headline, "Default headline") 59 | self.assertIsInstance(a.pub_date, datetime) 60 | self.assertEqual(a.cost, Decimal("3.33")) 61 | 62 | def test_null_db_default(self): 63 | obj1 = DBDefaults.objects.create() 64 | if not connection.features.can_return_columns_from_insert: 65 | obj1.refresh_from_db() 66 | self.assertEqual(obj1.null, 1.1) 67 | 68 | obj2 = DBDefaults.objects.create(null=None) 69 | self.assertIsNone(obj2.null) 70 | 71 | # 
@skipUnlessDBFeature("supports_expression_defaults") 72 | # def test_db_default_function(self): 73 | # m = DBDefaultsFunction.objects.create() 74 | # if not connection.features.can_return_columns_from_insert: 75 | # m.refresh_from_db() 76 | # self.assertAlmostEqual(m.number, pi) 77 | # self.assertEqual(m.year, datetime.now().year) 78 | # self.assertAlmostEqual(m.added, pi + 4.5) 79 | # self.assertEqual(m.multiple_subfunctions, 4.5) 80 | 81 | @skipUnlessDBFeature("supports_expression_defaults") 82 | def test_db_default_function_tidb(self): 83 | m = TiDBDefaultsFunction.objects.create() 84 | if not connection.features.can_return_columns_from_insert: 85 | m.refresh_from_db() 86 | self.assertIsInstance(m.number, Decimal) 87 | self.assertTrue(0 <= m.number <= 1) 88 | self.assertIsInstance(m.created_at, datetime) 89 | self.assertEqual(m.created_at.year, datetime.now().year) 90 | 91 | @skipUnlessDBFeature("insert_test_table_with_defaults") 92 | def test_both_default(self): 93 | create_sql = connection.features.insert_test_table_with_defaults 94 | with connection.cursor() as cursor: 95 | cursor.execute(create_sql.format(DBDefaults._meta.db_table)) 96 | obj1 = DBDefaults.objects.get() 97 | self.assertEqual(obj1.both, 2) 98 | 99 | obj2 = DBDefaults.objects.create() 100 | self.assertEqual(obj2.both, 1) 101 | 102 | def test_pk_db_default(self): 103 | obj1 = DBDefaultsPK.objects.create() 104 | if not connection.features.can_return_columns_from_insert: 105 | # refresh_from_db() cannot be used because that needs the pk to 106 | # already be known to Django. 
107 | obj1 = DBDefaultsPK.objects.get(pk="en") 108 | self.assertEqual(obj1.pk, "en") 109 | self.assertEqual(obj1.language_code, "en") 110 | 111 | obj2 = DBDefaultsPK.objects.create(language_code="de") 112 | self.assertEqual(obj2.pk, "de") 113 | self.assertEqual(obj2.language_code, "de") 114 | 115 | def test_foreign_key_db_default(self): 116 | parent1 = DBDefaultsPK.objects.create(language_code="fr") 117 | child1 = DBDefaultsFK.objects.create() 118 | if not connection.features.can_return_columns_from_insert: 119 | child1.refresh_from_db() 120 | self.assertEqual(child1.language_code, parent1) 121 | 122 | parent2 = DBDefaultsPK.objects.create() 123 | if not connection.features.can_return_columns_from_insert: 124 | # refresh_from_db() cannot be used because that needs the pk to 125 | # already be known to Django. 126 | parent2 = DBDefaultsPK.objects.get(pk="en") 127 | child2 = DBDefaultsFK.objects.create(language_code=parent2) 128 | self.assertEqual(child2.language_code, parent2) 129 | 130 | # @skipUnlessDBFeature( 131 | # "can_return_columns_from_insert", "supports_expression_defaults" 132 | # ) 133 | # def test_case_when_db_default_returning(self): 134 | # m = DBDefaultsFunction.objects.create() 135 | # self.assertEqual(m.case_when, 3) 136 | 137 | # @skipIfDBFeature("can_return_columns_from_insert") 138 | # @skipUnlessDBFeature("supports_expression_defaults") 139 | # def test_case_when_db_default_no_returning(self): 140 | # m = DBDefaultsFunction.objects.create() 141 | # m.refresh_from_db() 142 | # self.assertEqual(m.case_when, 3) 143 | 144 | @skipUnlessDBFeature("supports_expression_defaults") 145 | def test_bulk_create_all_db_defaults(self): 146 | articles = [DBArticle(), DBArticle()] 147 | DBArticle.objects.bulk_create(articles) 148 | 149 | headlines = DBArticle.objects.values_list("headline", flat=True) 150 | self.assertSequenceEqual(headlines, ["Default headline", "Default headline"]) 151 | 152 | @skipUnlessDBFeature("supports_expression_defaults") 153 | def 
test_bulk_create_all_db_defaults_one_field(self): 154 | pub_date = datetime.now() 155 | articles = [DBArticle(pub_date=pub_date), DBArticle(pub_date=pub_date)] 156 | DBArticle.objects.bulk_create(articles) 157 | 158 | headlines = DBArticle.objects.values_list("headline", "pub_date", "cost") 159 | self.assertSequenceEqual( 160 | headlines, 161 | [ 162 | ("Default headline", pub_date, Decimal("3.33")), 163 | ("Default headline", pub_date, Decimal("3.33")), 164 | ], 165 | ) 166 | 167 | @skipUnlessDBFeature("supports_expression_defaults") 168 | def test_bulk_create_mixed_db_defaults(self): 169 | articles = [DBArticle(), DBArticle(headline="Something else")] 170 | DBArticle.objects.bulk_create(articles) 171 | 172 | headlines = DBArticle.objects.values_list("headline", flat=True) 173 | self.assertCountEqual(headlines, ["Default headline", "Something else"]) 174 | 175 | # @skipUnlessDBFeature("supports_expression_defaults") 176 | # def test_bulk_create_mixed_db_defaults_function(self): 177 | # instances = [DBDefaultsFunction(), DBDefaultsFunction(year=2000)] 178 | # DBDefaultsFunction.objects.bulk_create(instances) 179 | 180 | # years = DBDefaultsFunction.objects.values_list("year", flat=True) 181 | # self.assertCountEqual(years, [2000, datetime.now().year]) 182 | 183 | def test_full_clean(self): 184 | obj = DBArticle() 185 | obj.full_clean() 186 | obj.save() 187 | obj.refresh_from_db() 188 | self.assertEqual(obj.headline, "Default headline") 189 | 190 | obj = DBArticle(headline="Other title") 191 | obj.full_clean() 192 | obj.save() 193 | obj.refresh_from_db() 194 | self.assertEqual(obj.headline, "Other title") 195 | 196 | obj = DBArticle(headline="") 197 | with self.assertRaises(ValidationError): 198 | obj.full_clean() 199 | 200 | 201 | class AllowedDefaultTests(SimpleTestCase): 202 | def test_allowed(self): 203 | class Max(Func): 204 | function = "MAX" 205 | 206 | tests = [ 207 | Value(10), 208 | Max(1, 2), 209 | RawSQL("Now()", ()), 210 | Value(10) + Value(7), # 
Combined expression. 211 | ExpressionList(Value(1), Value(2)), 212 | ExpressionWrapper(Value(1), output_field=FloatField()), 213 | Case(When(GreaterThan(2, 1), then=3), default=4), 214 | ] 215 | for expression in tests: 216 | with self.subTest(expression=expression): 217 | self.assertIs(expression.allowed_default, True) 218 | 219 | def test_disallowed(self): 220 | class Max(Func): 221 | function = "MAX" 222 | 223 | tests = [ 224 | Expression(), 225 | F("field"), 226 | Max(F("count"), 1), 227 | Value(10) + F("count"), # Combined expression. 228 | ExpressionList(F("count"), Value(2)), 229 | ExpressionWrapper(F("count"), output_field=FloatField()), 230 | Collate(Value("John"), "nocase"), 231 | OrderByList("field"), 232 | ] 233 | for expression in tests: 234 | with self.subTest(expression=expression): 235 | self.assertIs(expression.allowed_default, False) 236 | -------------------------------------------------------------------------------- /tests/tidb_vector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pingcap/django-tidb/d149d1014eb138c006638ecb1ad4c144aab3fefb/tests/tidb_vector/__init__.py -------------------------------------------------------------------------------- /tests/tidb_vector/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | from django_tidb.fields.vector import ( 4 | VectorField, 5 | VectorIndex, 6 | CosineDistance, 7 | L2Distance, 8 | ) 9 | 10 | 11 | class Document(models.Model): 12 | content = models.TextField() 13 | embedding = VectorField() 14 | 15 | 16 | class DocumentExplicitDimension(models.Model): 17 | content = models.TextField() 18 | embedding = VectorField(dimensions=3) 19 | 20 | 21 | class DocumentWithAnnIndex(models.Model): 22 | content = models.TextField() 23 | embedding = VectorField(dimensions=3) 24 | 25 | class Meta: 26 | indexes = [ 27 | 
VectorIndex(CosineDistance("embedding"), name="idx_cos"), 28 | VectorIndex(L2Distance("embedding"), name="idx_l2"), 29 | ] 30 | -------------------------------------------------------------------------------- /tests/tidb_vector/test_vector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import sqrt 3 | from django.db.utils import OperationalError 4 | from django.test import TestCase 5 | from django_tidb.fields.vector import ( 6 | CosineDistance, 7 | L1Distance, 8 | L2Distance, 9 | NegativeInnerProduct, 10 | ) 11 | 12 | from .models import Document, DocumentExplicitDimension, DocumentWithAnnIndex 13 | 14 | 15 | class TiDBVectorFieldTests(TestCase): 16 | model = Document 17 | 18 | def test_create_get(self): 19 | obj = self.model.objects.create( 20 | content="test content", 21 | embedding=[1, 2, 3], 22 | ) 23 | obj = self.model.objects.get(pk=obj.pk) 24 | self.assertTrue(np.array_equal(obj.embedding, np.array([1, 2, 3]))) 25 | self.assertEqual(obj.embedding.dtype, np.float32) 26 | 27 | def test_get_with_different_dimension(self): 28 | self.model.objects.create( 29 | content="test content", 30 | embedding=[1, 2, 3], 31 | ) 32 | with self.assertRaises(OperationalError) as cm: 33 | list( 34 | self.model.objects.annotate( 35 | distance=CosineDistance("embedding", [3, 1, 2, 4]) 36 | ).values_list("distance", flat=True) 37 | ) 38 | self.assertIn("vectors have different dimensions", str(cm.exception)) 39 | 40 | def create_documents(self): 41 | vectors = [[1, 1, 1], [2, 2, 2], [1, 1, 2]] 42 | for i, v in enumerate(vectors): 43 | self.model.objects.create( 44 | content=f"{i + 1}", 45 | embedding=v, 46 | ) 47 | 48 | def test_l1_distance(self): 49 | self.create_documents() 50 | distance = L1Distance("embedding", [1, 1, 1]) 51 | docs = self.model.objects.annotate(distance=distance).order_by("distance") 52 | self.assertEqual([d.content for d in docs], ["1", "3", "2"]) 53 | self.assertEqual([d.distance for d 
in docs], [0, 1, 3]) 54 | 55 | def test_l2_distance(self): 56 | self.create_documents() 57 | distance = L2Distance("embedding", [1, 1, 1]) 58 | docs = self.model.objects.annotate(distance=distance).order_by("distance") 59 | self.assertEqual([d.content for d in docs], ["1", "3", "2"]) 60 | self.assertEqual([d.distance for d in docs], [0, 1, sqrt(3)]) 61 | 62 | def test_cosine_distance(self): 63 | self.create_documents() 64 | distance = CosineDistance("embedding", [1, 1, 1]) 65 | docs = self.model.objects.annotate(distance=distance).order_by("distance") 66 | self.assertEqual([d.content for d in docs], ["1", "2", "3"]) 67 | self.assertEqual([d.distance for d in docs], [0, 0, 0.05719095841793653]) 68 | 69 | def test_negative_inner_product(self): 70 | self.create_documents() 71 | distance = NegativeInnerProduct("embedding", [1, 1, 1]) 72 | docs = self.model.objects.annotate(distance=distance).order_by("distance") 73 | self.assertEqual([d.content for d in docs], ["2", "3", "1"]) 74 | self.assertEqual([d.distance for d in docs], [-6, -4, -3]) 75 | 76 | 77 | class TiDBVectorFieldExplicitDimensionTests(TiDBVectorFieldTests): 78 | model = DocumentExplicitDimension 79 | 80 | 81 | class TiDBVectorFieldWithAnnIndexTests(TiDBVectorFieldTests): 82 | model = DocumentWithAnnIndex 83 | -------------------------------------------------------------------------------- /tidb_settings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 PingCAP, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | 14 | import os 15 | 16 | hosts = os.getenv("TIDB_HOST", "127.0.0.1") 17 | port = os.getenv("TIDB_PORT", 4000) 18 | user = os.getenv("TIDB_USER", "root") 19 | password = os.getenv("TIDB_PASSWORD", "") 20 | 21 | DATABASES = { 22 | "default": { 23 | "ENGINE": "django_tidb", 24 | "USER": user, 25 | "PASSWORD": password, 26 | "HOST": hosts, 27 | "PORT": port, 28 | "TEST": { 29 | "NAME": "django_tests", 30 | "CHARSET": "utf8mb4", 31 | "COLLATION": "utf8mb4_general_ci", 32 | }, 33 | "OPTIONS": { 34 | "init_command": "SET @@tidb_allow_remove_auto_inc = ON", 35 | }, 36 | }, 37 | "other": { 38 | "ENGINE": "django_tidb", 39 | "USER": user, 40 | "PASSWORD": password, 41 | "HOST": hosts, 42 | "PORT": port, 43 | "TEST": { 44 | "NAME": "django_tests2", 45 | "CHARSET": "utf8mb4", 46 | "COLLATION": "utf8mb4_general_ci", 47 | }, 48 | "OPTIONS": { 49 | "init_command": "SET @@tidb_allow_remove_auto_inc = ON", 50 | }, 51 | }, 52 | } 53 | DEFAULT_AUTO_FIELD = "django.db.models.AutoField" 54 | USE_TZ = False 55 | SECRET_KEY = "django_tests_secret_key" 56 | 57 | # Use a fast hasher to speed up tests. 58 | PASSWORD_HASHERS = [ 59 | "django.contrib.auth.hashers.MD5PasswordHasher", 60 | ] 61 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Copyright 2021 PingCAP, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | 14 | [tox] 15 | alwayscopy=true 16 | envlist = py313,py312,py311,py310,lint 17 | 18 | # Keep envlist and this mapping in sync with the python-version matrix in .github/workflows/ci.yml 19 | [gh-actions] 20 | python = 21 | 3.10: py310 22 | 3.11: py311 23 | 3.12: py312 24 | 3.13: py313 25 | 26 | [testenv] 27 | passenv = * 28 | commands = 29 | python3 run_testing_worker.py 30 | setenv = 31 | LANG = en_US.utf-8 32 | 33 | [testenv:lint] 34 | skip_install = True 35 | allowlist_externals = bash 36 | deps = 37 | flake8==6.0.0 38 | black==23.7.0 39 | commands = 40 | bash -c "flake8 --max-line-length 130 django_tidb tests *py" 41 | bash -c "black --diff --check django_tidb tests *py" 42 | --------------------------------------------------------------------------------