├── .dockerignore ├── .env.example ├── .github └── workflows │ └── docker-image.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app ├── __init__.py ├── api │ ├── endpoints │ │ ├── enterprise.py │ │ ├── health.py │ │ └── sso.py │ └── router.py ├── core │ └── config.py ├── db │ └── session.py ├── main.py ├── models │ ├── __init__.py │ ├── account.py │ ├── base.py │ ├── engine.py │ └── types.py ├── services │ ├── oidc.py │ ├── passport.py │ └── token.py └── utils │ └── redis.py ├── assets └── image-20250408142818633.png ├── requirements.txt └── yaml ├── docker-compose.yaml └── k8s-deployment.yaml /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git 2 | .git 3 | .gitignore 4 | .gitattributes 5 | 6 | 7 | # CI 8 | .codeclimate.yml 9 | .travis.yml 10 | .taskcluster.yml 11 | 12 | # Docker 13 | docker-compose.yml 14 | Dockerfile 15 | .docker 16 | .dockerignore 17 | 18 | # Byte-compiled / optimized / DLL files 19 | **/__pycache__/ 20 | **/*.py[cod] 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | env/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Virtual environment 74 | .env.example 75 | .env 76 | .venv/ 77 | venv/ 78 | 79 | # PyCharm 80 | .idea 81 | 82 | # Python mode for VIM 83 | .ropeproject 84 | **/.ropeproject 85 | 86 | # Vim swap files 87 | **/*.swp 88 | 89 | # VS Code 90 | .vscode/ -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # 服务配置 2 | CONSOLE_WEB_URL= 3 | SECRET_KEY=your-secret-key 4 | TENANT_ID= 5 | EDITION=SELF_HOSTED 6 | 7 | # 令牌配置 8 | ACCESS_TOKEN_EXPIRE_MINUTES=600 9 | REFRESH_TOKEN_EXPIRE_DAYS=30 10 | REFRESH_TOKEN_PREFIX=refresh_token: 11 | ACCOUNT_REFRESH_TOKEN_PREFIX=account_refresh_token: 12 | 13 | # OIDC配置 14 | OIDC_ENABLED=false 15 | OIDC_CLIENT_ID= 16 | OIDC_CLIENT_SECRET= 17 | OIDC_DISCOVERY_URL= 18 | OIDC_REDIRECT_URI= 19 | OIDC_SCOPE=openid profile email 20 | OIDC_RESPONSE_TYPE=code 21 | 22 | # 数据库配置 23 | DB_USERNAME=dify_admin 24 | DB_PASSWORD=123456 25 | DB_HOST=127.0.0.1 26 | DB_PORT=5432 27 | DB_DATABASE=dify 28 | 29 | # Redis配置 30 | REDIS_HOST=127.0.0.1 31 | REDIS_PORT=6379 32 | REDIS_DB=0 33 | REDIS_PASSWORD= 34 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Login 
to DockerHub 17 | uses: docker/login-action@v1 18 | with: 19 | username: ${{ secrets.DOCKERHUB_USERNAME }} 20 | password: ${{ secrets.DOCKERHUB_TOKEN }} 21 | - 22 | name: Build and push 23 | id: docker_build 24 | uses: docker/build-push-action@v2 25 | with: 26 | push: true 27 | tags: lework/dify-sso:latest 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__ 3 | **/__pycache__/** 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # UV 99 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | #uv.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 117 | .pdm.toml 118 | .pdm-python 119 | .pdm-build/ 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
169 | #.idea/ 170 | 171 | # Ruff stuff: 172 | .ruff_cache/ 173 | 174 | # PyPI configuration file 175 | .pypirc 176 | 177 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # ---- 运行环境 ---- 2 | FROM python:3.11-alpine AS running 3 | 4 | ENV LANG='en_US.UTF-8' \ 5 | LANGUAGE='en_US.UTF-8' \ 6 | TZ='Asia/Shanghai' \ 7 | PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ 8 | 9 | RUN \ 10 | sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories \ 11 | && apk --update -t --no-cache add tzdata libpq \ 12 | && ln -snf /usr/share/zoneinfo/${TZ} /etc/localtime \ 13 | && echo "${TZ}" > /etc/timezone \ 14 | && apk add --no-cache --virtual .build-deps gcc python3-dev musl-dev postgresql-dev \ 15 | && pip install --upgrade pip \ 16 | && pip install --no-cache-dir psycopg2 \ 17 | && apk del --no-cache .build-deps 18 | 19 | WORKDIR /app 20 | 21 | # 下载依赖 22 | COPY requirements.txt . 23 | RUN --mount=type=cache,id=pip,target=/root/.cache \ 24 | pip install -r requirements.txt 25 | 26 | # 拷贝代码 27 | COPY . . 
28 | 29 | CMD ["python", "-m", "app.main"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Lework 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dify SSO 集成 2 | 3 | 这个项目实现了 Dify 的企业 SSO 集成,当前支持 OIDC 协议。 4 | 5 | ## 功能特性 6 | 7 | - OIDC 登录集成 8 | - 配置灵活,通过环境变量控制 9 | - 支持授权码流程 10 | - 支持用户数据库自动创建和关联 11 | 12 | ## 技术栈 13 | 14 | - Python 3.11+ 15 | - FastAPI 16 | - SQLAlchemy 17 | - PostgreSQL 18 | - Redis 19 | - Flask-Login 20 | 21 | ## 系统要求 22 | 23 | - Python 3.11 或更高版本 24 | - PostgreSQL 12 或更高版本 25 | - Redis 6 或更高版本 26 | - 支持 OIDC 的身份提供商 27 | 28 | ## 项目结构 29 | 30 | ``` 31 | dify-sso/ 32 | ├── app/ # 主应用代码 33 | │ ├── api/ # API 路由和端点 34 | │ ├── core/ # 核心配置和功能 35 | │ ├── db/ # 数据库连接与会话 36 | │ ├── models/ # 数据模型 37 | │ ├── services/ # 业务逻辑服务 38 | │ ├── utils/ # 工具函数 39 | │ ├── __init__.py # 包初始化文件 40 | │ └── main.py # 应用入口文件 41 | ├── assets/ # 静态资源和图片 42 | ├── yaml/ # 部署配置文件 43 | │ ├── docker-compose.yaml # Docker Compose 配置 44 | │ └── k8s-deployment.yaml # Kubernetes 部署配置 45 | ├── .env.example # 环境变量示例 46 | ├── .env # 环境变量配置 47 | ├── .dockerignore # Docker 忽略文件 48 | ├── .gitignore # Git 忽略文件 49 | ├── requirements.txt # 项目依赖 50 | └── Dockerfile # Docker 构建文件 51 | ``` 52 | 53 | ## 配置说明 54 | 55 | OIDC SSO 集成需要以下环境变量配置: 56 | 57 | ```bash 58 | # Dify 配置 59 | CONSOLE_WEB_URL=your-dify-web-address # dify 的 web 地址 60 | SECRET_KEY=your-secret-key # dify 的 secret key 61 | TENANT_ID=your-tenant-id # dify 的 tenant id 62 | EDITION=SELF_HOSTED 63 | 64 | # 令牌配置 65 | ACCESS_TOKEN_EXPIRE_MINUTES=600 66 | REFRESH_TOKEN_EXPIRE_DAYS=30 67 | REFRESH_TOKEN_PREFIX=refresh_token: 68 | ACCOUNT_REFRESH_TOKEN_PREFIX=account_refresh_token: 69 | 70 | # OIDC配置 71 | OIDC_ENABLED=true # 是否启用OIDC 72 | OIDC_CLIENT_ID=your-client-id # OIDC客户端ID 73 | OIDC_CLIENT_SECRET=your-client-secret # OIDC客户端密钥 74 | OIDC_DISCOVERY_URL=https://your-oidc-provider/.well-known/openid-configuration # OIDC发现端点 75 | OIDC_REDIRECT_URI=http://localhost:8000/console/api/enterprise/sso/oidc/callback # 
回调URI 76 | OIDC_SCOPE=openid profile email # 请求的范围 77 | OIDC_RESPONSE_TYPE=code # 响应类型 78 | 79 | # 数据库配置 80 | DB_HOST=127.0.0.1 81 | DB_PORT=5432 82 | DB_DATABASE=dify 83 | DB_USERNAME=dify_admin 84 | DB_PASSWORD=123456 85 | 86 | # Redis配置 87 | REDIS_HOST=127.0.0.1 88 | REDIS_PORT=6379 89 | REDIS_DB=0 90 | REDIS_PASSWORD= # Redis密码,如无密码则留空 91 | ``` 92 | 93 | ## 安装与运行 94 | 95 | ### 使用 Docker 96 | 97 | 1. 构建镜像: 98 | 99 | ```bash 100 | docker build -t dify-sso . 101 | ``` 102 | 103 | 2. 运行容器: 104 | 105 | ```bash 106 | docker run -p 8000:8000 --env-file .env dify-sso 107 | ``` 108 | 109 | ### 本地开发 110 | 111 | 1. 克隆仓库: 112 | 113 | ```bash 114 | git clone https://github.com/lework/dify-sso.git 115 | cd dify-sso 116 | ``` 117 | 118 | 2. 创建并激活虚拟环境: 119 | 120 | ```bash 121 | python -m venv .venv 122 | source .venv/bin/activate # Linux/Mac 123 | .venv\Scripts\activate # Windows 124 | ``` 125 | 126 | 3. 安装依赖: 127 | 128 | ```bash 129 | pip install -r requirements.txt 130 | ``` 131 | 132 | 4. 配置环境变量: 133 | 134 | ```bash 135 | cp .env.example .env 136 | # 编辑.env文件,设置您的OIDC配置和数据库配置 137 | ``` 138 | 139 | 5. 运行应用: 140 | 141 | ```bash 142 | python -m app.main 143 | ``` 144 | 145 | ### 接入流程 146 | 147 | 1. 创建 sso 服务商 148 | 149 | ![image-20250408142818633](./assets/image-20250408142818633.png) 150 | 151 | 2. 启动 dify-sso 容器 152 | 153 | ```bash 154 | docker run -p 8000:8000 --env-file .env lework/dify-sso:0.0.1 155 | ``` 156 | 157 | 3. 
在 dify-proxy 的 nginx 配置文件中添加以下配置: 158 | 159 | ```nginx 160 | location ~ (/console/api/system-features|/console/api/enterprise/sso/) { 161 | proxy_pass http://dify-sso:8000; 162 | include proxy.conf; 163 | } 164 | ``` 165 | 166 | > nginx 完整的配置看[default.conf.template](https://github.com/langgenius/dify/blob/main/docker/nginx/conf.d/default.conf.template) 167 | 168 | 如果 dify-proxy 是部署在 k8s 中。 可使用 [k8s-deployment.yaml](./yaml/k8s-deployment.yaml) 文件部署 dify-sso 。 169 | 170 | 如果 dify-proxy 是部署在 docker 中。 可使用 [docker-compose.yaml](./yaml/docker-compose.yaml) 文件部署 dify-sso 。 171 | 172 | ## API 端点 173 | 174 | OIDC SSO 集成提供以下 API 端点: 175 | 176 | - **GET /console/api/enterprise/sso/oidc/login**: 启动 OIDC 登录流程,将用户重定向到 OIDC 提供商 177 | - **GET /console/api/enterprise/sso/oidc/callback**: OIDC 回调处理,处理授权码并获取用户信息 178 | - **GET /console/api/system-features**: 获取系统功能配置 179 | - **GET /health**: 健康检查端点 180 | - **GET /info**: 获取企业信息 181 | 182 | ## OIDC 认证流程 183 | 184 | OIDC 登录流程遵循标准的授权码流程(Authorization Code Flow): 185 | 186 | 1. 用户访问 `/console/api/enterprise/sso/oidc/login` 端点 187 | 2. 系统生成授权 URL,重定向用户到 OIDC 提供商的登录页面 188 | 3. 用户在 OIDC 提供商处认证 189 | 4. OIDC 提供商将用户重定向回 `/console/api/enterprise/sso/oidc/callback`,带有授权码 190 | 5. 系统使用授权码获取访问令牌和 ID 令牌 191 | 6. 系统使用访问令牌获取用户信息 192 | 7. 系统通过 OIDC 用户信息中的 `sub` 或 `email` 查询数据库,确认用户是否存在: 193 | - 如果用户存在,更新其信息(如姓名)并记录登录时间和 IP 194 | - 如果用户不存在,创建新用户并关联到默认租户 195 | 8. 系统生成 JWT 令牌和刷新令牌,并将用户重定向到 Dify 控制台 196 | 197 | ## 数据库表说明 198 | 199 | 系统使用以下主要表格: 200 | 201 | - `accounts`: 存储用户账号信息 202 | - `tenants`: 存储租户信息 203 | - `tenant_account_joins`: 存储用户与租户的关联 204 | 205 | ## 贡献指南 206 | 207 | 1. Fork 本仓库 208 | 2. 创建特性分支 (`git checkout -b feature/AmazingFeature`) 209 | 3. 提交更改 (`git commit -m 'Add some AmazingFeature'`) 210 | 4. 推送到分支 (`git push origin feature/AmazingFeature`) 211 | 5. 
创建 Pull Request 212 | 213 | ## 许可证 214 | 215 | 本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件 216 | 217 | ## ⚠️ 特别声明 218 | 219 | 请注意,Dify 官方版本包含商业授权模式,其内置的 SSO 功能通常属于其商业计划的一部分。建议有能力的去购买 Dify 商业授权。 220 | 221 | 本项目 `dify-sso` 是一个独立的外部集成方案,**并未修改任何 Dify 官方源代码**。它仅提供了一种通过标准 OIDC 协议对接企业现有身份认证系统、实现单点登录到 Dify 的**可选方式**,旨在方便那些已经拥有统一身份认证体系的企业用户。 222 | 223 | 我们尊重 Dify 的知识产权和商业模式。如果您认为本项目的存在可能对 Dify 的商业权益产生影响,请随时联系项目作者,我们将及时沟通处理或根据要求移除本项目。 224 | 225 | ## 参考资料 226 | 227 | - [OpenID Connect 规范](https://openid.net/connect/) 228 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dify SSO Service 3 | """ 4 | 5 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /app/api/endpoints/enterprise.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from datetime import datetime, timedelta 3 | 4 | router = APIRouter() 5 | 6 | # 模拟企业信息 7 | MOCK_ENTERPRISE_INFO = { 8 | "sso_enforced_for_signin": True, 9 | "sso_enforced_for_signin_protocol": "oidc", 10 | "sso_enforced_for_web": True, 11 | "sso_enforced_for_web_protocol": "oidc", 12 | "enable_web_sso_switch_component": True, 13 | "enable_email_code_login": True, 14 | "enable_email_password_login": True, 15 | "is_allow_register": True, 16 | "is_allow_create_workspace": False, 17 | "license": { 18 | "status": "active", 19 | "expired_at": (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d") 20 | } 21 | } 22 | 23 | # 模拟计费信息 24 | MOCK_BILLING_INFO = { 25 | "enabled": True, 26 | "subscription": { 27 | "plan": "enterprise", 28 | "interval": "year" 29 | }, 30 | "members": { 31 | "size": 1, 32 | "limit": 100 33 | }, 34 | "apps": { 35 | "size": 1, 36 | "limit": 200 37 | }, 38 | "vector_space": { 39 | "size": 1, 40 | "limit": 500 41 | }, 42 | "documents_upload_quota": { 43 | "size": 1, 44 | 
"limit": 10000 45 | }, 46 | "annotation_quota_limit": { 47 | "size": 1, 48 | "limit": 10000 49 | }, 50 | "docs_processing": "top-priority", 51 | "can_replace_logo": True, 52 | "model_load_balancing_enabled": True, 53 | "dataset_operator_enabled": True, 54 | "knowledge_rate_limit": { 55 | "limit": 200000, 56 | "subscription_plan": "enterprise" 57 | } 58 | } 59 | 60 | # 系统功能 61 | SYSTEM_FEATURES = { 62 | "sso_enforced_for_signin": True, 63 | "sso_enforced_for_signin_protocol": "oidc", 64 | "sso_enforced_for_web": True, 65 | "sso_enforced_for_web_protocol": "oidc", 66 | "enable_web_sso_switch_component": True, 67 | "enable_marketplace": True, 68 | "max_plugin_package_size": 52428800, 69 | "enable_email_code_login": False, 70 | "enable_email_password_login": True, 71 | "enable_social_oauth_login": False, 72 | "is_allow_register": False, 73 | "is_allow_create_workspace": False, 74 | "is_email_setup": True, 75 | "license": { 76 | "status": "active", 77 | "expired_at": (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d") 78 | } 79 | } 80 | 81 | 82 | @router.get("/info") 83 | async def get_enterprise_info(): 84 | return MOCK_ENTERPRISE_INFO 85 | 86 | @router.get("/app-sso-setting") 87 | async def get_app_sso_setting(app_code: str): 88 | return { 89 | "enabled": True, 90 | "protocol": "oidc", 91 | "app_code": app_code 92 | } 93 | 94 | # 计费相关接口 95 | @router.get("/subscription/info") 96 | async def get_billing_info(): 97 | return MOCK_BILLING_INFO 98 | 99 | 100 | # 系统功能 101 | @router.get("/console/api/system-features") 102 | async def get_system_features(): 103 | return SYSTEM_FEATURES 104 | -------------------------------------------------------------------------------- /app/api/endpoints/health.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException 2 | from app.utils.redis import check_redis_connection 3 | from app.db.session import check_database_connection 4 | 5 | router = APIRouter() 6 | 
7 | @router.get("/health") 8 | async def health_check(detail: bool = False): 9 | if detail: 10 | health_status = { 11 | "status": "healthy", 12 | "redis": check_redis_connection(), 13 | "database": check_database_connection() 14 | } 15 | if not health_status["redis"] or not health_status["database"]: 16 | health_status["status"] = "unhealthy" 17 | raise HTTPException( 18 | status_code=503, 19 | detail={ 20 | "redis": "Redis connection failed" if not health_status["redis"] else "OK", 21 | "database": "Database connection failed" if not health_status["database"] else "OK" 22 | } 23 | ) 24 | else: 25 | health_status = { 26 | "status": "healthy", 27 | } 28 | 29 | return health_status -------------------------------------------------------------------------------- /app/api/endpoints/sso.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Request, Depends 2 | from fastapi.responses import RedirectResponse 3 | from sqlalchemy.orm import Session 4 | from app.db.session import get_db 5 | from app.services.oidc import OIDCService 6 | from app.core.config import settings 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | router = APIRouter() 12 | oidc_service = OIDCService() 13 | 14 | @router.get("/oidc/login") 15 | async def oidc_login( 16 | is_login: bool = False 17 | ): 18 | login_url = oidc_service.get_login_url() 19 | if is_login: 20 | return RedirectResponse(url=login_url) 21 | else: 22 | return {"url": login_url} 23 | 24 | @router.get("/oidc/callback") 25 | async def oidc_callback( 26 | code: str, 27 | db: Session = Depends(get_db), 28 | request: Request = None 29 | ): 30 | client_host = request.client.host 31 | xff = request.headers.get('X-Forwarded-For') 32 | if xff: 33 | xffs = xff.split(',') 34 | if len(xffs) > 0: 35 | client_host = xffs[0].strip() 36 | 37 | try: 38 | tokens = oidc_service.handle_callback(code, db, client_host) 39 | return RedirectResponse( 40 
| url=f"{settings.CONSOLE_WEB_URL}/signin?access_token={tokens['access_token']}&refresh_token={tokens['refresh_token']}") 41 | except Exception as e: 42 | logger.exception("OIDC回调处理失败: %s", str(e)) 43 | raise HTTPException(status_code=400, detail=str(e)) -------------------------------------------------------------------------------- /app/api/router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from app.api.endpoints import enterprise, health, sso 3 | 4 | router = APIRouter() 5 | 6 | router.include_router(sso.router, prefix="/console/api/enterprise/sso") 7 | router.include_router(enterprise.router) 8 | router.include_router(health.router) 9 | -------------------------------------------------------------------------------- /app/core/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | from typing import Optional 3 | 4 | class Settings(BaseSettings): 5 | # 通用配置 6 | CONSOLE_WEB_URL: str = "" 7 | SECRET_KEY: str = "your-secret-key" 8 | TENANT_ID: str = "" 9 | EDITION: str = "SELF_HOSTED" 10 | 11 | # 刷新令牌配置 12 | ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 13 | REFRESH_TOKEN_EXPIRE_DAYS: str = "30" 14 | REFRESH_TOKEN_PREFIX: str = "refresh_token:" 15 | ACCOUNT_REFRESH_TOKEN_PREFIX: str = "account_refresh_token:" 16 | 17 | # OIDC配置 18 | OIDC_ENABLED: bool = False 19 | OIDC_CLIENT_ID: str = "" 20 | OIDC_CLIENT_SECRET: str = "" 21 | OIDC_DISCOVERY_URL: str = "" 22 | OIDC_REDIRECT_URI: str = "" 23 | OIDC_SCOPE: str = "openid profile email" 24 | OIDC_RESPONSE_TYPE: str = "code" 25 | 26 | # 数据库配置 27 | DB_USERNAME: str = "dify_admin" 28 | DB_PASSWORD: str = "123456" 29 | DB_HOST: str = "127.0.0.1" 30 | DB_PORT: str = "5432" 31 | DB_DATABASE: str = "dify" 32 | 33 | # Redis配置 34 | REDIS_HOST: str = "127.0.0.1" 35 | REDIS_PORT: str = "6379" 36 | REDIS_DB: str = "0" 37 | REDIS_PASSWORD: str = "" 38 | 39 | class Config: 40 | 
env_file = ".env" 41 | case_sensitive = True 42 | 43 | settings = Settings() -------------------------------------------------------------------------------- /app/db/session.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine, text 2 | from sqlalchemy.orm import sessionmaker 3 | from app.core.config import settings 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | DATABASE_URL = f"postgresql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_DATABASE}" 9 | engine = create_engine(DATABASE_URL) 10 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 11 | 12 | def get_db(): 13 | db = SessionLocal() 14 | try: 15 | yield db 16 | finally: 17 | db.close() 18 | 19 | def check_database_connection() -> bool: 20 | try: 21 | with engine.connect() as conn: 22 | conn.execute(text("SELECT 1")) 23 | return True 24 | except Exception as e: 25 | logger.exception("数据库连接检查失败: %s", str(e)) 26 | return False -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from fastapi import FastAPI 4 | from fastapi.middleware.cors import CORSMiddleware 5 | from contextlib import asynccontextmanager 6 | from app.api.router import router 7 | from app.db.session import check_database_connection, engine 8 | from app.utils.redis import check_redis_connection, redis_manager 9 | from app.models.account import Base 10 | from app.api.endpoints.sso import oidc_service 11 | 12 | # 配置日志 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | # 创建数据库表 17 | Base.metadata.create_all(bind=engine) 18 | 19 | @asynccontextmanager 20 | async def lifespan(app: FastAPI): 21 | """应用生命周期管理""" 22 | # 启动时执行 23 | logger.info("Starting application") 24 | # 检查数据库连接 25 | if 
not check_database_connection(): 26 | logger.error("Failed to connect to PostgreSQL database") 27 | raise Exception("Database connection failed") 28 | logger.info("Database connection successful") 29 | 30 | # 检查Redis连接 31 | if not check_redis_connection(): 32 | logger.error("Failed to connect to Redis") 33 | raise Exception("Redis connection failed") 34 | logger.info("Redis connection successful") 35 | 36 | # 检查OIDC配置 37 | if not oidc_service.check_oidc_config(): 38 | logger.error("OIDC configuration is incomplete") 39 | raise Exception("OIDC configuration is incomplete") 40 | logger.info("OIDC configuration is complete") 41 | 42 | yield 43 | 44 | # 关闭时执行 45 | redis_manager.close() 46 | logger.info("Redis connection pool closed") 47 | 48 | # 关闭数据库连接池 49 | engine.dispose() 50 | logger.info("PostgreSQL connection pool closed") 51 | 52 | app = FastAPI(lifespan=lifespan) 53 | 54 | # 添加CORS中间件 55 | app.add_middleware( 56 | CORSMiddleware, 57 | allow_origins=["*"], 58 | allow_credentials=True, 59 | allow_methods=["*"], 60 | allow_headers=["*"], 61 | ) 62 | 63 | # 注册路由 64 | app.include_router(router) 65 | 66 | if __name__ == "__main__": 67 | import uvicorn 68 | uvicorn.run(app, host="0.0.0.0", port=8000) -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .account import ( 2 | Account, 3 | Tenant, 4 | TenantAccountJoin, 5 | AccountStatus, 6 | TenantAccountRole, 7 | TenantStatus, 8 | AccountIntegrate, 9 | ) -------------------------------------------------------------------------------- /app/models/account.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import json 3 | from datetime import UTC, datetime 4 | from flask_login import UserMixin 5 | from sqlalchemy import func 6 | from sqlalchemy.orm import Mapped, mapped_column 7 | from sqlalchemy.dialects.postgresql import UUID 
8 | 9 | from .base import Base 10 | from .engine import db 11 | from .types import StringUUID 12 | 13 | 14 | class AccountStatus(enum.StrEnum): 15 | PENDING = "pending" 16 | UNINITIALIZED = "uninitialized" 17 | ACTIVE = "active" 18 | BANNED = "banned" 19 | CLOSED = "closed" 20 | 21 | 22 | class Account(UserMixin, Base): 23 | __tablename__ = "accounts" 24 | __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) 25 | 26 | id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) 27 | name = db.Column(db.String(255), nullable=False) 28 | email = db.Column(db.String(255), nullable=False) 29 | password = db.Column(db.String(255), nullable=True) 30 | password_salt = db.Column(db.String(255), nullable=True) 31 | avatar = db.Column(db.String(255)) 32 | interface_language = db.Column(db.String(255)) 33 | interface_theme = db.Column(db.String(255)) 34 | timezone = db.Column(db.String(255)) 35 | last_login_at = db.Column(db.DateTime) 36 | last_login_ip = db.Column(db.String(255)) 37 | last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 38 | status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) 39 | initialized_at = db.Column(db.DateTime) 40 | created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 41 | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 42 | 43 | @property 44 | def is_password_set(self): 45 | return self.password is not None 46 | 47 | @property 48 | def current_tenant(self): 49 | # FIXME: fix the type error later, because the type is important maybe cause some bugs 50 | return self._current_tenant # type: ignore 51 | 52 | @current_tenant.setter 53 | def current_tenant(self, value: "Tenant"): 54 | tenant = value 55 | ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first() 
56 | if ta: 57 | tenant.current_role = ta.role 58 | else: 59 | tenant = None # type: ignore 60 | 61 | self._current_tenant = tenant 62 | 63 | @property 64 | def current_tenant_id(self) -> str | None: 65 | return self._current_tenant.id if self._current_tenant else None 66 | 67 | @current_tenant_id.setter 68 | def current_tenant_id(self, value: str): 69 | try: 70 | tenant_account_join = ( 71 | db.session.query(Tenant, TenantAccountJoin) 72 | .filter(Tenant.id == value) 73 | .filter(TenantAccountJoin.tenant_id == Tenant.id) 74 | .filter(TenantAccountJoin.account_id == self.id) 75 | .one_or_none() 76 | ) 77 | 78 | if tenant_account_join: 79 | tenant, ta = tenant_account_join 80 | tenant.current_role = ta.role 81 | else: 82 | tenant = None 83 | except Exception: 84 | tenant = None 85 | 86 | self._current_tenant = tenant 87 | 88 | @property 89 | def current_role(self): 90 | return self._current_tenant.current_role 91 | 92 | def get_status(self) -> AccountStatus: 93 | status_str = self.status 94 | return AccountStatus(status_str) 95 | 96 | @classmethod 97 | def get_by_openid(cls, provider: str, open_id: str): 98 | account_integrate = ( 99 | db.session.query(AccountIntegrate) 100 | .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) 101 | .one_or_none() 102 | ) 103 | if account_integrate: 104 | return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() 105 | return None 106 | 107 | # check current_user.current_tenant.current_role in ['admin', 'owner'] 108 | @property 109 | def is_admin_or_owner(self): 110 | return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) 111 | 112 | @property 113 | def is_admin(self): 114 | return TenantAccountRole.is_admin_role(self._current_tenant.current_role) 115 | 116 | @property 117 | def is_editor(self): 118 | return TenantAccountRole.is_editing_role(self._current_tenant.current_role) 119 | 120 | @property 121 | def is_dataset_editor(self): 122 | 
return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role) 123 | 124 | @property 125 | def is_dataset_operator(self): 126 | return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR 127 | 128 | @classmethod 129 | def get_by_email(cls, db, email: str): 130 | """通过邮箱查找用户""" 131 | return db.query(cls).filter(cls.email == email).first() 132 | 133 | @classmethod 134 | def create(cls, db, email: str, name: str, avatar: str = None, tenant_id: str = None): 135 | """创建新用户""" 136 | account = cls( 137 | email=email, 138 | name=name, 139 | avatar=avatar, 140 | interface_theme="light", 141 | interface_language="zh-Hans", 142 | timezone="Asia/Shanghai", 143 | status=AccountStatus.ACTIVE, 144 | initialized_at=datetime.now(UTC), 145 | ) 146 | db.add(account) 147 | db.flush() # 获取account.id 148 | 149 | # 如果提供了tenant_id,创建租户关联 150 | if tenant_id: 151 | tenant_account_join = TenantAccountJoin( 152 | tenant_id=tenant_id, 153 | account_id=account.id, 154 | role=TenantAccountRole.EDITOR 155 | ) 156 | db.add(tenant_account_join) 157 | 158 | db.commit() 159 | return account 160 | 161 | 162 | class TenantStatus(enum.StrEnum): 163 | NORMAL = "normal" 164 | ARCHIVE = "archive" 165 | 166 | 167 | class TenantAccountRole(enum.StrEnum): 168 | OWNER = "owner" 169 | ADMIN = "admin" 170 | EDITOR = "editor" 171 | NORMAL = "normal" 172 | DATASET_OPERATOR = "dataset_operator" 173 | 174 | @staticmethod 175 | def is_valid_role(role: str) -> bool: 176 | if not role: 177 | return False 178 | return role in { 179 | TenantAccountRole.OWNER, 180 | TenantAccountRole.ADMIN, 181 | TenantAccountRole.EDITOR, 182 | TenantAccountRole.NORMAL, 183 | TenantAccountRole.DATASET_OPERATOR, 184 | } 185 | 186 | @staticmethod 187 | def is_privileged_role(role: str) -> bool: 188 | if not role: 189 | return False 190 | return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} 191 | 192 | @staticmethod 193 | def is_admin_role(role: str) -> bool: 194 | if not role: 195 | return 
False 196 | return role == TenantAccountRole.ADMIN 197 | 198 | @staticmethod 199 | def is_non_owner_role(role: str) -> bool: 200 | if not role: 201 | return False 202 | return role in { 203 | TenantAccountRole.ADMIN, 204 | TenantAccountRole.EDITOR, 205 | TenantAccountRole.NORMAL, 206 | TenantAccountRole.DATASET_OPERATOR, 207 | } 208 | 209 | @staticmethod 210 | def is_editing_role(role: str) -> bool: 211 | if not role: 212 | return False 213 | return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} 214 | 215 | @staticmethod 216 | def is_dataset_edit_role(role: str) -> bool: 217 | if not role: 218 | return False 219 | return role in { 220 | TenantAccountRole.OWNER, 221 | TenantAccountRole.ADMIN, 222 | TenantAccountRole.EDITOR, 223 | TenantAccountRole.DATASET_OPERATOR, 224 | } 225 | 226 | 227 | class Tenant(db.Model): # type: ignore[name-defined] 228 | __tablename__ = "tenants" 229 | __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) 230 | 231 | id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) 232 | name = db.Column(db.String(255), nullable=False) 233 | encrypt_public_key = db.Column(db.Text) 234 | plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) 235 | status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) 236 | custom_config = db.Column(db.Text) 237 | created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 238 | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 239 | 240 | def get_accounts(self) -> list[Account]: 241 | return ( 242 | db.session.query(Account) 243 | .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) 244 | .all() 245 | ) 246 | 247 | @property 248 | def custom_config_dict(self) -> dict: 249 | return json.loads(self.custom_config) if self.custom_config 
else {} 250 | 251 | @custom_config_dict.setter 252 | def custom_config_dict(self, value: dict): 253 | self.custom_config = json.dumps(value) 254 | 255 | 256 | class TenantAccountJoin(db.Model): # type: ignore[name-defined] 257 | __tablename__ = "tenant_account_joins" 258 | __table_args__ = ( 259 | db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), 260 | db.Index("tenant_account_join_account_id_idx", "account_id"), 261 | db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), 262 | db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), 263 | ) 264 | 265 | id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) 266 | tenant_id = db.Column(StringUUID, nullable=False) 267 | account_id = db.Column(StringUUID, nullable=False) 268 | current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) 269 | role = db.Column(db.String(16), nullable=False, server_default="normal") 270 | invited_by = db.Column(StringUUID, nullable=True) 271 | created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 272 | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 273 | 274 | 275 | class AccountIntegrate(db.Model): # type: ignore[name-defined] 276 | __tablename__ = "account_integrates" 277 | __table_args__ = ( 278 | db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), 279 | db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), 280 | db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), 281 | ) 282 | 283 | id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) 284 | account_id = db.Column(StringUUID, nullable=False) 285 | provider = db.Column(db.String(16), nullable=False) 286 | open_id = db.Column(db.String(255), nullable=False) 287 | encrypted_token = db.Column(db.String(255), nullable=False) 288 | created_at = db.Column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) 289 | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 290 | 291 | 292 | class InvitationCode(db.Model): # type: ignore[name-defined] 293 | __tablename__ = "invitation_codes" 294 | __table_args__ = ( 295 | db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), 296 | db.Index("invitation_codes_batch_idx", "batch"), 297 | db.Index("invitation_codes_code_idx", "code", "status"), 298 | ) 299 | 300 | id = db.Column(db.Integer, nullable=False) 301 | batch = db.Column(db.String(255), nullable=False) 302 | code = db.Column(db.String(32), nullable=False) 303 | status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) 304 | used_at = db.Column(db.DateTime) 305 | used_by_tenant_id = db.Column(StringUUID) 306 | used_by_account_id = db.Column(StringUUID) 307 | deprecated_at = db.Column(db.DateTime) 308 | created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) 309 | 310 | 311 | class TenantPluginPermission(Base): 312 | class InstallPermission(enum.StrEnum): 313 | EVERYONE = "everyone" 314 | ADMINS = "admins" 315 | NOBODY = "noone" 316 | 317 | class DebugPermission(enum.StrEnum): 318 | EVERYONE = "everyone" 319 | ADMINS = "admins" 320 | NOBODY = "noone" 321 | 322 | __tablename__ = "account_plugin_permissions" 323 | __table_args__ = ( 324 | db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), 325 | db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), 326 | ) 327 | 328 | id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) 329 | tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) 330 | install_permission: Mapped[InstallPermission] = mapped_column( 331 | db.String(16), nullable=False, server_default="everyone" 332 | ) 333 | debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, 
server_default="noone") -------------------------------------------------------------------------------- /app/models/base.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import declarative_base 2 | 3 | from .engine import metadata 4 | 5 | Base = declarative_base(metadata=metadata) -------------------------------------------------------------------------------- /app/models/engine.py: -------------------------------------------------------------------------------- 1 | from flask_sqlalchemy import SQLAlchemy 2 | from sqlalchemy import MetaData 3 | 4 | POSTGRES_INDEXES_NAMING_CONVENTION = { 5 | "ix": "%(column_0_label)s_idx", 6 | "uq": "%(table_name)s_%(column_0_name)s_key", 7 | "ck": "%(table_name)s_%(constraint_name)s_check", 8 | "fk": "%(table_name)s_%(column_0_name)s_fkey", 9 | "pk": "%(table_name)s_pkey", 10 | } 11 | 12 | metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) 13 | db = SQLAlchemy(metadata=metadata) -------------------------------------------------------------------------------- /app/models/types.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import CHAR, TypeDecorator 2 | from sqlalchemy.dialects.postgresql import UUID 3 | 4 | 5 | class StringUUID(TypeDecorator): 6 | impl = CHAR 7 | cache_ok = True 8 | 9 | def process_bind_param(self, value, dialect): 10 | if value is None: 11 | return value 12 | elif dialect.name == "postgresql": 13 | return str(value) 14 | else: 15 | return value.hex 16 | 17 | def load_dialect_impl(self, dialect): 18 | if dialect.name == "postgresql": 19 | return dialect.type_descriptor(UUID()) 20 | else: 21 | return dialect.type_descriptor(CHAR(36)) 22 | 23 | def process_result_value(self, value, dialect): 24 | if value is None: 25 | return value 26 | return str(value) -------------------------------------------------------------------------------- /app/services/oidc.py: 
-------------------------------------------------------------------------------- 1 | import requests 2 | from urllib.parse import urlencode 3 | from typing import Dict 4 | from sqlalchemy.orm import Session 5 | from datetime import UTC, datetime, timedelta 6 | import logging 7 | 8 | from app.core.config import settings 9 | from .passport import PassportService 10 | from .token import TokenService 11 | from app.models.account import Account, AccountStatus 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class OIDCService: 16 | def __init__(self): 17 | self.client_id = settings.OIDC_CLIENT_ID 18 | self.client_secret = settings.OIDC_CLIENT_SECRET 19 | self.discovery_url = settings.OIDC_DISCOVERY_URL 20 | self.redirect_uri = settings.OIDC_REDIRECT_URI 21 | self.scope = settings.OIDC_SCOPE 22 | self.response_type = settings.OIDC_RESPONSE_TYPE 23 | self.tenant_id = settings.TENANT_ID 24 | self.passport_service = PassportService() 25 | self.token_service = TokenService() 26 | 27 | # 获取OIDC配置 28 | self._load_oidc_config() 29 | 30 | def _load_oidc_config(self): 31 | """加载OIDC配置""" 32 | response = requests.get(self.discovery_url) 33 | if response.status_code == 200: 34 | config = response.json() 35 | self.authorization_endpoint = config.get('authorization_endpoint') 36 | self.token_endpoint = config.get('token_endpoint') 37 | self.userinfo_endpoint = config.get('userinfo_endpoint') 38 | else: 39 | raise Exception("Failed to load OIDC configuration") 40 | 41 | def check_oidc_config(self) -> bool: 42 | if not self.authorization_endpoint or not self.token_endpoint or not self.userinfo_endpoint: 43 | return False 44 | return True 45 | 46 | def get_login_url(self) -> str: 47 | """生成登录URL""" 48 | params = { 49 | 'client_id': self.client_id, 50 | 'response_type': self.response_type, 51 | 'scope': self.scope, 52 | 'redirect_uri': self.redirect_uri, 53 | 'state': 'random_state' # 实际应用中应该使用随机生成的状态 54 | } 55 | return f"{self.authorization_endpoint}?{urlencode(params)}" 56 | 57 
| def get_token(self, code: str) -> Dict: 58 | """获取访问令牌""" 59 | data = { 60 | 'grant_type': 'authorization_code', 61 | 'code': code, 62 | 'redirect_uri': self.redirect_uri, 63 | 'client_id': self.client_id, 64 | 'client_secret': self.client_secret 65 | } 66 | response = requests.post(self.token_endpoint, data=data) 67 | if response.status_code != 200: 68 | logger.exception("获取token失败: status_code=%d, response=%s", 69 | response.status_code, response.text) 70 | raise Exception("Failed to get token") 71 | return response.json() 72 | 73 | def get_user_info(self, access_token: str) -> Dict: 74 | """获取用户信息""" 75 | headers = {'Authorization': f'Bearer {access_token}'} 76 | response = requests.get(self.userinfo_endpoint, headers=headers) 77 | if response.status_code != 200: 78 | logger.exception("获取用户信息失败: status_code=%d, response=%s", 79 | response.status_code, response.text) 80 | raise Exception("Failed to get user info") 81 | return response.json() 82 | 83 | def handle_callback(self, code: str, db: Session, client_host: str) -> Dict[str, str]: 84 | """处理回调,返回access token和refresh token""" 85 | # 获取访问令牌 86 | token_response = self.get_token(code) 87 | access_token = token_response.get('access_token') 88 | 89 | # 获取用户信息 90 | user_info = self.get_user_info(access_token) 91 | user_name = user_info.get('name') 92 | user_email = user_info.get('email') 93 | 94 | # 查找系统用户 95 | account = Account.get_by_email(db, user_email) 96 | 97 | # 如果系统用户不存在,则创建系统用户 98 | if not account: 99 | account = Account.create( 100 | db=db, 101 | email=user_email, 102 | name=user_name, 103 | avatar="", 104 | tenant_id=self.tenant_id 105 | ) 106 | logger.info("创建用户: %s", user_email) 107 | 108 | 109 | # 更新用户登录信息 110 | account.last_login_at = datetime.now(UTC) 111 | account.last_login_ip = client_host 112 | if account.status != AccountStatus.ACTIVE: 113 | account.status = AccountStatus.ACTIVE 114 | if account.name != user_name: 115 | account.name = user_name 116 | db.commit() 117 | 118 | 
logger.info("用户验证成功: %s", user_email) 119 | 120 | # 生成JWT token 121 | exp_dt = datetime.now(UTC) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) 122 | exp = int(exp_dt.timestamp()) 123 | account_id = str(account.id) 124 | payload = { 125 | "user_id": account_id, # 将UUID转换为字符串 126 | "exp": exp, 127 | "iss": settings.EDITION, 128 | "sub": "Console API Passport", 129 | } 130 | 131 | # 生成access token 132 | console_access_token: str = self.passport_service.issue(payload) 133 | 134 | # 生成并存储refresh token 135 | refresh_token = self.token_service.generate_refresh_token() 136 | self.token_service.store_refresh_token(refresh_token, account_id) 137 | 138 | return { 139 | "access_token": console_access_token, 140 | "refresh_token": refresh_token, 141 | } -------------------------------------------------------------------------------- /app/services/passport.py: -------------------------------------------------------------------------------- 1 | import jwt 2 | 3 | from app.core.config import settings 4 | 5 | 6 | class PassportService: 7 | def __init__(self): 8 | self.sk = settings.SECRET_KEY 9 | 10 | def issue(self, payload): 11 | return jwt.encode(payload, self.sk, algorithm="HS256") -------------------------------------------------------------------------------- /app/services/token.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from datetime import timedelta 3 | from app.utils.redis import redis_manager 4 | from app.core.config import settings 5 | 6 | class TokenService: 7 | @staticmethod 8 | def generate_refresh_token() -> str: 9 | """生成refresh token""" 10 | return secrets.token_hex(64) 11 | 12 | @staticmethod 13 | def store_refresh_token(refresh_token: str, account_id: str) -> None: 14 | """存储refresh token到Redis""" 15 | refresh_token_key = f"{settings.REFRESH_TOKEN_PREFIX}{refresh_token}" 16 | account_refresh_token_key = f"{settings.ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" 17 | 18 | # 设置过期时间 19 | 
REFRESH_TOKEN_EXPIRY = timedelta(days=int(settings.REFRESH_TOKEN_EXPIRE_DAYS)) 20 | 21 | # 存储到Redis 22 | redis_manager.setex(refresh_token_key, REFRESH_TOKEN_EXPIRY, account_id) 23 | redis_manager.setex(account_refresh_token_key, REFRESH_TOKEN_EXPIRY, refresh_token) -------------------------------------------------------------------------------- /app/utils/redis.py: -------------------------------------------------------------------------------- 1 | import redis 2 | from app.core.config import settings 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | redis_manager = redis.Redis( 8 | host=settings.REDIS_HOST, 9 | port=settings.REDIS_PORT, 10 | password=settings.REDIS_PASSWORD, 11 | decode_responses=True 12 | ) 13 | 14 | def check_redis_connection() -> bool: 15 | try: 16 | redis_manager.ping() 17 | return True 18 | except Exception as e: 19 | logger.exception("Redis连接检查失败: %s", str(e)) 20 | return False -------------------------------------------------------------------------------- /assets/image-20250408142818633.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lework/dify-sso/b2f0ae4cc8f26f580917905d11ea73f4217e64e4/assets/image-20250408142818633.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | psycopg2==2.9.10 2 | redis==5.2.1 3 | requests==2.32.3 4 | SQLAlchemy==2.0.39 5 | fastapi==0.115.12 6 | uvicorn==0.34.0 7 | PyJWT==2.10.1 8 | python-dotenv==1.0.1 9 | Flask==3.1.0 10 | Flask-Login==0.6.3 11 | Flask-SQLAlchemy==3.1.1 12 | pydantic-settings==2.8.1 -------------------------------------------------------------------------------- /yaml/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | dify-sso: 3 | image: lework/dify-sso 4 | container_name: dify-sso 5 | restart: 
always 6 | environment: 7 | # 服务配置 8 | CONSOLE_WEB_URL: 'https://test-dify.test.com' 9 | SECRET_KEY: 'sk-123456' 10 | TENANT_ID: 'f9ea64ca-2a cf-44a7-aca0-123456' 11 | EDITION: 'SELF_HOSTED' 12 | ACCESS_TOKEN_EXPIRE_MINUTES: '300' 13 | REFRESH_TOKEN_EXPIRE_DAYS: '30' 14 | REFRESH_TOKEN_PREFIX: 'refresh_token:' 15 | ACCOUNT_REFRESH_TOKEN_PREFIX: 'account_refresh_token:' 16 | # OIDC配置 17 | OIDC_ENABLED: 'true' 18 | OIDC_CLIENT_ID: '123456' 19 | OIDC_CLIENT_SECRET: '123456' 20 | OIDC_DISCOVERY_URL: 'https://test-dify.sso.test.com/oidc/.well-known/openid-configuration' 21 | OIDC_REDIRECT_URI: 'https://test-dify.test.com/console/api/enterprise/sso/oidc/callback' 22 | OIDC_SCOPE: 'openid profile email' 23 | OIDC_RESPONSE_TYPE: 'code' 24 | # 数据库配置 25 | DB_HOST: '127.0.0.1' 26 | DB_PORT: '5432' 27 | DB_DATABASE: 'dify' 28 | DB_PASSWORD: '123456' 29 | DB_USERNAME: 'dify_admin' 30 | # Redis 配置 31 | REDIS_DB: '13' 32 | REDIS_HOST: '127.0.0.1' 33 | REDIS_PORT: '6379' 34 | REDIS_PASSWORD: '123456' 35 | -------------------------------------------------------------------------------- /yaml/k8s-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ConfigMap 3 | metadata: 4 | name: dify-sso 5 | labels: 6 | app.kubernetes.io/name: dify-sso 7 | data: 8 | # 服务配置 9 | CONSOLE_WEB_URL: 'https://test-dify.test.com' 10 | SECRET_KEY: 'sk-123456' 11 | TENANT_ID: 'f9ea64ca-2acf-44a7-aca0-123456' 12 | EDITION: 'SELF_HOSTED' 13 | ACCESS_TOKEN_EXPIRE_MINUTES: '300' 14 | REFRESH_TOKEN_EXPIRE_DAYS: '30' 15 | REFRESH_TOKEN_PREFIX: 'refresh_token:' 16 | ACCOUNT_REFRESH_TOKEN_PREFIX: 'account_refresh_token:' 17 | 18 | # OIDC配置 19 | OIDC_ENABLED: 'true' 20 | OIDC_CLIENT_ID: '123456' 21 | OIDC_CLIENT_SECRET: '123456' 22 | OIDC_DISCOVERY_URL: 'https://test-dify.sso.test.com/oidc/.well-known/openid-configuration' 23 | OIDC_REDIRECT_URI: 'https://test-dify.test.com/console/api/enterprise/sso/oidc/callback' 24 | OIDC_SCOPE: 'openid 
profile email' 25 | OIDC_RESPONSE_TYPE: 'code' 26 | 27 | # 数据库配置 28 | DB_HOST: '127.0.0.1' 29 | DB_PORT: '5432' 30 | DB_DATABASE: 'dify' 31 | DB_PASSWORD: '123456' 32 | DB_USERNAME: 'dify_admin' 33 | 34 | # Redis 配置 35 | REDIS_DB: '13' 36 | REDIS_HOST: '127.0.0.1' 37 | REDIS_PORT: '6379' 38 | REDIS_PASSWORD: '123456' 39 | 40 | --- 41 | apiVersion: apps/v1 42 | kind: Deployment 43 | metadata: 44 | name: dify-sso 45 | labels: 46 | app.kubernetes.io/name: dify-sso 47 | spec: 48 | replicas: 1 49 | selector: 50 | matchLabels: 51 | app.kubernetes.io/name: dify-sso 52 | template: 53 | metadata: 54 | labels: 55 | app.kubernetes.io/name: dify-sso 56 | spec: 57 | containers: 58 | - name: dify-sso 59 | image: lework/dify-sso 60 | ports: 61 | - name: api 62 | containerPort: 8000 63 | protocol: TCP 64 | envFrom: 65 | - configMapRef: 66 | name: dify-sso 67 | resources: 68 | limits: 69 | cpu: '500m' 70 | memory: '512Mi' 71 | requests: 72 | cpu: '100m' 73 | memory: '128Mi' 74 | terminationMessagePath: /dev/termination-log 75 | terminationMessagePolicy: File 76 | imagePullPolicy: Always 77 | lifecycle: 78 | preStop: 79 | exec: 80 | command: 81 | - /bin/sh 82 | - -c 83 | - 'sleep 10' 84 | livenessProbe: 85 | httpGet: 86 | path: /health 87 | port: api 88 | scheme: HTTP 89 | initialDelaySeconds: 3 90 | periodSeconds: 10 91 | successThreshold: 1 92 | failureThreshold: 3 93 | timeoutSeconds: 5 94 | readinessProbe: 95 | httpGet: 96 | path: /health 97 | port: api 98 | scheme: HTTP 99 | initialDelaySeconds: 3 100 | periodSeconds: 10 101 | successThreshold: 1 102 | failureThreshold: 3 103 | timeoutSeconds: 5 104 | restartPolicy: Always 105 | terminationGracePeriodSeconds: 90 106 | dnsPolicy: ClusterFirst 107 | securityContext: {} 108 | schedulerName: default-scheduler 109 | enableServiceLinks: false 110 | strategy: 111 | type: RollingUpdate 112 | rollingUpdate: 113 | maxUnavailable: 25% 114 | maxSurge: 25% 115 | revisionHistoryLimit: 5 116 | progressDeadlineSeconds: 600 117 | 118 | --- 119 
| apiVersion: v1 120 | kind: Service 121 | metadata: 122 | name: dify-sso 123 | labels: 124 | app.kubernetes.io/name: dify-sso 125 | spec: 126 | ports: 127 | - name: http-api 128 | protocol: TCP 129 | port: 8000 130 | targetPort: api 131 | selector: 132 | app.kubernetes.io/name: dify-sso 133 | type: ClusterIP 134 | sessionAffinity: None 135 | ipFamilies: 136 | - IPv4 137 | ipFamilyPolicy: SingleStack 138 | internalTrafficPolicy: Cluster 139 | --------------------------------------------------------------------------------