├── .gitignore ├── README.md ├── README_cn.md ├── backend ├── .cursorrules ├── .env.template ├── .gitignore ├── Dockerfile ├── app.py ├── assets │ ├── demo_columns.csv │ └── demo_tables.csv ├── base_enum.py ├── database.py ├── docker_build.sh ├── docker_run.sh ├── dto │ ├── ai_comment_dto.py │ ├── definition_doc_query_result_dto.py │ ├── definition_rule_dto.py │ ├── disable_table_query.py │ ├── gen_ai_comments_dto.py │ ├── job_dto.py │ ├── learn_result_dto.py │ ├── project_settings_dto.py │ ├── refresh_index_query_dto.py │ ├── schema_dto.py │ ├── selected_column_dto.py │ ├── task_dto.py │ ├── update_ddl_by_query_dto.py │ └── update_task_query.py ├── enums.py ├── gunicorn.conf.py ├── jobs │ ├── __init__.py │ ├── job_job.py │ └── job_vector_db.py ├── logger.py ├── models │ ├── __init__.py │ ├── base.py │ ├── definition_column.py │ ├── definition_doc.py │ ├── definition_relation.py │ ├── definition_rule.py │ ├── definition_table.py │ ├── job.py │ ├── project.py │ ├── task.py │ ├── task_column.py │ ├── task_doc.py │ ├── task_sql.py │ └── task_table.py ├── prompt_templates │ ├── gen_ai_comments.mustache │ ├── gen_related_columns.mustache │ ├── gen_sql.mustache │ ├── learn.mustache │ └── optimize_question.mustache ├── requirements.txt ├── routes │ ├── __init__.py │ ├── main.py │ ├── project.py │ └── test.py ├── services │ ├── __init__.py │ ├── def_service.py │ ├── job_service.py │ ├── openai_service.py │ ├── project_service.py │ ├── task_service.py │ └── translate_service.py ├── start.sh ├── utils │ ├── __init__.py │ ├── prompt_util.py │ ├── schemas.py │ ├── structure_util.py │ └── utils.py ├── vector_stores.py ├── vectors │ ├── __init__.py │ ├── translate_wrapper.py │ ├── vector_chroma.py │ └── vector_store.py └── wsgi.py └── frontend ├── .cursorrules ├── .gitignore ├── eslint.config.js ├── getOpenapi.cjs ├── index.html ├── openapitools.json ├── package.json ├── pnpm-lock.yaml ├── postcss.config.js ├── public ├── columns_template.csv ├── favicon.svg └── 
tables_template.csv ├── src ├── App.tsx ├── components │ ├── AppSidebar.tsx │ ├── LanguageSwitcher.tsx │ ├── ProjectLayout.tsx │ ├── ddl │ │ ├── DDL.tsx │ │ ├── FileImportTab.tsx │ │ ├── QueryImportTab.tsx │ │ ├── RelationEditModal.tsx │ │ ├── TableCommentEditor.tsx │ │ ├── TableList.tsx │ │ └── UploadDDLModal.tsx │ ├── docs │ │ └── DocList.tsx │ ├── projects │ │ └── ProjectList.tsx │ ├── records │ │ ├── GenerationRecords.tsx │ │ ├── JobList.tsx │ │ ├── LearningDetails.tsx │ │ ├── MainContent.tsx │ │ ├── OptimizedQuestion.tsx │ │ ├── QuestionInput.tsx │ │ ├── QuestionList.tsx │ │ ├── QuestionSupplement.tsx │ │ ├── Refs.tsx │ │ ├── RuleList.tsx │ │ ├── SqlFeedback.tsx │ │ ├── SqlLearning.tsx │ │ ├── SqlResult.tsx │ │ └── refs │ │ │ ├── AICommentModal.tsx │ │ │ ├── DDLRefs.tsx │ │ │ ├── DDLRefsModal.css │ │ │ ├── DDLRefsModal.tsx │ │ │ ├── DocRefs.tsx │ │ │ ├── RelatedColumnsRefs.tsx │ │ │ └── SqlRefs.tsx │ ├── rules │ │ └── RuleList.tsx │ └── settings │ │ └── Settings.tsx ├── consts.ts ├── hooks │ ├── useAppService.ts │ ├── useEnvService.ts │ └── useTask.ts ├── i18n │ ├── en-US.json │ ├── i18n.ts │ └── zh-CN.json ├── index.css ├── main.tsx ├── store │ ├── hooks.ts │ ├── index.ts │ └── slices │ │ ├── appSlice.ts │ │ ├── ddlSlice.ts │ │ ├── recordsSlice.ts │ │ ├── schemaSlice.ts │ │ └── taskSlice.ts ├── typings.d.ts ├── utils │ ├── bizUtil.ts │ ├── learnUtil.ts │ └── stringUtils.ts └── vite-env.d.ts ├── tailwind.config.js ├── tsconfig.app.json ├── tsconfig.json ├── tsconfig.node.json └── vite.config.ts /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | .DS_Store 3 | .env 4 | dist/ 5 | *.log 6 | .idea/ 7 | .vscode/ 8 | *.pyc 9 | __pycache__/ 10 | venv/ 11 | .env.local 12 | .env.development.local 13 | .env.test.local 14 | .env.production.local 15 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## Introduction 2 | Generate precise SQL queries through simple natural language descriptions. 3 | 4 | ### Support Status 5 | - Interface languages: Chinese, English (more to come) 6 | - Supported databases: Any type 7 | - For query import scripts, currently supports: SQLite, MySQL, PostgreSQL, SQLServer 8 | - Supported vector databases: Chroma (more to come) 9 | - Supported LLMs: OpenAI (more to come) 10 | - Supported translation (optional): Azure Translator (more to come) 11 | 12 | ### SQL Generation Process 13 | Similar to a workflow concept, executed step by step: 14 | 15 | 1. Match business documents 16 | 2. Match generation records 17 | 3. AI generates potentially relevant fields 18 | 4. Match most similar tables and fields based on AI-generated fields 19 | 5. AI generates SQL 20 | 6. Learning: Learn from results to improve table comments, field comments, and field relationships 21 | 22 | ## Features 23 | - Progressive improvement, becomes more accurate with use 24 | - Convenient and easy to understand 25 | - Rich and refined functionality 26 | 27 | ## Usage 28 | ### Project Structure 29 | - `backend`: Backend project, using Python and Flask framework 30 | - `frontend`: Frontend project, using React and Tailwind CSS framework 31 | 32 | ### Dependencies 33 | Backend project depends on the following: 34 | 35 | - Database: Stores application data. Default configuration uses SQLite (created in `backend/instance` directory). Can be modified to connect to other databases. Tables are automatically created on first startup 36 | - Vector database: Stores application data. Can start a Chroma docker container and modify configuration to connect to it 37 | - LLM: Uses OpenAI by default. Can modify apikey and other configurations 38 | - Translation: Uses Azure Translator by default. Can modify configuration. If you're using English, this is optional. 
However, it's highly recommended for other languages as it greatly improves vector database matching accuracy 39 | 40 | ### Backend Deployment 41 | - Install dependencies: `pip install -r requirements.txt` 42 | - Local development: Enter backend directory, run `python app.py` 43 | - Production deployment: Deploy using gunicorn, refer to `./start.sh` 44 | - Docker deployment: 45 | - Refer to `./docker_build.sh` to build Docker image 46 | - Refer to `./docker_run.sh` to run Docker image 47 | 48 | ### Frontend Deployment 49 | - Local development: Enter frontend directory, run `npm run dev` for development debugging, or `npm run start` to fetch backend API and regenerate API definitions before starting debug 50 | - Production deployment: Use vite to package, run `npm run build` 51 | 52 | ### Development URLs 53 | In local development mode, frontend URL: `http://localhost:5173` 54 | 55 | ## Screenshots 56 | ### Projects page 57 | ![image](https://github.com/user-attachments/assets/f20f4bfd-8e21-435c-a195-088381ca8b97) 58 | 59 | ### Project page 60 | ![image](https://github.com/user-attachments/assets/e56e74ea-bf18-4bc4-bbdb-f36c158a3bbb) 61 | ![image](https://github.com/user-attachments/assets/998ffe78-9f67-4c69-8c54-8164a9c2c938) 62 | ![image](https://github.com/user-attachments/assets/e77129b5-c315-447d-b552-77e2da86a29d) 63 | 64 | ### DDL page 65 | ![image](https://github.com/user-attachments/assets/9f8c3798-15ce-4f7b-bbb4-fabf07d39fc8) 66 | 67 | ### Document page 68 | ![image](https://github.com/user-attachments/assets/c06320d4-2670-4391-85b0-4db19d291e33) 69 | 70 | ### Rule page 71 | ![image](https://github.com/user-attachments/assets/3d676c9e-8b33-4036-9d66-8975bcc1e0b9) 72 | 73 | ### Settings page 74 | ![image](https://github.com/user-attachments/assets/9793426f-7bce-4a50-9f3b-ca94fcc87247) 75 | -------------------------------------------------------------------------------- /README_cn.md: 
-------------------------------------------------------------------------------- 1 | ## 介绍 2 | 通过简洁的自然语言描述,生成精准的SQL查询语句。 3 | 4 | ### 支持情况 5 | - 界面语言:中文,英文(后续会支持更多) 6 | - 支持数据库:任意类型 7 | - 查询导入示例脚本那里,目前只添加了:SQLite, MySQL, PostgreSQL, SQLServer 8 | - 支持向量数据库:chroma(后续会支持更多) 9 | - 支持LLM:OpenAI(后续会支持更多) 10 | - 支持翻译(可选):Azure Translator(后续会支持更多) 11 | 12 | ### 生成SQL流程 13 | 类似工作流的概念,按步骤一步步执行: 14 | 15 | 1. 匹配业务文档 16 | 2. 匹配生成记录 17 | 3. AI生成可能相关的字段 18 | 4. 根据AI生成的字段匹配出最相似的表与字段 19 | 5. AI生成SQL 20 | 6. 学习,根据结果学习表备注,字段备注,字段关系 21 | 22 | ## 特点 23 | - 渐进式完善,越用越准确 24 | - 方便快捷,容易理解 25 | - 功能丰富且精细 26 | 27 | ## 使用 28 | ### 项目结构 29 | - `backend`: 后端项目,使用python, flask框架 30 | - `frontend`: 前端项目,使用react,tailwindcss框架 31 | 32 | ### 依赖要求 33 | 后端项目依赖以下内容: 34 | 35 | - 数据库,存放应用数据,默认配置的sqlite数据库连接(会生成在`backend/instance`目录),可以自行修改连接为其他数据库,第一次启动会自动创建表 36 | - 向量数据库,存放应用数据,可以先启动一个chroma的docker容器,然后修改配置连接到chroma 37 | - LLM,默认使用OpenAI,可以自行修改apikey等配置 38 | - 翻译,默认使用Azure Translator,可以自行修改配置,如果你使用的是英语,可以不用配置,否则非常推荐配置上,会大大提高向量数据库匹配的准确性 39 | 40 | ### 后端部署 41 | - 安装依赖: `pip install -r requirements.txt` 42 | - 本地开发调试:进入backend目录,`python app.py` 43 | - 线上部署:使用gunicorn部署,参考脚本`./start.sh` 44 | - Docker部署: 45 | - 参考脚本`./docker_build.sh`,构建Docker镜像 46 | - 参考脚本`./docker_run.sh`,运行Docker镜像 47 | 48 | ### 前端部署 49 | - 本地开发调试:进入frontend目录,`npm run dev`开启开发调试,或者`npm run start`拉取后端接口重新生成接口定义并开启调试 50 | - 线上部署:使用vite打包,`npm run build` 51 | 52 | ### 开发网址 53 | 本地开发调试情况下,前端网址:`http://localhost:5173` 54 | 55 | ## 截图 56 | ### 项目列表页面 57 | ![image](https://github.com/user-attachments/assets/f20f4bfd-8e21-435c-a195-088381ca8b97) 58 | 59 | ### 项目页面 60 | ![image](https://github.com/user-attachments/assets/e56e74ea-bf18-4bc4-bbdb-f36c158a3bbb) 61 | ![image](https://github.com/user-attachments/assets/998ffe78-9f67-4c69-8c54-8164a9c2c938) 62 | ![image](https://github.com/user-attachments/assets/e77129b5-c315-447d-b552-77e2da86a29d) 63 | 64 | ### DDL页面 65 | 
![image](https://github.com/user-attachments/assets/9f8c3798-15ce-4f7b-bbb4-fabf07d39fc8) 66 | 67 | ### 文档页面 68 | ![image](https://github.com/user-attachments/assets/c06320d4-2670-4391-85b0-4db19d291e33) 69 | 70 | ### 规则页面 71 | ![image](https://github.com/user-attachments/assets/3d676c9e-8b33-4036-9d66-8975bcc1e0b9) 72 | 73 | ### 设置页面 74 | ![image](https://github.com/user-attachments/assets/9793426f-7bce-4a50-9f3b-ca94fcc87247) 75 | -------------------------------------------------------------------------------- /backend/.cursorrules: -------------------------------------------------------------------------------- 1 | This project use python+flask+apiflask+psycopg2-binary+SQLAlchemy -------------------------------------------------------------------------------- /backend/.env.template: -------------------------------------------------------------------------------- 1 | # database 2 | DATABASE_URL=sqlite:///sqlwise.db 3 | # DATABASE_URL=postgresql://postgres:password@localhost:5432/sqlwise 4 | 5 | # chroma 6 | CHROMA_HOST=localhost 7 | CHROMA_PORT=8051 8 | 9 | # openai 10 | OPENAI_API_KEY=sk-xxx 11 | OPENAI_API_BASE=https://api.openai.com/v1 12 | OPENAI_API_MODEL=gpt-4o-mini 13 | OPENAI_API_TEMPERATURE=0.0 14 | 15 | # azure translator(optional, recommended for non-English languages!) 
16 | # AZURE_TRANSLATOR_KEY=xxx 17 | # AZURE_TRANSLATOR_ENDPOINT=https://api.cognitive.microsofttranslator.com 18 | # AZURE_TRANSLATOR_LOCATION=global -------------------------------------------------------------------------------- /backend/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | venv/ 25 | env/ 26 | ENV/ 27 | .env 28 | 29 | # IDE 30 | .idea/ 31 | .vscode/ 32 | *.swp 33 | *.swo 34 | .DS_Store 35 | 36 | # Database 37 | instance/ 38 | *.db 39 | *.sqlite3 40 | *.sqlite 41 | 42 | # Logs 43 | *.log 44 | logs/ 45 | 46 | # Local development settings 47 | .env.local 48 | .env.development.local 49 | .env.test.local 50 | .env.production.local 51 | 52 | # Coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | 62 | # Gunicorn 63 | gunicorn.pid 64 | 65 | # Temp data 66 | temp-data/ -------------------------------------------------------------------------------- /backend/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12 2 | 3 | # set working directory 4 | WORKDIR /app 5 | 6 | # set environment variable 7 | ENV PYTHONUNBUFFERED=1 8 | 9 | # copy dependency files 10 | COPY requirements.txt . 11 | 12 | # install python dependencies 13 | RUN pip3 install --no-cache-dir -r requirements.txt 14 | 15 | # copy application code 16 | COPY . . 
17 | 18 | # expose port 19 | EXPOSE 8000 20 | 21 | CMD ["./start.sh"] -------------------------------------------------------------------------------- /backend/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | from flask_smorest import Api 3 | from dotenv import load_dotenv 4 | from flask_cors import CORS 5 | from sqlalchemy.exc import SQLAlchemyError 6 | from openai import OpenAIError 7 | import traceback 8 | from jobs import init_scheduler 9 | from database import db, init_db, OptimisticLockException 10 | from logger import init_logger 11 | from routes.main import main_bp 12 | from routes.test import test_bp 13 | from routes.project import project_bp 14 | import os 15 | 16 | # Load .env file 17 | load_dotenv() 18 | 19 | # Import routes 20 | from routes.main import main_bp 21 | from routes.test import test_bp 22 | from routes.project import project_bp 23 | 24 | # Initialize logger 25 | logger = init_logger() 26 | 27 | # Create Flask instance 28 | app = Flask(__name__) 29 | 30 | # API configuration 31 | app.config["API_TITLE"] = "SQLWise API" 32 | app.config["API_VERSION"] = "1.0" 33 | app.config["OPENAPI_VERSION"] = "3.0.2" 34 | app.config["OPENAPI_URL_PREFIX"] = "/api" 35 | app.config["OPENAPI_SWAGGER_UI_PATH"] = "/swagger-ui" 36 | app.config["OPENAPI_SWAGGER_UI_URL"] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist/" 37 | 38 | # Initialize API 39 | api = Api(app) 40 | 41 | translator_key = os.getenv('AZURE_TRANSLATOR_KEY') 42 | if translator_key: 43 | app.logger.info("Azure Translator service is active!") 44 | else: 45 | app.logger.info("Azure Translator service is not active!") 46 | 47 | # Log startup 48 | app.logger.info('Application startup') 49 | 50 | # Register error handlers 51 | @app.errorhandler(SQLAlchemyError) 52 | def handle_db_error(error): 53 | app.logger.error(f"Database error: {str(error)}\n{traceback.format_exc()}") 54 | return { 55 | 'code': 500, 56 | 'message': f"Database 
error: {str(error)}" 57 | }, 500 58 | 59 | @app.errorhandler(ValueError) 60 | def handle_validation_error(error): 61 | app.logger.error(f"Validation error: {str(error)}\n{traceback.format_exc()}") 62 | return { 63 | 'code': 400, 64 | 'message': str(error) 65 | }, 400 66 | 67 | @app.errorhandler(Exception) 68 | def handle_generic_error(error): 69 | app.logger.error(f"Internal server error: {str(error)}\n{traceback.format_exc()}") 70 | return { 71 | 'code': 500, 72 | 'message': f"Internal server error: {str(error)}" 73 | }, 500 74 | 75 | @app.errorhandler(OpenAIError) 76 | def handle_openai_error(error): 77 | app.logger.error(f"OpenAI API error: {str(error)}\n{traceback.format_exc()}") 78 | return { 79 | 'code': 500, 80 | 'message': f"OpenAI API error: {str(error)}" 81 | }, 500 82 | 83 | @app.errorhandler(OptimisticLockException) 84 | def handle_optimistic_lock_error(error): 85 | return { 86 | 'code': 409, 87 | 'message': str(error) 88 | }, 409 89 | 90 | @app.before_request 91 | def setup_db_connection(): 92 | """确保每个请求都有新的数据库会话""" 93 | try: 94 | if not db.session.is_active: 95 | db.session.remove() 96 | db.session = db.create_scoped_session() 97 | app.logger.debug("Created new database session") 98 | except Exception as e: 99 | app.logger.error(f"Error setting up database session: {str(e)}") 100 | raise 101 | 102 | @app.teardown_appcontext 103 | def shutdown_session(exception=None): 104 | """确保请求结束后正确清理数据库会话""" 105 | try: 106 | if db.session.is_active: 107 | if exception: 108 | db.session.rollback() 109 | db.session.remove() 110 | except Exception as e: 111 | app.logger.error(f"Error during session cleanup: {str(e)}") 112 | finally: 113 | # 确保连接返回到连接池 114 | db.session.close() 115 | 116 | # Enable CORS 117 | CORS(app, resources={r"/*": {"origins": "*"}}) 118 | 119 | # Register blueprints 120 | api.register_blueprint(main_bp, url_prefix='/main') 121 | api.register_blueprint(test_bp, url_prefix='/test') 122 | api.register_blueprint(project_bp) 123 | 124 | # Initialize 
database 125 | init_db(app) 126 | 127 | if __name__ == "__main__": 128 | # This code block only runs when directly running python app.py 129 | # When using gunicorn, this code block will not be executed 130 | if os.environ.get('WERKZEUG_RUN_MAIN') == 'true': 131 | init_scheduler(app) 132 | app.run(debug=True, host='0.0.0.0', port=8000) 133 | -------------------------------------------------------------------------------- /backend/assets/demo_columns.csv: -------------------------------------------------------------------------------- 1 | TABLE_NAME,COLUMN_NAME,COLUMN_TYPE,COLUMN_COMMENT 2 | user,id,varchar(36),User ID 3 | user,username,varchar(50),Username 4 | user,password,varchar(100),Encrypted Password 5 | user,email,varchar(100),User Email 6 | user,phone,varchar(20),Phone Number 7 | user,created_at,timestamp,Account Creation Time 8 | user,updated_at,timestamp,Account Update Time 9 | user,status,smallint,Account Status: 1-Active 0-Inactive 10 | 11 | product,id,varchar(36),Product ID 12 | product,category_id,varchar(36),Category ID 13 | product,name,varchar(100),Product Name 14 | product,description,text,Product Description 15 | product,price,decimal(10;2),Product Price 16 | product,stock,int,Current Stock 17 | product,created_at,timestamp,Product Creation Time 18 | product,updated_at,timestamp,Product Update Time 19 | product,status,smallint,Product Status: 1-Active 0-Inactive 20 | 21 | category,id,varchar(36),Category ID 22 | category,name,varchar(50),Category Name 23 | category,parent_id,varchar(36),Parent Category ID 24 | category,description,text,Category Description 25 | category,created_at,timestamp,Category Creation Time 26 | 27 | order,id,varchar(36),Order ID 28 | order,user_id,varchar(36),User ID 29 | order,address_id,varchar(36),Delivery Address ID 30 | order,total_amount,decimal(10;2),Order Total Amount 31 | order,status,smallint,Order Status: 0-Pending 1-Paid 2-Shipped 3-Delivered 4-Cancelled 32 | order,created_at,timestamp,Order Creation Time 33 | 
order,updated_at,timestamp,Order Update Time 34 | 35 | order_item,id,varchar(36),Order Item ID 36 | order_item,order_id,varchar(36),Order ID 37 | order_item,product_id,varchar(36),Product ID 38 | order_item,quantity,int,Product Quantity 39 | order_item,price,decimal(10;2),Product Price at Order Time 40 | order_item,subtotal,decimal(10;2),Item Subtotal 41 | 42 | payment,id,varchar(36),Payment ID 43 | payment,order_id,varchar(36),Order ID 44 | payment,amount,decimal(10;2),Payment Amount 45 | payment,payment_method,varchar(20),Payment Method 46 | payment,status,smallint,Payment Status: 0-Pending 1-Success 2-Failed 47 | payment,created_at,timestamp,Payment Time 48 | payment,transaction_id,varchar(100),External Transaction ID 49 | 50 | address,id,varchar(36),Address ID 51 | address,user_id,varchar(36),User ID 52 | address,receiver_name,varchar(50),Receiver Name 53 | address,phone,varchar(20),Receiver Phone 54 | address,province,varchar(50),Province 55 | address,city,varchar(50),City 56 | address,district,varchar(50),District 57 | address,detail_address,varchar(200),Detailed Address 58 | address,is_default,smallint,Default Address Flag: 1-Yes 0-No 59 | -------------------------------------------------------------------------------- /backend/assets/demo_tables.csv: -------------------------------------------------------------------------------- 1 | TABLE_NAME,TABLE_COMMENT 2 | user,User Table 3 | product,Product Table 4 | category,Product Category Table 5 | order,Order Table 6 | order_item,Order Items Table 7 | payment,Payment Table 8 | address,User Address Table 9 | -------------------------------------------------------------------------------- /backend/base_enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class BaseEnum(Enum): 4 | """Base enumeration class providing common methods""" 5 | def __init__(self, value, display_name): 6 | self._value_ = value 7 | self.display_name = display_name 8 | 9 
| @property 10 | def value(self): 11 | """Get the status value""" 12 | return self._value_ 13 | 14 | @classmethod 15 | def values(cls): 16 | """Get a list of all status values""" 17 | return [item.value for item in cls] 18 | 19 | @classmethod 20 | def names(cls): 21 | """Get a list of all status names""" 22 | return [item.display_name for item in cls] 23 | 24 | @classmethod 25 | def get_by_value(cls, value): 26 | """Get enum value by value""" 27 | for item in cls: 28 | if item.value == value: 29 | return item 30 | return None 31 | 32 | @classmethod 33 | def get_display_name_by_value(cls, value): 34 | """Get status name by value""" 35 | item = cls.get_by_value(value) 36 | return item.display_name if item else None -------------------------------------------------------------------------------- /backend/database.py: -------------------------------------------------------------------------------- 1 | from flask_sqlalchemy import SQLAlchemy 2 | import importlib 3 | import pkgutil 4 | import os 5 | from contextlib import contextmanager 6 | 7 | # Define database engine options as a constant 8 | DB_ENGINE_OPTIONS = { 9 | "pool_pre_ping": True, 10 | "pool_recycle": 300, # 5 minutes 11 | "pool_size": 5, # 增加连接池大小 12 | "pool_timeout": 30, # 添加连接超时 13 | "max_overflow": 10, # 增加最大溢出连接数 14 | "pool_use_lifo": True, # 使用LIFO策略以更好地处理突发负载 15 | } 16 | 17 | # Create SQLAlchemy instance with engine options 18 | db = SQLAlchemy(engine_options=DB_ENGINE_OPTIONS) 19 | 20 | class OptimisticLockException(Exception): 21 | """Exception for optimistic locking""" 22 | def __init__(self, message="Concurrent modification conflict, please refresh and try again"): 23 | super().__init__(message) 24 | 25 | def init_db(app): 26 | """Initialize database""" 27 | app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL") 28 | app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False 29 | app.config["SQLALCHEMY_ENGINE_OPTIONS"] = DB_ENGINE_OPTIONS 30 | 31 | # Initialize SQLAlchemy 32 | db.init_app(app) 
33 | 34 | # Import all models 35 | import models 36 | for _, name, _ in pkgutil.iter_modules(models.__path__): 37 | importlib.import_module(f'models.{name}') 38 | 39 | # Create all tables 40 | with app.app_context(): 41 | db.create_all() 42 | 43 | # def get_db(): 44 | # """Get database session""" 45 | # return db.session 46 | 47 | @contextmanager 48 | def session_scope(read_only=False): 49 | """提供一个事务范围的会话上下文 50 | Args: 51 | read_only (bool): 如果为True,则不会尝试提交事务 52 | """ 53 | session = db.session 54 | try: 55 | yield session 56 | if not read_only: 57 | session.commit() 58 | except: 59 | session.rollback() 60 | raise 61 | finally: 62 | session.close() 63 | -------------------------------------------------------------------------------- /backend/docker_build.sh: -------------------------------------------------------------------------------- 1 | # build docker image 2 | docker build -t sqlwise . -------------------------------------------------------------------------------- /backend/docker_run.sh: -------------------------------------------------------------------------------- 1 | # start docker container 2 | docker run \ 3 | -it \ 4 | -p 8000:8000 \ 5 | --env-file .env.template \ 6 | --name sqlwise \ 7 | sqlwise -------------------------------------------------------------------------------- /backend/dto/ai_comment_dto.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from dataclasses_json import dataclass_json 4 | from dto.learn_result_dto import TableRelationDTO 5 | 6 | @dataclass_json 7 | @dataclass 8 | class ColumnAICommentDTO: 9 | col: str 10 | comment: str 11 | 12 | @dataclass_json 13 | @dataclass 14 | class UpdateAICommentDTO: 15 | project_id: int 16 | table: str 17 | comment: str 18 | columns: List[ColumnAICommentDTO] 19 | relations: List[TableRelationDTO] -------------------------------------------------------------------------------- 
/backend/dto/definition_doc_query_result_dto.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class DefinitionDocQueryResultDTO: 5 | id: int 6 | def_doc: str 7 | def_selected: bool 8 | disabled: bool -------------------------------------------------------------------------------- /backend/dto/definition_rule_dto.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class DefinitionRuleDTO: 5 | id: int 6 | name: str 7 | content: str 8 | def_selected: bool 9 | disabled: bool -------------------------------------------------------------------------------- /backend/dto/disable_table_query.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class DisableTableQueryDTO: 5 | project_id: int 6 | table: str 7 | disabled: bool -------------------------------------------------------------------------------- /backend/dto/gen_ai_comments_dto.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from marshmallow_dataclass import dataclass 3 | 4 | @dataclass 5 | class GenAICommentsColumnDTO: 6 | """ 7 | Column information DTO 8 | """ 9 | column: str 10 | type: str 11 | comment: str 12 | 13 | @dataclass 14 | class GenAICommentsTableDTO: 15 | """ 16 | Table information DTO 17 | """ 18 | table: str 19 | comment: str 20 | columns: List[GenAICommentsColumnDTO] 21 | 22 | @dataclass 23 | class GenAICommentsResponseColumnDTO: 24 | """Column information for AI generated comments""" 25 | column: str 26 | comment: str 27 | 28 | @dataclass 29 | class GenAICommentsResponseDTO: 30 | """Response DTO for AI generated comments""" 31 | table: str 32 | comment: str 33 | columns: list[GenAICommentsResponseColumnDTO] 34 | 
-------------------------------------------------------------------------------- /backend/dto/job_dto.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import datetime 3 | 4 | @dataclass 5 | class JobDTO: 6 | id: int 7 | version: int 8 | task_id: int 9 | project_id: int 10 | job_type: str 11 | job_data: dict 12 | job_status: str 13 | job_type_display_name: str 14 | job_status_display_name: str 15 | error_message: str 16 | created_at: datetime 17 | updated_at: datetime 18 | job_cost_time: int -------------------------------------------------------------------------------- /backend/dto/learn_result_dto.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | @dataclass 5 | class TableDescDTO: 6 | table: str 7 | desc: str 8 | 9 | @dataclass 10 | class ColumnDescDTO: 11 | table: str 12 | column: str 13 | desc: str 14 | 15 | @dataclass 16 | class TableRelationDTO: 17 | table1: str 18 | column1: str 19 | table2: str 20 | column2: str 21 | relation_type: str 22 | 23 | @dataclass 24 | class LearnResultDTO: 25 | tables: List[TableDescDTO] 26 | columns: List[ColumnDescDTO] 27 | relations: List[TableRelationDTO] 28 | -------------------------------------------------------------------------------- /backend/dto/project_settings_dto.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class ProjectSettingsDTO: 5 | name: str 6 | description: str 7 | db_type: str 8 | db_version: str 9 | 10 | vector_waiting_table_count: int 11 | vector_waiting_column_count: int 12 | vector_waiting_doc_count: int 13 | vector_waiting_task_count: int 14 | 15 | definition_doc_count: int 16 | definition_rule_count: int 17 | definition_table_count: int 18 | definition_column_count: int 19 | 
definition_relation_count: int 20 | 21 | task_count: int 22 | task_doc_count: int 23 | task_sql_count: int 24 | task_table_count: int 25 | task_column_count: int 26 | job_count: int 27 | -------------------------------------------------------------------------------- /backend/dto/refresh_index_query_dto.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class RefreshIndexQueryDTO: 5 | project_id: int 6 | refresh_table: bool = False 7 | refresh_column: bool = False 8 | refresh_doc: bool = False 9 | refresh_sql: bool = False 10 | -------------------------------------------------------------------------------- /backend/dto/schema_dto.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from dto.definition_rule_dto import DefinitionRuleDTO 4 | 5 | @dataclass 6 | class DefinitionTableDTO: 7 | table: str 8 | comment: str 9 | ai_comment: str 10 | disabled: bool 11 | 12 | @dataclass 13 | class DefinitionColumnDTO: 14 | table: str 15 | type: str 16 | column: str 17 | comment: str 18 | ai_comment: str 19 | 20 | @dataclass 21 | class DefinitionRelationDTO: 22 | table1: str 23 | column1: str 24 | table2: str 25 | column2: str 26 | relation_type: str 27 | 28 | @dataclass 29 | class SchemaDTO: 30 | tables: List[DefinitionTableDTO] 31 | columns: List[DefinitionColumnDTO] 32 | relations: List[DefinitionRelationDTO] 33 | rules: List[DefinitionRuleDTO] -------------------------------------------------------------------------------- /backend/dto/selected_column_dto.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class SelectedColumnDTO: 5 | table: str 6 | columns: list[str] -------------------------------------------------------------------------------- 
/backend/dto/task_dto.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from datetime import datetime 3 | from dto.job_dto import JobDTO 4 | from dto.learn_result_dto import LearnResultDTO 5 | from marshmallow_dataclass import dataclass 6 | 7 | @dataclass 8 | class TaskTableDTO: 9 | table_name: str 10 | 11 | @dataclass 12 | class TaskColumnDTO: 13 | table_name: str 14 | column_name: str 15 | 16 | @dataclass 17 | class TaskDocDTO: 18 | doc_id: int 19 | def_doc: str 20 | 21 | @dataclass 22 | class TaskSQLDTO: 23 | task_id: int 24 | question: str 25 | sql: str 26 | 27 | @dataclass 28 | class TaskDTO: 29 | id: int 30 | project_id: int 31 | version: int 32 | question: str 33 | question_supplement: str 34 | options: dict 35 | rules: List[int] 36 | related_columns: str 37 | sql: str 38 | sql_right: bool 39 | sql_refer: bool 40 | learn_result: LearnResultDTO 41 | created_at: datetime 42 | updated_at: datetime 43 | tables: List[TaskTableDTO] 44 | columns: List[TaskColumnDTO] 45 | docs: List[TaskDocDTO] 46 | sqls: List[TaskSQLDTO] 47 | jobs: List[JobDTO] -------------------------------------------------------------------------------- /backend/dto/update_ddl_by_query_dto.py: -------------------------------------------------------------------------------- 1 | from marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class UpdateDDLByQueryTableDTO: 5 | table: str 6 | comment: str | None = None 7 | 8 | @dataclass 9 | class UpdateDDLByQueryColumnDTO: 10 | table: str 11 | column: str 12 | type: str 13 | comment: str | None = None 14 | 15 | @dataclass 16 | class UpdateDDLByQueryDTO: 17 | project_id: int 18 | tables: list[UpdateDDLByQueryTableDTO] 19 | columns: list[UpdateDDLByQueryColumnDTO] 20 | -------------------------------------------------------------------------------- /backend/dto/update_task_query.py: -------------------------------------------------------------------------------- 1 | from 
marshmallow_dataclass import dataclass 2 | 3 | @dataclass 4 | class UpdateTaskColumnQueryDTO: 5 | table: str 6 | columns: list[str] 7 | 8 | @dataclass 9 | class UpdateTaskQueryDTO: 10 | task_id: int 11 | 12 | question: str 13 | question_supplement: str 14 | options: dict 15 | rules: list[int] 16 | doc_ids: list[int] 17 | sql_ids: list[int] 18 | columns: list[UpdateTaskColumnQueryDTO] 19 | sql: str 20 | 21 | question_modified: bool | None = False 22 | question_supplement_modified: bool | None = False 23 | options_modified: bool | None = False 24 | rules_modified: bool | None = False 25 | doc_ids_modified: bool | None = False 26 | sql_ids_modified: bool | None = False 27 | columns_modified: bool | None = False 28 | sql_modified: bool | None = False 29 | -------------------------------------------------------------------------------- /backend/enums.py: -------------------------------------------------------------------------------- 1 | from base_enum import BaseEnum 2 | 3 | class JobType(BaseEnum): 4 | """Job type enumeration""" 5 | GEN_RELATED_COLUMNS = ('gen_related_columns', 'Generate Related Columns') 6 | MATCH_DOC = ('match_doc', 'Match Document') 7 | MATCH_SQL_LOG = ('match_sql_log', 'Match SQL Log') 8 | MATCH_DDL = ('match_ddl', 'Match DDL') 9 | GENERATE_SQL = ('generate_sql', 'Generate SQL') 10 | LEARN_FROM_SQL = ('learn_from_sql', 'Learn') 11 | 12 | class JobStatus(BaseEnum): 13 | """Job status enumeration""" 14 | INIT = ('init', 'Initial') # Initial state 15 | RUNNING = ('running', 'Running') # In progress 16 | SUCCESS = ('success', 'Success') # Successful 17 | FAIL = ('fail', 'Failed') # Failed 18 | CANCELED = ('canceled', 'Canceled') # Canceled 19 | 20 | class DbType(BaseEnum): 21 | """Database type enumeration""" 22 | SQLITE = ('sqlite', 'SQLite') 23 | MYSQL = ('mysql', 'MySQL') 24 | POSTGRESQL = ('postgresql', 'PostgreSQL') 25 | SQLSERVER = ('sqlserver', 'SQLServer') 26 | -------------------------------------------------------------------------------- 
/backend/gunicorn.conf.py: -------------------------------------------------------------------------------- 1 | # Gunicorn configuration file 2 | import multiprocessing 3 | import os 4 | from wsgi import on_post_fork, on_when_ready 5 | 6 | # Get current directory 7 | current_dir = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | # Bind IP and port 10 | bind = "0.0.0.0:8000" 11 | 12 | # Number of worker processes 13 | workers = multiprocessing.cpu_count() * 2 + 1 14 | 15 | # Worker mode 16 | worker_class = "sync" 17 | 18 | # Maximum number of concurrent clients 19 | worker_connections = 2000 20 | 21 | # Process ID file 22 | pidfile = os.path.join(current_dir, "gunicorn.pid") 23 | 24 | # Access and error logs 25 | accesslog = os.path.join(current_dir, "logs/access.log") 26 | errorlog = os.path.join(current_dir, "logs/error.log") 27 | 28 | # Log level 29 | loglevel = "info" 30 | 31 | # Prevent multiple scheduler instances 32 | preload_app = True 33 | 34 | # Keep workers alive for long-running tasks 35 | timeout = 300 36 | 37 | # 确保启用钩子函数 38 | post_fork = on_post_fork 39 | when_ready = on_when_ready 40 | -------------------------------------------------------------------------------- /backend/jobs/__init__.py: -------------------------------------------------------------------------------- 1 | from flask_apscheduler import APScheduler 2 | import os 3 | 4 | scheduler = APScheduler() 5 | 6 | def init_scheduler(app): 7 | # Apply scheduler configuration 8 | app.config['SCHEDULER_API_ENABLED'] = True 9 | app.config['SCHEDULER_TIMEZONE'] = "Asia/Shanghai" 10 | 11 | scheduler.init_app(app) 12 | 13 | # Import jobs module to register tasks 14 | from . import job_job 15 | from . 
import job_vector_db 16 | 17 | scheduler.start() 18 | app.logger.info("Scheduler started successfully") 19 | -------------------------------------------------------------------------------- /backend/jobs/job_job.py: -------------------------------------------------------------------------------- 1 | from . import scheduler 2 | from services.task_service import TaskService 3 | from services.job_service import JobService 4 | from enums import JobType, JobStatus 5 | import asyncio 6 | import time 7 | from models.job import Job 8 | from models.task import Task 9 | from database import db 10 | from app import app 11 | from database import session_scope 12 | 13 | # Define scheduled task 14 | @scheduler.task('interval', id='job', seconds=2, coalesce=True, max_instances=1) 15 | def job(): 16 | with app.app_context(): 17 | with session_scope() as session: 18 | while True: 19 | # Query all jobs with initial status 20 | job_ids = JobService.get_init_job_ids() 21 | if len(job_ids) == 0: 22 | break 23 | print(f"jobs count: {len(job_ids)}") 24 | for job_id in job_ids: 25 | try: 26 | # Update job status 27 | job = session.query(Job).get(job_id) 28 | if job.job_status != JobStatus.INIT.value: 29 | print(f"job {job_id} status is not INIT, skip") 30 | continue 31 | job_type = job.job_type 32 | job.job_status = JobStatus.RUNNING.value 33 | session.commit() 34 | 35 | # Execute task (time-consuming) 36 | start_time = time.time() 37 | if job_type == JobType.GENERATE_SQL.value: 38 | asyncio.run(TaskService.generate_sql_async(session, job_id)) 39 | elif job_type == JobType.GEN_RELATED_COLUMNS.value: 40 | asyncio.run(TaskService.gen_related_columns_async(session, job_id)) 41 | elif job_type == JobType.MATCH_DOC.value: 42 | asyncio.run(TaskService.match_doc_async(session, job_id)) 43 | elif job_type == JobType.MATCH_SQL_LOG.value: 44 | asyncio.run(TaskService.match_sql_log_async(session, job_id)) 45 | elif job_type == JobType.MATCH_DDL.value: 46 | 
asyncio.run(TaskService.match_ddl_async(session, job_id)) 47 | elif job_type == JobType.LEARN_FROM_SQL.value: 48 | asyncio.run(TaskService.learn_from_sql_async(session, job_id)) 49 | 50 | end_time = time.time() 51 | job_cost_time = int((end_time - start_time) * 1000) 52 | # Update job status 53 | job = session.query(Job).get(job_id) 54 | if job.job_status != JobStatus.RUNNING.value: 55 | print(f"job {job_id} status is not RUNNING, skip") 56 | continue 57 | job.job_status = JobStatus.SUCCESS.value 58 | job.job_cost_time = job_cost_time 59 | session.commit() 60 | 61 | # Create next job based on task options 62 | task_options = session.query(Task).get(job.task_id).options 63 | if job_type == JobType.MATCH_DOC.value: 64 | if task_options.get('autoMatchSqlLog'): 65 | JobService.create_job(session, job.task_id, JobType.MATCH_SQL_LOG.value) 66 | elif job_type == JobType.MATCH_SQL_LOG.value: 67 | if task_options.get('autoGenRelatedColumns'): 68 | JobService.create_job(session, job.task_id, JobType.GEN_RELATED_COLUMNS.value) 69 | elif job_type == JobType.GEN_RELATED_COLUMNS.value: 70 | if task_options.get('autoMatchDDL'): 71 | JobService.create_job(session, job.task_id, JobType.MATCH_DDL.value) 72 | elif job_type == JobType.MATCH_DDL.value: 73 | if task_options.get('autoGenSql'): 74 | JobService.create_job(session, job.task_id, JobType.GENERATE_SQL.value) 75 | elif job_type == JobType.GENERATE_SQL.value: 76 | pass 77 | elif job_type == JobType.LEARN_FROM_SQL.value: 78 | pass 79 | except Exception as e: 80 | # Update job status on error 81 | job = session.query(Job).get(job_id) 82 | if job.job_status != JobStatus.RUNNING.value: 83 | print(f"job {job_id} status is not RUNNING, skip") 84 | continue 85 | job.job_status = JobStatus.FAIL.value 86 | job.error_message = str(e) 87 | session.commit() 88 | -------------------------------------------------------------------------------- /backend/jobs/job_vector_db.py: 
-------------------------------------------------------------------------------- 1 | from . import scheduler 2 | from models.project import Project 3 | from models.definition_table import DefinitionTable 4 | from models.definition_column import DefinitionColumn 5 | from models.definition_doc import DefinitionDoc 6 | from models.task import Task 7 | from services.def_service import DefService 8 | from services.task_service import TaskService 9 | from database import db 10 | from app import app 11 | 12 | # Task for adding to vector database 13 | @scheduler.task('interval', id='add_to_vector_db_job', seconds=1, coalesce=True, max_instances=1) 14 | def add_to_vector_db_job(): 15 | with app.app_context(): 16 | session = db.session 17 | 18 | table_count = 0 19 | column_count = 0 20 | doc_count = 0 21 | task_count = 0 22 | 23 | # Query 10 pending table definitions 24 | table_definitions = session.query(DefinitionTable).filter_by(def_waiting=True).limit(10).all() 25 | for table_definition in table_definitions: 26 | DefService.add_or_update_table_vector_db(table_definition) 27 | table_definition.def_waiting = False 28 | table_count += 1 29 | if table_count > 0: 30 | print(f"Number of table definitions added to vector database: {table_count}") 31 | 32 | # Query 10 pending column definitions 33 | column_definitions = session.query(DefinitionColumn).filter_by(def_waiting=True).limit(10).all() 34 | for column_definition in column_definitions: 35 | DefService.add_or_update_column_vector_db(column_definition) 36 | column_definition.def_waiting = False 37 | column_count += 1 38 | if column_count > 0: 39 | print(f"Number of column definitions added to vector database: {column_count}") 40 | 41 | # Query 10 pending document definitions 42 | doc_definitions = session.query(DefinitionDoc).filter_by(def_waiting=True).limit(10).all() 43 | for doc_definition in doc_definitions: 44 | DefService.refresh_doc_vector_db(doc_definition) 45 | doc_definition.def_waiting = False 46 | doc_count += 
1 47 | if doc_count > 0: 48 | print(f"Number of document definitions added to vector database: {doc_count}") 49 | 50 | # Query 10 pending tasks 51 | tasks = session.query(Task).filter_by(def_waiting=True).limit(10).all() 52 | for task in tasks: 53 | TaskService.refresh_task_vector_db(task) 54 | task.def_waiting = False 55 | task_count += 1 56 | if task_count > 0: 57 | print(f"Number of tasks added to vector database: {task_count}") 58 | 59 | session.commit() 60 | 61 | # # Check if building is complete 62 | # if table_count + column_count + doc_count + task_count == 0: 63 | # index_model = session.query(IndexModel).first() 64 | # if index_model.status == IndexModelStatus.BUILDING.value: 65 | # index_model.status = IndexModelStatus.READY.value 66 | # session.commit() 67 | 68 | # Task for removing old version table and column definitions 69 | @scheduler.task('interval', id='remove_old_version_defs_job', seconds=10, coalesce=True, max_instances=1) 70 | def remove_old_version_defs_job(): 71 | with app.app_context(): 72 | session = db.session 73 | table_count = 0 74 | column_count = 0 75 | 76 | # Query 50 old version table definitions 77 | table_definitions = session.query(DefinitionTable) \ 78 | .join(Project, Project.id == DefinitionTable.project_id) \ 79 | .filter(DefinitionTable.def_version < Project.cur_version) \ 80 | .limit(50).all() 81 | for table_definition in table_definitions: 82 | DefService.remove_table_vector_db(table_definition) 83 | session.delete(table_definition) 84 | table_count += 1 85 | if table_count > 0: 86 | print(f"Number of old version table definitions deleted: {table_count}") 87 | 88 | # Query 50 old version column definitions 89 | column_definitions = session.query(DefinitionColumn) \ 90 | .join(Project, Project.id == DefinitionColumn.project_id) \ 91 | .filter(DefinitionColumn.def_version < Project.cur_version) \ 92 | .limit(50).all() 93 | for column_definition in column_definitions: 94 | DefService.remove_column_vector_db(column_definition) 
95 | session.delete(column_definition) 96 | column_count += 1 97 | if column_count > 0: 98 | print(f"Number of old version column definitions deleted: {column_count}") 99 | 100 | if table_count + column_count > 0: 101 | session.commit() -------------------------------------------------------------------------------- /backend/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging.handlers import RotatingFileHandler 3 | import os 4 | 5 | # Create logs directory if it doesn't exist 6 | if not os.path.exists('logs'): 7 | os.makedirs('logs') 8 | 9 | my_logger = logging.getLogger(__name__) 10 | # Configure log format 11 | formatter = logging.Formatter( 12 | '[%(asctime)s] %(levelname)s in %(module)s: %(message)s' 13 | ) 14 | file_handler = logging.FileHandler('logs/my_logger.log') 15 | file_handler.setFormatter(formatter) 16 | file_handler.setLevel(logging.INFO) 17 | my_logger.addHandler(file_handler) 18 | 19 | def init_logger(): 20 | """Initialize logger configuration""" 21 | # Configure root logger 22 | logging.basicConfig(level=logging.INFO) 23 | root_logger = logging.getLogger() 24 | 25 | # Configure log format 26 | formatter = logging.Formatter( 27 | '[%(asctime)s] %(levelname)s in %(module)s: %(message)s' 28 | ) 29 | 30 | # File handler - with size-based rotation 31 | file_handler = RotatingFileHandler( 32 | 'logs/app.log', 33 | maxBytes=10485760, # 10MB 34 | backupCount=10, 35 | encoding='utf-8' 36 | ) 37 | file_handler.setFormatter(formatter) 38 | file_handler.setLevel(logging.ERROR) 39 | 40 | # Console handler 41 | console_handler = logging.StreamHandler() 42 | console_handler.setFormatter(formatter) 43 | console_handler.setLevel(logging.INFO) 44 | 45 | # Clear existing handlers 46 | root_logger.handlers.clear() 47 | 48 | # Add handlers to root logger 49 | root_logger.addHandler(file_handler) 50 | root_logger.addHandler(console_handler) 51 | 52 | return root_logger 
-------------------------------------------------------------------------------- /backend/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Empty file to make models a Python package 2 | -------------------------------------------------------------------------------- /backend/models/base.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | from database import db 3 | from sqlalchemy import event 4 | 5 | class BaseModel(db.Model): 6 | """Base model class, containing common fields""" 7 | __abstract__ = True 8 | 9 | id = db.Column(db.Integer, primary_key=True) 10 | created_at = db.Column(db.DateTime(timezone=True), nullable=False) 11 | updated_at = db.Column(db.DateTime(timezone=True), nullable=False) 12 | version = db.Column(db.Integer, nullable=False, default=0) 13 | 14 | # Define fields that should not be updated 15 | do_not_update_fields = {'created_at', 'updated_at', 'version'} 16 | 17 | def update(self, **kwargs): 18 | """Smart update properties, excluding specified fields""" 19 | self.version += 1 20 | for key, value in kwargs.items(): 21 | if key not in self.do_not_update_fields: 22 | setattr(self, key, value) 23 | return self 24 | 25 | # Add SQLAlchemy event listener 26 | @event.listens_for(BaseModel, 'before_insert', propagate=True) 27 | def set_created_updated_at(mapper, connection, target): 28 | """Set creation and update times before insertion""" 29 | now = datetime.now(timezone.utc) 30 | target.created_at = now 31 | target.updated_at = now 32 | 33 | @event.listens_for(BaseModel, 'before_update', propagate=True) 34 | def set_updated_at(mapper, connection, target): 35 | """Set update time before update""" 36 | target.updated_at = datetime.now(timezone.utc) 37 | 38 | class ProjectBaseModel(BaseModel): 39 | """Project base model class""" 40 | __abstract__ = True 41 | 42 | project_id = db.Column(db.Integer, nullable=False) 
-------------------------------------------------------------------------------- /backend/models/definition_column.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class DefinitionColumn(ProjectBaseModel): 6 | """Column definition model""" 7 | __tablename__ = 'definition_column' 8 | 9 | def_table = db.Column(db.String(100), nullable=False, comment='Table name') 10 | def_type = db.Column(db.String(50), nullable=False, comment='Field type') 11 | def_column = db.Column(db.String(100), nullable=False, comment='Field name') 12 | def_comment = db.Column(db.String(500), comment='Field comment') 13 | def_ai_comment = db.Column(db.String(500), comment='AI field comment') 14 | def_waiting = db.Column(db.Boolean, default=False, comment='Whether waiting for building') 15 | def_version = db.Column(db.Integer, default=1, comment='Current index version') 16 | 17 | __table_args__ = ( 18 | db.UniqueConstraint('project_id', 'def_table', 'def_column', name='uix_table_column'), 19 | db.Index('ix_column_definition_def_waiting', 'def_waiting'), 20 | db.Index('ix_column_definition_def_version', 'def_version'), 21 | ) 22 | 23 | class DefinitionColumnSchema(SQLAlchemyAutoSchema): 24 | class Meta: 25 | model = DefinitionColumn 26 | load_instance = True 27 | include_relationships = True 28 | sqla_session = db.session 29 | 30 | definition_column_schema = DefinitionColumnSchema() 31 | definition_columns_schema = DefinitionColumnSchema(many=True) 32 | -------------------------------------------------------------------------------- /backend/models/definition_doc.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class DefinitionDoc(ProjectBaseModel): 
6 | """Document definition model""" 7 | __tablename__ = 'definition_doc' 8 | 9 | def_doc = db.Column(db.Text, nullable=False, comment='Document content') 10 | def_selected = db.Column(db.Boolean, default=False, comment='Whether default selected') 11 | def_waiting = db.Column(db.Boolean, default=False, comment='Whether waiting for building') 12 | disabled = db.Column(db.Boolean, default=False, comment='Whether disabled') 13 | 14 | # Non-unique index: def_selected 15 | __table_args__ = (db.Index('idx_definition_doc_def_selected', 'def_selected'),) 16 | 17 | class DefinitionDocSchema(SQLAlchemyAutoSchema): 18 | class Meta: 19 | model = DefinitionDoc 20 | load_instance = True 21 | include_relationships = True 22 | sqla_session = db.session 23 | 24 | definition_doc_schema = DefinitionDocSchema() 25 | definition_docs_schema = DefinitionDocSchema(many=True) -------------------------------------------------------------------------------- /backend/models/definition_relation.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class DefinitionRelation(ProjectBaseModel): 6 | """Relation definition model""" 7 | __tablename__ = 'definition_relation' 8 | 9 | table1 = db.Column(db.String(100), nullable=False, comment='Table name 1') 10 | column1 = db.Column(db.String(100), nullable=False, comment='Table 1 column name') 11 | table2 = db.Column(db.String(100), nullable=False, comment='Table name 2') 12 | column2 = db.Column(db.String(100), nullable=False, comment='Table 2 column name') 13 | relation_type = db.Column(db.String(100), nullable=False, comment='Relationship type between table 1 and table 2, options: 1-1(one-to-one),1-n(one-to-many),n-1(many-to-one),n-n(many-to-many)') 14 | 15 | __table_args__ = ( 16 | db.UniqueConstraint('project_id', 'table1', 'column1', 'table2', 'column2', 
name='uix_relation_definition'), 17 | ) 18 | 19 | class DefinitionRelationSchema(SQLAlchemyAutoSchema): 20 | class Meta: 21 | model = DefinitionRelation 22 | load_instance = True 23 | include_relationships = True 24 | sqla_session = db.session 25 | 26 | definition_relation_schema = DefinitionRelationSchema() 27 | definition_relations_schema = DefinitionRelationSchema(many=True) -------------------------------------------------------------------------------- /backend/models/definition_rule.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class DefinitionRule(ProjectBaseModel): 6 | """Rule definition model""" 7 | __tablename__ = 'definition_rule' 8 | 9 | name = db.Column(db.String(100), nullable=False, comment='Rule name') 10 | content = db.Column(db.Text, nullable=False, comment='Rule content') 11 | def_selected = db.Column(db.Boolean, default=False, comment='Whether default selected') 12 | disabled = db.Column(db.Boolean, default=False, comment='Whether disabled') 13 | 14 | __table_args__ = ( 15 | db.Index('idx_definition_rule_def_selected', 'def_selected'), 16 | ) 17 | 18 | class DefinitionRuleSchema(SQLAlchemyAutoSchema): 19 | class Meta: 20 | model = DefinitionRule 21 | load_instance = True 22 | include_relationships = True 23 | sqla_session = db.session 24 | 25 | definition_rule_schema = DefinitionRuleSchema() 26 | definition_rules_schema = DefinitionRuleSchema(many=True) -------------------------------------------------------------------------------- /backend/models/definition_table.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class DefinitionTable(ProjectBaseModel): 6 | """Table definition model""" 7 | 
__tablename__ = 'definition_table' 8 | 9 | def_table = db.Column(db.String(100), nullable=False, comment='Table name') 10 | def_comment = db.Column(db.String(500), comment='Table comment') 11 | def_ai_comment = db.Column(db.String(500), comment='AI table comment') 12 | def_waiting = db.Column(db.Boolean, default=False, comment='Whether waiting for building') 13 | def_version = db.Column(db.Integer, default=1, comment='Current index version') 14 | disabled = db.Column(db.Boolean, default=False, comment='Whether disabled') 15 | 16 | __table_args__ = ( 17 | db.Index('uix_table_definition_def_table', 'project_id', 'def_table', unique=True), 18 | db.Index('ix_table_definition_def_waiting', 'def_waiting'), 19 | db.Index('ix_table_definition_def_version', 'def_version'), 20 | ) 21 | 22 | class DefinitionTableSchema(SQLAlchemyAutoSchema): 23 | class Meta: 24 | model = DefinitionTable 25 | load_instance = True 26 | include_relationships = True 27 | sqla_session = db.session 28 | 29 | definition_table_schema = DefinitionTableSchema() 30 | definition_tables_schema = DefinitionTableSchema(many=True) -------------------------------------------------------------------------------- /backend/models/job.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | from enums import JobStatus 5 | 6 | class Job(ProjectBaseModel): 7 | """Task model""" 8 | __tablename__ = 'job' 9 | 10 | task_id = db.Column(db.Integer, nullable=False, comment='Task ID') 11 | job_type = db.Column(db.String(20), nullable=False, comment='Task type') 12 | job_data = db.Column(db.JSON, comment='Task data') 13 | job_status = db.Column(db.String(20), nullable=False, default=JobStatus.INIT.value, comment='Task status') 14 | job_cost_time = db.Column(db.Integer, nullable=False, default=0, comment='Task cost time, unit: ms') 15 | error_message = 
db.Column(db.Text, comment='Error message') 16 | 17 | class JobSchema(SQLAlchemyAutoSchema): 18 | class Meta: 19 | model = Job 20 | load_instance = True 21 | include_relationships = True 22 | sqla_session = db.session 23 | 24 | job_schema = JobSchema() 25 | jobs_schema = JobSchema(many=True) 26 | -------------------------------------------------------------------------------- /backend/models/project.py: -------------------------------------------------------------------------------- 1 | from models.base import BaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema, auto_field 4 | 5 | class Project(BaseModel): 6 | """Project model""" 7 | __tablename__ = 'project' 8 | 9 | name = db.Column(db.String(20), comment='Project name') 10 | description = db.Column(db.String(200), comment='Project description') 11 | db_type = db.Column(db.String(20), comment='Database type') 12 | db_version = db.Column(db.String(255), comment='Database version information') 13 | cur_version = db.Column(db.Integer, default=1, comment='Current index version') 14 | 15 | class ProjectSchema(SQLAlchemyAutoSchema): 16 | class Meta: 17 | model = Project 18 | load_instance = True 19 | include_relationships = True 20 | sqla_session = db.session 21 | 22 | project_schema = ProjectSchema() 23 | projects_schema = ProjectSchema(many=True) 24 | -------------------------------------------------------------------------------- /backend/models/task.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class Task(ProjectBaseModel): 6 | """Task model""" 7 | __tablename__ = 'task' 8 | 9 | question = db.Column(db.Text, nullable=False, comment='User question content') 10 | question_supplement = db.Column(db.Text, comment='Question supplement') 11 | options = db.Column(db.JSON, nullable=False, default={}, 
comment='Task options') 12 | rules = db.Column(db.JSON, comment='Rule id list') 13 | related_columns = db.Column(db.Text, comment='Related columns') 14 | sql = db.Column(db.Text, comment='Generated SQL') 15 | sql_right = db.Column(db.Boolean, comment='Whether the generated SQL is correct') 16 | sql_refer = db.Column(db.Boolean, comment='Whether it can be referenced') 17 | learn_result = db.Column(db.Text, comment='Learning result') 18 | def_waiting = db.Column(db.Boolean, default=False, comment='Whether to wait for construction') 19 | 20 | def __repr__(self): 21 | return f"" 22 | 23 | class TaskSchema(SQLAlchemyAutoSchema): 24 | class Meta: 25 | model = Task 26 | load_instance = True 27 | include_relationships = True 28 | sqla_session = db.session 29 | 30 | task_schema = TaskSchema() 31 | tasks_schema = TaskSchema(many=True) -------------------------------------------------------------------------------- /backend/models/task_column.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class TaskColumn(ProjectBaseModel): 6 | """Task model: which columns are selected""" 7 | __tablename__ = 'task_column' 8 | 9 | task_id = db.Column(db.Integer, nullable=False, comment='Task ID') 10 | table_name = db.Column(db.String(100), nullable=False, comment='Table name') 11 | column_name = db.Column(db.String(100), nullable=False, comment='Column name') 12 | 13 | __table_args__ = ( 14 | db.Index('ix_task_column_task_id', 'task_id'), 15 | ) 16 | 17 | class TaskColumnSchema(SQLAlchemyAutoSchema): 18 | class Meta: 19 | model = TaskColumn 20 | load_instance = True 21 | include_relationships = True 22 | sqla_session = db.session 23 | 24 | task_column_schema = TaskColumnSchema() 25 | task_columns_schema = TaskColumnSchema(many=True) 26 | -------------------------------------------------------------------------------- 
/backend/models/task_doc.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class TaskDoc(ProjectBaseModel): 6 | """Task model: which documents are selected""" 7 | __tablename__ = 'task_doc' 8 | 9 | task_id = db.Column(db.Integer, nullable=False, comment='Task ID') 10 | doc_id = db.Column(db.Integer, nullable=False, comment='Document definition ID') 11 | 12 | __table_args__ = ( 13 | db.Index('ix_task_doc_task_id', 'task_id'), 14 | ) 15 | 16 | class TaskDocSchema(SQLAlchemyAutoSchema): 17 | class Meta: 18 | model = TaskDoc 19 | load_instance = True 20 | include_relationships = True 21 | sqla_session = db.session 22 | 23 | task_doc_schema = TaskDocSchema() 24 | task_docs_schema = TaskDocSchema(many=True) 25 | -------------------------------------------------------------------------------- /backend/models/task_sql.py: -------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class TaskSQL(ProjectBaseModel): 6 | """Task model: which SQL records are selected""" 7 | __tablename__ = 'task_sql' 8 | 9 | task_id = db.Column(db.Integer, nullable=False, comment='Task ID') 10 | sql_id = db.Column(db.Integer, nullable=False, comment='SQL record ID') 11 | 12 | __table_args__ = ( 13 | db.Index('ix_task_sql_task_id', 'task_id'), 14 | ) 15 | 16 | class TaskSQLSchema(SQLAlchemyAutoSchema): 17 | class Meta: 18 | model = TaskSQL 19 | load_instance = True 20 | include_relationships = True 21 | sqla_session = db.session 22 | 23 | task_sql_schema = TaskSQLSchema() 24 | task_sqls_schema = TaskSQLSchema(many=True) 25 | -------------------------------------------------------------------------------- /backend/models/task_table.py: 
-------------------------------------------------------------------------------- 1 | from models.base import ProjectBaseModel 2 | from database import db 3 | from marshmallow_sqlalchemy import SQLAlchemyAutoSchema 4 | 5 | class TaskTable(ProjectBaseModel): 6 | """Task model: which tables are selected""" 7 | __tablename__ = 'task_table' 8 | 9 | task_id = db.Column(db.Integer, nullable=False, comment='Task ID') 10 | table_name = db.Column(db.String(100), nullable=False, comment='Table name') 11 | 12 | __table_args__ = ( 13 | db.Index('ix_task_table_task_id', 'task_id'), 14 | ) 15 | 16 | class TaskTableSchema(SQLAlchemyAutoSchema): 17 | class Meta: 18 | model = TaskTable 19 | load_instance = True 20 | include_relationships = True 21 | sqla_session = db.session 22 | 23 | task_table_schema = TaskTableSchema() 24 | task_tables_schema = TaskTableSchema(many=True) 25 | -------------------------------------------------------------------------------- /backend/prompt_templates/gen_ai_comments.mustache: -------------------------------------------------------------------------------- 1 | 2 | You are a database expert, skilled at generating AI comments for tables and columns based on provided table information. 3 | 4 | 5 | 6 | {{{tableStr}}} 7 |
8 | 9 | 10 | Please generate AI comments for tables and columns based on the table information. 11 | 12 | 13 | 14 | Please output JSON result in the following format, only including comments for tables and columns: 15 | 16 | ```json 17 | { 18 | "table": { 19 | "t": "Table Name", 20 | "v": "Table Comment", 21 | "cols": [ 22 | { 23 | "c": "Column Name", 24 | "v": "Column Comment" 25 | } 26 | ] 27 | } 28 | } 29 | ``` 30 | 31 | -------------------------------------------------------------------------------- /backend/prompt_templates/gen_related_columns.mustache: -------------------------------------------------------------------------------- 1 | 2 | You are a database expert, skilled at inferring all possible table names and column names based on user questions and provided information. 3 | 4 | 5 | 6 | {{{doc_content}}} 7 | 8 | 9 | 10 | {{{sql_content}}} 11 | 12 | 13 | 14 | {{{question}}} 15 | 16 | 17 | 18 | {{{question_supplement}}} 19 | 20 | 21 | 22 | Please analyze and infer all possible table names and column names based on the user question. Consider the following aspects: 23 | 24 | 1. Fields to query in the SELECT clause 25 | 2. Tables and fields used for joining in JOIN/ON clauses 26 | 3. Filter fields used in WHERE conditions 27 | 4. Fields used in GROUP BY/ORDER BY clauses 28 | 5. Any other potentially relevant tables and fields 29 | 30 | Prioritize using table names and column names that already exist in related documents and SQL. If exact matches cannot be found, infer appropriate table and column names based on business context. 31 | 32 | Please list all relevant table and field combinations as completely as possible. Even if you're not sure whether they will be used, please include them as long as they might be relevant. 
33 | 34 | 35 | 36 | Please output JSON results in the following format: 37 | 38 | ```json 39 | { 40 | "tables": [ 41 | { 42 | "t": "Table Name", 43 | "d": "Table Description" 44 | }, 45 | { 46 | "t": "Table Name", 47 | "d": "Table Description" 48 | } 49 | ], 50 | "columns": [ 51 | { 52 | "t": "Table Name", 53 | "c": "Column Name", 54 | "d": "Column Description" 55 | }, 56 | { 57 | "t": "Table Name", 58 | "c": "Column Name", 59 | "d": "Column Description" 60 | } 61 | ] 62 | } 63 | ``` 64 | 65 | -------------------------------------------------------------------------------- /backend/prompt_templates/gen_sql.mustache: -------------------------------------------------------------------------------- 1 | 2 | You are a database expert, skilled at generating SQL queries based on user questions and provided information. 3 | 4 | 5 | 6 | {{{doc_content}}} 7 | 8 | 9 | 10 | {{{sql_content}}} 11 | 12 | 13 | 14 | {{{table_structure}}} 15 |
16 | 17 | 18 | {{{relation_structure}}} 19 |
20 | 21 | 22 | {{{rules}}} 23 | 24 | 25 | 26 | {{{question}}} 27 | 28 | 29 | 30 | {{{question_supplement}}} 31 | 32 | 33 | 34 | Please generate SQL according to the following requirements: 35 | 1. Use concise and clear query statements 36 | 2. Ensure query results accurately answer user questions 37 | 3. Use appropriate JOIN types when connecting multiple tables 38 | 4. Add necessary WHERE conditions to ensure result accuracy 39 | 5. Use subqueries or Common Table Expression (CTE) if needed 40 | 41 | 42 | 43 | {{{db_type_name}}} {{{db_version}}} 44 | 45 | 46 | 47 | Please output the SQL query in JSON format without any other content. 48 | 49 | The JSON should contain the following field: 50 | 51 | - sql: The generated SQL query 52 | -------------------------------------------------------------------------------- /backend/prompt_templates/learn.mustache: -------------------------------------------------------------------------------- 1 | 2 | You are a database expert, skilled at extracting table descriptions, field descriptions, field relationships and other information based on user questions and provided information. 3 | 4 | 5 | 6 | {{{table_structure}}} 7 | 8 | 9 | 10 | {{{sql_structure}}} 11 | 12 | 13 | 14 | {{{question}}} 15 | 16 | 17 | 18 | {{{question_supplement}}} 19 | 20 | 21 | 22 | {{{sql}}} 23 | 24 | 25 | 26 | Please extract table descriptions, field descriptions, field relationships based on the provided information and output in a standardized format. 
27 | 28 | Requirements: 29 | 30 | - Extract descriptions of tables, fields and relationships from user questions and generated SQL statements 31 | - For tables, fields and relationships that do not appear in user questions and generated SQL statements, they can be ignored even if they appear in possible related table structures and SQL 32 | - Table descriptions should include table name and business purpose 33 | - Field descriptions should include field name, belonging table, and field meaning 34 | - Relationship descriptions should include related tables and related fields 35 | - If the meaning of a field is unclear, it can be ignored 36 | - All descriptions should be clear and accurate, avoiding ambiguity 37 | 38 | 39 | 40 | Please output in JSON format. 41 | 42 | ```json 43 | { 44 | "tables": [ 45 | { 46 | "table": "Table Name", 47 | "desc": "Table Description" 48 | } 49 | ], 50 | "columns": [ 51 | { 52 | "table": "Table Name", 53 | "column": "Column Name", 54 | "desc": "Column Description" 55 | } 56 | ], 57 | "relations": [ 58 | { 59 | "table1": "Table 1", 60 | "column1": "Column 1 of Table 1", 61 | "table2": "Table 2", 62 | "column2": "Column 2 of Table 2", 63 | "relation_type": "Relationship type between Table 1 and Table 2, values:1-1(one to one),1-n(one to many),n-1(many to one),n-n(many to many)" 64 | } 65 | ] 66 | } 67 | ``` 68 | 69 | Example: 70 | 71 | ```json 72 | { 73 | "tables": [ 74 | { 75 | "table": "user", 76 | "desc": "User Table" 77 | } 78 | ], 79 | "columns": [ 80 | { 81 | "table": "user", 82 | "column": "id", 83 | "desc": "User ID" 84 | } 85 | ], 86 | "relations": [ 87 | { 88 | "table1": "user", 89 | "column1": "id", 90 | "table2": "order", 91 | "column2": "user_id", 92 | "relation_type": "1-n" 93 | } 94 | ] 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /backend/prompt_templates/optimize_question.mustache: -------------------------------------------------------------------------------- 1 | 
## Context 2 | User Question: 3 | 4 | ```txt 5 | {{{question}}} 6 | ``` 7 | 8 | ## Instructions 9 | 10 | Please optimize the expression to be more standardized, clear, and unambiguous, eliminating ambiguity, and making it easier to generate SQL statements. If there are contradictions or ambiguities in the statement, please understand the user's intention and fix them: 11 | 12 | ### Optimization Example 13 | 14 | Example 1: 15 | Original Question: Query the number of orders per day, aggregated by customer 16 | Optimized: Please count the total number of orders per customer per day, including date, customer ID, and order count. Results should be sorted by date and customer ID. 17 | 18 | Example 2: 19 | Original Question: Query items with sales greater than 1000 20 | Optimized: Please query information about items with a cumulative sales total exceeding 1000 yuan, including item ID, item name, and sales total. Results should be sorted in descending order by sales total. 21 | 22 | ## Output Format 23 | Output in JSON format, as follows: 24 | 25 | ```json 26 | { 27 | "result": "Optimized Question" 28 | } 29 | ``` -------------------------------------------------------------------------------- /backend/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==3.0.2 2 | Flask-APScheduler==1.13.1 3 | python-dotenv==1.0.1 4 | flask-smorest==0.45.0 5 | flask-cors==4.0.0 6 | psycopg2-binary==2.9.9 7 | SQLAlchemy==2.0.27 8 | gunicorn==23.0.0 9 | chromadb==0.5.18 10 | openai==1.54.4 11 | sqlparse==0.5.2 12 | dataclasses-json==0.6.7 13 | marshmallow-dataclass==8.7.1 14 | marshmallow-sqlalchemy==1.1.0 15 | marshmallow==3.23.1 16 | Flask-SQLAlchemy==3.1.1 17 | requests==2.31.0 18 | chevron==0.14.0 19 | httpx==0.27.2 -------------------------------------------------------------------------------- /backend/routes/__init__.py: -------------------------------------------------------------------------------- 1 | # Empty file, used to 
from flask.views import MethodView
from flask_smorest import Blueprint, abort
from marshmallow import Schema, fields
from services.project_service import ProjectService
from models.project import ProjectSchema
from flask import jsonify
from database import session_scope
from utils.schemas import MessageResponseSchema
from models.project import Project

# create blueprint
project_bp = Blueprint('project', __name__, description='Project operations')

class CreateProjectSchema(Schema):
    """Request body for creating a project."""
    # marshmallow 3.x deprecates free-form field kwargs such as
    # ``description=``; OpenAPI descriptions belong in ``metadata``.
    name = fields.Str(required=True, metadata={'description': 'Project name'})
    description = fields.Str(required=True, metadata={'description': 'Project description'})
    db_type = fields.Str(required=True, metadata={'description': 'Database type'})
    db_version = fields.Str(required=True, metadata={'description': 'Database version information'})

class UpdateProjectSchema(CreateProjectSchema):
    """Request body for updating a project (identical fields to creation)."""
    pass

class ProjectListResponseSchema(Schema):
    """Response envelope for the project list endpoint."""
    projects = fields.List(fields.Nested(ProjectSchema),
                           metadata={'description': 'Project list'})

@project_bp.route('/project')
class ProjectView(MethodView):
    @project_bp.arguments(CreateProjectSchema)
    @project_bp.response(200, MessageResponseSchema)
    def post(self, json_data):
        """Create new project"""
        try:
            with session_scope() as session:
                ProjectService.create_project(
                    session,
                    json_data['name'],
                    json_data['description'],
                    json_data['db_type'],
                    json_data['db_version']
                )
            return {'message': 'Project created'}
        except Exception as e:
            abort(400, message=str(e))

    @project_bp.response(200, ProjectListResponseSchema)
    def get(self):
        """Get all project list"""
        try:
            with session_scope(read_only=True) as session:
                projects = ProjectService.get_all_projects(session)
                return {"projects": projects}
        except Exception as e:
            abort(400, message=str(e))

# The URL rule must declare the converter so Flask extracts ``id`` and
# passes it to the view methods; a bare '/project/' would never supply it.
@project_bp.route('/project/<int:id>')
class ProjectDetailView(MethodView):
    @project_bp.arguments(UpdateProjectSchema)
    @project_bp.response(200, ProjectSchema)
    def put(self, json_data, id):
        """Update project information"""
        try:
            with session_scope() as session:
                ProjectService.update_project(
                    session,
                    id,
                    json_data['name'],
                    json_data['description'],
                    json_data['db_type'],
                    json_data['db_version']
                )

            # Re-read in a fresh read-only session so the response reflects
            # the committed row.  Session.get is the SQLAlchemy 2.0 API
            # (Query.get is legacy).
            with session_scope(read_only=True) as session:
                return session.get(Project, id)
        except ValueError as e:
            abort(404, message=str(e))
        except Exception as e:
            abort(400, message=str(e))

    @project_bp.response(200, ProjectSchema)
    def get(self, id):
        """Get project detail"""
        with session_scope(read_only=True) as session:
            project = ProjectService.get_project_by_id(session, id)
            session.refresh(project)
            return project

    @project_bp.response(204)
    def delete(self, id):
        """Delete project"""
        try:
            with session_scope() as session:
                ProjectService.delete_project(session, id)
        except ValueError as e:
            abort(404, message=str(e))
        except Exception as e:
            abort(400, message=str(e))

@project_bp.route('/example', methods=['POST'])
class CreateExampleProjectView(MethodView):
    """Create an example project with predefined documents and rules"""
    @project_bp.response(200, MessageResponseSchema)
    def post(self):
        with session_scope() as session:
            ProjectService.create_example_project(session)
        return {'message': 'Example project created successfully'}
from models.job import Job
from models.task_doc import TaskDoc
from models.task_sql import TaskSQL
from models.task_table import TaskTable
from models.task_column import TaskColumn
from models.task import Task
from enums import JobStatus
from dto.job_dto import JobDTO
from database import db
from enums import JobType

class JobService:
    """Service-layer helpers around background Job rows."""

    @staticmethod
    def job_to_job_dto(job: Job) -> JobDTO:
        """Map a Job ORM row to its transport DTO (adds display names)."""
        return JobDTO(
            id=job.id,
            version=job.version,
            task_id=job.task_id,
            project_id=job.project_id,
            job_type=job.job_type,
            job_data=job.job_data,
            job_status=job.job_status,
            job_type_display_name=JobType.get_display_name_by_value(job.job_type),
            job_status_display_name=JobStatus.get_display_name_by_value(job.job_status),
            error_message=job.error_message,
            created_at=job.created_at,
            updated_at=job.updated_at,
            job_cost_time=job.job_cost_time,
        )

    @staticmethod
    def create_job(session, task_id: int, job_type: str) -> None:
        """Create a new job for a task, first clearing any stale results
        produced by a previous run of the same job type.

        Raises:
            ValueError: if the task does not exist.
        """
        task = session.query(Task).get(task_id)
        if not task:
            raise ValueError("Task does not exist")

        # Clean up task information left over from earlier runs.  The
        # job_type values are mutually exclusive, so this elif chain is
        # equivalent to the original independent ifs.
        if job_type == JobType.MATCH_DOC.value:
            session.query(TaskDoc).filter(TaskDoc.task_id == task_id).delete()
        elif job_type == JobType.MATCH_SQL_LOG.value:
            session.query(TaskSQL).filter(TaskSQL.task_id == task_id).delete()
        elif job_type == JobType.GEN_RELATED_COLUMNS.value:
            task.related_columns = None
            session.commit()  # persist the reset before queueing the job
        elif job_type == JobType.MATCH_DDL.value:
            session.query(TaskTable).filter(TaskTable.task_id == task_id).delete()
            session.query(TaskColumn).filter(TaskColumn.task_id == task_id).delete()
        elif job_type == JobType.GENERATE_SQL.value:
            task.sql = None
            task.sql_right = None
            task.sql_refer = None
            session.commit()
        elif job_type == JobType.LEARN_FROM_SQL.value:
            task.learn_result = None
            session.commit()

        # Queue the job; the caller's session scope commits the insert.
        job = Job(project_id=task.project_id, task_id=task_id, job_type=job_type)
        session.add(job)

    @staticmethod
    def cancel_job(session, job_id: int) -> None:
        """Cancel a job (bulk update, no row fetch)."""
        session.query(Job).filter(Job.id == job_id).update(
            {'job_status': JobStatus.CANCELED.value})

    @staticmethod
    def get_init_job_ids() -> list[int]:
        """Get all job IDs with initialization status.

        Uses the global db.session because it is called from the scheduler,
        outside a request-scoped session.
        """
        session = db.session
        return [job.id for job in
                session.query(Job).filter(Job.job_status == JobStatus.INIT.value).all()]

    @staticmethod
    def get_jobs(session, task_id: int) -> list[JobDTO]:
        """Get all jobs for a specific task."""
        jobs = session.query(Job).filter(Job.task_id == task_id).all()
        return [JobService.job_to_job_dto(job) for job in jobs]

    @staticmethod
    def get_job(session, job_id: int) -> JobDTO:
        """Get details of a specific job.

        Raises:
            ValueError: if the job does not exist (previously this failed
            with an opaque AttributeError on the missing row).
        """
        job = session.query(Job).get(job_id)
        if not job:
            raise ValueError("Job does not exist")
        return JobService.job_to_job_dto(job)
from typing import List, Optional
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.completion_create_params import ResponseFormat
from openai import AsyncOpenAI
import os

# Client and defaults are configured once at import time from the environment.
openai_client = AsyncOpenAI(
    api_key=os.getenv('OPENAI_API_KEY'),
    base_url=os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')  # Support custom base URL
)

default_model = os.getenv('OPENAI_API_MODEL')
# getenv defaults should be strings like the real env values.
default_temperature = float(os.getenv('OPENAI_API_TEMPERATURE', '0.5'))

class OpenAIService:
    @staticmethod
    async def chat_completion(
        messages: List[ChatCompletionMessageParam],
        model: str = default_model,
        temperature: float = default_temperature,
        response_format: Optional[ResponseFormat] = None
    ) -> str:
        """
        Call OpenAI Chat Completion API

        Args:
            messages: List of messages
            model: Model name
            temperature: Temperature parameter
            response_format: Response format parameter (None = API default)

        Returns:
            str: AI response text

        Raises:
            Exception: wraps any API/transport error, with the original
            exception chained as the cause.
        """
        try:
            response = await openai_client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                response_format=response_format
            )
            print('openai response', response)  # debug trace; consider switching to logger
            return response.choices[0].message.content
        except Exception as e:
            # Keep the original traceback chained for diagnosis.
            raise Exception(f"OpenAI API call failed: {str(e)}") from e
from database import db
from models.project import Project
from typing import List
from models.definition_column import DefinitionColumn
from models.definition_relation import DefinitionRelation
from models.definition_doc import DefinitionDoc
from models.definition_table import DefinitionTable
from models.definition_rule import DefinitionRule
from models.task import Task
import csv
import os
from sqlalchemy.orm import Session

class ProjectService:
    """Project lifecycle operations: create, update, delete and query."""

    @staticmethod
    def create_project(session: Session, name: str, description: str,
                       db_type: str, db_version: str) -> Project:
        """Create a new project (committed by the caller's session scope)."""
        project = Project(name=name, description=description,
                          db_type=db_type, db_version=db_version)
        session.add(project)
        return project

    @staticmethod
    def update_project(session: Session, id: int, name: str, description: str,
                       db_type: str, db_version: str) -> Project:
        """Update project information.

        Raises:
            ValueError: if the project does not exist.
        """
        # Session.get is the SQLAlchemy 2.0 API (Query.get is legacy).
        project = session.get(Project, id)
        if not project:
            raise ValueError(f"Project with id {id} not found")

        project.update(
            name=name,
            description=description,
            db_type=db_type,
            db_version=db_version
        )
        return project

    @staticmethod
    def delete_project(session: Session, id: int) -> None:
        """Delete a project and all its related records.

        Raises:
            ValueError: if the project does not exist.
        """
        project = session.get(Project, id)
        if not project:
            raise ValueError(f"Project with id {id} not found")

        # Delete all related records explicitly (no DB-level cascade is used here).
        session.query(DefinitionTable).filter(DefinitionTable.project_id == id).delete()
        session.query(DefinitionColumn).filter(DefinitionColumn.project_id == id).delete()
        session.query(DefinitionRelation).filter(DefinitionRelation.project_id == id).delete()
        session.query(DefinitionDoc).filter(DefinitionDoc.project_id == id).delete()
        session.query(DefinitionRule).filter(DefinitionRule.project_id == id).delete()
        session.query(Task).filter(Task.project_id == id).delete()

        # Delete project
        session.delete(project)

    @staticmethod
    def get_all_projects(session: Session) -> List[Project]:
        """Get all projects list, ordered by id."""
        return session.query(Project).order_by(Project.id).all()

    @staticmethod
    def get_project_by_id(session: Session, id: int) -> Project:
        """Get project by ID.

        Raises:
            ValueError: if the project does not exist.
        """
        project = session.get(Project, id)
        if not project:
            raise ValueError(f"Project with id {id} not found")
        return project

    @classmethod
    def create_example_project(cls, session: Session):
        """Create an example project with predefined tables, columns, documents and rules"""
        # Create a new project
        project = Project(
            name="Example Project",
            description="An example project with predefined tables and rules",
            db_type="mysql",
            db_version="8.0"
        )
        session.add(project)
        session.flush()  # Get project.id

        assets_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets')

        # Read demo tables from CSV.  newline='' and an explicit encoding
        # follow the csv-module documentation and avoid locale-dependent reads.
        demo_tables_path = os.path.join(assets_dir, 'demo_tables.csv')
        with open(demo_tables_path, 'r', newline='', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                session.add(DefinitionTable(
                    project_id=project.id,
                    def_table=row['TABLE_NAME'],
                    def_comment=row['TABLE_COMMENT'],
                    def_waiting=True
                ))

        # Read demo columns from CSV
        demo_columns_path = os.path.join(assets_dir, 'demo_columns.csv')
        with open(demo_columns_path, 'r', newline='', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                session.add(DefinitionColumn(
                    project_id=project.id,
                    def_table=row['TABLE_NAME'],
                    def_column=row['COLUMN_NAME'],
                    def_type=row['COLUMN_TYPE'],
                    def_comment=row['COLUMN_COMMENT'],
                    def_waiting=True
                ))

        # Create example documents
        example_docs = [
            {
                "def_doc": "The order table contains basic order information, and the user table contains basic user information.\n"
                           "The order table is linked to the user table through user_id.\n"
                           "Order status: 0-Pending 1-Paid 2-Shipped 3-Delivered 4-Cancelled",
                "def_selected": True
            },
            {
                "def_doc": "The product table contains basic product information.\n"
                           "The order_item table is a junction table between order and product, containing product information within order.\n"
                           "Product status: 1-Active 0-Inactive",
                "def_selected": False
            }
        ]

        for doc in example_docs:
            session.add(DefinitionDoc(
                project_id=project.id,
                def_doc=doc["def_doc"],
                def_selected=doc["def_selected"]
            ))

        # Create example rules
        example_rules = [
            {
                "name": "Prefer LEFT JOIN",
                "content": "When writing SQL queries, prefer LEFT JOIN over INNER JOIN to preserve all records from the main table.",
                "def_selected": False
            },
            {
                "name": "Use Table Aliases",
                "content": "Use meaningful table aliases in SQL queries to improve readability. For example: 'o' for order table, 'u' for user table.",
                "def_selected": False
            },
            {
                "name": "Filter Condition Placement",
                "content": "Place non-join filter conditions in the WHERE clause, and join conditions in the ON clause.",
                "def_selected": True
            }
        ]

        for rule in example_rules:
            session.add(DefinitionRule(
                project_id=project.id,
                name=rule["name"],
                content=rule["content"],
                def_selected=rule["def_selected"]
            ))

        return project
#!/bin/bash

# Start (or restart) the backend under gunicorn with timestamped logs.

# Exit on error
set -e

# Set project root directory
PROJECT_ROOT=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

# Configuration
LOG_DIR="$PROJECT_ROOT/logs"
PID_FILE="$PROJECT_ROOT/gunicorn.pid"
ACCESS_LOG_FILE="$LOG_DIR/gunicorn_access_$(date +%Y%m%d_%H%M%S).log"
ERROR_LOG_FILE="$LOG_DIR/gunicorn_error_$(date +%Y%m%d_%H%M%S).log"

# Create logs directory if it doesn't exist
if [ ! -d "$LOG_DIR" ]; then
    echo "$(date '+%Y-%m-%d %H:%M:%S') Creating logs directory..."
    mkdir -p "$LOG_DIR"
fi

# Function to stop gunicorn gracefully via the recorded PID file
stop_gunicorn() {
    if [ -f "$PID_FILE" ]; then
        echo "$(date '+%Y-%m-%d %H:%M:%S') Stopping existing gunicorn process..."
        # Quote the command substitution so an empty/odd PID file cannot
        # word-split into a malformed kill invocation.
        if kill -15 "$(cat "$PID_FILE")" 2>/dev/null; then
            echo "$(date '+%Y-%m-%d %H:%M:%S') Gunicorn process stopped gracefully"
            rm -f "$PID_FILE"
        else
            # Stale PID file: no such process, just clean up.
            echo "$(date '+%Y-%m-%d %H:%M:%S') No running gunicorn process found"
            rm -f "$PID_FILE"
        fi
    fi
}

# Stop any existing gunicorn process
stop_gunicorn

# Start gunicorn (exec replaces this shell so signals reach gunicorn directly)
echo "$(date '+%Y-%m-%d %H:%M:%S') Starting gunicorn..."
exec gunicorn -c gunicorn.conf.py "wsgi:app" \
    --pid "$PID_FILE" \
    --access-logfile "$ACCESS_LOG_FILE" \
    --error-logfile "$ERROR_LOG_FILE" \
    --capture-output \
    --log-level info
    # Read the template (re-read on every call; templates are small)
    template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                'prompt_templates/gen_sql.mustache')
    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    # Triple-mustache placeholders in the template receive these values unescaped
    return chevron.render(template, {
        'question': question,
        'question_supplement': question_supplement,
        'doc_content': doc_content,
        'sql_content': sql_content,
        'table_structure': table_structure,
        'relation_structure': relation_structure,
        'rules': rules,
        'db_type_name': db_type_name,
        'db_version': db_version
    })

def get_gen_related_columns(question: str, question_supplement: str, doc_content: str, sql_content: str) -> str:
    """
    Generate related columns based on user's question.

    Renders prompt_templates/gen_related_columns.mustache with the
    question, its supplement, and the matched doc/SQL context.
    """
    # Read the template
    template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                'prompt_templates/gen_related_columns.mustache')
    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    return chevron.render(template, {
        'question': question,
        'question_supplement': question_supplement,
        'doc_content': doc_content,
        'sql_content': sql_content
    })

def get_learn(question: str, question_supplement: str, sql: str, table_structure: str, sql_structure: str) -> str:
    """
    Generate learning prompt.

    Renders prompt_templates/learn.mustache so the model can extract
    table/column/relationship descriptions from a confirmed SQL answer.
    """
    # Read the template
    template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                'prompt_templates/learn.mustache')
    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    return chevron.render(template, {
        'question': question,
        'question_supplement': question_supplement,
        'sql': sql,
        'table_structure': table_structure,
        'sql_structure': sql_structure
    })

def get_optimize_question(question: str) -> str:
    """
    Generate optimize question prompt.

    Renders prompt_templates/optimize_question.mustache around the raw
    user question.
    """
    # Read the template
    template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                'prompt_templates/optimize_question.mustache')
    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    return chevron.render(template, {
        'question': question
    })
from models.task_doc import TaskDoc
from models.definition_doc import DefinitionDoc
from models.task import Task
from models.definition_rule import DefinitionRule
from models.definition_relation import DefinitionRelation
from models.task_table import TaskTable
from models.task_sql import TaskSQL
from models.definition_table import DefinitionTable
from models.definition_column import DefinitionColumn
from models.task_column import TaskColumn
from dto.gen_ai_comments_dto import GenAICommentsTableDTO
from database import db

# Separator placed between independent markdown sections in generated prompts.
_SECTION_SEPARATOR = "\n\n---\n\n"


def fix_question(question: str) -> str:
    """Collapse line breaks in a question to spaces so it fits on one markdown line."""
    return question.replace('\n', ' ')


def get_rule_structure_markdown(session, task: Task) -> str:
    """Return the contents of the task's rules, separated by '---'.

    Args:
        session: Database session used for the query.
        task: Task whose ``rules`` holds the rule ids to fetch.
    """
    rules = session.query(DefinitionRule)\
        .filter(DefinitionRule.id.in_(task.rules))\
        .order_by(DefinitionRule.id)\
        .all()
    return _SECTION_SEPARATOR.join(rule.content for rule in rules)


def get_doc_content(session, task_id: int) -> str:
    """Return the content of documents linked to the task, separated by '---'."""
    task_docs = session.query(TaskDoc.id, DefinitionDoc.def_doc)\
        .join(DefinitionDoc, DefinitionDoc.id == TaskDoc.doc_id)\
        .filter(TaskDoc.task_id == task_id)\
        .all()
    return _SECTION_SEPARATOR.join(task_doc.def_doc for task_doc in task_docs)


def get_sql_log_structure_markdown(session, task_id: int) -> str:
    """Render the SQL log entries linked to the task as markdown sections.

    Each entry is a heading with the (line-break-free) question followed by
    the SQL in a fenced code block; entries are separated by '---'.
    """
    task_sqls = session.query(TaskSQL.id, Task.question, Task.sql)\
        .join(Task, Task.id == TaskSQL.sql_id)\
        .filter(TaskSQL.task_id == task_id)\
        .all()

    sections = []
    for s in task_sqls:
        sections.append("\n".join([
            f"## {fix_question(s.question)}\n",
            "```sql",
            f"{s.sql}",
            "```",
        ]))
    return _SECTION_SEPARATOR.join(sections)


def get_table_structure_markdown(session, task_id: int) -> str:
    """Generate a markdown table-structure description for the task's selected columns.

    Args:
        session: Database session used for the queries.
        task_id: Task ID whose selected columns are rendered.

    Returns:
        str: One markdown section per table (title, optional comment, column
        table), joined by blank lines.
    """
    # Definitions for every table/column; the AI-generated comment wins over
    # the raw schema comment when present.
    table_defs = {t.def_table: t.def_ai_comment or t.def_comment
                  for t in session.query(DefinitionTable).all()}
    column_defs = {
        f"{c.def_table}.{c.def_column}": {
            "type": c.def_type,
            "comment": c.def_ai_comment or c.def_comment,
        }
        for c in session.query(DefinitionColumn).all()
    }

    selected_columns = session.query(TaskColumn).filter(TaskColumn.task_id == task_id).all()

    # Group the selected columns by table, preserving selection order.
    table_columns = {}
    for col in selected_columns:
        table_columns.setdefault(col.table_name, []).append(col.column_name)

    markdowns = []
    for table_name, columns in table_columns.items():
        section = [f"## {table_name}\n"]
        table_comment = table_defs.get(table_name, '')
        if table_comment:
            section.append(f"{table_comment}\n")

        section.append("| Column | Type | Description |")
        section.append("|--------|------|-------------|")
        for column in columns:
            col_info = column_defs.get(f"{table_name}.{column}", {"type": "", "comment": ""})
            section.append(f"| {column} | {col_info['type']} | {col_info['comment']} |")

        markdowns.append("\n".join(section))

    return "\n\n".join(markdowns)
"\n\n".join(markdowns) 119 | 120 | def get_relation_structure_markdown(session, task_id: int) -> str: 121 | """ 122 | Generate table relationship description based on selected tables (markdown format) 123 | 124 | Args: 125 | task_id: Task ID 126 | 127 | Returns: 128 | str: Table relationship description in markdown format 129 | """ 130 | # Get task-related tables 131 | selected_tables = session.query(TaskTable).filter(TaskTable.task_id == task_id).all() 132 | table_names = [t.table_name for t in selected_tables] 133 | 134 | # Get related table relationship definitions 135 | relations = session.query(DefinitionRelation).filter( 136 | DefinitionRelation.table1.in_(table_names), 137 | DefinitionRelation.table2.in_(table_names) 138 | ).all() 139 | 140 | # Generate markdown 141 | markdowns = [] 142 | if relations: 143 | markdown = [] 144 | markdown.append("## Table Relationships\n") 145 | 146 | # Add headers 147 | markdown.append("| Table1 | Column1 | Table2 | Column2 | Relationship Type |") 148 | markdown.append("|--------|---------|---------|---------|------------------|") 149 | 150 | # Add relationship information 151 | for relation in relations: 152 | markdown.append( 153 | f"| {relation.table1} | {relation.column1} | " 154 | f"{relation.table2} | {relation.column2} | {relation.relation_type} |" 155 | ) 156 | markdowns.append("\n".join(markdown)) 157 | 158 | return "\n\n".join(markdowns) 159 | -------------------------------------------------------------------------------- /backend/utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def extract_json(text: str) -> dict: 4 | """ 5 | Extract JSON from text 6 | """ 7 | text = text.strip() 8 | if text.startswith('```json'): 9 | return json.loads(text[len('```json'):-len('```')]) 10 | return json.loads(text) 11 | 12 | def reverse_relation_type(relation_type: str) -> str: 13 | """Reverse relation type""" 14 | return { 15 | '1-1': '1-1', 16 | '1-n': 'n-1', 
17 | 'n-1': '1-n', 18 | 'n-n': 'n-n' 19 | }[relation_type] -------------------------------------------------------------------------------- /backend/vector_stores.py: -------------------------------------------------------------------------------- 1 | import os 2 | from vectors.vector_chroma import ChromaDBHandler 3 | from vectors.translate_wrapper import TranslateWrapper 4 | 5 | table_def_store = TranslateWrapper( 6 | vector_store=ChromaDBHandler( 7 | host=os.getenv("CHROMA_HOST", "localhost"), 8 | port=int(os.getenv("CHROMA_PORT", "8000")), 9 | collection_name="table_def" 10 | ), 11 | target_language='en' 12 | ) 13 | 14 | column_def_store = TranslateWrapper( 15 | vector_store=ChromaDBHandler( 16 | host=os.getenv("CHROMA_HOST", "localhost"), 17 | port=int(os.getenv("CHROMA_PORT", "8000")), 18 | collection_name="column_def" 19 | ), 20 | target_language='en' 21 | ) 22 | 23 | doc_def_store = TranslateWrapper( 24 | vector_store=ChromaDBHandler( 25 | host=os.getenv("CHROMA_HOST", "localhost"), 26 | port=int(os.getenv("CHROMA_PORT", "8000")), 27 | collection_name="doc_def" 28 | ), 29 | target_language='en' 30 | ) 31 | 32 | sql_log_store = TranslateWrapper( 33 | vector_store=ChromaDBHandler( 34 | host=os.getenv("CHROMA_HOST", "localhost"), 35 | port=int(os.getenv("CHROMA_PORT", "8000")), 36 | collection_name="sql_log" 37 | ), 38 | target_language='en' 39 | ) -------------------------------------------------------------------------------- /backend/vectors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IndieYe/sqlwise/ed4901f7a28ffb9150ac0acbdd6f9839392779b9/backend/vectors/__init__.py -------------------------------------------------------------------------------- /backend/vectors/translate_wrapper.py: -------------------------------------------------------------------------------- 1 | from vectors.vector_store import VectorStore 2 | 3 | class TranslateWrapper(VectorStore): 4 | def 
from vectors.vector_store import VectorStore

class TranslateWrapper(VectorStore):
    """Decorates a VectorStore, translating text to a target language before it
    is stored or used as a query, so all embeddings live in one language.

    Translation is best-effort: if the translate service is inactive or
    raises, the original text is used unchanged.
    """

    def __init__(self, vector_store: VectorStore, target_language: str = 'en'):
        self.vector_store = vector_store
        self.target_language = target_language

    def _translate_or_passthrough(self, text):
        """Translate text to the target language, falling back to the original
        text when the service is inactive or errors out."""
        # Imported lazily to avoid a circular import with the services package.
        from services.translate_service import translate_service
        try:
            if translate_service.is_active():
                return translate_service.translate(text, self.target_language)
        except Exception:
            # Best-effort by design: a translation failure must not block
            # indexing or querying, so fall back silently to the raw text.
            pass
        return text

    def add_document(self, document, metadata, doc_id):
        self.vector_store.add_document(self._translate_or_passthrough(document), metadata, doc_id)

    def query_documents(self, query_text, n_results=1, where=None):
        return self.vector_store.query_documents(self._translate_or_passthrough(query_text), n_results, where)

    def delete_documents(self, where):
        self.vector_store.delete_documents(where)

    def clear_collection(self):
        self.vector_store.clear_collection()
from abc import ABC, abstractmethod

class VectorStore(ABC):
    """Abstract interface for a vector database backend."""

    @abstractmethod
    def add_document(self, document, metadata, doc_id):
        """Add (or upsert) a single document with its metadata under doc_id."""
        pass

    @abstractmethod
    def query_documents(self, query_text, n_results=1, where=None):
        """Return the n_results documents most similar to query_text.

        Args:
            query_text: Natural-language query text.
            n_results: Maximum number of matches to return.
            where: Optional metadata filter. Added to the abstract signature
                so it matches the concrete implementations, which all accept it.
        """
        pass

    @abstractmethod
    def delete_documents(self, where):
        """Delete all documents matching the metadata filter."""
        pass

    @abstractmethod
    def clear_collection(self):
        """Remove every document in the collection."""
        pass
23 | if __name__ == "__main__": 24 | app.run() -------------------------------------------------------------------------------- /frontend/.cursorrules: -------------------------------------------------------------------------------- 1 | Project Name: SQLWise 2 | Project Description: Use AI to generate SQL. 3 | Project Framework: typescript+vite+react+tailwindcss+flowbite-react+react-icons+react-toastify 4 | Import path use @ instead of /src directory -------------------------------------------------------------------------------- /frontend/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | pnpm-debug.log* 8 | lerna-debug.log* 9 | 10 | node_modules 11 | dist 12 | dist-ssr 13 | *.local 14 | 15 | # Editor directories and files 16 | .vscode/* 17 | !.vscode/extensions.json 18 | .idea 19 | .DS_Store 20 | *.suo 21 | *.ntvs* 22 | *.njsproj 23 | *.sln 24 | *.sw? 25 | 26 | # OpenAPI generated files 27 | openapi.json 28 | src/api-docs -------------------------------------------------------------------------------- /frontend/eslint.config.js: -------------------------------------------------------------------------------- 1 | import js from '@eslint/js' 2 | import globals from 'globals' 3 | import reactHooks from 'eslint-plugin-react-hooks' 4 | import reactRefresh from 'eslint-plugin-react-refresh' 5 | import tseslint from 'typescript-eslint' 6 | 7 | export default tseslint.config( 8 | { ignores: ['dist'] }, 9 | { 10 | extends: [js.configs.recommended, ...tseslint.configs.recommended], 11 | files: ['**/*.{ts,tsx}'], 12 | languageOptions: { 13 | ecmaVersion: 2020, 14 | globals: globals.browser, 15 | }, 16 | plugins: { 17 | 'react-hooks': reactHooks, 18 | 'react-refresh': reactRefresh, 19 | }, 20 | rules: { 21 | ...reactHooks.configs.recommended.rules, 22 | 'react-refresh/only-export-components': [ 23 | 'warn', 24 | { allowConstantExport: true }, 
// Fetch the backend OpenAPI spec (served at /api/openapi.json) and write it
// to openapi.json, then remove the previously generated client so the
// follow-up `openapi:gen` step regenerates it from the fresh spec.

const axios = require('axios')
const fs = require('fs')

axios
  .get('http://127.0.0.1:8000/api/openapi.json')
  .then((res) => {
    fs.writeFileSync('openapi.json', JSON.stringify(res.data))
  })
  .catch((err) => {
    // Without this handler a backend that is not running produced an
    // unhandled promise rejection and the script still exited with code 0.
    console.error('Failed to fetch openapi.json:', err.message)
    process.exitCode = 1
  })

// Kicks in immediately (before the request resolves), matching the original
// behavior of always clearing the stale generated client.
fs.rmSync('src/api-docs', { recursive: true, force: true })
11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /frontend/openapitools.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "./node_modules/@openapitools/openapi-generator-cli/config.schema.json", 3 | "spaces": 2, 4 | "generator-cli": { 5 | "version": "7.9.0" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sqlwise", 3 | "private": true, 4 | "version": "0.1.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vite", 8 | "start": "npm run openapi:get && npm run openapi:gen && vite", 9 | "build": "tsc -b && vite build", 10 | "lint": "eslint .", 11 | "preview": "vite preview", 12 | "openapi:get": "node getOpenapi.cjs", 13 | "openapi:gen": "openapi-generator-cli generate -i openapi.json -g typescript-axios -o src/api-docs" 14 | }, 15 | "dependencies": { 16 | "@reduxjs/toolkit": "^2.2.1", 17 | "ahooks": "^3.8.1", 18 | "axios": "^1.7.7", 19 | "flowbite-react": "^0.10.2", 20 | "i18next": "^24.0.2", 21 | "react": "^18.3.1", 22 | "react-dom": "^18.3.1", 23 | "react-i18next": "^15.1.3", 24 | "react-icons": "^5.3.0", 25 | "react-markdown": "^9.0.1", 26 | "react-redux": "^9.1.0", 27 | "react-router-dom": "^6.28.0", 28 | "react-toastify": "^10.0.6", 29 | "react-transition-group": "^4.4.5", 30 | "sql-formatter": "^15.4.6", 31 | "tailwind-merge": "^2.5.4", 32 | "tailwind-scrollbar-hide": "^1.1.7" 33 | }, 34 | "devDependencies": { 35 | "@eslint/js": "^9.13.0", 36 | "@openapitools/openapi-generator-cli": "^2.15.3", 37 | "@tailwindcss/typography": "^0.5.15", 38 | "@types/react": "^18.3.12", 39 | "@types/react-dom": "^18.3.1", 40 | "@types/react-transition-group": "^4.4.11", 41 | "@vitejs/plugin-react-swc": "^3.5.0", 42 | "autoprefixer": "^10.4.18", 43 | "eslint": "^9.13.0", 44 | 
"eslint-plugin-react-hooks": "^5.0.0", 45 | "eslint-plugin-react-refresh": "^0.4.14", 46 | "globals": "^15.11.0", 47 | "postcss": "^8.4.35", 48 | "tailwindcss": "^3.4.1", 49 | "typescript": "~5.6.2", 50 | "typescript-eslint": "^8.11.0", 51 | "vite": "^5.4.10" 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /frontend/postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | } -------------------------------------------------------------------------------- /frontend/public/columns_template.csv: -------------------------------------------------------------------------------- 1 | TABLE_NAME,COLUMN_NAME,COLUMN_TYPE,COLUMN_COMMENT 2 | user,id,varchar(36),User ID 3 | -------------------------------------------------------------------------------- /frontend/public/tables_template.csv: -------------------------------------------------------------------------------- 1 | TABLE_NAME,TABLE_COMMENT 2 | user,User Table 3 | -------------------------------------------------------------------------------- /frontend/src/App.tsx: -------------------------------------------------------------------------------- 1 | import { Configuration, MainApi, ProjectApi } from './api-docs' 2 | import { Flowbite, CustomFlowbiteTheme } from 'flowbite-react' 3 | import { ToastContainer } from 'react-toastify' 4 | import 'react-toastify/dist/ReactToastify.css' 5 | import useEnvService from './hooks/useEnvService' 6 | import { BrowserRouter, Routes, Route } from 'react-router-dom' 7 | import { ProjectLayout } from './components/ProjectLayout' 8 | import { ProjectList } from '@/components/projects/ProjectList' 9 | import { I18nextProvider } from 'react-i18next'; 10 | import i18n from './i18n/i18n' 11 | 12 | // Create API instance 13 | export const mainApi = new MainApi(new Configuration({ 14 | basePath: 
'http://localhost:8000' 15 | })) 16 | 17 | // Create project API instance 18 | export const projectApi = new ProjectApi(new Configuration({ 19 | basePath: 'http://localhost:8000' 20 | })); 21 | 22 | // Custom theme configuration 23 | const customTheme: CustomFlowbiteTheme = { 24 | sidebar: { 25 | item: { 26 | active: 'bg-cyan-700 text-white hover:bg-cyan-800 [&>svg]:text-white', 27 | base: 'group flex items-center justify-center rounded-lg p-2 text-base font-normal text-gray-900 hover:bg-gray-200 [&>svg]:text-gray-500', 28 | } 29 | } 30 | } 31 | 32 | function App() { 33 | useEnvService() 34 | 35 | return ( 36 | 37 | 38 | 39 | 40 | } /> 41 | } /> 42 | 43 | 44 | 45 | 46 | 47 | ) 48 | } 49 | 50 | export default App 51 | -------------------------------------------------------------------------------- /frontend/src/components/AppSidebar.tsx: -------------------------------------------------------------------------------- 1 | import { Sidebar } from 'flowbite-react'; 2 | import { 3 | HiOutlineDocumentText, 4 | HiOutlineDatabase, 5 | HiOutlineClipboardList, 6 | HiOutlineBookOpen, 7 | HiOutlineCog 8 | } from 'react-icons/hi'; 9 | import { useAppDispatch, useAppSelector } from '@/store/hooks'; 10 | import { setCurrentMenu } from '@/store/slices/appSlice'; 11 | import { twMerge } from 'tailwind-merge'; 12 | import { useHover } from 'ahooks'; 13 | import { useLocation, useNavigate } from 'react-router-dom'; 14 | import { useTranslation } from 'react-i18next'; 15 | 16 | export function AppSidebar() { 17 | const dispatch = useAppDispatch(); 18 | const projectId = useAppSelector(state => state.app.projectId); 19 | const location = useLocation(); 20 | const navigate = useNavigate(); 21 | const { t } = useTranslation(); 22 | 23 | const handleMenuClick = (menu: string) => { 24 | dispatch(setCurrentMenu(menu)); 25 | navigate(`/${projectId}/${menu}`); 26 | }; 27 | 28 | const isCollapsed = location.pathname.endsWith('/records'); 29 | 30 | const hovered = 
useHover(document.querySelector('.my-sidebar')); 31 | 32 | return ( 33 | 40 | 41 | 42 | {[ 43 | { icon: HiOutlineClipboardList, text: t('project.records'), menu: 'records' }, 44 | { icon: HiOutlineDatabase, text: t('project.ddl'), menu: 'ddls' }, 45 | { icon: HiOutlineDocumentText, text: t('project.docs'), menu: 'docs' }, 46 | { icon: HiOutlineBookOpen, text: t('project.rules'), menu: 'rules' }, 47 | { icon: HiOutlineCog, text: t('project.settings'), menu: 'settings' } 48 | ].map(({ icon: Icon, text, menu }) => ( 49 | handleMenuClick(menu)} 54 | className={twMerge( 55 | "flex my-sidebar-item cursor-pointer", 56 | isCollapsed ? "justify-center px-0 group-hover:justify-start group-hover:px-2" : "" 57 | )} 58 | > 59 | 60 | {(!isCollapsed || hovered) ? text : } 61 | 62 | 63 | ))} 64 | 65 | 66 | 67 | ); 68 | } -------------------------------------------------------------------------------- /frontend/src/components/LanguageSwitcher.tsx: -------------------------------------------------------------------------------- 1 | import { Dropdown } from 'flowbite-react'; 2 | import { useTranslation } from 'react-i18next'; 3 | import { IoLanguageOutline } from 'react-icons/io5'; 4 | import { useAppDispatch } from '@/store/hooks'; 5 | import { setLanguage } from '@/store/slices/appSlice'; 6 | import { changeLanguage } from '@/i18n/i18n'; 7 | 8 | const languages = [ 9 | { code: 'zh-CN', name: '中文' }, 10 | { code: 'en-US', name: 'English' }, 11 | ]; 12 | 13 | export const LanguageSwitcher = () => { 14 | const { i18n } = useTranslation(); 15 | const dispatch = useAppDispatch(); 16 | 17 | const handleLanguageChange = (langCode: string) => { 18 | dispatch(setLanguage(langCode)); 19 | changeLanguage(langCode); 20 | }; 21 | 22 | const currentLanguage = languages.find(lang => lang.code === i18n.language)?.name || '中文'; 23 | 24 | return ( 25 | ( 29 | 35 | )} 36 | > 37 | {languages.map((lang) => ( 38 | handleLanguageChange(lang.code)} 41 | className={i18n.language === lang.code ? 
'bg-gray-100 dark:bg-gray-600' : ''} 42 | > 43 | {lang.name} 44 | 45 | ))} 46 | 47 | ); 48 | }; 49 | -------------------------------------------------------------------------------- /frontend/src/components/ProjectLayout.tsx: -------------------------------------------------------------------------------- 1 | import { Routes, Route, Navigate, useLocation } from 'react-router-dom'; 2 | import { GenerationRecords } from '@/components/records/GenerationRecords'; 3 | import { DDL } from '@/components/ddl/DDL'; 4 | import { DocList } from '@/components/docs/DocList'; 5 | import { RuleList } from '@/components/rules/RuleList'; 6 | import Settings from '@/components/settings/Settings'; 7 | import { AppSidebar } from '@/components/AppSidebar'; 8 | import AICommentModal from '@/components/records/refs/AICommentModal'; 9 | import useAppService from '@/hooks/useAppService'; 10 | import { Dropdown } from 'flowbite-react'; 11 | import { HiChevronDown } from 'react-icons/hi'; 12 | import { useNavigate } from 'react-router-dom'; 13 | import { useRequest } from 'ahooks'; 14 | import { projectApi } from '@/App'; 15 | import { useAppSelector } from '@/store/hooks'; 16 | import { LanguageSwitcher } from './LanguageSwitcher'; 17 | import { useTranslation } from 'react-i18next'; 18 | import { DB_TYPE_LABELS, type DbType } from '@/consts'; 19 | 20 | export function ProjectLayout() { 21 | const navigate = useNavigate(); 22 | const location = useLocation(); 23 | const projectId = useAppSelector(state => state.app.projectId); 24 | const { t } = useTranslation(); 25 | 26 | // Get project list 27 | const { data: projectList } = useRequest(async () => { 28 | const response = await projectApi.projectGet(); 29 | return response.data.projects || []; 30 | }); 31 | 32 | // Get current project information 33 | const { data: currentProject } = useRequest(async () => { 34 | if (projectId) { 35 | const response = await projectApi.projectIdGet(projectId); 36 | return response.data; 37 | } 38 | return 
null; 39 | }, { 40 | refreshDeps: [projectId] 41 | }); 42 | 43 | // Services 44 | useAppService(); 45 | 46 | if (!projectId) { 47 | return
48 |
Loading...
49 |
; 50 | } 51 | 52 | return ( 53 |
54 |
55 |
56 | 59 | {currentProject?.name || t('project.loading', 'Loading...')} 60 | 61 |
62 | } 63 | dismissOnClick={true} 64 | inline={true} 65 | arrowIcon={false} 66 | > 67 | {projectList?.map((project) => ( 68 | navigate(`/${project.id}${location.pathname.substring(location.pathname.indexOf('/', 1))}`, { replace: true })} 71 | className="text-sm" 72 | > 73 | 74 | {project.name} 75 | {project.db_type && ( 76 | 77 | {DB_TYPE_LABELS[project.db_type as DbType]} {project.db_version} 78 | 79 | )} 80 | 81 | 82 | ))} 83 | 84 | navigate('/')}> 85 | {t('project.backToList', 'Back to project list')} 86 | 87 | 88 |
89 | 90 |
91 |
92 |
93 | 94 |
95 | 96 | } /> 97 | } /> 98 | } /> 99 | } /> 100 | } /> 101 | } /> 102 | 103 |
104 | 105 |
106 |
107 | ); 108 | } -------------------------------------------------------------------------------- /frontend/src/components/ddl/DDL.tsx: -------------------------------------------------------------------------------- 1 | import { useAppSelector } from '@/store/hooks'; 2 | import { TableList } from './TableList'; 3 | import TableCommentEditor from '@/components/ddl/TableCommentEditor'; 4 | 5 | export function DDL() { 6 | const selectedTable = useAppSelector(state => state.ddl.selectedTable); 7 | 8 | const handleConfirm = () => { 9 | // Handle logic after confirmation 10 | }; 11 | 12 | return ( 13 |
14 | {/* Left list */} 15 | 16 | {/* Right content */} 17 |
18 | {selectedTable && ( 19 | 23 | )} 24 |
25 |
// NOTE(review): this chunk is a flattened extraction — original line numbers ("N |")
// are embedded in the text and most JSX tag markup (element names, props, classNames,
// button labels) has been lost, so only the TypeScript logic is documented here; the
// render markup must be reviewed against the real FileImportTab.tsx.
//
// FileImportTab — modal tab for importing DDL metadata from two CSV files
// (a tables CSV and a columns CSV). Visible behavior, grounded in the code below:
//  * selectedTableFile / selectedColumnFile are set either by file-input change
//    handlers (handleTableFileChange / handleColumnFileChange, first file only)
//    or by drag-and-drop (handleTableDrop / handleColumnDrop). Drops are accepted
//    only when file.type === 'text/csv'; otherwise toast.error(t('ddl.onlySupportCSV')).
//  * handleDragOver / handleDragLeave toggle a 'border-blue-500' highlight class
//    on the drop target (presumably a styled drop zone — markup not visible here).
//  * handleUpload requires BOTH files (else toast.error(t('ddl.selectBothFiles'))),
//    then awaits mainApi.mainUpdateDDLPost(projectId, tableFile, columnFile); on
//    success it toasts, calls refreshSchema() and onClose(). An HTTP 400 surfaces
//    the server's response message (fallback t('common.invalidParams')); any other
//    error falls back to error.message or t('common.unknownError').
//  * downloadTemplate(type) downloads /tables_template.csv or /columns_template.csv
//    (served from public/) by clicking a transient anchor element that is appended
//    to and immediately removed from document.body.
26 | ); 27 | } -------------------------------------------------------------------------------- /frontend/src/components/ddl/FileImportTab.tsx: -------------------------------------------------------------------------------- 1 | import { Button } from 'flowbite-react'; 2 | import { HiUpload, HiDownload } from 'react-icons/hi'; 3 | import { useState } from 'react'; 4 | import { toast } from 'react-toastify'; 5 | import { useTranslation } from 'react-i18next'; 6 | import { mainApi } from '@/App'; 7 | import useTask from '@/hooks/useTask'; 8 | import { useAppSelector } from '@/store/hooks'; 9 | 10 | interface FileImportTabProps { 11 | onClose: () => void; 12 | } 13 | 14 | export function FileImportTab({ onClose }: FileImportTabProps) { 15 | const { t } = useTranslation(); 16 | const [selectedTableFile, setSelectedTableFile] = useState(null); 17 | const [selectedColumnFile, setSelectedColumnFile] = useState(null); 18 | const { refreshSchema } = useTask(); 19 | const projectId = useAppSelector(state => state.app.projectId); 20 | 21 | const handleTableFileChange = (event: React.ChangeEvent) => { 22 | if (event.target.files && event.target.files[0]) { 23 | setSelectedTableFile(event.target.files[0]); 24 | } 25 | }; 26 | 27 | const handleColumnFileChange = (event: React.ChangeEvent) => { 28 | if (event.target.files && event.target.files[0]) { 29 | setSelectedColumnFile(event.target.files[0]); 30 | } 31 | }; 32 | 33 | const handleDragOver = (event: React.DragEvent) => { 34 | event.preventDefault(); 35 | event.currentTarget.classList.add('border-blue-500'); 36 | }; 37 | 38 | const handleDragLeave = (event: React.DragEvent) => { 39 | event.preventDefault(); 40 | event.currentTarget.classList.remove('border-blue-500'); 41 | }; 42 | 43 | const handleTableDrop = (event: React.DragEvent) => { 44 | event.preventDefault(); 45 | event.currentTarget.classList.remove('border-blue-500'); 46 | 47 | if (event.dataTransfer.files && event.dataTransfer.files[0]) { 48 | const file =
event.dataTransfer.files[0]; 49 | if (file.type === 'text/csv') { 50 | setSelectedTableFile(file); 51 | } else { 52 | toast.error(t('ddl.onlySupportCSV')); 53 | } 54 | } 55 | }; 56 | 57 | const handleColumnDrop = (event: React.DragEvent) => { 58 | event.preventDefault(); 59 | event.currentTarget.classList.remove('border-blue-500'); 60 | 61 | if (event.dataTransfer.files && event.dataTransfer.files[0]) { 62 | const file = event.dataTransfer.files[0]; 63 | if (file.type === 'text/csv') { 64 | setSelectedColumnFile(file); 65 | } else { 66 | toast.error(t('ddl.onlySupportCSV')); 67 | } 68 | } 69 | }; 70 | 71 | const handleUpload = async () => { 72 | if (!selectedTableFile || !selectedColumnFile) { 73 | toast.error(t('ddl.selectBothFiles')); 74 | return; 75 | } 76 | 77 | try { 78 | await mainApi.mainUpdateDDLPost(projectId, selectedTableFile, selectedColumnFile); 79 | toast.success(t('ddl.uploadSuccess')); 80 | refreshSchema(); 81 | onClose(); 82 | } catch (error: any) { 83 | if (error.response && error.response.status === 400) { 84 | const errorMessage = error.response.data?.message || t('common.invalidParams'); 85 | toast.error(t('ddl.uploadFailed') + ': ' + errorMessage); 86 | } else { 87 | toast.error(t('ddl.uploadFailed') + ': ' + (error.message || t('common.unknownError'))); 88 | } 89 | } 90 | }; 91 | 92 | const downloadTemplate = (type: 'tables' | 'columns') => { 93 | const link = document.createElement('a'); 94 | link.href = `/${type}_template.csv`; 95 | link.download = `${type}_template.csv`; 96 | document.body.appendChild(link); 97 | link.click(); 98 | document.body.removeChild(link); 99 | }; 100 | 101 | return ( 102 |
// (render section follows — original JSX lines 102–184 are garbled by extraction;
// see the header note above for what the visible handlers wire into it)
103 |
104 | 124 | 132 |
133 | {selectedTableFile && ( 134 |

135 | {t('ddl.selectedTableFile')}: {selectedTableFile.name} 136 |

137 | )} 138 | 139 |
140 | 160 | 168 |
169 | {selectedColumnFile && ( 170 |

171 | {t('ddl.selectedColumnFile')}: {selectedColumnFile.name} 172 |

173 | )} 174 | 175 |
176 | 182 |
183 |
184 | ); 185 | } -------------------------------------------------------------------------------- /frontend/src/components/ddl/QueryImportTab.tsx: -------------------------------------------------------------------------------- 1 | import { Button, Select } from 'flowbite-react'; 2 | import { HiClipboard } from 'react-icons/hi'; 3 | import { useState, useEffect } from 'react'; 4 | import { toast } from 'react-toastify'; 5 | import useTask from '@/hooks/useTask'; 6 | import { mainApi } from '@/App'; 7 | import { useAppSelector } from '@/store/hooks'; 8 | import { useTranslation } from 'react-i18next'; 9 | import { DB_COLUMN_QUERIES, DB_TABLE_QUERIES, DB_TYPES, DB_TYPE_LABELS, DbType } from '@/consts'; 10 | import useAsyncEffect from 'ahooks/lib/useAsyncEffect'; 11 | 12 | interface QueryImportTabProps { 13 | onClose: () => void; 14 | } 15 | 16 | export function QueryImportTab({ onClose }: QueryImportTabProps) { 17 | const { t } = useTranslation(); 18 | const { refreshSchema } = useTask(); 19 | const projectId = useAppSelector(state => state.app.projectId); 20 | const [selectedDbType, setSelectedDbType] = useState(DB_TYPES.MYSQL); 21 | const [tableQuery, setTableQuery] = useState(''); 22 | const [columnQuery, setColumnQuery] = useState(''); 23 | const [tableQueryResult, setTableQueryResult] = useState(''); 24 | const [columnQueryResult, setColumnQueryResult] = useState(''); 25 | 26 | useAsyncEffect(async () => { 27 | if (projectId) { 28 | const response = await mainApi.mainProjectSettingsGet(projectId); 29 | setSelectedDbType(response.data.db_type as DbType); 30 | } 31 | }, [projectId]); 32 | 33 | useEffect(() => { 34 | const tableQueryTemplate = DB_TABLE_QUERIES[selectedDbType] || ''; 35 | const columnQueryTemplate = DB_COLUMN_QUERIES[selectedDbType] || ''; 36 | setTableQuery(tableQueryTemplate); 37 | setColumnQuery(columnQueryTemplate); 38 | }, [selectedDbType]); 39 | 40 | const handleQuerySubmit = async () => { 41 | if (!tableQuery.trim() || !columnQuery.trim()) {
42 | toast.error(t('ddl.importMethods.queryPlaceholder')); 43 | return; 44 | } 45 | 46 | try { 47 | await mainApi.mainUpdateDDLByQueryPost({ 48 | project_id: projectId, 49 | tables: JSON.parse(tableQueryResult), 50 | columns: JSON.parse(columnQueryResult), 51 | }); 52 | toast.success(t('ddl.importMethods.querySuccess')); 53 | setTableQuery(''); 54 | setColumnQuery(''); 55 | setTableQueryResult(''); 56 | setColumnQueryResult(''); 57 | refreshSchema(); 58 | onClose(); 59 | } catch (error: any) { 60 | if (error.response && error.response.status === 400) { 61 | const errorMessage = error.response.data?.message || t('common.invalidParams'); 62 | toast.error(t('ddl.importMethods.queryFailed') + ': ' + errorMessage); 63 | } else { 64 | toast.error(t('ddl.importMethods.queryFailed') + ': ' + (error.message || t('common.unknownError'))); 65 | } 66 | } 67 | }; 68 | 69 | return ( 70 |
71 |
72 | 82 |
83 | 84 |
85 | 88 |
89 |
 90 |                         
 91 |                             {tableQuery}
 92 |                         
 93 |                     
94 | 103 |
104 |
105 | 106 |
107 | 110 |
111 |
112 |                         
113 |                             {columnQuery}
114 |                         
115 |                     
116 | 125 |
126 |
127 | 128 |
129 | 132 |