├── nonebot_plugin_orm ├── py.typed ├── templates │ ├── generic │ │ ├── __init__.py │ │ ├── README │ │ ├── versions │ │ │ └── __init__.py │ │ ├── script.py.mako │ │ └── env.py │ └── multidb │ │ ├── __init__.py │ │ ├── versions │ │ └── __init__.py │ │ ├── README │ │ ├── script.py.mako │ │ └── env.py ├── exception.py ├── env.py ├── config.py ├── model.py ├── param.py ├── __main__.py ├── __init__.py ├── utils.py └── migrate.py ├── .github ├── actions │ └── setup-python │ │ └── action.yml └── workflows │ └── release.yml ├── .editorconfig ├── .pre-commit-config.yaml ├── LICENSE ├── .devcontainer └── devcontainer.json ├── pyproject.toml ├── .gitignore └── README.md /nonebot_plugin_orm/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/generic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/multidb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/generic/README: -------------------------------------------------------------------------------- 1 | 单数据库模板 2 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/generic/versions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/multidb/versions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/exception.py: -------------------------------------------------------------------------------- 1 | from alembic import util 2 | from click import ClickException 3 | 4 | 5 | class CommandError(ClickException, util.CommandError): 6 | pass 7 | 8 | 9 | class AutogenerateDiffsDetected(util.AutogenerateDiffsDetected, CommandError): 10 | pass 11 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/multidb/README: -------------------------------------------------------------------------------- 1 | 多数据库模板 2 | 3 | 多数据库与单数据库没有太大区别。 4 | 最主要的区别是会运行迁移脚本 N 次(取决于你有多少个数据库), 5 | 并为每一次运行提供一个 `bind_key` 与对应的 `AsyncEngine` 和 `MetaData` 对象。 6 | 7 | `bind_key` 允许迁移脚本正确地运行数据库对应的迁移操作,参见 mako 模板文件。 8 | 9 | 在模型定义中,你需要声明 `__bind_key__` 以将模型加入对应的 `MetaData` 中。 10 | 在配置中,你可以提供 `bind_key` 对应的 `AsyncEngine` 或数据库 URL,或者使用默认数据库。 11 | 在插件中,`bind_key` 默认为插件名(将 "-" 替换为 "\_")。 12 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/env.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from alembic import context 4 | from sqlalchemy.sql.schema import SchemaItem 5 | 6 | from . import migrate 7 | 8 | 9 | def no_drop_table( 10 | _, __, type_: str, reflected: bool, compare_to: SchemaItem | None 11 | ) -> bool: 12 | return not ( 13 | getattr(context.config.cmd_opts, "cmd", (None,))[0] == migrate.check 14 | and type_ == "table" 15 | and reflected 16 | and compare_to is None 17 | ) 18 | -------------------------------------------------------------------------------- /.github/actions/setup-python/action.yml: -------------------------------------------------------------------------------- 1 | name: Setup Python 2 | description: Setup Python 3 | 4 | inputs: 5 | python-version: 6 | description: Python version 7 | required: false 8 | default: "3.10" 9 | 10 | runs: 11 | using: "composite" 12 | steps: 13 | - uses: pdm-project/setup-pdm@v3 14 | with: 15 | python-version: ${{ inputs.python-version }} 16 | architecture: "x64" 17 | cache: true 18 | 19 | - name: Install dependencies 20 | run: pdm install 21 | shell: bash 22 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/generic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | 迁移 ID: ${up_revision} 4 | 父迁移: ${down_revision | comma,n} 5 | 创建时间: ${create_date} 6 | 7 | """ 8 | from __future__ import annotations 9 | 10 | from collections.abc import Sequence 11 | 12 | from alembic import op 13 | import sqlalchemy as sa 14 | ${imports if imports else ""} 15 | 16 | revision: str = ${repr(up_revision)} 17 | down_revision: str | Sequence[str] | None = ${repr(down_revision)} 18 | branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} 19 | depends_on: str | Sequence[str] | None = ${repr(depends_on)} 20 | 21 | 22 | def upgrade(name: str = "") -> None: 23 | if name: 24 | return 25 | ${upgrades} 26 | 27 | 28 | def downgrade(name: str = "") -> None: 29 | if name: 30 | return 31 | ${downgrades} 32 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | root = true 3 | 4 | [*] 5 | indent_style = space 6 | indent_size = 2 7 | end_of_line = lf 8 | charset = utf-8 9 | trim_trailing_whitespace = true 10 | insert_final_newline = true 11 | 12 | # The JSON files contain newlines inconsistently 13 | [*.json] 14 | insert_final_newline = ignore 15 | 16 | # Minified JavaScript files shouldn't be changed 17 | [**.min.js] 18 | indent_style = ignore 19 | insert_final_newline = ignore 20 | 21 | # Makefiles always use tabs for indentation 22 | [Makefile] 23 | indent_style = tab 24 | 25 | # Batch files use tabs for indentation 26 | [*.bat] 27 | indent_style = tab 28 | 29 | [*.md] 30 | trim_trailing_whitespace = false 31 | 32 | # Matches the exact files either package.json or .travis.yml 33 | [{package.json,.travis.yml}] 34 | indent_size = 2 35 | 36 | [{*.py,*.pyi}] 37 | indent_size = 4 38 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_install_hook_types: [pre-commit, prepare-commit-msg] 2 | ci: 3 | autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks" 4 | autofix_prs: true 5 | autoupdate_branch: master 6 | autoupdate_schedule: monthly 7 | autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks" 8 | repos: 9 | - repo: https://github.com/hadialqattan/pycln 10 | rev: v2.4.0 11 | hooks: 12 | - id: pycln 13 | args: [--config, pyproject.toml] 14 | stages: [pre-commit] 15 | 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.13.2 18 | hooks: 19 | - id: isort 20 | stages: [pre-commit] 21 | 22 | - repo: https://github.com/psf/black 23 | rev: 24.10.0 24 | hooks: 25 | - id: black 26 | stages: [pre-commit] 27 | 28 | - repo: https://github.com/nonebot/nonemoji 29 | rev: v0.1.4 30 | hooks: 31 | - id: nonemoji 32 | stages: [prepare-commit-msg] 33 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Union, Optional 3 | 4 | from sqlalchemy import URL 5 | from pydantic import BaseModel 6 | from nonebot import get_plugin_config 7 | from sqlalchemy.ext.asyncio import AsyncEngine 8 | 9 | from .migrate import AlembicConfig 10 | 11 | __all__ = ( 12 | "Config", 13 | "plugin_config", 14 | ) 15 | 16 | 17 | class Config(BaseModel, arbitrary_types_allowed=True): 18 | sqlalchemy_database_url: Union[str, URL, AsyncEngine] = "" 19 | sqlalchemy_binds: dict[str, Union[str, URL, dict[str, Any], AsyncEngine]] = {} 20 | sqlalchemy_echo: bool = False 21 | sqlalchemy_engine_options: dict[str, Any] = {} 22 | sqlalchemy_session_options: dict[str, Any] = {} 23 | 24 | alembic_config: Union[Path, AlembicConfig, None] = None 25 | alembic_script_location: Optional[Path] = None 26 | alembic_version_locations: Union[Path, dict[str, Path], None] = None 27 | alembic_context: dict[str, Any] = {} 28 | alembic_startup_check: bool = True 29 | 30 | 31 | plugin_config = get_plugin_config(Config) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2023 NoneBot Team 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of 5 | this software and associated documentation files (the "Software"), to deal in 6 | the Software without restriction, including without limitation the rights to 7 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 8 | the Software, and to permit persons to whom the Software is furnished to do so, 9 | subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 17 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 18 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 19 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/multidb/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | 迁移 ID: ${up_revision} 4 | 父迁移: ${down_revision | comma,n} 5 | 创建时间: ${create_date} 6 | 7 | """ 8 | from __future__ import annotations 9 | 10 | from collections.abc import Sequence 11 | from contextlib import suppress 12 | 13 | import sqlalchemy as sa 14 | from alembic import op 15 | ${imports if imports else ""} 16 | 17 | revision: str = ${repr(up_revision)} 18 | down_revision: str | Sequence[str] | None = ${repr(down_revision)} 19 | branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} 20 | depends_on: str | Sequence[str] | None = ${repr(depends_on)} 21 | 22 | 23 | def upgrade(name: str) -> None: 24 | with suppress(KeyError): 25 | globals()[f"upgrade_{name}"]() 26 | 27 | 28 | def downgrade(name: str) -> None: 29 | with suppress(KeyError): 30 | globals()[f"downgrade_{name}"]() 31 | 32 | % for name in config.attributes["metadatas"]: 33 | 34 | def upgrade_${name}() -> None: 35 | ${context.get(f"{name}_upgrades", "pass")} 36 | 37 | 38 | def downgrade_${name}() -> None: 39 | ${context.get(f"{name}_downgrades", "pass")} 40 | 41 | % endfor 42 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Default Linux Universal", 3 | "image": "mcr.microsoft.com/devcontainers/universal:2-linux", 4 | "features": { 5 | "ghcr.io/devcontainers-contrib/features/pdm:2": {} 6 | }, 7 | "postCreateCommand": "pdm install && pdm run pre-commit install", 8 | "customizations": { 9 | "vscode": { 10 | "settings": { 11 | "python.analysis.diagnosticMode": "workspace", 12 | "python.analysis.typeCheckingMode": "basic", 13 | "[python]": { 14 | "editor.defaultFormatter": "ms-python.black-formatter", 15 | "editor.codeActionsOnSave": { 16 | "source.organizeImports": "explicit" 17 | } 18 | }, 19 | "files.exclude": { 20 | "**/__pycache__": true 21 | }, 22 | "files.watcherExclude": { 23 | "**/target/**": true, 24 | "**/__pycache__": true 25 | } 26 | }, 27 | "extensions": [ 28 | "ms-python.python", 29 | "ms-python.vscode-pylance", 30 | "ms-python.isort", 31 | "ms-python.black-formatter", 32 | "EditorConfig.EditorConfig" 33 | ] 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | permissions: 9 | id-token: write 10 | contents: write 11 | 12 | jobs: 13 | release: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Setup Python 20 | uses: ./.github/actions/setup-python 21 | 22 | - name: Get version 23 | id: version 24 | run: | 25 | echo "VERSION=$(pdm show --version)" >> $GITHUB_OUTPUT 26 | echo "TAG_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 27 | echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT 28 | 29 | - name: Check version 30 | if: steps.version.outputs.VERSION != steps.version.outputs.TAG_VERSION 31 | run: exit 1 32 | 33 | - name: Build package 34 | run: pdm build 35 | 36 | - name: Upload dist 37 | uses: actions/upload-artifact@v4 38 | with: 39 | name: dist 40 | path: dist/* 41 | 42 | - name: Publish package to PyPI 43 | uses: pypa/gh-action-pypi-publish@release/v1 44 | 45 | - name: Publish package to GitHub 46 | run: | 47 | gh release create ${{ steps.version.outputs.TAG_NAME }} dist/* \ 48 | -t "🔖 Release ${{ steps.version.outputs.VERSION }}" --generate-notes 49 | env: 50 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "nonebot-plugin-orm" 3 | version = "0.8.2" 4 | description = "SQLAlchemy ORM support for nonebot" 5 | authors = [ 6 | { name = "yanyongyu", email = "yyy@nonebot.dev" }, 7 | { name = "ProgramRipper", email = "programripper@foxmail.com" }, 8 | ] 9 | dependencies = [ 10 | "alembic~=1.16", 11 | "click~=8.1", 12 | "importlib-metadata>=4.6; python_version < \"3.10\"", 13 | "importlib-resources>=5.12; python_version < \"3.12\"", 14 | "nonebot-plugin-localstore~=0.7", 15 | "nonebot2~=2.4", 16 | "sqlalchemy~=2.0", 17 | "typing-extensions~=4.13", 18 | ] 19 | requires-python = ">=3.9,<4.0" 20 | readme = "README.md" 21 | license = { text = "MIT" } 22 | keywords = ["nonebot", "orm", "sqlalchemy"] 23 | 24 | [project.urls] 25 | homepage = "https://github.com/nonebot/plugin-orm" 26 | repository = "https://github.com/nonebot/plugin-orm" 27 | documentation = "https://github.com/nonebot/plugin-orm" 28 | 29 | [project.optional-dependencies] 30 | default = ["sqlalchemy[aiosqlite]"] 31 | mysql = ["sqlalchemy[aiomysql]"] 32 | asyncmy = ["sqlalchemy[asyncmy]"] 33 | aiomysql = ["sqlalchemy[aiomysql]"] 34 | postgresql = ["sqlalchemy[postgresql-psycopgbinary]"] 35 | psycopg = ["sqlalchemy[postgresql-psycopgbinary]"] 36 | asyncpg = ["sqlalchemy[postgresql-asyncpg]"] 37 | sqlite = ["sqlalchemy[aiosqlite]"] 38 | aiosqlite = ["sqlalchemy[aiosqlite]"] 39 | 40 | [project.entry-points.nb_scripts] 41 | orm = "nonebot_plugin_orm.__main__:main" 42 | 43 | [build-system] 44 | requires = ["pdm-backend"] 45 | build-backend = "pdm.backend" 46 | 47 | [tool.pdm] 48 | [tool.pdm.dev-dependencies] 49 | dev = [ 50 | "black~=24.2", 51 | "importlib-metadata~=7.0", 52 | "importlib-resources~=6.1", 53 | "isort~=5.13", 54 | "nonemoji~=0.1", 55 | "pre-commit~=3.5", 56 | "pycln~=2.4", 57 | "sqlalchemy[aiosqlite]", 58 | "typing-extensions~=4.9", 59 | ] 60 | 61 | [tool.black] 62 | line-length = 88 63 | include = '\.pyi?$' 64 | extend-exclude = ''' 65 | ''' 66 | 67 | [tool.isort] 68 | profile = "black" 69 | line_length = 88 70 | length_sort = true 71 | skip_gitignore = true 72 | force_sort_within_sections = true 73 | extra_standard_library = [ 74 | "importlib_metadata", 75 | "importlib_resources", 76 | "typing_extensions", 77 | ] 78 | 79 | [tool.pycln] 80 | path = "." 81 | 82 | [tool.pyright] 83 | pythonVersion = "3.9" 84 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/generic/env.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from typing import Any, cast 5 | 6 | from alembic import context 7 | from sqlalchemy import Connection 8 | from sqlalchemy.util import await_only 9 | from sqlalchemy.ext.asyncio import AsyncEngine 10 | 11 | from nonebot_plugin_orm.env import no_drop_table 12 | from nonebot_plugin_orm import AlembicConfig, plugin_config 13 | 14 | # Alembic Config 对象, 它提供正在使用的 .ini 文件中的值. 15 | config = cast(AlembicConfig, context.config) 16 | 17 | # 默认 AsyncEngine 18 | engine: AsyncEngine = config.attributes["engines"][""] 19 | 20 | # 模型的 MetaData, 用于 "autogenerate" 支持. 21 | # from myapp import mymodel 22 | # target_metadata = mymodel.Base.metadata 23 | target_metadata = config.attributes["metadatas"][""] 24 | 25 | # 其他来自 config 的值, 可以按 env.py 的需求定义, 例如可以获取: 26 | # my_important_option = config.get_main_option("my_important_option") 27 | # ... 等等. 28 | 29 | 30 | def run_migrations_offline() -> None: 31 | """在“离线”模式下运行迁移. 32 | 33 | 虽然这里也可以获得 Engine, 但我们只需要一个 URL 即可配置 context. 34 | 通过跳过 Engine 的创建, 我们甚至不需要 DBAPI 可用. 35 | 36 | 在这里调用 context.execute() 会将给定的字符串写入到脚本输出. 37 | """ 38 | 39 | opts: dict[str, Any] = { 40 | "url": engine.url, 41 | "dialect_opts": {"paramstyle": "named"}, 42 | "target_metadata": target_metadata, 43 | "literal_binds": True, 44 | } | plugin_config.alembic_context 45 | context.configure(**opts) 46 | 47 | with context.begin_transaction(): 48 | context.run_migrations() 49 | 50 | 51 | def do_run_migrations(connection: Connection) -> None: 52 | opts: dict[str, Any] = { 53 | "connection": connection, 54 | "render_as_batch": True, 55 | "target_metadata": target_metadata, 56 | "include_object": no_drop_table, 57 | } | plugin_config.alembic_context 58 | context.configure(**opts) 59 | 60 | with context.begin_transaction(): 61 | context.run_migrations() 62 | 63 | 64 | async def run_migrations_online() -> None: 65 | """在“在线”模式下运行迁移. 66 | 67 | 这种情况下, 我们需要为 context 创建一个连接. 68 | """ 69 | 70 | async with engine.connect() as connection: 71 | await connection.run_sync(do_run_migrations) 72 | 73 | 74 | if context.is_offline_mode(): 75 | run_migrations_offline() 76 | else: 77 | coro = run_migrations_online() 78 | 79 | try: 80 | asyncio.run(coro) 81 | except RuntimeError: 82 | await_only(coro) 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm-python 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from inspect import Parameter, Signature 4 | from typing_extensions import ( 5 | TYPE_CHECKING, 6 | Any, 7 | ClassVar, 8 | Annotated, 9 | get_args, 10 | get_origin, 11 | get_annotations, 12 | ) 13 | 14 | from sqlalchemy import Table, MetaData 15 | from nonebot import get_plugin_by_module_name 16 | from sqlalchemy.orm import Mapped, DeclarativeBase 17 | 18 | from .utils import DependsInner 19 | 20 | __all__ = ("Model",) 21 | 22 | 23 | _NAMING_CONVENTION = { 24 | "ix": "ix_%(column_0_label)s", 25 | "uq": "uq_%(table_name)s_%(column_0_name)s", 26 | "ck": "ck_%(table_name)s_%(constraint_name)s", 27 | "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", 28 | "pk": "pk_%(table_name)s", 29 | } 30 | 31 | 32 | class Model(DeclarativeBase): 33 | metadata = MetaData(naming_convention=_NAMING_CONVENTION) 34 | 35 | if TYPE_CHECKING: 36 | __bind_key__: ClassVar[str] 37 | __signature__: ClassVar[Signature] 38 | __table__: ClassVar[Table] 39 | 40 | def __init_subclass__(cls, **kwargs) -> None: 41 | _setup_di(cls) 42 | _setup_tablename(cls) 43 | 44 | super().__init_subclass__(**kwargs) 45 | 46 | if not hasattr(cls, "__table__"): 47 | return 48 | 49 | _setup_bind(cls) 50 | 51 | 52 | def _setup_di(cls: type[Model]) -> None: 53 | """Get signature for NoneBot's dependency injection, 54 | and set annotations for SQLAlchemy declarative class. 55 | """ 56 | parameters: list[Parameter] = [] 57 | 58 | annotations: dict[str, Any] = {} 59 | for base in reversed(cls.__mro__): 60 | annotations.update(get_annotations(base, eval_str=True)) 61 | 62 | for name, type_annotation in annotations.items(): 63 | # Check if the attribute is both a dependent and a mapped column 64 | depends_inner = None 65 | if get_origin(type_annotation) is Annotated: 66 | (type_annotation, *extra_args) = get_args(type_annotation) 67 | depends_inner = next( 68 | (x for x in extra_args if isinstance(x, DependsInner)), None 69 | ) 70 | 71 | if get_origin(type_annotation) is not Mapped: 72 | continue 73 | 74 | default = getattr(cls, name, Signature.empty) 75 | 76 | depends_inner = default if isinstance(default, DependsInner) else depends_inner 77 | if depends_inner is None: 78 | continue 79 | 80 | # Set parameter for NoneBot dependency injection 81 | parameters.append( 82 | Parameter( 83 | name, 84 | Parameter.KEYWORD_ONLY, 85 | default=depends_inner, 86 | annotation=get_args(type_annotation)[0], 87 | ) 88 | ) 89 | 90 | # Set annotation for SQLAlchemy declarative class 91 | cls.__annotations__[name] = type_annotation 92 | if default is not Signature.empty and not isinstance(default, Mapped): 93 | delattr(cls, name) 94 | 95 | cls.__signature__ = Signature(parameters) 96 | 97 | 98 | def _setup_tablename(cls: type[Model]) -> None: 99 | for attr in ("__abstract__", "__tablename__", "__table__"): 100 | if getattr(cls, attr, None): 101 | return 102 | 103 | cls.__tablename__ = cls.__name__.lower() 104 | 105 | if plugin := get_plugin_by_module_name(cls.__module__): 106 | cls.__tablename__ = f"{plugin.name.replace('-', '_')}_{cls.__tablename__}" 107 | 108 | 109 | def _setup_bind(cls: type[Model]) -> None: 110 | bind_key: str | None = getattr(cls, "__bind_key__", None) 111 | 112 | if bind_key is None: 113 | if plugin := get_plugin_by_module_name(cls.__module__): 114 | bind_key = plugin.name 115 | else: 116 | bind_key = "" 117 | 118 | cls.__table__.info["bind_key"] = bind_key 119 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/templates/multidb/env.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from typing import Any, cast 5 | 6 | from alembic import context 7 | from sqlalchemy.util import await_only 8 | from sqlalchemy.ext.asyncio import AsyncEngine, AsyncConnection 9 | from sqlalchemy import MetaData, Connection, TwoPhaseTransaction 10 | 11 | from nonebot_plugin_orm.env import no_drop_table 12 | from nonebot_plugin_orm import AlembicConfig, plugin_config 13 | 14 | # 是否使用二阶段提交 (Two-Phase Commit), 15 | # 当同时迁移多个数据库时, 可以启用以保证迁移的原子性. 16 | # 注意: 只有部分数据库支持(例如 SQLite 就不支持). 17 | USE_TWOPHASE = False 18 | 19 | # Alembic Config 对象, 它提供正在使用的 .ini 文件中的值. 20 | config = cast(AlembicConfig, context.config) 21 | 22 | # bind key 到 AsyncEngine 的映射 23 | engines: dict[str, AsyncEngine] = config.attributes["engines"] 24 | 25 | # bind key 到 MetaData 的映射, 用于 "autogenerate" 支持. 26 | # Metadata 对象必须仅包含对应数据库中的表. 27 | # table.to_metadata() 在需要“复制”表到 MetaData 中时可能很有用. 28 | # from myapp import mymodel 29 | # target_metadata = { 30 | # "engine1": mymodel.metadata1, 31 | # "engine2": mymodel.metadata2 32 | # } 33 | target_metadatas: dict[str, MetaData] = config.attributes["metadatas"] 34 | 35 | # 其他来自 config 的值, 可以按 env.py 的需求定义, 例如可以获取: 36 | # my_important_option = config.get_main_option("my_important_option") 37 | # ... 等等. 38 | 39 | 40 | def run_migrations_offline() -> None: 41 | """在“离线”模式下运行迁移. 42 | 43 | 虽然这里也可以获得 Engine, 但我们只需要一个 URL 即可配置 context. 44 | 通过跳过 Engine 的创建, 我们甚至不需要 DBAPI 可用. 45 | 46 | 在这里调用 context.execute() 会将给定的字符串写入到脚本输出. 47 | 48 | """ 49 | 50 | for name, engine in engines.items(): 51 | config.print_stdout(f"迁移数据库 {name or ''} 中 ...") 52 | file_ = f"{name}.sql" 53 | with open(file_, "w") as buffer: 54 | opts: dict[str, Any] = { 55 | "url": engine.url, 56 | "dialect_opts": {"paramstyle": "named"}, 57 | "output_buffer": buffer, 58 | "target_metadata": target_metadatas[name], 59 | "literal_binds": True, 60 | } | plugin_config.alembic_context 61 | context.configure(**opts) 62 | 63 | with context.begin_transaction(): 64 | context.run_migrations(name=name) 65 | config.print_stdout(f"将输出写入到 {file_}") 66 | 67 | 68 | def do_run_migrations(conn: Connection, name: str, metadata: MetaData) -> None: 69 | opts: dict[str, Any] = { 70 | "connection": conn, 71 | "render_as_batch": True, 72 | "target_metadata": metadata, 73 | "include_object": no_drop_table, 74 | "upgrade_token": f"{name}_upgrades", 75 | "downgrade_token": f"{name}_downgrades", 76 | } | plugin_config.alembic_context 77 | context.configure(**opts) 78 | 79 | context.run_migrations(name=name) 80 | 81 | 82 | async def run_migrations_online() -> None: 83 | """在“在线”模式下运行迁移. 84 | 85 | 这种情况下, 我们需要为 context 创建一个连接. 86 | """ 87 | 88 | conns: dict[str, AsyncConnection] = {} 89 | txns: dict[str, TwoPhaseTransaction] = {} 90 | 91 | try: 92 | for name, engine in engines.items(): 93 | config.print_stdout(f"迁移数据库 {name or ''} 中 ...") 94 | conn = conns[name] = await engine.connect() 95 | if USE_TWOPHASE: 96 | txns[name] = await conn.run_sync(Connection.begin_twophase) 97 | else: 98 | await conn.begin() 99 | 100 | await conn.run_sync(do_run_migrations, name, target_metadatas[name]) 101 | 102 | if USE_TWOPHASE: 103 | await asyncio.gather( 104 | *( 105 | conn.run_sync(lambda _: txns[name].prepare()) 106 | for name, conn in conns.items() 107 | ) 108 | ) 109 | 110 | await asyncio.gather(*(conn.commit() for conn in conns.values())) 111 | except BaseException: 112 | await asyncio.gather(*(conn.rollback() for conn in conns.values())) 113 | raise 114 | finally: 115 | await asyncio.gather(*(conn.close() for conn in conns.values())) 116 | 117 | 118 | if context.is_offline_mode(): 119 | run_migrations_offline() 120 | else: 121 | coro = run_migrations_online() 122 | 123 | try: 124 | asyncio.run(coro) 125 | except RuntimeError: 126 | await_only(coro) 127 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/param.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from operator import methodcaller 5 | from inspect import Parameter, isclass 6 | from collections.abc import Iterator, Sequence, AsyncIterator 7 | from typing_extensions import Any, Self, Annotated, cast, get_args, get_origin 8 | 9 | from pydantic.fields import FieldInfo 10 | from nonebot.dependencies import Param 11 | from nonebot.typing import origin_is_union 12 | from nonebot.params import Depends, DependParam 13 | from sqlalchemy import Row, Result, ScalarResult, select 14 | from sqlalchemy.sql.selectable import ExecutableReturnsRows 15 | from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult 16 | 17 | from .model import Model 18 | from .utils import Option, Dependency, generic_issubclass 19 | 20 | __all__ = ( 21 | "SQLDepends", 22 | "ORMParam", 23 | ) 24 | 25 | 26 | PATTERNS = { 27 | AsyncIterator[Sequence[Row[tuple[Any, ...]]]]: Option( 28 | True, 29 | False, 30 | (methodcaller("partitions"),), 31 | ), 32 | AsyncIterator[Sequence[tuple[Any, ...]]]: Option( 33 | True, 34 | False, 35 | (methodcaller("partitions"),), 36 | ), 37 | AsyncIterator[Sequence[Any]]: Option( 38 | True, 39 | True, 40 | (methodcaller("partitions"),), 41 | ), 42 | Iterator[Sequence[Row[tuple[Any, ...]]]]: Option( 43 | False, 44 | False, 45 | (methodcaller("partitions"),), 46 | ), 47 | Iterator[Sequence[tuple[Any, ...]]]: Option( 48 | False, 49 | False, 50 | (methodcaller("partitions"),), 51 | ), 52 | Iterator[Sequence[Any]]: Option( 53 | False, 54 | True, 55 | (methodcaller("partitions"),), 56 | ), 57 | AsyncResult[tuple[Any, ...]]: Option( 58 | True, 59 | False, 60 | ), 61 | AsyncScalarResult[Any]: Option( 62 | True, 63 | True, 64 | ), 65 | Result[tuple[Any, ...]]: Option( 66 | False, 67 | False, 68 | ), 69 | ScalarResult[Any]: Option( 70 | False, 71 | True, 72 | ), 73 | AsyncIterator[Row[tuple[Any, ...]]]: Option( 74 | True, 75 | False, 76 | ), 77 | Iterator[Row[tuple[Any, ...]]]: Option( 78 | False, 79 | False, 80 | ), 81 | Sequence[Row[tuple[Any, ...]]]: Option( 82 | True, 83 | False, 84 | (), 85 | methodcaller("all"), 86 | ), 87 | Sequence[tuple[Any, ...]]: Option( 88 | True, 89 | False, 90 | (), 91 | methodcaller("all"), 92 | ), 93 | Sequence[Any]: Option( 94 | True, 95 | True, 96 | (), 97 | methodcaller("all"), 98 | ), 99 | tuple[Any, ...]: Option( 100 | True, 101 | False, 102 | (), 103 | methodcaller("one_or_none"), 104 | ), 105 | Any: Option( 106 | True, 107 | True, 108 | (), 109 | methodcaller("one_or_none"), 110 | ), 111 | } 112 | 113 | 114 | @dataclass 115 | class SQLDependsInner: 116 | dependency: ExecutableReturnsRows 117 | use_cache: bool = True 118 | validate: bool | FieldInfo = False 119 | 120 | 121 | def SQLDepends( 122 | dependency: ExecutableReturnsRows, 123 | *, 124 | use_cache: bool = True, 125 | validate: bool | FieldInfo = False, 126 | ) -> Any: 127 | return SQLDependsInner(dependency, use_cache, validate) 128 | 129 | 130 | class ORMParam(DependParam): 131 | @classmethod 132 | def _check_param( 133 | cls, param: Parameter, allow_types: tuple[type[Param], ...] 134 | ) -> Self | None: 135 | type_annotation, depends_inner = param.annotation, None 136 | if get_origin(param.annotation) is Annotated: 137 | type_annotation, *extra_args = get_args(param.annotation) 138 | depends_inner = next( 139 | (x for x in reversed(extra_args) if isinstance(x, SQLDependsInner)), 140 | None, 141 | ) 142 | 143 | if isinstance(param.default, SQLDependsInner): 144 | depends_inner = param.default 145 | 146 | for pattern, option in PATTERNS.items(): 147 | if models := cast( 148 | "list[Any]", generic_issubclass(pattern, type_annotation) 149 | ): 150 | break 151 | else: 152 | models, option = [], Option() 153 | 154 | for index, model in enumerate(models): 155 | if origin_is_union(get_origin(model)): 156 | models[index] = next( 157 | ( 158 | arg 159 | for arg in get_args(model) 160 | if isclass(arg) and issubclass(arg, Model) 161 | ), 162 | None, 163 | ) 164 | 165 | if not (isclass(models[index]) and issubclass(models[index], Model)): 166 | models = [] 167 | break 168 | 169 | if depends_inner is not None: 170 | statement = depends_inner.dependency 171 | elif models: 172 | # NOTE: statement is generated (see below) 173 | statement = select(*models).where( 174 | *( 175 | getattr(model, name) == param.default 176 | for model in models 177 | for name, param in model.__signature__.parameters.items() 178 | ) 179 | ) 180 | else: 181 | return 182 | 183 | return super()._check_param( 184 | param.replace( 185 | default=Depends( 186 | Dependency(statement, option), 187 | use_cache=( 188 | depends_inner.use_cache if depends_inner else False 189 | ), # NOTE: default use_cache=False as it is impossible to reuse a generated statement (see above) 190 | validate=depends_inner.validate if depends_inner else False, 191 | ) 192 | ), 193 | allow_types, 194 | ) 195 | 196 | @classmethod 197 | def _check_parameterless( 198 | cls, value: Any, allow_types: tuple[type[Param], ...] 199 | ) -> None: 200 | return 201 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | nonebot 4 |

5 | 6 |
7 | 8 | # NoneBot Plugin ORM 9 | 10 | 11 | 12 | _✨ NoneBot 数据库支持插件 ✨_ 13 | 14 | 15 |
16 | 17 |

18 | 19 | license 20 | 21 | 22 | pypi 23 | 24 | python 25 |

26 | 27 | ## 安装 28 | 29 | ```shell 30 | pip install nonebot-plugin-orm 31 | poetry add nonebot-plugin-orm 32 | pdm add nonebot-plugin-orm 33 | 34 | # 无需配置、开箱即用的默认依赖 35 | pip install nonebot-plugin-orm[default] 36 | 37 | # 特定数据库后端的依赖 38 | pip install nonebot-plugin-orm[mysql] 39 | pip install nonebot-plugin-orm[postgresql] 40 | pip install nonebot-plugin-orm[sqlite] 41 | 42 | # 特定数据库驱动的依赖 43 | pip install nonebot-plugin-orm[asyncmy] 44 | pip install nonebot-plugin-orm[aiomysql] 45 | pip install nonebot-plugin-orm[psycopg] 46 | pip install nonebot-plugin-orm[asyncpg] 47 | pip install nonebot-plugin-orm[aiosqlite] 48 | ``` 49 | 50 | ## 使用方式 51 | 52 | ### ORM 53 | 54 | #### Model 依赖注入 55 | 56 | ```python 57 | from nonebot.adapters import Event 58 | from nonebot.params import Depends 59 | from nonebot import require, on_message 60 | from sqlalchemy.orm import Mapped, mapped_column 61 | 62 | require("nonebot_plugin_orm") 63 | from nonebot_plugin_orm import Model, async_scoped_session 64 | 65 | matcher = on_message() 66 | 67 | 68 | def get_user_id(event: Event) -> str: 69 | return event.get_user_id() 70 | 71 | 72 | class User(Model): 73 | id: Mapped[int] = mapped_column(primary_key=True) 74 | user_id: Mapped[str] = Depends(get_user_id) 75 | 76 | 77 | @matcher.handle() 78 | async def _(event: Event, sess: async_scoped_session, user: User | None): 79 | if user: 80 | await matcher.finish(f"Hello, {user.user_id}") 81 | 82 | sess.add(User(user_id=get_user_id(event))) 83 | await sess.commit() 84 | await matcher.finish("Hello, new user!") 85 | ``` 86 | 87 | #### SQL 依赖注入 88 | 89 | ```python 90 | from sqlalchemy import select 91 | from nonebot.adapters import Event 92 | from nonebot.params import Depends 93 | from nonebot import require, on_message 94 | from sqlalchemy.orm import Mapped, mapped_column 95 | 96 | require("nonebot_plugin_orm") 97 | from nonebot_plugin_orm import Model, SQLDepends, async_scoped_session 98 | 99 | matcher = on_message() 100 | 101 | 102 | def get_session_id(event: Event) -> str: 103 | return event.get_session_id() 104 | 105 | 106 | class Session(Model): 107 | id: Mapped[int] = mapped_column(primary_key=True) 108 | session_id: Mapped[str] 109 | 110 | 111 | @matcher.handle() 112 | async def _( 113 | event: Event, 114 | sess: async_scoped_session, 115 | session: Session 116 | | None = SQLDepends( 117 | select(Session).where(Session.session_id == Depends(get_session_id)) 118 | ), 119 | ): 120 | if session: 121 | await matcher.finish(f"Hello, {session.session_id}") 122 | 123 | sess.add(Session(session_id=get_session_id(event))) 124 | await sess.commit() 125 | await matcher.finish("Hello, new user!") 126 | 127 | ``` 128 | 129 | ### CLI 130 | 131 | 依赖 [NB CLI](https://github.com/nonebot/nb-cli) 132 | 133 | ```properties 134 | $ nb orm 135 | Usage: nb orm [OPTIONS] COMMAND [ARGS]... 136 | 137 | Options: 138 | -c, --config FILE 可选的配置文件;默认为 ALEMBIC_CONFIG 环境变量的值,或者 "alembic.ini"(如果存在) 139 | -n, --name TEXT .ini 文件中用于 Alembic 配置的小节的名称 [default: alembic] 140 | -x TEXT 自定义 env.py 脚本使用的其他参数,例如:-x setting1=somesetting -x 141 | setting2=somesetting 142 | -q, --quite 不要输出日志到标准输出 143 | --help Show this message and exit. 144 | 145 | Commands: 146 | branches 显示所有的分支。 147 | check 检查数据库是否与模型定义一致。 148 | current 显示当前的迁移。 149 | downgrade 回退到先前版本。 150 | edit 使用 $EDITOR 编辑迁移脚本。 151 | ensure_version 创建版本表。 152 | heads 显示所有的分支头。 153 | history 显示迁移的历史。 154 | init 初始化脚本目录。 155 | list_templates 列出所有可用的模板。 156 | merge 合并多个迁移。创建一个新的迁移脚本。 157 | revision 创建一个新迁移脚本。 158 | show 显示迁移的信息。 159 | stamp 将数据库标记为特定的迁移版本,不运行任何迁移。 160 | upgrade 升级到较新版本。 161 | ``` 162 | 163 | ## 配置项 164 | 165 | ### sqlalchemy_database_url 166 | 167 | 默认数据库连接 URL。 168 | 参见:[Engine Configuration — SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls) 169 | 170 | ```properties 171 | SQLALCHEMY_DATABASE_URL=sqlite+aiosqlite:// 172 | ``` 173 | 174 | ### sqlalchemy_binds 175 | 176 | bind keys 到 `AsyncEngine` 选项的映射。值可以是数据库连接 URL、`AsyncEngine` 选项字典或者 `AsyncEngine` 实例。 177 | 178 | ```properties 179 | SQLALCHEMY_BINDS='{ 180 | "": "sqlite+aiosqlite://", 181 | "nonebot_plugin_user": { 182 | "url": "postgresql+asyncpg://scott:tiger@localhost/mydatabase", 183 | "echo": true 184 | } 185 | }' 186 | ``` 187 | 188 | ### sqlalchemy_echo 189 | 190 | 所有 `AsyncEngine` 的 `echo` 和 `echo_pool` 选项的默认值。用于快速调试连接和 SQL 生成问题。 191 | 192 | ```properties 193 | SQLALCHEMY_ECHO=true 194 | ``` 195 | 196 | ### sqlalchemy_engine_options 197 | 198 | 所有 `AsyncEngine` 的默认选项字典。 199 | 参见:[Engine Configuration — SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/core/engines.html#engine-configuration) 200 | 201 | ```properties 202 | SQLALCHEMY_ENGINE_OPTIONS='{ 203 | "pool_size": 5, 204 | "max_overflow": 10, 205 | "pool_timeout": 30, 206 | "pool_recycle": 3600, 207 | "echo": true 208 | }' 209 | ``` 210 | 211 | ### sqlalchemy_session_options 212 | 213 | `AsyncSession` 的选项字典。 214 | 参见:[Session API — SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session.__init__) 215 | 216 | ```properties 217 | SQLALCHEMY_SESSION_OPTIONS='{ 218 | "autoflush": false, 219 | "autobegin": true, 220 | "expire_on_commit": true 221 | }' 222 | ``` 223 | 224 | ### alembic_config 225 | 226 | 配置文件路径或 `AlembicConfig` 实例。 227 | 228 | ```properties 229 | ALEMBIC_CONFIG=alembic.ini 230 | ``` 231 | 232 | ### alembic_script_location 233 | 234 | 脚本目录路径。 235 | 236 | ```properties 237 | ALEMBIC_SCRIPT_LOCATION=migrations 238 | ``` 239 | 240 | ### alembic_version_locations 241 | 242 | 迁移脚本目录路径或分支标签到迁移脚本目录路径的映射。 243 | 244 | ```properties 245 | ALEMBIC_VERSION_LOCATIONS=migrations/versions 246 | 247 | ALEMBIC_VERSION_LOCATIONS='{ 248 | "": "migrations/versions", 249 | "nonebot_plugin_user": "src/nonebot_plugin_user/versions", 250 | "nonebot_plugin_chatrecorder": "migrations/versions/nonebot_plugin_chatrecorder" 251 | }' 252 | ``` 253 | 254 | ### alembic_context 255 | 256 | `MigrationContext` 的选项字典。 257 | 参见:[Runtime Objects — Alembic 1.12.0 documentation](https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.environment.EnvironmentContext.configure) 258 | 259 | ```properties 260 | ALEMBIC_CONTEXT='{ 261 | "render_as_batch": true 262 | }' 263 | ``` 264 | 265 | ### alembic_startup_check 266 | 267 | 是否在启动时检查数据库与模型定义的一致性。 268 | 269 | ```properties 270 | ALEMBIC_STARTUP_CHECK=true 271 | ``` 272 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from functools import wraps 5 | from argparse import Namespace 6 | from collections.abc import Callable, Iterable 7 | from typing_extensions import TypeVar, ParamSpec, Concatenate 8 | 9 | import click 10 | from alembic.script import Script 11 | 12 | from . import migrate 13 | from .config import plugin_config 14 | from .migrate import AlembicConfig 15 | 16 | _P = ParamSpec("_P") 17 | _R = TypeVar("_R") 18 | 19 | 20 | @click.group() 21 | @click.option( 22 | "-c", 23 | "--config", 24 | envvar="ALEMBIC_CONFIG", 25 | type=click.Path(exists=True, dir_okay=False, path_type=Path), 26 | help='可选的配置文件; 默认为 ALEMBIC_CONFIG 环境变量的值, 或者 "alembic.ini" (如果存在)', 27 | ) 28 | @click.option( 29 | "-n", 30 | "--name", 31 | default="alembic", 32 | show_default=True, 33 | help=".ini 文件中用于 Alembic 配置的小节的名称", 34 | ) 35 | @click.option( 36 | "-x", 37 | multiple=True, 38 | help="自定义 env.py 脚本使用的其他参数, 例如:-x setting1=somesetting -x setting2=somesetting", 39 | ) 40 | @click.option("-q", "--quite", is_flag=True, help="不要输出日志到标准输出") 41 | @click.pass_context 42 | def orm(ctx: click.Context, config: Path, name: str, **_) -> None: 43 | ctx.show_default = True 44 | 45 | if isinstance(plugin_config.alembic_config, AlembicConfig): 46 | ctx.obj = plugin_config.alembic_config 47 | else: 48 | cmd_opts = Namespace(**ctx.params) 49 | 50 | if ctx.invoked_subcommand: 51 | arguments = [] 52 | options = [] 53 | 54 | for param in globals()[ctx.invoked_subcommand].params: 55 | if isinstance(param, click.Argument): 56 | arguments.append(param.name) 57 | elif isinstance(param, click.Option): 58 | options.append(param.name) 59 | 60 | cmd_opts.cmd = ( 61 | getattr(migrate, ctx.invoked_subcommand), 62 | arguments, 63 | options, 64 | ) 65 | 66 | ctx.obj = AlembicConfig(config, ini_section=name, cmd_opts=cmd_opts) 67 | 68 | ctx.call_on_close(ctx.obj.close) 69 | 70 | 71 | def update_cmd_opts( 72 | f: Callable[Concatenate[AlembicConfig, _P], _R] 73 | ) -> Callable[_P, _R]: 74 | @wraps(f) 75 | @click.pass_context 76 | def wrapper(ctx: click.Context, *args: _P.args, **kwargs: _P.kwargs) -> _R: 77 | for key, value in kwargs.items(): 78 | setattr(ctx.obj.cmd_opts, key, value) 79 | 80 | return f(ctx.obj, *args, **kwargs) 81 | 82 | return wrapper 83 | 84 | 85 | @orm.result_callback() 86 | @click.pass_obj 87 | def move_script(config_: AlembicConfig, scripts: Iterable[Script] | None, **_) -> None: 88 | if not scripts: 89 | return 90 | 91 | for script in scripts: 92 | config_.move_script(script) 93 | 94 | 95 | @orm.command("list_templates") 96 | @update_cmd_opts 97 | def list_templates(*args, **kwargs) -> None: 98 | """列出所有可用的模板.""" 99 | 100 | return migrate.list_templates(*args, **kwargs) 101 | 102 | 103 | @orm.command() 104 | @click.argument( 105 | "directory", 106 | default=Path("migrations"), 107 | type=click.Path(file_okay=False, writable=True, resolve_path=True, path_type=Path), 108 | ) 109 | @click.option("-t", "--template", default="generic", help="使用的迁移环境模板") 110 | @click.option( 111 | "--package", is_flag=True, help="在脚本目录和版本目录中创建 __init__.py 文件" 112 | ) 113 | @update_cmd_opts 114 | def init(*args, **kwargs) -> None: 115 | """初始化脚本目录.""" 116 | 117 | return migrate.init(*args, **kwargs) 118 | 119 | 120 | @orm.command() 121 | @click.option("-m", "--message", help="描述") 122 | @click.option("--sql", is_flag=True, help="以 SQL 的形式输出迁移脚本") 123 | @click.option("--head", help="基准版本") 124 | @click.option("--splice", is_flag=True, help="允许非头部迁移作为基准版本") 125 | @click.option("--branch-label", help="分支标签") 126 | @click.option( 127 | "--version-path", 128 | default=None, 129 | type=click.Path(file_okay=False, writable=True, resolve_path=True, path_type=Path), 130 | help="存放迁移脚本的目录", 131 | ) 132 | @click.option("--rev-id", help="指定而不是使用生成的迁移 ID") 133 | @click.option("--depends-on", help="依赖的迁移") 134 | @update_cmd_opts 135 | def revision(*args, **kwargs) -> Iterable[Script]: 136 | """创建一个新迁移脚本.""" 137 | 138 | return migrate.revision(*args, **kwargs) 139 | 140 | 141 | @orm.command() 142 | @update_cmd_opts 143 | def check(*args, **kwargs) -> None: 144 | """检查数据库是否与模型定义一致.""" 145 | 146 | return migrate.check(*args, **kwargs) 147 | 148 | 149 | @orm.command() 150 | @click.argument("revisions", nargs=-1) 151 | @click.option("-m", "--message", help="描述") 152 | @click.option("--branch-label", help="分支标签") 153 | @click.option("--rev-id", help="指定而不是使用生成的迁移 ID") 154 | @update_cmd_opts 155 | def merge(*args, **kwargs) -> Iterable[Script]: 156 | """合并多个迁移.创建一个新的迁移脚本.""" 157 | 158 | return migrate.merge(*args, **kwargs) 159 | 160 | 161 | @orm.command() 162 | @click.argument("revision", required=False) 163 | @click.option("--sql", is_flag=True, help="以 SQL 的形式输出迁移脚本") 164 | @click.option("--tag", help="一个任意的字符串, 可在自定义的 env.py 中使用") 165 | @update_cmd_opts 166 | def upgrade(*args, **kwargs) -> None: 167 | """升级到较新版本.""" 168 | 169 | return migrate.upgrade(*args, **kwargs) 170 | 171 | 172 | @orm.command() 173 | @click.argument("revision") 174 | @click.option("--sql", is_flag=True, help="以 SQL 的形式输出迁移脚本") 175 | @click.option("--tag", help="一个任意的字符串, 可在自定义的 env.py 中使用") 176 | @update_cmd_opts 177 | def downgrade(*args, **kwargs) -> None: 178 | """回退到先前版本.""" 179 | 180 | return migrate.downgrade(*args, **kwargs) 181 | 182 | 183 | @orm.command() 184 | @click.argument("revision", required=False) 185 | @update_cmd_opts 186 | def sync(*args, **kwargs) -> None: 187 | """同步数据库模式 (仅用于开发).""" 188 | 189 | return migrate.sync(*args, **kwargs) 190 | 191 | 192 | @orm.command() 193 | @click.argument("revs", nargs=-1) 194 | @update_cmd_opts 195 | def show(*args, **kwargs) -> None: 196 | """显示迁移的信息.""" 197 | 198 | return migrate.show(*args, **kwargs) 199 | 200 | 201 | @orm.command() 202 | @click.option("-r", "--rev-range", required=False, help="范围") 203 | @click.option("-v", "--verbose", is_flag=True, help="显示详细信息") 204 | @click.option("-i", "--indicate-current", is_flag=True, help="指示出当前迁移") 205 | @update_cmd_opts 206 | def history(*args, **kwargs) -> None: 207 | """显示迁移的历史.""" 208 | 209 | return migrate.history(*args, **kwargs) 210 | 211 | 212 | @orm.command() 213 | @click.option("-v", "--verbose", is_flag=True, help="显示详细信息") 214 | @click.option("--resolve-dependencies", is_flag=True, help="将依赖的迁移视作父迁移") 215 | @update_cmd_opts 216 | def heads(*args, **kwargs) -> None: 217 | """显示所有的分支头.""" 218 | 219 | return migrate.heads(*args, **kwargs) 220 | 221 | 222 | @orm.command() 223 | @click.option("-v", "--verbose", is_flag=True, help="显示详细信息") 224 | @update_cmd_opts 225 | def branches(*args, **kwargs) -> None: 226 | """显示所有的分支.""" 227 | 228 | return migrate.branches(*args, **kwargs) 229 | 230 | 231 | @orm.command() 232 | @click.option("-v", "--verbose", is_flag=True, help="显示详细信息") 233 | @update_cmd_opts 234 | def current(*args, **kwargs) -> None: 235 | """显示当前的迁移.""" 236 | 237 | return migrate.current(*args, **kwargs) 238 | 239 | 240 | @orm.command() 241 | @click.argument("revisions", nargs=-1) 242 | @click.option("--sql", is_flag=True, help="以 SQL 的形式输出迁移脚本") 243 | @click.option("--tag", help="一个任意的字符串, 可在自定义的 env.py 中使用") 244 | @click.option("--purge", is_flag=True, help="在标记前清空数据库版本表") 245 | @update_cmd_opts 246 | def stamp(*args, **kwargs) -> None: 247 | """将数据库标记为特定的迁移版本, 不运行任何迁移.""" 248 | 249 | return migrate.stamp(*args, **kwargs) 250 | 251 | 252 | @orm.command() 253 | @click.argument("rev", default="current") 254 | @update_cmd_opts 255 | def edit(*args, **kwargs) -> None: 256 | """使用 $EDITOR 编辑迁移脚本.""" 257 | 258 | return migrate.edit(*args, **kwargs) 259 | 260 | 261 | @orm.command("ensure_version") 262 | @click.option("--sql", is_flag=True, help="以 SQL 的形式输出迁移脚本") 263 | @update_cmd_opts 264 | def ensure_version(*args, **kwargs) -> None: 265 | """创建版本表.""" 266 | 267 | return migrate.ensure_version(*args, **kwargs) 268 | 269 | 270 | def main(*args, **kwargs) -> None: 271 | from . import _init_orm 272 | 273 | if not (args or kwargs): 274 | kwargs["prog_name"] = "nb orm" 275 | 276 | _init_orm() 277 | orm(*args, **kwargs) 278 | 279 | 280 | if __name__ == "__main__": 281 | main() 282 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from argparse import Namespace 5 | from functools import cache, wraps 6 | from collections.abc import Generator 7 | from contextlib import contextmanager 8 | from typing_extensions import Any, Annotated 9 | 10 | import click 11 | from nonebot.rule import Rule 12 | from alembic.op import get_bind 13 | import sqlalchemy.ext.asyncio as sa_async 14 | from nonebot.permission import Permission 15 | from sqlalchemy.util import greenlet_spawn 16 | from sqlalchemy import URL, Table, MetaData 17 | from nonebot.message import run_postprocessor 18 | from nonebot.params import Depends, DefaultParam 19 | from nonebot.plugin import Plugin, PluginMetadata 20 | from sqlalchemy.log import Identified, _qual_logger_name_for_cls 21 | from nonebot.matcher import Matcher, current_event, current_matcher 22 | from nonebot import logger, require, get_driver, get_plugin_by_module_name 23 | from sqlalchemy.ext.asyncio import AsyncEngine, AsyncConnection, create_async_engine 24 | 25 | from . import migrate 26 | from .param import ORMParam 27 | from .config import Config, plugin_config 28 | from .utils import LoguruHandler, StreamToLogger, coroutine, get_subclasses 29 | 30 | require("nonebot_plugin_localstore") 31 | from nonebot_plugin_localstore import get_data_dir, get_plugin_data_dir 32 | 33 | __all__ = ( 34 | # __init__ 35 | "init_orm", 36 | "get_session", 37 | "AsyncSession", 38 | "get_scoped_session", 39 | "async_scoped_session", 40 | # model 41 | "Model", 42 | # param 43 | "SQLDepends", 44 | # config 45 | "Config", 46 | "plugin_config", 47 | # migrate 48 | "AlembicConfig", 49 | ) 50 | __plugin_meta__ = PluginMetadata( 51 | name="nonebot-plugin-orm", 52 | description="SQLAlchemy ORM support for nonebot", 53 | usage="https://github.com/nonebot/plugin-orm", 54 | type="library", 55 | homepage="https://github.com/nonebot/plugin-orm", 56 | config=Config, 57 | ) 58 | 59 | _binds: dict[type[Model], AsyncEngine] 60 | _engines: dict[str, AsyncEngine] 61 | _metadatas: dict[str, MetaData] 62 | _plugins: dict[str, Plugin] 63 | _session_factory: sa_async.async_sessionmaker[sa_async.AsyncSession] 64 | _scoped_sessions: sa_async.async_scoped_session[sa_async.AsyncSession] 65 | 66 | _data_dir = get_plugin_data_dir() 67 | if ( 68 | _deprecated_data_dir := get_data_dir(None) / "nonebot-plugin-orm" 69 | ).exists() and next(_deprecated_data_dir.iterdir(), None): 70 | if next(_data_dir.iterdir(), None): 71 | raise RuntimeError( 72 | "无法自动迁移数据目录, 请手动将 " 73 | f"{_deprecated_data_dir} 中的数据移动到 {_data_dir} 中." 74 | ) 75 | _data_dir.rmdir() 76 | _deprecated_data_dir.rename(_data_dir) 77 | 78 | _driver = get_driver() 79 | 80 | 81 | @_driver.on_startup 82 | async def init_orm() -> None: 83 | _init_orm() 84 | 85 | cmd_opts = Namespace() 86 | with migrate.AlembicConfig( 87 | stdout=StreamToLogger(), cmd_opts=cmd_opts 88 | ) as alembic_config: 89 | if plugin_config.alembic_startup_check: 90 | cmd_opts.cmd = (migrate.check, [], []) 91 | try: 92 | await greenlet_spawn(migrate.check, alembic_config) 93 | except click.UsageError as e: 94 | try: 95 | click.confirm("目标数据库未更新到最新迁移, 是否更新?", abort=True) 96 | except click.Abort: 97 | raise e 98 | 99 | cmd_opts.cmd = (migrate.upgrade, [], []) 100 | await greenlet_spawn(migrate.upgrade, alembic_config) 101 | else: 102 | logger.warning("跳过启动检查, 正在同步数据库模式...") 103 | cmd_opts.cmd = (migrate.sync, ["revision"], []) 104 | await greenlet_spawn(migrate.sync, alembic_config) 105 | 106 | 107 | def get_session(**local_kw: Any) -> sa_async.AsyncSession: 108 | try: 109 | return _session_factory(**local_kw) 110 | except NameError: 111 | _init_orm() 112 | 113 | return _session_factory(**local_kw) 114 | 115 | 116 | # NOTE: NoneBot DI will run sync function in thread pool executor, 117 | # which is poor performance for this simple function, so we wrap it as a coroutine function. 118 | AsyncSession = Annotated[ 119 | sa_async.AsyncSession, 120 | Depends(coroutine(wraps(lambda: None)(get_session)), use_cache=False), 121 | ] 122 | 123 | 124 | def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]: 125 | try: 126 | return _scoped_sessions 127 | except NameError: 128 | _init_orm() 129 | 130 | return _scoped_sessions 131 | 132 | 133 | async_scoped_session = Annotated[ 134 | sa_async.async_scoped_session[sa_async.AsyncSession], 135 | Depends(coroutine(get_scoped_session)), 136 | ] 137 | 138 | 139 | @contextmanager 140 | def _patch_migrate_session() -> Generator[None, Any, None]: 141 | global _session_factory, _scoped_sessions 142 | 143 | session_factory, scoped_sessions = _session_factory, _scoped_sessions 144 | 145 | _session_factory = sa_async.async_sessionmaker( 146 | AsyncConnection._retrieve_proxy_for_target(get_bind()), 147 | **plugin_config.sqlalchemy_session_options, 148 | ) 149 | _scoped_sessions = sa_async.async_scoped_session( 150 | _session_factory, 151 | lambda: (id(current_event.get(None)), current_matcher.get(None)), 152 | ) 153 | 154 | yield 155 | 156 | _session_factory, _scoped_sessions = session_factory, scoped_sessions 157 | 158 | 159 | def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine: 160 | if isinstance(engine, AsyncEngine): 161 | return engine 162 | 163 | options = plugin_config.sqlalchemy_engine_options.copy() 164 | 165 | if isinstance(engine, dict): 166 | url: str | URL = engine.pop("url") 167 | options.update(engine) 168 | else: 169 | url = engine 170 | 171 | return create_async_engine(url, **options) 172 | 173 | 174 | def _init_engines(): 175 | global _engines, _metadatas 176 | 177 | _engines = {} 178 | _metadatas = {"": MetaData()} 179 | for name, engine in plugin_config.sqlalchemy_binds.items(): 180 | _engines[name] = _create_engine(engine) 181 | _metadatas[name] = MetaData() 182 | 183 | if plugin_config.sqlalchemy_database_url: 184 | _engines[""] = _create_engine(plugin_config.sqlalchemy_database_url) 185 | 186 | if "" in _engines: 187 | return 188 | 189 | try: 190 | import aiosqlite 191 | 192 | del aiosqlite 193 | except (ImportError, RuntimeError): 194 | raise ValueError( 195 | '必须指定一个默认数据库 (SQLALCHEMY_DATABASE_URL 或 SQLALCHEMY_BINDS[""]). ' 196 | "可以通过 `pip install nonebot-plugin-orm[default]` 获得开箱即用的数据库配置." 197 | ) from None 198 | 199 | _engines[""] = _create_engine( 200 | URL.create("sqlite+aiosqlite", database=str(_data_dir / "db.sqlite3")) 201 | ) 202 | 203 | 204 | def _init_table(): 205 | global _binds, _metadatas, _plugins 206 | 207 | _binds = {} 208 | _plugins = {} 209 | 210 | _get_plugin_by_module_name = cache(get_plugin_by_module_name) 211 | for model in set(get_subclasses(Model)): 212 | table: Table | None = getattr(model, "__table__", None) 213 | 214 | if table is None or (bind_key := table.info.get("bind_key")) is None: 215 | continue 216 | 217 | if plugin := _get_plugin_by_module_name(model.__module__): 218 | _plugins[plugin.name.replace("-", "_")] = plugin 219 | 220 | _binds[model] = _engines.get(bind_key, _engines[""]) 221 | table.to_metadata(_metadatas.get(bind_key, _metadatas[""])) 222 | 223 | 224 | def _init_orm(): 225 | global _session_factory, _scoped_sessions 226 | 227 | _init_engines() 228 | _init_table() 229 | _session_factory = sa_async.async_sessionmaker( 230 | _engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options 231 | ) 232 | _scoped_sessions = sa_async.async_scoped_session( 233 | _session_factory, 234 | lambda: (id(current_event.get(None)), current_matcher.get(None)), 235 | ) 236 | 237 | run_postprocessor(_scoped_sessions.remove) 238 | 239 | 240 | def _init_logger(): 241 | handler = LoguruHandler() 242 | logging.getLogger("alembic").addHandler(handler) 243 | logging.getLogger("sqlalchemy").addHandler(handler) 244 | 245 | log_level = _driver.config.log_level 246 | if isinstance(log_level, str): 247 | log_level = logger.level(log_level).no 248 | 249 | echo_log_level = log_level if plugin_config.sqlalchemy_echo else logging.WARNING 250 | 251 | levels = { 252 | "alembic": log_level, 253 | "sqlalchemy": log_level, 254 | **{ 255 | _qual_logger_name_for_cls(cls): echo_log_level 256 | for cls in set(get_subclasses(Identified)) 257 | }, 258 | } 259 | 260 | for name, level in levels.items(): 261 | logging.getLogger(name).setLevel(level) 262 | 263 | 264 | _init_logger() 265 | 266 | 267 | def _init_param(): 268 | for cls in (Rule, Permission): 269 | cls.HANDLER_PARAM_TYPES.insert(-1, ORMParam) 270 | 271 | Matcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + ( 272 | ORMParam, 273 | DefaultParam, 274 | ) 275 | 276 | 277 | _init_param() 278 | 279 | 280 | from .model import Model as Model 281 | from .config import Config as Config 282 | from .param import SQLDepends as SQLDepends 283 | from .config import plugin_config as plugin_config 284 | from .migrate import AlembicConfig as AlembicConfig 285 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | import json 5 | import logging 6 | from io import StringIO 7 | from pathlib import Path 8 | from functools import wraps 9 | from itertools import repeat 10 | from contextlib import suppress 11 | from operator import methodcaller 12 | from importlib.resources import files 13 | from dataclasses import field, dataclass 14 | from inspect import Parameter, Signature 15 | from collections.abc import Callable, Iterable, Coroutine, Generator 16 | from importlib.metadata import Distribution, PackageNotFoundError, distribution 17 | from typing_extensions import ( 18 | TYPE_CHECKING, 19 | Any, 20 | TypeVar, 21 | Annotated, 22 | ParamSpec, 23 | get_args, 24 | get_origin, 25 | ) 26 | 27 | import click 28 | from nonebot.plugin import Plugin 29 | from nonebot.params import Depends 30 | from nonebot import logger, get_driver 31 | from sqlalchemy.sql.selectable import ExecutableReturnsRows 32 | from nonebot.typing import origin_is_union, origin_is_literal 33 | 34 | if sys.version_info >= (3, 10): 35 | from importlib.metadata import packages_distributions 36 | else: 37 | from importlib_metadata import packages_distributions 38 | 39 | 40 | if TYPE_CHECKING: 41 | from . import async_scoped_session 42 | 43 | 44 | _T = TypeVar("_T") 45 | _P = ParamSpec("_P") 46 | 47 | 48 | DependsInner = type(Depends()) 49 | 50 | 51 | class LoguruHandler(logging.Handler): 52 | def emit(self, record: logging.LogRecord): 53 | try: 54 | level = logger.level(record.levelname).name 55 | if record.levelno <= logging.INFO: 56 | level = {"DEBUG": "TRACE", "INFO": "DEBUG"}.get(level, level) 57 | except ValueError: 58 | level = record.levelno 59 | 60 | frame, depth = sys._getframe(6), 6 61 | while frame and frame.f_code.co_filename == logging.__file__: 62 | frame = frame.f_back 63 | depth += 1 64 | 65 | logger.opt(depth=depth, exception=record.exc_info).log( 66 | level, record.getMessage() 67 | ) 68 | 69 | 70 | class StreamToLogger(StringIO): 71 | """Use for startup migrate, AlembicConfig.print_stdout() only""" 72 | 73 | def __init__(self, level="INFO"): 74 | self._level = level 75 | 76 | def write(self, buffer: str): 77 | frame, depth = sys._getframe(3), 3 78 | while frame and frame.f_code.co_name != "print_stdout": 79 | frame = frame.f_back 80 | depth += 1 81 | 82 | for line in buffer.rstrip().splitlines(): 83 | logger.opt(depth=depth + 1).log(self._level, line.rstrip()) 84 | 85 | return len(buffer) 86 | 87 | def flush(self): 88 | pass 89 | 90 | 91 | @dataclass(unsafe_hash=True) 92 | class Option: 93 | stream: bool = True 94 | scalars: bool = False 95 | calls: tuple[methodcaller, ...] = field(default_factory=tuple) 96 | result: methodcaller | None = None 97 | 98 | 99 | @dataclass 100 | class Dependency: 101 | __signature__: Signature = field(init=False) 102 | 103 | statement: ExecutableReturnsRows 104 | option: Option 105 | 106 | def __post_init__(self) -> None: 107 | from . import async_scoped_session 108 | 109 | self.__signature__ = Signature( 110 | [ 111 | Parameter( 112 | "_session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session 113 | ), 114 | *( 115 | Parameter(name, Parameter.KEYWORD_ONLY, default=depends) 116 | for name, depends in self.statement.compile().params.items() 117 | if isinstance(depends, DependsInner) 118 | ), 119 | ] 120 | ) 121 | 122 | async def __call__(self, *, _session: async_scoped_session, **params: Any) -> Any: 123 | if self.option.stream: 124 | result = await _session.stream(self.statement, params) 125 | else: 126 | result = await _session.execute(self.statement, params) 127 | 128 | if self.option.scalars: 129 | result = result.scalars() 130 | 131 | for call in self.option.calls: 132 | result = call(result) 133 | 134 | if call := self.option.result: 135 | result = call(result) 136 | 137 | if self.option.stream: 138 | result = await result 139 | 140 | return result 141 | 142 | def __hash__(self) -> int: 143 | return hash((self.statement, self.option)) 144 | 145 | 146 | def generic_issubclass(scls: Any, cls: Any) -> bool | list[Any]: 147 | if isinstance(cls, tuple): 148 | return _map_generic_issubclass(repeat(scls), cls) 149 | 150 | if scls is Any: 151 | return [cls] 152 | 153 | if cls is Any: 154 | return True 155 | 156 | with suppress(TypeError): 157 | return issubclass(scls, cls) 158 | 159 | scls_origin, scls_args = get_origin(scls) or scls, get_args(scls) 160 | cls_origin, cls_args = get_origin(cls) or cls, get_args(cls) 161 | 162 | if scls_origin is tuple and cls_origin is tuple: 163 | if len(scls_args) == 2 and scls_args[1] is Ellipsis: 164 | return generic_issubclass(scls_args[0], cls_args) 165 | 166 | if len(cls_args) == 2 and cls_args[1] is Ellipsis: 167 | return _map_generic_issubclass( 168 | scls_args, repeat(cls_args[0]), failfast=True 169 | ) 170 | 171 | if scls_origin is Annotated: 172 | return generic_issubclass(scls_args[0], cls) 173 | if cls_origin is Annotated: 174 | return generic_issubclass(scls, cls_args[0]) 175 | 176 | if origin_is_union(scls_origin): 177 | return _map_generic_issubclass(scls_args, repeat(cls), failfast=True) 178 | if origin_is_union(cls_origin): 179 | return generic_issubclass(scls, cls_args) 180 | 181 | if origin_is_literal(scls_origin) and origin_is_literal(cls_origin): 182 | return set(scls_args) <= set(cls_args) 183 | 184 | try: 185 | if not issubclass(scls_origin, cls_origin): 186 | return False 187 | except TypeError: 188 | return False 189 | 190 | if not cls_args: 191 | return True 192 | 193 | if len(scls_args) != len(cls_args): 194 | return False 195 | 196 | return _map_generic_issubclass(scls_args, cls_args, failfast=True) 197 | 198 | 199 | def _map_generic_issubclass( 200 | scls: Iterable[Any], cls: Iterable[Any], *, failfast: bool = False 201 | ) -> bool | list[Any]: 202 | results = [] 203 | for scls_arg, cls_arg in zip(scls, cls): 204 | if not (result := generic_issubclass(scls_arg, cls_arg)) and failfast: 205 | return False 206 | elif isinstance(result, list): 207 | results.extend(result) 208 | elif not isinstance(result, bool): 209 | results.append(result) 210 | 211 | return results or False 212 | 213 | 214 | def return_progressbar(func: Callable[_P, Iterable[_T]]) -> Callable[_P, Iterable[_T]]: 215 | log_level = get_driver().config.log_level 216 | if isinstance(log_level, str): 217 | log_level = logger.level(log_level).no 218 | 219 | if log_level <= logger.level("INFO").no: 220 | return func 221 | 222 | @wraps(func) 223 | def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Iterable[_T]: 224 | with click.progressbar( 225 | func(*args, **kwargs), label="运行迁移中", item_show_func=str 226 | ) as bar: 227 | yield from bar 228 | 229 | return wrapper 230 | 231 | 232 | def get_parent_plugins(plugin: Plugin | None) -> Generator[Plugin, Any, None]: 233 | while plugin: 234 | yield plugin 235 | plugin = plugin.parent_plugin 236 | 237 | 238 | pkgs = packages_distributions() 239 | 240 | 241 | def is_editable(plugin: Plugin) -> bool: 242 | *_, plugin = get_parent_plugins(plugin) 243 | 244 | try: 245 | path = files(plugin.module) 246 | except TypeError: 247 | return False 248 | 249 | if not isinstance(path, Path) or "site-packages" in path.parts: 250 | return False 251 | 252 | dist: Distribution | None = None 253 | 254 | with suppress(PackageNotFoundError): 255 | dist = distribution(plugin.name.replace("_", "-")) 256 | 257 | if not dist and plugin.module.__file__: 258 | path = Path(plugin.module.__file__) 259 | for name in pkgs.get(plugin.module_name.split(".")[0], ()): 260 | dist = distribution(name) 261 | if path in (file.locate() for file in dist.files or ()): 262 | break 263 | else: 264 | dist = None 265 | 266 | if not dist: 267 | return True 268 | 269 | # https://github.com/pdm-project/pdm/blob/fee1e6bffd7de30315e2134e19f9a6f58e15867c/src/pdm/utils.py#L361-L374 270 | if getattr(dist, "link_file", None) is not None: 271 | return True 272 | 273 | direct_url = dist.read_text("direct_url.json") 274 | if not direct_url: 275 | return False 276 | 277 | direct_url_data = json.loads(direct_url) 278 | return direct_url_data.get("dir_info", {}).get("editable", False) 279 | 280 | 281 | def get_subclasses(cls: type[_T]) -> Generator[type[_T], None, None]: 282 | yield from cls.__subclasses__() 283 | for subclass in cls.__subclasses__(): 284 | yield from get_subclasses(subclass) 285 | 286 | 287 | def coroutine(func: Callable[_P, _T]) -> Callable[_P, Coroutine[Any, Any, _T]]: 288 | @wraps(func) 289 | async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: 290 | return func(*args, **kwargs) 291 | 292 | return wrapper 293 | -------------------------------------------------------------------------------- /nonebot_plugin_orm/migrate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | import shutil 6 | import inspect 7 | from pathlib import Path 8 | from pprint import pformat 9 | from argparse import Namespace 10 | from operator import attrgetter 11 | from itertools import filterfalse 12 | from tempfile import TemporaryDirectory 13 | from configparser import DuplicateSectionError 14 | from typing_extensions import Any, Self, TextIO, cast 15 | from contextlib import ExitStack, suppress, contextmanager 16 | from collections.abc import Mapping, Iterable, Sequence, Generator 17 | 18 | import click 19 | import alembic 20 | import sqlalchemy 21 | from alembic.config import Config 22 | from sqlalchemy.util import asbool 23 | from nonebot import logger, get_plugin 24 | from nonebot.matcher import current_matcher 25 | from sqlalchemy import MetaData, Connection 26 | from alembic.util.editor import open_in_editor 27 | from alembic.script import Script, ScriptDirectory 28 | from alembic.util.langhelpers import rev_id as _rev_id 29 | from alembic.operations.ops import UpgradeOps, DowngradeOps 30 | from sqlalchemy.ext.asyncio import AsyncConnection, async_sessionmaker 31 | from alembic.migration import StampStep, RevisionStep, MigrationContext 32 | from alembic.runtime.environment import EnvironmentContext, ProcessRevisionDirectiveFn 33 | from alembic.autogenerate.api import ( 34 | RevisionContext, 35 | produce_migrations, 36 | render_python_code, 37 | ) 38 | 39 | from .exception import AutogenerateDiffsDetected 40 | from .utils import is_editable, get_parent_plugins, return_progressbar 41 | 42 | if sys.version_info >= (3, 12): 43 | from importlib.resources import files, as_file 44 | else: 45 | from importlib_resources import files, as_file 46 | 47 | 48 | __all__ = ( 49 | "AlembicConfig", 50 | "list_templates", 51 | "init", 52 | "revision", 53 | "check", 54 | "merge", 55 | "upgrade", 56 | "downgrade", 57 | "sync", 58 | "show", 59 | "history", 60 | "heads", 61 | "branches", 62 | "current", 63 | "stamp", 64 | "edit", 65 | "ensure_version", 66 | ) 67 | 68 | _SPLIT_ON_PATH = { 69 | None: " ", 70 | "space": " ", 71 | "os": os.pathsep, 72 | ":": ":", 73 | ";": ";", 74 | } 75 | 76 | 77 | class AlembicConfig(Config): 78 | _exit_stack: ExitStack 79 | _plugin_version_locations: dict[str, Path] 80 | _temp_dir: Path 81 | 82 | def __init__( 83 | self, 84 | file_: str | os.PathLike[str] | None = None, 85 | toml_file: str | os.PathLike[str] | None = None, 86 | ini_section: str = "alembic", 87 | output_buffer: TextIO | None = None, 88 | stdout: TextIO = sys.stdout, 89 | cmd_opts: Namespace | None = None, 90 | config_args: Mapping[str, Any] = {}, 91 | attributes: dict = {}, 92 | ) -> None: 93 | from . import _engines, _metadatas, plugin_config 94 | 95 | self._exit_stack = ExitStack() 96 | self._plugin_version_locations = {} 97 | self._temp_dir = Path(self._exit_stack.enter_context(TemporaryDirectory())) 98 | 99 | if file_ is None and isinstance(plugin_config.alembic_config, Path): 100 | file_ = plugin_config.alembic_config 101 | 102 | if plugin_config.alembic_script_location: 103 | script_location = plugin_config.alembic_script_location 104 | elif ( 105 | Path("migrations", "env.py").is_file() 106 | and Path("migrations", "script.py.mako").is_file() 107 | ): 108 | script_location = Path("migrations") 109 | elif len(_engines) == 1: 110 | script_location = self._exit_stack.enter_context( 111 | as_file(files(__name__) / "templates" / "generic") 112 | ) 113 | else: 114 | script_location = self._exit_stack.enter_context( 115 | as_file(files(__name__) / "templates" / "multidb") 116 | ) 117 | 118 | super().__init__( 119 | file_, 120 | toml_file, 121 | ini_section, 122 | output_buffer, 123 | stdout, 124 | cmd_opts, 125 | { 126 | "script_location": script_location, 127 | "prepend_sys_path": ".", 128 | "revision_environment": "true", 129 | "version_path_separator": "os", 130 | } 131 | | dict(config_args), 132 | { 133 | "engines": _engines, 134 | "metadatas": _metadatas, 135 | } 136 | | attributes, 137 | ) 138 | 139 | self._init_post_write_hooks() 140 | self._init_version_locations() 141 | 142 | def __enter__(self: Self) -> Self: 143 | return self 144 | 145 | def __exit__(self, *_) -> None: 146 | self.close() 147 | 148 | def close(self) -> None: 149 | self._exit_stack.close() 150 | 151 | def get_template_directory(self) -> str: 152 | return str(Path(__file__).parent / "templates") 153 | 154 | def print_stdout(self, text: str, *arg, **kwargs) -> None: 155 | if not getattr(self.cmd_opts, "quite", False): 156 | click.secho(text % arg, self.stdout, **kwargs) 157 | 158 | @contextmanager 159 | def status(self, status_msg: str) -> Generator[None, Any, None]: 160 | self.print_stdout(f"{status_msg} ...", nl=False) 161 | 162 | try: 163 | yield 164 | except: 165 | self.print_stdout(" 失败", fg="red") 166 | raise 167 | else: 168 | self.print_stdout(" 成功", fg="green") 169 | 170 | def move_script(self, script: Script) -> Path: 171 | script_path = Path(script.path) 172 | 173 | try: 174 | script_path = script_path.relative_to(self._temp_dir) 175 | except ValueError: 176 | return script_path 177 | 178 | plugin_name = script_path.parent.name 179 | version_location = self._plugin_version_locations.get(plugin_name) 180 | 181 | if not version_location: 182 | version_location = self._plugin_version_locations.get("") 183 | 184 | if not version_location: 185 | self.print_stdout( 186 | f'无法找到 {plugin_name or ""} 对应的版本目录, 忽略 "{script.path}"', 187 | fg="yellow", 188 | ) 189 | return script_path 190 | 191 | version_location.mkdir(parents=True, exist_ok=True) 192 | return Path(shutil.move(script.path, version_location)) 193 | 194 | def _add_post_write_hook(self, name: str, **kwargs: str) -> None: 195 | self.set_section_option( 196 | "post_write_hooks", 197 | "hooks", 198 | f"{self.get_section_option('post_write_hooks', 'hooks', '')}, {name}", 199 | ) 200 | for key, value in kwargs.items(): 201 | self.set_section_option("post_write_hooks", f"{name}.{key}", value) 202 | 203 | def _init_post_write_hooks(self) -> None: 204 | with suppress(DuplicateSectionError): 205 | self.file_config.add_section("post_write_hooks") 206 | 207 | if self.get_section_option("post_write_hooks", "hooks"): 208 | return 209 | 210 | with suppress(ImportError): 211 | import isort 212 | 213 | del isort 214 | self._add_post_write_hook( 215 | "isort", 216 | type="console_scripts", 217 | entrypoint="isort", 218 | options="REVISION_SCRIPT_FILENAME --profile black", 219 | ) 220 | 221 | with suppress(ImportError): 222 | import black 223 | 224 | del black 225 | self._add_post_write_hook( 226 | "black", 227 | type="console_scripts", 228 | entrypoint="black", 229 | options="REVISION_SCRIPT_FILENAME", 230 | ) 231 | 232 | def _init_version_locations(self) -> None: 233 | if self.get_main_option("version_locations"): 234 | return 235 | 236 | from . import _plugins, _data_dir, plugin_config 237 | 238 | alembic_version_locations = plugin_config.alembic_version_locations 239 | if isinstance(alembic_version_locations, dict): 240 | main_version_location = alembic_version_locations.get("") 241 | else: 242 | main_version_location = alembic_version_locations 243 | 244 | self._plugin_version_locations[""] = main_version_location or Path( 245 | "migrations", "versions" 246 | ) 247 | 248 | temp_version_locations: dict[Path, Path] = { 249 | _data_dir / "migrations": self._temp_dir 250 | } 251 | 252 | for plugin in _plugins.values(): 253 | if plugin.metadata and ( 254 | version_module := plugin.metadata.extra.get("orm_version_location") 255 | ): 256 | if sys.version_info[:2] == (3, 9): 257 | # importlib_resources.files() return a opaque Traversable object 258 | # even if the anchor is a namespace package in Python 3.9 259 | version_location = Path(version_module.__path__[0]) 260 | else: 261 | version_location = files(version_module) 262 | else: 263 | if sys.version_info[:2] == (3, 9): 264 | version_location = Path(plugin.module.__path__[0]) / "migrations" 265 | else: 266 | version_location = files(plugin.module) / "migrations" 267 | 268 | temp_version_location = Path( 269 | *map(attrgetter("name"), reversed(list(get_parent_plugins(plugin)))), 270 | ) 271 | 272 | if ( 273 | not main_version_location 274 | and is_editable(plugin) 275 | and isinstance(version_location, Path) 276 | ): 277 | self._plugin_version_locations[plugin.name] = version_location 278 | else: 279 | self._plugin_version_locations[plugin.name] = ( 280 | self._plugin_version_locations[""] / temp_version_location 281 | ) 282 | 283 | temp_version_locations[ 284 | self._exit_stack.enter_context(as_file(version_location)) 285 | ] = (self._temp_dir / temp_version_location) 286 | 287 | if isinstance(alembic_version_locations, dict): 288 | for plugin_name, version_location in alembic_version_locations.items(): 289 | if not (plugin := get_plugin(plugin_name)): 290 | continue 291 | 292 | version_location = Path(version_location) 293 | self._plugin_version_locations[plugin_name] = version_location 294 | temp_version_locations[version_location] = self._temp_dir.joinpath( 295 | *map( 296 | attrgetter("name"), 297 | reversed(list(get_parent_plugins(plugin))), 298 | ) 299 | ) 300 | 301 | temp_version_locations[self._plugin_version_locations[""]] = self._temp_dir 302 | 303 | for src, dst in temp_version_locations.items(): 304 | dst.mkdir(parents=True, exist_ok=True) 305 | with suppress(FileNotFoundError, shutil.Error): 306 | shutil.copytree(src, dst, dirs_exist_ok=True) 307 | 308 | pathsep = _SPLIT_ON_PATH[self.get_main_option("version_path_separator")] 309 | self.set_main_option( 310 | "version_locations", 311 | pathsep.join( 312 | str(path) 313 | for path in self._temp_dir.glob("**") 314 | if path.name != "__pycache__" 315 | ), 316 | ) 317 | 318 | 319 | def _move_run_scripts(config: AlembicConfig, script: ScriptDirectory, current) -> None: 320 | from . import _data_dir 321 | 322 | def ignore(path: str, names: list[str]) -> set[str]: 323 | path_ = Path(path) 324 | 325 | return set( 326 | name 327 | for name in names 328 | if Path(name).suffix in {".py", ".pyc", ".pyo"} 329 | and path_ / name not in run_script_path 330 | ) 331 | 332 | run_script_path = set( 333 | Path(sc.path) for sc in script.walk_revisions(base="base", head=current) 334 | ) 335 | shutil.rmtree(_data_dir / "migrations", ignore_errors=True) 336 | shutil.copytree( 337 | config._temp_dir, _data_dir / "migrations", ignore=ignore, dirs_exist_ok=True 338 | ) 339 | 340 | 341 | def list_templates(config: AlembicConfig) -> None: 342 | """列出所有可用的模板. 343 | 344 | 参数: 345 | config: `AlembicConfig` 对象 346 | """ 347 | 348 | config.print_stdout("可用的模板:\n") 349 | for tempname in Path(config.get_template_directory()).iterdir(): 350 | with (tempname / "README").open(encoding="utf-8") as readme: 351 | synopsis = readme.readline().rstrip() 352 | 353 | config.print_stdout(f"{tempname.name} - {synopsis}") 354 | 355 | config.print_stdout('\n可以通过 "init" 命令使用模板, 例如: ') 356 | config.print_stdout("\n nb orm init --template generic ./scripts") 357 | 358 | 359 | def init( 360 | config: AlembicConfig, 361 | directory: Path = Path("migrations"), 362 | template: str = "generic", 363 | package: bool = False, 364 | ) -> None: 365 | """初始化脚本目录. 366 | 367 | 参数: 368 | config: `AlembicConfig` 对象 369 | directory: 目标目录路径 370 | template: 使用的迁移环境模板 371 | package: 为 True 时, 在脚本目录和版本目录中创建 `__init__.py` 脚本 372 | """ 373 | 374 | if ( 375 | directory.is_dir() 376 | and next(directory.iterdir(), False) 377 | and not click.confirm(f'目录 "{directory}" 已存在并且不为空, 是否继续初始化?') 378 | ): 379 | raise click.BadParameter( 380 | f'目录 "{directory}" 已存在并且不为空', param_hint="DIRECTORY" 381 | ) 382 | 383 | template_dir = Path(config.get_template_directory()) / template 384 | if not template_dir.is_dir(): 385 | raise click.BadParameter(f"模板 {template} 不存在", param_hint="--template") 386 | 387 | with config.status(f'生成目录 "{directory}"'): 388 | shutil.copytree( 389 | template_dir, 390 | directory, 391 | ignore=None if package else shutil.ignore_patterns("__init__.py"), 392 | dirs_exist_ok=True, 393 | ) 394 | 395 | 396 | def revision( 397 | config: AlembicConfig, 398 | message: str | None = None, 399 | sql: bool | None = False, 400 | head: str | None = None, 401 | splice: bool = False, 402 | branch_label: str | None = None, 403 | version_path: str | Path | None = None, 404 | rev_id: str | None = None, 405 | depends_on: str | None = None, 406 | process_revision_directives: ProcessRevisionDirectiveFn | None = None, 407 | ) -> Iterable[Script]: 408 | """创建一个新迁移脚本. 409 | 410 | 参数: 411 | config: `AlembicConfig` 对象 412 | message: 迁移的描述 413 | sql: 是否以 SQL 的形式输出迁移脚本 414 | head: 迁移的基准版本, 如果提供了 branch_label 默认为 `branch_label@head`, 否则为主分支的头 415 | splice: 是否将迁移作为一个新的分支的头; 当 `head` 不是一个分支的头时, 此项必须为 `True` 416 | branch_label: 迁移的分支标签 417 | version_path: 存放迁移脚本的目录 418 | rev_id: 迁移的 ID 419 | depends_on: 迁移的依赖 420 | process_revision_directives: 迁移的处理函数, 参见: `alembic.EnvironmentContext.configure.process_revision_directives` 421 | """ 422 | from . import _plugins 423 | 424 | if version_path: 425 | version_path = Path(version_path).resolve() 426 | version_locations = config.get_main_option("version_locations", "") 427 | pathsep = _SPLIT_ON_PATH[config.get_main_option("version_path_separator")] 428 | 429 | if version_path not in ( 430 | Path(path).resolve() for path in version_locations.split(pathsep) 431 | ): 432 | config.set_main_option( 433 | "version_locations", f"{version_locations}{pathsep}{version_path}" 434 | ) 435 | logger.warning( 436 | f'临时将目录 "{version_path}" 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中' 437 | ) 438 | elif branch_label and (plugin := _plugins.get(branch_label)): 439 | version_path = config._temp_dir.joinpath( 440 | *map( 441 | attrgetter("name"), 442 | reversed(list(get_parent_plugins(plugin))), 443 | ) 444 | ) 445 | elif not head: 446 | version_path = config._temp_dir 447 | 448 | if isinstance(version_path, Path): 449 | version_path = str(version_path) 450 | 451 | script = ScriptDirectory.from_config(config) 452 | 453 | if not head: 454 | scripts = script.get_revisions(script.get_heads()) 455 | if branch_label: 456 | if any(branch_label in sc.branch_labels for sc in scripts): 457 | head = f"{branch_label}@head" 458 | branch_label = None 459 | else: 460 | head = "base" 461 | elif len(scripts) <= 1: 462 | head = "head" 463 | else: 464 | try: 465 | head = next(filterfalse(attrgetter("branch_labels"), scripts)).revision 466 | except StopIteration: 467 | head = "base" 468 | 469 | revision_context = RevisionContext( 470 | config, 471 | script, 472 | dict( 473 | message=message, 474 | autogenerate=not sql, 475 | sql=sql, 476 | head=head, 477 | splice=splice, 478 | branch_label=branch_label, 479 | version_path=version_path, 480 | rev_id=rev_id, 481 | depends_on=depends_on, 482 | ), 483 | process_revision_directives=process_revision_directives, 484 | ) 485 | 486 | if sql: 487 | 488 | def retrieve_migrations( 489 | rev, context: MigrationContext 490 | ) -> Iterable[StampStep | RevisionStep]: 491 | revision_context.run_no_autogenerate(rev, context) 492 | return () 493 | 494 | else: 495 | 496 | def retrieve_migrations( 497 | rev, context: MigrationContext 498 | ) -> Iterable[StampStep | RevisionStep]: 499 | if set(script.get_revisions(rev)) != set(script.get_revisions("heads")): 500 | raise click.UsageError( 501 | "目标数据库未更新到最新迁移. 请通过 `nb orm upgrade` 升级数据库后重试." 502 | ) 503 | revision_context.run_autogenerate(rev, context) 504 | return () 505 | 506 | with EnvironmentContext( 507 | config, 508 | script, 509 | fn=retrieve_migrations, 510 | as_sql=sql, 511 | template_args=revision_context.template_args, 512 | revision_context=revision_context, 513 | ): 514 | script.run_env() 515 | 516 | return filter(None, revision_context.generate_scripts()) 517 | 518 | 519 | def check(config: AlembicConfig) -> None: 520 | """检查数据库是否与模型定义一致. 521 | 522 | 参数: 523 | config: `AlembicConfig` 对象 524 | """ 525 | 526 | script = ScriptDirectory.from_config(config) 527 | 528 | revision_context = RevisionContext( 529 | config, 530 | script, 531 | dict( 532 | message=None, 533 | autogenerate=True, 534 | sql=False, 535 | head="head", 536 | splice=False, 537 | branch_label=None, 538 | version_path=None, 539 | rev_id=None, 540 | depends_on=None, 541 | ), 542 | ) 543 | 544 | def retrieve_migrations( 545 | rev, context: MigrationContext 546 | ) -> Iterable[StampStep | RevisionStep]: 547 | if set(script.get_revisions(rev)) != set(script.get_revisions("heads")): 548 | raise click.UsageError( 549 | "目标数据库未更新到最新迁移. 请通过 `nb orm upgrade` 升级数据库后重试." 550 | ) 551 | revision_context.run_autogenerate(rev, context) 552 | return () 553 | 554 | with EnvironmentContext( 555 | config, 556 | script, 557 | fn=retrieve_migrations, 558 | as_sql=False, 559 | template_args=revision_context.template_args, 560 | revision_context=revision_context, 561 | ): 562 | script.run_env() 563 | 564 | migration_script = revision_context.generated_revisions[-1] 565 | diffs = cast(UpgradeOps, migration_script.upgrade_ops).as_diffs() 566 | if diffs: 567 | raise AutogenerateDiffsDetected( 568 | f"检测到新的升级操作:\n{pformat(diffs)}", revision_context, diffs 569 | ) 570 | else: 571 | config.print_stdout("没有检测到新的升级操作") 572 | 573 | 574 | def merge( 575 | config: AlembicConfig, 576 | revisions: tuple[str, ...], 577 | message: str | None = None, 578 | branch_label: str | None = None, 579 | rev_id: str | None = None, 580 | ) -> Iterable[Script]: 581 | """合并多个迁移. 创建一个新的迁移脚本. 582 | 583 | 参数: 584 | config: `AlembicConfig` 对象 585 | revisions: 要合并的迁移 586 | message: 迁移的描述 587 | branch_label: 迁移的分支标签 588 | rev_id: 迁移的 ID 589 | """ 590 | 591 | script = ScriptDirectory.from_config(config) 592 | template_args: dict[str, Any] = {"config": config} 593 | 594 | environment = asbool(config.get_main_option("revision_environment")) 595 | 596 | if environment: 597 | with EnvironmentContext( 598 | config, 599 | script, 600 | fn=lambda *_: (), 601 | as_sql=False, 602 | template_args=template_args, 603 | ): 604 | script.run_env() 605 | 606 | sc = script.generate_revision( 607 | rev_id or _rev_id(), 608 | message, 609 | refresh=True, 610 | head=revisions, 611 | branch_labels=branch_label, 612 | **template_args, 613 | ) 614 | return (sc,) if sc else () 615 | 616 | 617 | def upgrade( 618 | config: AlembicConfig, 619 | revision: str | None = None, 620 | sql: bool = False, 621 | tag: str | None = None, 622 | ) -> None: 623 | """升级到较新版本. 624 | 625 | 参数: 626 | config: `AlembicConfig` 对象 627 | revision: 目标迁移 628 | sql: 是否以 SQL 的形式输出迁移脚本 629 | tag: 一个任意的字符串, 可在自定义的 `env.py` 中通过 `alembic.EnvironmentContext.get_tag_argument` 获得 630 | """ 631 | 632 | script = ScriptDirectory.from_config(config) 633 | 634 | if revision is None: 635 | revision = "head" if len(script.get_heads()) == 1 else "heads" 636 | 637 | starting_rev = None 638 | if ":" in revision: 639 | if not sql: 640 | raise click.BadParameter( 641 | "不允许在非 --sql 模式下使用迁移范围", param_hint="REVISION" 642 | ) 643 | starting_rev, revision = revision.split(":", 2) 644 | 645 | @return_progressbar 646 | def upgrade(rev, _) -> Iterable[StampStep | RevisionStep]: 647 | from . import _patch_migrate_session 648 | 649 | with _patch_migrate_session(): 650 | yield from script._upgrade_revs(revision, rev) 651 | 652 | _move_run_scripts(config, script, revision) 653 | 654 | with EnvironmentContext( 655 | config, 656 | script, 657 | fn=upgrade, 658 | as_sql=sql, 659 | starting_rev=starting_rev, 660 | destination_rev=revision, 661 | tag=tag, 662 | ): 663 | script.run_env() 664 | 665 | 666 | def downgrade( 667 | config: AlembicConfig, 668 | revision: str, 669 | sql: bool = False, 670 | tag: str | None = None, 671 | ) -> None: 672 | """回退到先前版本. 673 | 674 | 参数: 675 | config: `AlembicConfig` 对象 676 | revision: 目标迁移 677 | sql: 是否以 SQL 的形式输出迁移脚本 678 | tag: 一个任意的字符串, 可在自定义的 `env.py` 中通过 `alembic.EnvironmentContext.get_tag_argument` 获得 679 | """ 680 | 681 | script = ScriptDirectory.from_config(config) 682 | starting_rev = None 683 | if ":" in revision: 684 | if not sql: 685 | raise click.BadParameter( 686 | "不允许在非 --sql 模式下使用迁移范围", param_hint="REVISION" 687 | ) 688 | starting_rev, revision = revision.split(":", 2) 689 | elif sql: 690 | raise click.BadParameter( 691 | "--sql 模式下降级必须指定迁移范围 :", param_hint="REVISION" 692 | ) 693 | 694 | @return_progressbar 695 | def downgrade(rev, _) -> Iterable[StampStep | RevisionStep]: 696 | from . import _patch_migrate_session 697 | 698 | with _patch_migrate_session(): 699 | yield from script._downgrade_revs(revision, rev) 700 | 701 | _move_run_scripts(config, script, revision) 702 | 703 | with EnvironmentContext( 704 | config, 705 | script, 706 | fn=downgrade, 707 | as_sql=sql, 708 | starting_rev=starting_rev, 709 | destination_rev=revision, 710 | tag=tag, 711 | ): 712 | script.run_env() 713 | 714 | 715 | def sync(config: AlembicConfig, revision: str | None = None): 716 | """同步数据库模式 (仅用于开发). 717 | 718 | 参数: 719 | config: `AlembicConfig` 对象 720 | revision: 目标迁移, 如果不提供则与当前模型同步 721 | """ 722 | script = ScriptDirectory.from_config(config) 723 | 724 | revision_context = RevisionContext( 725 | config, 726 | script, 727 | dict( 728 | message=None, 729 | autogenerate=True, 730 | sql=False, 731 | head="head", 732 | splice=False, 733 | branch_label=None, 734 | version_path=None, 735 | rev_id=None, 736 | depends_on=None, 737 | ), 738 | ) 739 | 740 | def retrieve_migrations( 741 | rev, context: MigrationContext 742 | ) -> Iterable[StampStep | RevisionStep]: 743 | assert context.connection 744 | 745 | metadata = MetaData() if revision else context.opts["target_metadata"] 746 | ops = cast(UpgradeOps, produce_migrations(context, metadata).upgrade_ops) 747 | 748 | if not (revision or ops.as_diffs()): 749 | return 750 | 751 | try: 752 | _run_ops(context, ops) 753 | except Exception: 754 | if revision: 755 | raise 756 | 757 | _run_ops( 758 | context, 759 | cast(UpgradeOps, produce_migrations(context, MetaData()).upgrade_ops), 760 | ) 761 | metadata.create_all(context.connection) 762 | 763 | yield from script._stamp_revs("base", rev) 764 | 765 | if revision: 766 | yield from script._upgrade_revs(revision, "base") 767 | 768 | _move_run_scripts(config, script, revision or "base") 769 | 770 | with EnvironmentContext( 771 | config, 772 | script, 773 | fn=retrieve_migrations, 774 | as_sql=False, 775 | template_args=revision_context.template_args, 776 | revision_context=revision_context, 777 | ): 778 | script.run_env() 779 | 780 | 781 | def _run_ops(context: MigrationContext, ops: UpgradeOps | DowngradeOps) -> None: 782 | with context.begin_transaction(True): 783 | exec( 784 | inspect.cleandoc( 785 | render_python_code( 786 | ops, 787 | render_as_batch=True, 788 | ) 789 | ), 790 | {"sa": sqlalchemy, "op": alembic.op}, 791 | ) 792 | 793 | 794 | def show(config: AlembicConfig, revs: str | Sequence[str] = "current") -> None: 795 | """显示迁移的信息. 796 | 797 | 参数: 798 | config: `AlembicConfig` 对象 799 | revs: 目标迁移范围 800 | """ 801 | 802 | script = ScriptDirectory.from_config(config) 803 | 804 | if revs in {(), "current", ("current",)}: 805 | revs = [] 806 | 807 | with EnvironmentContext( 808 | config, script, fn=lambda rev, _: revs.append(rev) or () 809 | ): 810 | script.run_env() 811 | 812 | for sc in cast("tuple[Script]", script.get_revisions(revs)): 813 | config.print_stdout(sc.log_entry) 814 | 815 | 816 | def history( 817 | config: AlembicConfig, 818 | rev_range: str | None = None, 819 | verbose: bool = False, 820 | indicate_current: bool = False, 821 | ) -> None: 822 | """显示迁移的历史. 823 | 824 | 参数: 825 | config: `AlembicConfig` 对象 826 | rev_range: 迁移范围 827 | verbose: 是否显示详细信息 828 | indicate_current: 指示出当前迁移 829 | """ 830 | 831 | script = ScriptDirectory.from_config(config) 832 | if rev_range is not None: 833 | if ":" not in rev_range: 834 | raise click.BadParameter( 835 | "历史范围应为 [start]:[end]、[start]: 或 :[end]", param_hint="REV_RANGE" 836 | ) 837 | base, head = rev_range.strip().split(":") 838 | else: 839 | base = head = None 840 | 841 | environment = ( 842 | asbool(config.get_main_option("revision_environment")) or indicate_current 843 | ) 844 | 845 | def _display_history(config, script, base, head, currents=()): 846 | for sc in script.walk_revisions(base=base or "base", head=head or "heads"): 847 | if indicate_current: 848 | sc._db_current_indicator = sc.revision in currents 849 | 850 | config.print_stdout( 851 | sc.cmd_format( 852 | verbose=verbose, 853 | include_branches=True, 854 | include_doc=True, 855 | include_parents=True, 856 | ) 857 | ) 858 | 859 | def _display_history_w_current(config, script, base, head): 860 | def _display_current_history(rev): 861 | if head == "current": 862 | _display_history(config, script, base, rev, rev) 863 | elif base == "current": 864 | _display_history(config, script, rev, head, rev) 865 | else: 866 | _display_history(config, script, base, head, rev) 867 | 868 | revs = [] 869 | with EnvironmentContext( 870 | config, script, fn=lambda rev, _: revs.append(rev) or () 871 | ): 872 | script.run_env() 873 | 874 | for rev in revs: 875 | _display_current_history(rev) 876 | 877 | if base == "current" or head == "current" or environment: 878 | _display_history_w_current(config, script, base, head) 879 | else: 880 | _display_history(config, script, base, head) 881 | 882 | 883 | def heads( 884 | config: AlembicConfig, verbose: bool = False, resolve_dependencies: bool = False 885 | ) -> None: 886 | """显示所有的分支头. 887 | 888 | 参数: 889 | config: `AlembicConfig` 对象 890 | verbose: 是否显示详细信息 891 | resolve_dependencies: 是否将依赖的迁移视作父迁移 892 | """ 893 | 894 | script = ScriptDirectory.from_config(config) 895 | if resolve_dependencies: 896 | heads = script.get_revisions("heads") 897 | else: 898 | heads = script.get_revisions(script.get_heads()) 899 | 900 | for rev in cast("tuple[Script]", heads): 901 | config.print_stdout( 902 | rev.cmd_format(verbose, include_branches=True, tree_indicators=False) 903 | ) 904 | 905 | 906 | def branches(config: AlembicConfig, verbose: bool = False) -> None: 907 | """显示所有的分支. 908 | 909 | 参数: 910 | config: `AlembicConfig` 对象 911 | verbose: 是否显示详细信息 912 | """ 913 | script = ScriptDirectory.from_config(config) 914 | for sc in script.walk_revisions(): 915 | if not sc.is_branch_point: 916 | continue 917 | 918 | config.print_stdout( 919 | "%s\n%s\n", 920 | sc.cmd_format(verbose, include_branches=True), 921 | "\n".join( 922 | "%s -> %s" 923 | % ( 924 | " " * len(str(sc.revision)), 925 | cast(Script, script.get_revision(rev)).cmd_format( 926 | False, include_branches=True, include_doc=verbose 927 | ), 928 | ) 929 | for rev in sc.nextrev 930 | ), 931 | ) 932 | 933 | 934 | def current(config: AlembicConfig, verbose: bool = False) -> None: 935 | """显示当前的迁移. 936 | 937 | 参数: 938 | config: `AlembicConfig` 对象 939 | verbose: 是否显示详细信息 940 | """ 941 | 942 | script = ScriptDirectory.from_config(config) 943 | 944 | def display_version( 945 | rev, context: MigrationContext 946 | ) -> Iterable[StampStep | RevisionStep]: 947 | if verbose: 948 | config.print_stdout( 949 | "Current revision(s) for %s:", 950 | cast(Connection, context.connection).engine.url.render_as_string(), 951 | ) 952 | for sc in cast("set[Script]", script.get_all_current(rev)): 953 | config.print_stdout(sc.cmd_format(verbose)) 954 | 955 | return () 956 | 957 | with EnvironmentContext(config, script, fn=display_version, dont_mutate=True): 958 | script.run_env() 959 | 960 | 961 | def stamp( 962 | config: AlembicConfig, 963 | revisions: tuple[str, ...] = ("heads",), 964 | sql: bool = False, 965 | tag: str | None = None, 966 | purge: bool = False, 967 | ) -> None: 968 | """将数据库标记为特定的迁移版本, 不运行任何迁移. 969 | 970 | 参数: 971 | config: `AlembicConfig` 对象 972 | revisions: 目标迁移 973 | sql: 是否以 SQL 的形式输出迁移脚本 974 | tag: 一个任意的字符串, 可在自定义的 `env.py` 中通过 `alembic.EnvironmentContext.get_tag_argument` 获得 975 | purge: 是否在标记前清空数据库版本表 976 | """ 977 | 978 | revisions = revisions or ("heads",) 979 | script = ScriptDirectory.from_config(config) 980 | 981 | starting_rev = None 982 | if sql: 983 | destination_revs = [] 984 | for revision in revisions: 985 | if ":" in revision: 986 | srev, revision = revision.split(":", 2) 987 | 988 | if starting_rev != srev: 989 | if starting_rev is None: 990 | starting_rev = srev 991 | else: 992 | raise click.BadParameter( 993 | "--sql 模式下标记操作仅支持一个起始迁移", 994 | param_hint="REVISIONS", 995 | ) 996 | destination_revs.append(revision) 997 | else: 998 | destination_revs = revisions 999 | 1000 | def do_stamp(rev, _) -> Iterable[StampStep | RevisionStep]: 1001 | yield from script._stamp_revs(destination_revs, rev) 1002 | _move_run_scripts(config, script, destination_revs) 1003 | 1004 | with EnvironmentContext( 1005 | config, 1006 | script, 1007 | fn=do_stamp, 1008 | as_sql=sql, 1009 | starting_rev=starting_rev, 1010 | destination_rev=destination_revs, 1011 | tag=tag, 1012 | purge=purge, 1013 | ): 1014 | script.run_env() 1015 | 1016 | 1017 | def edit(config: AlembicConfig, rev: str = "current") -> None: 1018 | """使用 `$EDITOR` 编辑迁移脚本. 1019 | 1020 | 参数: 1021 | config: `AlembicConfig` 对象 1022 | rev: 目标迁移 1023 | """ 1024 | 1025 | script = ScriptDirectory.from_config(config) 1026 | 1027 | if rev == "current": 1028 | 1029 | def edit_current(rev, _) -> Iterable[StampStep | RevisionStep]: 1030 | if not rev: 1031 | raise click.UsageError("当前没有迁移") 1032 | 1033 | for sc in cast("tuple[Script]", script.get_revisions(rev)): 1034 | script_path = config.move_script(sc) 1035 | open_in_editor(str(script_path)) 1036 | 1037 | return () 1038 | 1039 | with EnvironmentContext(config, script, fn=edit_current): 1040 | script.run_env() 1041 | else: 1042 | revs = cast("tuple[Script, ...]", script.get_revisions(rev)) 1043 | 1044 | if not revs: 1045 | raise click.BadParameter(f'没有 "{rev}" 指示的迁移脚本') 1046 | 1047 | for sc in cast("tuple[Script]", revs): 1048 | script_path = config.move_script(sc) 1049 | open_in_editor(str(script_path)) 1050 | 1051 | 1052 | def ensure_version(config: AlembicConfig, sql: bool = False) -> None: 1053 | """创建版本表. 1054 | 1055 | 参数: 1056 | config: `AlembicConfig` 对象 1057 | sql: 是否以 SQL 的形式输出迁移脚本 1058 | """ 1059 | 1060 | script = ScriptDirectory.from_config(config) 1061 | 1062 | def do_ensure_version( 1063 | _, context: MigrationContext 1064 | ) -> Iterable[StampStep | RevisionStep]: 1065 | context._ensure_version_table() 1066 | return () 1067 | 1068 | with EnvironmentContext( 1069 | config, 1070 | script, 1071 | fn=do_ensure_version, 1072 | as_sql=sql, 1073 | ): 1074 | script.run_env() 1075 | --------------------------------------------------------------------------------