├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yml ├── docker-entrypoint.sh ├── images ├── extra.sql ├── pip.txt └── viper.sql ├── logs └── .gitkeep ├── requirements.txt ├── start_huey.sh ├── tests ├── __init__.py ├── demo_jsonschema.py ├── get_access_token.py ├── get_chat_id.py └── nacos_encrypt.py └── viper ├── __init__.py ├── core ├── __init__.py ├── errors.py ├── events.py ├── middlewares.py ├── routers.py └── settings.py ├── delays ├── __init__.py ├── backgrounds │ ├── __init__.py │ └── long_task.py ├── huey_instance.py ├── log_config.py └── schedules │ ├── __init__.py │ └── scheduled_task.py ├── models ├── __init__.py ├── base_model.py ├── chat_model.py ├── content_model.py ├── message_model.py └── user_model.py ├── schemas ├── __init__.py ├── chat_schema.py └── user_schema.py ├── urls ├── __init__.py ├── chat_url.py ├── delay_url.py └── user_url.py ├── utils ├── __init__.py ├── conf_util.py ├── db_util.py ├── decorators.py ├── errors.py ├── file_util.py ├── http_util.py ├── json_util.py ├── jwt_util.py ├── log_util.py ├── meta_util.py ├── pools.py ├── pwd_util.py ├── redis_util.py ├── resp_util.py └── tools.py └── views ├── __init__.py ├── chat_view.py ├── delay_view.py └── user_view.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Development 107 | *.db 108 | .idea 109 | logs/* 110 | !logs/.gitkeep 111 | uploads/* 112 | !uploads/.gitkeep 113 | 114 | # Nacos 115 | nacos-data/ 116 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | RUN groupadd -r viper && useradd -r -g viper viper 4 | 5 | WORKDIR /home/viper 6 | 7 | COPY pyproject.toml pdm.lock ./ 8 | RUN pip install -U pdm 9 | ENV PDM_CHECK_UPDATE=false 10 | RUN pdm install --check --prod --no-editable 11 | ENV PATH="/home/viper/.venv/bin:$PATH" 12 | 13 | COPY viper viper 14 | COPY app.py docker-entrypoint.sh ./ 15 | 16 | RUN chown -R viper:viper . 
17 | USER viper 18 | 19 | ENV FLASK_APP=app.py 20 | ENV FLASK_CONFIG=production 21 | 22 | EXPOSE 8848 23 | ENTRYPOINT ["./docker-entrypoint.sh"] 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sun Geer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Viper 2 | 3 | *A simple chat interface inspired by DeepSeek.* 4 | 5 | > This project is built on the Starlette framework and can be considered a comprehensive backend project template. The main advantage is that it can be conveniently used directly for other new projects. 
6 | 7 | No Pydantic, no aiomysql: the project avoids dependencies that are poorly maintained or unstable. 8 | 9 | ## Installation 10 | 11 | Clone: 12 | ``` 13 | $ git clone https://github.com/sungeer/viper.git 14 | $ cd viper 15 | ``` 16 | Create and activate a virtual environment, then install the dependencies: 17 | 18 | With venv + pip: 19 | ``` 20 | $ python -m venv venv 21 | $ source venv/bin/activate # use `venv\Scripts\activate` on Windows 22 | $ pip install -r requirements.txt 23 | ``` 24 | 25 | Run (Starlette is an ASGI framework, so use granian's ASGI interface): 26 | ``` 27 | $ granian --interface asgi viper:app 28 | * Running on http://127.0.0.1:8000/ 29 | ``` 30 | 31 | ## License 32 | 33 | This project is licensed under the MIT License (see the 34 | [LICENSE](LICENSE) file for details). 35 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.7' 2 | services: 3 | web: 4 | build: . 5 | command: uvicorn viper:app --host 0.0.0.0 --port 8848 6 | volumes: 7 | - ./app:/app 8 | ports: 9 | - 8848:8848 10 | depends_on: 11 | - redis 12 | worker: 13 | build: .
14 | command: bash start_huey.sh 15 | volumes: 16 | - ./app:/app 17 | environment: 18 | - RUN_HUEY=true # 仅在目标服务器上设置 19 | depends_on: 20 | - redis 21 | redis: 22 | image: "redis:alpine" -------------------------------------------------------------------------------- /docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # A container command (e.g. docker-compose `command:`) arrives as arguments to this ENTRYPOINT; run it when given. 5 | if [ "$#" -gt 0 ]; then 6 | exec "$@" 7 | fi 8 | 9 | # Default: serve the ASGI app viper:app with uvicorn on the EXPOSEd port 8848 (gunicorn's default sync worker cannot run ASGI). 10 | exec uvicorn viper:app --host 0.0.0.0 --port 8848 --workers 4 11 | -------------------------------------------------------------------------------- /images/extra.sql: -------------------------------------------------------------------------------- 1 | 2 | -- 删除表中所有行,并重置自增列 3 | TRUNCATE TABLE table_name; 4 | 5 | 6 | CREATE TABLE workflow ( 7 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 8 | name VARCHAR(255) NOT NULL COMMENT '流程名称', 9 | description TEXT NOT NULL COMMENT '描述', 10 | 11 | PRIMARY KEY (id) USING BTREE 12 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='流程'; 13 | 14 | 15 | CREATE TABLE node ( 16 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 17 | workflow_id INT(10) NOT NULL COMMENT 't_workflow', 18 | name VARCHAR(255) NOT NULL COMMENT '节点名称', 19 | node_order INT NOT NULL COMMENT '节点顺序', 20 | func VARCHAR(255) NOT NULL COMMENT '调用的接口', 21 | node_group INT(10) DEFAULT 0 COMMENT '并行调用组', 22 | 23 | PRIMARY KEY (id) USING BTREE, 24 | INDEX idx_workflow_id (workflow_id) USING BTREE 25 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='节点'; 26 | 27 | 28 | CREATE TABLE record ( 29 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 30 | workflow_id INT(10) NOT NULL COMMENT 't_workflow', 31 | node_id INT(10) NOT NULL COMMENT 't_node', 32 | status ENUM('pending', 'active', 'done') DEFAULT 'pending' COMMENT '执行状态', 33 | func VARCHAR(255) NOT NULL COMMENT '调用的接口', 34 | result TEXT COMMENT '执行结果', 35 | created_time DATETIME DEFAULT CURRENT_TIMESTAMP, 36 |
updated_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 37 | 38 | PRIMARY KEY (id) USING BTREE, 39 | INDEX idx_workflow_id (workflow_id) USING BTREE, 40 | INDEX idx_node_id (node_id) USING BTREE 41 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='节点执行'; 42 | 43 | -- ext 44 | 45 | CREATE TABLE users ( 46 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 47 | name VARCHAR(100) NOT NULL COMMENT '姓名', 48 | gender ENUM('男', '女') NOT NULL COMMENT '性别', 49 | birth DATE NOT NULL COMMENT '出生年份', 50 | phone VARCHAR(20) NOT NULL COMMENT '手机', 51 | address VARCHAR(50) NOT NULL COMMENT '地区', 52 | comment VARCHAR(255) NULL DEFAULT NULL COMMENT '备注', 53 | created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 54 | 55 | PRIMARY KEY (id) USING BTREE, 56 | UNIQUE KEY uniq_phone (phone), 57 | INDEX idx_created_time (created_time) USING BTREE 58 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='客户'; 59 | 60 | 61 | CREATE TABLE user_sources ( 62 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 63 | name VARCHAR(100) NOT NULL COMMENT '来源', 64 | 65 | PRIMARY KEY (id) USING BTREE 66 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='客户来源'; 67 | 68 | 69 | CREATE TABLE user_source_relations ( 70 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 71 | user_id INT(10) NOT NULL COMMENT 't_users', 72 | source_id INT(10) NOT NULL COMMENT 't_user_sources', 73 | 74 | PRIMARY KEY (id) USING BTREE, 75 | INDEX idx_user_id (user_id) USING BTREE, 76 | INDEX idx_source_id (source_id) USING BTREE 77 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='客户-来源关系'; 78 | 79 | 80 | CREATE TABLE jobs ( 81 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 82 | name VARCHAR(100) NOT NULL COMMENT '职业', 83 | 84 | PRIMARY KEY (id) USING BTREE 85 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='职业'; 86 | 87 | 88 | CREATE TABLE goods ( 89 | id INT(10) NOT NULL 
AUTO_INCREMENT COMMENT 'ID', 90 | name VARCHAR(100) NOT NULL COMMENT '商品', 91 | descr VARCHAR(255) NOT NULL COMMENT '商品描述', 92 | price DOUBLE(10, 2) NOT NULL COMMENT '销售价格', 93 | num INT(11) NOT NULL COMMENT '数量', 94 | unit VARCHAR(255) NOT NULL COMMENT '商品规格', 95 | pack VARCHAR(255) NOT NULL COMMENT '包装单位', 96 | img VARCHAR(255) NOT NULL COMMENT '商品图片', 97 | product_num VARCHAR(255) NOT NULL COMMENT '生产批号', 98 | approve_num VARCHAR(255) NOT NULL COMMENT '批准文号', 99 | status ENUM('可用', '禁用') DEFAULT '可用' COMMENT '状态', 100 | 101 | PRIMARY KEY (id) USING BTREE 102 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='职业'; 103 | 104 | 105 | CREATE TABLE sales ( 106 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 107 | name VARCHAR(100) NOT NULL COMMENT '购买标识', 108 | pay_type VARCHAR(100) NOT NULL COMMENT '支付类型', 109 | created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 110 | 111 | PRIMARY KEY (id) USING BTREE, 112 | INDEX idx_created_time (created_time) USING BTREE 113 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='购买单'; 114 | 115 | 116 | CREATE TABLE sale_goods ( 117 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 118 | sale_id INT(10) NOT NULL COMMENT 't_sales', 119 | goods_id INT(10) NOT NULL COMMENT 't_goods', 120 | price DOUBLE NOT NULL COMMENT '销售单价', 121 | num INT(11) NOT NULL COMMENT '销售数量', 122 | unit VARCHAR(255) NOT NULL COMMENT '商品规格', 123 | total DOUBLE(10, 2) NOT NULL COMMENT '销售总价', 124 | comment VARCHAR(255) NOT NULL COMMENT '备注', 125 | 126 | PRIMARY KEY (id) USING BTREE, 127 | INDEX idx_sale_id (sale_id) USING BTREE, 128 | INDEX idx_goods_id (goods_id) USING BTREE 129 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='商品-购买单关系'; 130 | 131 | 132 | CREATE TABLE user_sales ( 133 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 134 | user_id INT(10) NOT NULL COMMENT 't_users', 135 | sale_id INT(10) NOT NULL COMMENT 't_sales', 136 | 137 | PRIMARY KEY (id) USING BTREE, 
138 | INDEX idx_user_id (user_id) USING BTREE, 139 | INDEX idx_sale_id (sale_id) USING BTREE 140 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='客户-购买单关系'; 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /images/pip.txt: -------------------------------------------------------------------------------- 1 | pip cache purge # 清除缓存 2 | 3 | pip freeze > requirements.txt 4 | pip install -r requirements.txt 5 | 6 | 7 | python -m pip install starlette 8 | python -m pip install loguru 9 | python -m pip install uvicorn 10 | python -m pip install mysqlclient 11 | python -m pip install DBUtils 12 | python -m pip install httpx 13 | python -m pip install redis 14 | python -m pip install gunicorn 15 | python -m pip install bcrypt 16 | python -m pip install pyjwt 17 | python -m pip install pycryptodome 18 | python -m pip install jsonschema 19 | python -m pip install huey 20 | 21 | 22 | python -m pip install granian 23 | 24 | granian --interface asgi viper:app 25 | 26 | granian --interface asgi --workers 1 --threads 4 viper:app 27 | 28 | 29 | 30 | huey_consumer viper.delays.huey_instance.huey 31 | -------------------------------------------------------------------------------- /images/viper.sql: -------------------------------------------------------------------------------- 1 | 2 | CREATE TABLE users ( 3 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 4 | name VARCHAR(100) NOT NULL COMMENT '姓名', 5 | phone VARCHAR(20) NOT NULL COMMENT '手机', 6 | password_hash VARCHAR(255) NOT NULL COMMENT '密码哈希', 7 | is_admin TINYINT(1) NOT NULL DEFAULT 0 COMMENT '是否管理员', 8 | created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', 9 | 10 | PRIMARY KEY (id) USING BTREE, 11 | UNIQUE KEY uniq_phone (phone), 12 | INDEX idx_created_time (created_time) USING BTREE 13 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='用户'; 14 | 15 | 16 | CREATE TABLE chats
( 17 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 18 | conversation_id VARCHAR(100) NOT NULL COMMENT '大模型的会话ID', 19 | title VARCHAR(255) NOT NULL COMMENT '会话标题', 20 | user_id INT(10) NOT NULL COMMENT 't_users', 21 | created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', 22 | 23 | PRIMARY KEY (id) USING BTREE, 24 | INDEX idx_user_id (user_id) USING BTREE, 25 | INDEX idx_created_time (created_time) USING BTREE 26 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='会话'; 27 | 28 | 29 | CREATE TABLE messages ( 30 | id INT(10) NOT NULL AUTO_INCREMENT COMMENT 'ID', 31 | chat_id INT(10) NOT NULL COMMENT 't_chats', 32 | trace_id VARCHAR(100) NOT NULL COMMENT '每次问答的配对', 33 | sender VARCHAR(50) NOT NULL COMMENT 'user or robot', 34 | created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', 35 | 36 | PRIMARY KEY (id) USING BTREE, 37 | INDEX idx_chat_id (chat_id) USING BTREE, 38 | INDEX idx_created_time (created_time) USING BTREE 39 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='消息'; 40 | 41 | 42 | CREATE TABLE contents ( 43 | message_id INT(10) NOT NULL COMMENT 't_messages', 44 | content TEXT NOT NULL COMMENT '消息内容', 45 | 46 | INDEX idx_message_id (message_id) USING BTREE 47 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='消息内容'; 48 | 49 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/logs/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/requirements.txt 
-------------------------------------------------------------------------------- /start_huey.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 检查是否为目标服务器 4 | if [ "$RUN_HUEY" = "true" ]; then 5 | # 启动 Huey worker 6 | huey_consumer viper.delays.huey_instance.huey 7 | else 8 | echo "Huey is not configured to run on this server." 9 | fi 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/tests/__init__.py -------------------------------------------------------------------------------- /tests/demo_jsonschema.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from jsonschema import validate, FormatChecker 4 | 5 | 6 | # 自定义手机号码验证函数 7 | def validate_phone_number(phone_number): 8 | pattern = r'^1[3-9]\d{9}$' 9 | if not re.match(pattern, phone_number): 10 | raise ValueError(f"'{phone_number}' 不是有效的手机号码") 11 | return True 12 | 13 | 14 | # 自定义邮箱验证函数 15 | def validate_email(email): 16 | pattern = r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$' 17 | if not re.match(pattern, email): 18 | raise ValueError(f"'{email}' 不是有效的邮箱地址") 19 | return True 20 | 21 | 22 | # 注册自定义格式验证器 23 | format_checker = FormatChecker() 24 | format_checker.checks('phone')(validate_phone_number) # 注册 'phone' 格式验证 25 | format_checker.checks('email')(validate_email) # 注册 'email' 格式验证 26 | 27 | # 定义 JSON Schema 28 | schema = { 29 | 'type': 'object', 30 | 'properties': { 31 | 'phone': { 32 | 'type': 'string', 33 | 'format': 'phone' # 使用自定义的 'phone' 格式验证 34 | }, 35 | 'email': { 36 | 'type': 'string', 37 | 'format': 'email' # 使用自定义的 'email' 格式验证 38 | } 39 | }, 40 | 'required': ['phone', 'email'] # 两个字段都是必填的 41 | } 42 | 43 | # 测试数据 44 | data = { 45 | 'phone': '13800138000', # 
有效的手机号码 46 | 'email': 'test@example.com' # 有效的邮箱地址 47 | } 48 | 49 | # data = { 50 | # 'phone': '1234567890', # 无效的手机号码 51 | # 'email': 'invalid-email' # 无效的邮箱地址 52 | # } 53 | 54 | # 验证数据 55 | try: 56 | validate(instance=data, schema=schema, format_checker=format_checker) 57 | print('验证成功:数据是有效的手机号码和邮箱地址') 58 | except Exception as e: 59 | print(f'验证失败:{e}') 60 | -------------------------------------------------------------------------------- /tests/get_access_token.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | 3 | 4 | def get_access_token(phone_number, password): 5 | url = 'http://127.0.0.1:8000/user/get-access-token' 6 | data = { 7 | 'phone_number': phone_number, 8 | 'password': password 9 | } 10 | with httpx.Client() as client: 11 | response = client.post(url, json=data) 12 | data = response.json() 13 | data_dict = data['data'] 14 | access_token = data_dict['access_token'] 15 | return access_token 16 | 17 | 18 | if __name__ == '__main__': 19 | phone_number = '' 20 | password = '' 21 | access_token = get_access_token(phone_number, password) 22 | print(access_token) 23 | -------------------------------------------------------------------------------- /tests/get_chat_id.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | 3 | 4 | def get_chat_id(access_token, title): 5 | url = 'http://bebinca.cc/chat/chat-id' 6 | headers = {'Authorization': f'Bearer {access_token}'} 7 | data = {'title': title} 8 | with httpx.Client() as client: 9 | response = client.post(url, json=data, headers=headers) 10 | data = response.json() 11 | chat_id = data['data'] 12 | return chat_id 13 | 14 | 15 | if __name__ == '__main__': 16 | access_token = 'eyJhbGlWG13Nd84nNIY' 17 | title = '你是谁?' 
18 | print(get_chat_id(access_token, title)) 19 | -------------------------------------------------------------------------------- /tests/nacos_encrypt.py: -------------------------------------------------------------------------------- 1 | from binascii import b2a_hex 2 | from Crypto.Util.Padding import pad 3 | 4 | from Crypto.Cipher import AES # pip install pycryptodomex 5 | 6 | 7 | def encrypt_sec(plaintext, seckey): 8 | if len(seckey) not in [16, 32]: 9 | raise ValueError(f'The length of the seckey must be 16, or 32, it cannot be {len(seckey)}.') 10 | plaintext_padded = pad(plaintext.encode(), AES.block_size) 11 | aes = AES.new(seckey.encode(), AES.MODE_ECB) 12 | encrypted_data = aes.encrypt(plaintext_padded) 13 | return b2a_hex(encrypted_data).decode() 14 | 15 | 16 | if __name__ == '__main__': 17 | pt = 'admin' 18 | sk = '1f2095a2ec0cefd2c2ab9dd258ad22c3' 19 | es = encrypt_sec(pt, sk) 20 | print(es) 21 | -------------------------------------------------------------------------------- /viper/__init__.py: -------------------------------------------------------------------------------- 1 | from starlette.applications import Starlette 2 | 3 | from viper.core.errors import register_errors 4 | from viper.core.events import register_events 5 | from viper.core.middlewares import register_middlewares 6 | from viper.core.routers import register_routes 7 | 8 | 9 | app = Starlette( 10 | routes = register_routes, 11 | middleware = register_middlewares, 12 | exception_handlers = register_errors, 13 | lifespan = register_events 14 | ) 15 | -------------------------------------------------------------------------------- /viper/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/viper/core/__init__.py -------------------------------------------------------------------------------- /viper/core/errors.py: 
-------------------------------------------------------------------------------- 1 | from starlette.requests import Request 2 | from starlette.exceptions import HTTPException, WebSocketException 3 | from starlette.websockets import WebSocket 4 | 5 | from viper.utils.log_util import logger 6 | from viper.utils.resp_util import jsonify_exc 7 | from viper.utils.errors import ValidationError, TokenExpiredError, AuthFailureError 8 | 9 | 10 | async def validation_exception_handler(request: Request, exc: ValidationError): 11 | logger.opt(exception=True).warning(exc) 12 | return jsonify_exc(422, exc.message) 13 | 14 | 15 | async def jwt_expired_exception_handler(request: Request, exc: TokenExpiredError): 16 | return jsonify_exc(401, exc.message) 17 | 18 | 19 | async def jwt_failure_exception_handler(request: Request, exc: AuthFailureError): 20 | return jsonify_exc(400, exc.message) 21 | 22 | 23 | async def http_exception_handler(request: Request, exc: HTTPException): 24 | logger.opt(exception=True).warning(exc) 25 | return jsonify_exc(exc.status_code, exc.detail) 26 | 27 | 28 | async def global_exception_handler(request: Request, exc: Exception): 29 | logger.exception(exc) 30 | return jsonify_exc(500) 31 | 32 | 33 | async def websocket_exception_handler(websocket: WebSocket, exc: WebSocketException): 34 | logger.opt(exception=True).warning(exc) 35 | await websocket.close(code=1008) 36 | 37 | 38 | register_errors = { 39 | ValidationError: validation_exception_handler, 40 | TokenExpiredError: jwt_expired_exception_handler, 41 | AuthFailureError: jwt_failure_exception_handler, 42 | HTTPException: http_exception_handler, 43 | WebSocketException: websocket_exception_handler, 44 | Exception: global_exception_handler, 45 | } 46 | -------------------------------------------------------------------------------- /viper/core/events.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | 3 | from 
viper.utils import http_util, redis_util, pools 4 | 5 | 6 | @asynccontextmanager 7 | async def register_events(app): 8 | pass 9 | yield 10 | await http_util.close_httpx() 11 | await redis_util.close_redis() 12 | pools.close_threads() 13 | -------------------------------------------------------------------------------- /viper/core/middlewares.py: -------------------------------------------------------------------------------- 1 | from starlette.requests import Request 2 | from starlette.middleware import Middleware 3 | from starlette.middleware.cors import CORSMiddleware 4 | from starlette.middleware.authentication import AuthenticationMiddleware 5 | from starlette.authentication import BaseUser, AuthCredentials, AuthenticationBackend 6 | 7 | from viper.utils import jwt_util 8 | from viper.models.user_model import UserModel 9 | 10 | # cors 11 | origins = [ 12 | 'http://127.0.0.1:8000', # 后端应用使用的端口 13 | 'http://127.0.0.1:8080', # 前端应用使用的端口 14 | ] 15 | 16 | 17 | # auth_required 18 | class User(BaseUser): 19 | 20 | def __init__(self, user_id: int, username, phone): 21 | self.user_id = user_id 22 | self.username = username 23 | self.phone = phone 24 | 25 | @property 26 | def is_authenticated(self) -> bool: 27 | return True 28 | 29 | @property 30 | def display_name(self) -> str: 31 | return self.username 32 | 33 | 34 | class JWTAuthBackend(AuthenticationBackend): 35 | 36 | async def authenticate(self, request: Request): 37 | if 'Authorization' not in request.headers: 38 | return None 39 | 40 | auth_header = request.headers['Authorization'] 41 | try: 42 | scheme, token = auth_header.split() 43 | if scheme.lower() != 'bearer': 44 | return None 45 | except ValueError: 46 | return None 47 | 48 | user_id = jwt_util.verify_token(token) 49 | db_user = await UserModel().get_user_by_id(user_id) 50 | username = db_user['name'] 51 | phone = db_user['phone'] 52 | is_admin = db_user['is_admin'] 53 | 54 | if is_admin: 55 | scopes = ['authenticated', 'admin'] 56 | else: 57 | scopes = 
['authenticated'] 58 | 59 | return AuthCredentials(scopes), User(user_id, username, phone) 60 | 61 | 62 | register_middlewares = [ 63 | Middleware( 64 | CORSMiddleware, # type: ignore 65 | allow_origins=origins, 66 | allow_credentials=True, 67 | allow_methods=['*'], 68 | allow_headers=['*'], 69 | ), 70 | Middleware( 71 | AuthenticationMiddleware, # type: ignore 72 | backend=JWTAuthBackend() 73 | ), 74 | ] 75 | -------------------------------------------------------------------------------- /viper/core/routers.py: -------------------------------------------------------------------------------- 1 | from starlette.routing import Mount 2 | 3 | from viper.urls import chat_url, user_url, delay_url 4 | 5 | register_routes = [ 6 | Mount('/chat', app=chat_url.chat_url), 7 | Mount('/user', app=user_url.user_url), 8 | Mount('/delay', app=delay_url.delay_url) 9 | ] 10 | -------------------------------------------------------------------------------- /viper/core/settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from starlette.config import Config 4 | 5 | from viper.utils.conf_util import ConfigDetector 6 | 7 | CURRENT_DIR = Path(__file__).resolve() # 当前文件 的 绝对路径 8 | BASE_DIR = CURRENT_DIR.parent.parent.parent 9 | 10 | config = Config('.env') 11 | 12 | DEBUG = config('DEBUG', cast=bool, default=False) 13 | 14 | if DEBUG: 15 | conf_dir = BASE_DIR / 'nacos-data' 16 | CONF = ConfigDetector(conf_dir) 17 | else: 18 | CONF = ConfigDetector( 19 | nacos_addr=config('NACOS_ADDR'), 20 | namespace=config('NACOS_NAMESPACE') 21 | ) 22 | -------------------------------------------------------------------------------- /viper/delays/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/viper/delays/__init__.py 
-------------------------------------------------------------------------------- /viper/delays/backgrounds/__init__.py: -------------------------------------------------------------------------------- 1 | from viper.utils.pools import run_in_thread_pool_delay 2 | from viper.delays.backgrounds import long_task 3 | 4 | 5 | async def delay_long_task(data): 6 | await run_in_thread_pool_delay(long_task.long_task, data) 7 | -------------------------------------------------------------------------------- /viper/delays/backgrounds/long_task.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from viper.delays.log_config import logger 4 | from viper.delays.huey_instance import huey 5 | 6 | 7 | @huey.task() 8 | def long_task(data: str): 9 | logger.info('Starting long-running task...') 10 | time.sleep(5) 11 | logger.info('Long-running task completed!') 12 | return f'Processed data: {data}' 13 | -------------------------------------------------------------------------------- /viper/delays/huey_instance.py: -------------------------------------------------------------------------------- 1 | from huey import RedisHuey 2 | from redis import ConnectionPool 3 | 4 | from viper.core import settings 5 | from viper.delays.log_config import logger # noqa 配置日志记录器 6 | 7 | redis_pool = ConnectionPool( 8 | host=settings.CONF.get_conf('REDIS', 'HOST'), 9 | port=6379, 10 | # password=settings.CONF.get_int_conf('REDIS', 'PORT'), 11 | db=0 12 | ) 13 | 14 | huey = RedisHuey( 15 | name = settings.CONF.get_conf('APP', 'NAME'), 16 | connection_pool=redis_pool 17 | ) 18 | 19 | from viper.delays import schedules # noqa 在消费者 启动时 加载 定时任务 20 | -------------------------------------------------------------------------------- /viper/delays/log_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from viper.core import settings 5 | 6 | log_dir = 
Path(settings.BASE_DIR) / 'logs' 7 | log_dir.mkdir(parents=True, exist_ok=True) 8 | log_file = log_dir / 'huey.log' 9 | 10 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 11 | 12 | file_handler = logging.FileHandler(log_file, encoding='utf-8') 13 | file_handler.setFormatter(formatter) 14 | 15 | logger = logging.getLogger('huey') 16 | 17 | logger.setLevel(logging.INFO) 18 | logger.addHandler(file_handler) 19 | -------------------------------------------------------------------------------- /viper/delays/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | from viper.delays.schedules import scheduled_task 2 | -------------------------------------------------------------------------------- /viper/delays/schedules/scheduled_task.py: -------------------------------------------------------------------------------- 1 | from huey import crontab 2 | 3 | from viper.delays.log_config import logger 4 | from viper.delays.huey_instance import huey 5 | 6 | 7 | @huey.periodic_task(crontab(minute='0', hour='3')) # 每天凌晨3点执行 8 | def scheduled_task(): 9 | logger.info('Scheduled task running...') 10 | 11 | 12 | @huey.periodic_task(crontab()) # 每分钟执行 13 | def every_minute_task(): 14 | logger.info('Task running every minute...') 15 | -------------------------------------------------------------------------------- /viper/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/viper/models/__init__.py -------------------------------------------------------------------------------- /viper/models/base_model.py: -------------------------------------------------------------------------------- 1 | from viper.utils.db_util import create_dbconn 2 | 3 | 4 | class BaseModel: 5 | 6 | def __init__(self): 7 | self.cursor = None 8 | self._conn = None 9 | 10 | def conn(self): 11 
| if not self.cursor: 12 | self._conn = create_dbconn() 13 | self.cursor = self._conn.cursor() 14 | 15 | def rollback(self): 16 | self._conn.rollback() 17 | 18 | def commit(self): 19 | try: 20 | self._conn.commit() 21 | except Exception: 22 | self.rollback() 23 | raise 24 | 25 | def begin(self): 26 | self._conn.begin() 27 | 28 | def close(self): 29 | try: 30 | if self.cursor: 31 | self.cursor.execute('UNLOCK TABLES;') 32 | self.cursor.close() 33 | if self._conn: 34 | self._conn.close() 35 | finally: 36 | self.cursor = None 37 | self._conn = None 38 | 39 | def execute(self, sql_str, values=None): 40 | try: 41 | self.cursor.execute(sql_str, values) 42 | except Exception: 43 | self.rollback() 44 | self.close() 45 | raise 46 | 47 | def executemany(self, sql_str, values=None): 48 | try: 49 | self.cursor.executemany(sql_str, values) 50 | except Exception: 51 | self.rollback() 52 | self.close() 53 | raise 54 | -------------------------------------------------------------------------------- /viper/models/chat_model.py: -------------------------------------------------------------------------------- 1 | from viper.models.base_model import BaseModel 2 | from viper.utils.decorators import sync_to_async_db 3 | 4 | 5 | class ChatModel(BaseModel): 6 | 7 | @sync_to_async_db 8 | def add_chat(self, conversation_id, title, user_id): 9 | sql_str = ''' 10 | INSERT INTO 11 | chats 12 | (conversation_id, title, user_id) 13 | VALUES 14 | (%s, %s, %s) 15 | ''' 16 | values = (conversation_id, title, user_id) 17 | self.conn() 18 | self.execute(sql_str, values) 19 | self.commit() 20 | lastrowid = self.cursor.lastrowid 21 | self.close() 22 | return lastrowid 23 | 24 | @sync_to_async_db 25 | def get_chats(self, user_id): 26 | sql_str = ''' 27 | SELECT 28 | id, conversation_id, title, created_time 29 | FROM 30 | chats 31 | WHERE 32 | user_id = %s 33 | LIMIT 100 34 | ''' 35 | self.conn() 36 | self.execute(sql_str, (user_id,)) 37 | chats = self.cursor.fetchall() 38 | self.close() 39 | return 
chats 40 | 41 | @sync_to_async_db 42 | def get_chat_by_conversation(self, conversation_id): 43 | sql_str = ''' 44 | SELECT 45 | id, conversation_id, title, created_time 46 | FROM 47 | chats 48 | WHERE 49 | conversation_id = %s 50 | ''' 51 | self.conn() 52 | self.execute(sql_str, (conversation_id,)) 53 | chat_info = self.cursor.fetchone() 54 | self.close() 55 | return chat_info 56 | -------------------------------------------------------------------------------- /viper/models/content_model.py: -------------------------------------------------------------------------------- 1 | from viper.models.base_model import BaseModel 2 | from viper.utils.decorators import sync_to_async_db 3 | 4 | 5 | class ContentModel(BaseModel): 6 | 7 | @sync_to_async_db 8 | def add_content(self, message_id, content): 9 | sql_str = ''' 10 | INSERT INTO 11 | contents 12 | (message_id, content) 13 | VALUES 14 | (%s, %s) 15 | ''' 16 | self.conn() 17 | self.execute(sql_str, (message_id, content)) 18 | self.commit() 19 | last_row_id = self.cursor.lastrowid 20 | self.close() 21 | return last_row_id 22 | -------------------------------------------------------------------------------- /viper/models/message_model.py: -------------------------------------------------------------------------------- 1 | from viper.models.base_model import BaseModel 2 | from viper.utils.decorators import sync_to_async_db 3 | 4 | 5 | class MessageModel(BaseModel): 6 | 7 | @sync_to_async_db 8 | def add_message(self, chat_id, trace_id, sender): 9 | sql_str = ''' 10 | INSERT INTO 11 | messages 12 | (chat_id, trace_id, sender) 13 | VALUES 14 | (%s, %s, %s) 15 | ''' 16 | self.conn() 17 | self.execute(sql_str, (chat_id, trace_id, sender)) 18 | self.commit() 19 | lastrowid = self.cursor.lastrowid 20 | self.close() 21 | return lastrowid 22 | 23 | @sync_to_async_db 24 | def get_messages(self, chat_id): 25 | sql_str = ''' 26 | SELECT 27 | trace_id, 28 | MAX(CASE WHEN sender = 'user' THEN content END) AS 问题, 29 | MAX(CASE WHEN sender 
= 'robot' THEN content END) AS 回答, 30 | MAX(CASE WHEN sender = 'user' THEN created_time END) AS 问题时间, 31 | MAX(CASE WHEN sender = 'robot' THEN created_time END) AS 回答时间 32 | FROM 33 | ( 34 | SELECT 35 | m.trace_id, ct.content, m.sender, m.created_time 36 | FROM 37 | chats c 38 | LEFT JOIN messages m ON c.id = m.chat_id 39 | LEFT JOIN contents ct ON m.id = ct.message_id 40 | WHERE 41 | c.conversation_id = %s 42 | ) AS subquery 43 | GROUP BY 44 | trace_id 45 | LIMIT 100; 46 | ''' 47 | self.conn() 48 | self.execute(sql_str, (chat_id,)) # NOTE: chat_id is bound to chats.conversation_id, not chats.id 49 | chats = self.cursor.fetchall() 50 | self.close() 51 | return chats 52 | -------------------------------------------------------------------------------- /viper/models/user_model.py: -------------------------------------------------------------------------------- 1 | from viper.models.base_model import BaseModel 2 | from viper.utils.decorators import sync_to_async_db 3 | 4 | 5 | class UserModel(BaseModel): 6 | 7 | @sync_to_async_db 8 | def get_user_by_phone(self, phone_number): 9 | sql_str = ''' 10 | SELECT 11 | id, name, phone, password_hash, is_admin, created_time 12 | FROM 13 | users 14 | WHERE 15 | phone = %s 16 | ''' 17 | self.conn() 18 | self.execute(sql_str, (phone_number,)) 19 | user_info = self.cursor.fetchone() 20 | self.close() 21 | return user_info 22 | 23 | @sync_to_async_db 24 | def get_user_by_id(self, user_id): 25 | sql_str = ''' 26 | SELECT 27 | id, name, phone, is_admin, created_time 28 | FROM 29 | users 30 | WHERE 31 | id = %s 32 | ''' 33 | self.conn() 34 | self.execute(sql_str, (user_id,)) 35 | user_info = self.cursor.fetchone() 36 | self.close() 37 | return user_info 38 | -------------------------------------------------------------------------------- /viper/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from jsonschema import validate, FormatChecker, ValidationError as ValidationException 4 | 5 | from viper.utils.errors import
ValidationError 6 | 7 | 8 | # 自定义手机号码验证函数 9 | def validate_phone(phone_number): 10 | pattern = r'^1[3-9]\d{9}$' 11 | if isinstance(phone_number, str) and not re.match(pattern, phone_number): # non-strings pass, as with jsonschema's built-in format checkers 12 | raise ValueError(f"'{phone_number}' 不是有效的手机号码") 13 | return True 14 | 15 | 16 | # 自定义邮箱验证函数 17 | def validate_email(email): 18 | pattern = r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$' 19 | if isinstance(email, str) and not re.match(pattern, email): # non-strings pass, as with jsonschema's built-in format checkers 20 | raise ValueError(f"'{email}' 不是有效的邮箱地址") 21 | return True 22 | 23 | 24 | # 注册自定义格式验证器 25 | format_checker = FormatChecker() 26 | format_checker.checks('phone', raises=ValueError)(validate_phone) # register 'phone'; raises=ValueError turns a failed check into a ValidationError instead of an uncaught ValueError 27 | format_checker.checks('email', raises=ValueError)(validate_email) # register 'email' with the same ValueError handling 28 | 29 | 30 | # 验证数据 31 | def validator(data, schema, check_format=None): 32 | try: 33 | if check_format: 34 | validate(instance=data, schema=schema, format_checker=format_checker) 35 | return data 36 | validate(instance=data, schema=schema) 37 | return data 38 | except ValidationException as exc: 39 | raise ValidationError(exc.message) 40 | -------------------------------------------------------------------------------- /viper/schemas/chat_schema.py: -------------------------------------------------------------------------------- 1 | chat_id_schema = { 2 | 'type': 'object', 3 | 'properties': { 4 | 'title': { 5 | 'type': 'string', 6 | 'minLength': 1, # 不能为空字符串 7 | 'maxLength': 20 8 | } 9 | }, 10 | 'required': ['title'] # 必填字段 11 | } 12 | 13 | send_message_schema = { 14 | 'type': 'object', 15 | 'properties': { 16 | 'conversation_id': { 17 | 'type': 'string', 18 | 'minLength': 1 19 | }, 20 | 'content': { 21 | 'type': 'string', 22 | 'minLength': 1 23 | } 24 | }, 25 | 'required': ['conversation_id', 'content'] 26 | } 27 | 28 | get_messages_schema = { 29 | 'type': 'object', 30 | 'properties': { 31 | 'conversation_id': { 32 | 'type': 'string', 33 | 'minLength': 1 34 | } 35 | }, 36 | 'required': ['conversation_id'] 37 | } 38 | --------------------------------------------------------------------------------
/viper/schemas/user_schema.py: -------------------------------------------------------------------------------- 1 | access_token_schema = { 2 | 'type': 'object', 3 | 'properties': { 4 | 'phone_number': { 5 | 'type': 'string', 6 | 'pattern': r'^1[3-9]\d{9}$' # 中国大陆手机号码正则 7 | }, 8 | 'password': { 9 | 'type': 'string', 10 | 'minLength': 6, # 密码最小长度 11 | 'maxLength': 12 # 密码最大长度 12 | } 13 | }, 14 | 'required': ['phone_number', 'password'] # 必填字段 15 | } 16 | -------------------------------------------------------------------------------- /viper/urls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/viper/urls/__init__.py -------------------------------------------------------------------------------- /viper/urls/chat_url.py: -------------------------------------------------------------------------------- 1 | from starlette.routing import Router 2 | 3 | from viper.views import chat_view 4 | 5 | chat_url = Router() 6 | 7 | chat_url.add_route('/chat-id', chat_view.get_chat_id, ['POST']) 8 | chat_url.add_route('/send-message', chat_view.send_message, ['POST']) 9 | chat_url.add_route('/chats', chat_view.get_chats, ['POST']) 10 | chat_url.add_route('/messages', chat_view.get_messages, ['POST']) 11 | -------------------------------------------------------------------------------- /viper/urls/delay_url.py: -------------------------------------------------------------------------------- 1 | from starlette.routing import Router 2 | 3 | from viper.views import delay_view 4 | 5 | delay_url = Router() 6 | 7 | delay_url.add_route('/start-task', delay_view.start_task, ['POST']) 8 | delay_url.add_route('/task-status', delay_view.check_task_status, ['POST']) 9 | -------------------------------------------------------------------------------- /viper/urls/user_url.py: -------------------------------------------------------------------------------- 1 | from 
starlette.routing import Router 2 | 3 | from viper.views import user_view 4 | 5 | user_url = Router() 6 | 7 | user_url.add_route('/get-access-token', user_view.get_access_token, ['POST']) 8 | -------------------------------------------------------------------------------- /viper/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/viper/utils/__init__.py -------------------------------------------------------------------------------- /viper/utils/conf_util.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import time 3 | from pathlib import Path 4 | from binascii import a2b_hex 5 | 6 | import httpx 7 | from Crypto.Cipher import AES # pip install pycryptodome 8 | 9 | 10 | class ConfigDetector: 11 | 12 | def __init__(self, conf_dir=None, nacos_addr=None, namespace='prd', nacos_user=None, nacos_passwd=None): 13 | self.conf_dir = conf_dir 14 | self.nacos_addr = nacos_addr 15 | self.namespace = namespace 16 | self.nacos_user = nacos_user 17 | self.nacos_passwd = nacos_passwd 18 | self.load_conf() 19 | 20 | def load_conf(self): # try _load_conf up to max_times, 1s apart; re-raise the last error if every attempt fails 21 | max_times = 10 22 | for i in range(max_times): 23 | try: 24 | self._load_conf() 25 | except Exception: 26 | if i == max_times - 1: raise # out of retries: surface the real cause instead of continuing with an empty/partial config 27 | time.sleep(1) 28 | else: 29 | break 30 | 31 | def _load_conf(self): 32 | self.config = configparser.ConfigParser() 33 | self.key = configparser.ConfigParser() 34 | if self.conf_dir: 35 | self.config.read(Path(self.conf_dir) / 'default_conf.ini') 36 | self.key.read(Path(self.conf_dir) / 'seckey_conf.ini') 37 | else: 38 | conf = self._get_client('viper_default_conf.ini', 'DEFAULT_GROUP') 39 | salt = self._get_client('viper_seckey_conf.ini', 'DEFAULT_GROUP') 40 | self.config.read_string(conf) 41 | self.key.read_string(salt) 42 | 43 | def _get_client(self, data_id, group): 44 | url = f'{self.nacos_addr}/nacos/v2/cs/config' 45 |
params = { 46 | 'dataId': data_id, 47 | 'group': group, 48 | } 49 | response = httpx.get(url, params=params, timeout=30.0) 50 | return response.text 51 | 52 | def get_conf(self, section='DEFAULT', key='DEFAULT'): 53 | value = self.config.get(section, key) 54 | return value 55 | 56 | def get_sec_conf(self, section='DEFAULT', key='DEFAULT'): 57 | text = self.get_conf(section, key) 58 | seckey = self.key.get(section, key) 59 | if len(seckey) not in [32, 16]: 60 | raise ValueError(f'The length of the seckey must be 16 or 32, it cannot be {len(seckey)}.') 61 | aes = AES.new(seckey.encode(), AES.MODE_ECB) 62 | sec_conf = str(aes.decrypt(a2b_hex(text)), encoding='utf-8', errors='ignore') 63 | return sec_conf.strip() 64 | 65 | def get_boolean_conf(self, section='DEFAULT', key='DEFAULT'): 66 | value = self.config.getboolean(section, key) 67 | return value 68 | 69 | def get_int_conf(self, section='DEFAULT', key='DEFAULT'): 70 | value = self.config.getint(section, key) 71 | return value 72 | 73 | def get_float_conf(self, section='DEFAULT', key='DEFAULT'): 74 | value = self.config.getfloat(section, key) 75 | return value 76 | -------------------------------------------------------------------------------- /viper/utils/db_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | 4 | import MySQLdb 5 | from MySQLdb.cursors import DictCursor 6 | from dbutils.pooled_db import PooledDB 7 | 8 | from viper.core import settings 9 | 10 | dbpool = PooledDB( 11 | creator=MySQLdb, 12 | maxcached=5, 13 | host=settings.CONF.get_conf('DATABASE', 'HOST'), 14 | port=settings.CONF.get_int_conf('DATABASE', 'PORT'), 15 | db=settings.CONF.get_conf('DATABASE', 'NAME'), 16 | user=settings.CONF.get_conf('DATABASE', 'USER'), 17 | passwd=settings.CONF.get_sec_conf('DATABASE', 'PASSWD'), 18 | charset='utf8mb4', 19 | cursorclass=DictCursor 20 | ) 21 | 22 | 23 | def create_dbconn(): 24 | conn = dbpool.connection() # 从连接池中获取一个连接 25 | return 
conn 26 | 27 | 28 | class DBConnection: 29 | def __init__(self): 30 | self.dbconn = None 31 | self.cursor = None 32 | 33 | def commit(self): 34 | self.dbconn.commit() 35 | 36 | def __enter__(self): 37 | if not self.dbconn: 38 | self.dbconn = create_dbconn() 39 | if not self.cursor: 40 | self.cursor = self.dbconn.cursor() 41 | return self 42 | 43 | def __exit__(self, exc_type, exc_val, exc_tb): 44 | try: 45 | if self.cursor: 46 | self.cursor.execute('UNLOCK TABLES;') 47 | self.cursor.close() 48 | if self.dbconn: 49 | self.dbconn.close() 50 | finally: 51 | self.dbconn = None 52 | self.cursor = None 53 | 54 | 55 | class Common: 56 | 57 | @staticmethod 58 | def parse_limit_str(page_info=None): 59 | if page_info is None: 60 | page_info = {} 61 | page = int(page_info.get('page', 1)) 62 | page_size = int(page_info.get('rows', 20)) 63 | limit_str = ' LIMIT %s, %s ' % ((page - 1) * page_size, page_size) 64 | return limit_str 65 | 66 | @staticmethod 67 | def parse_update_str(table, p_key, p_id, update_dict): 68 | sql_str = ' UPDATE %s SET ' % (table,) 69 | temp_str = [] 70 | sql_values = [] 71 | for key, value in update_dict.items(): 72 | temp_str.append(key + ' = %s ') 73 | sql_values.append(value) 74 | sql_str += ', '.join(r for r in temp_str) + ' WHERE ' + p_key + ' = %s ' 75 | sql_values.append(p_id) 76 | return sql_str, sql_values 77 | 78 | @staticmethod 79 | def parse_where_str(filter_fields, request_data): 80 | if not isinstance(filter_fields, tuple) and not isinstance(filter_fields, list): 81 | filter_fields = (filter_fields,) 82 | where_str = ' WHERE 1 = %s ' 83 | where_values = [1] 84 | for key in filter_fields: 85 | value = request_data.get(key) 86 | if value: 87 | where_str += ' AND ' + key + ' = %s ' 88 | where_values.append(value) 89 | if not where_values: 90 | where_values = None 91 | return where_str, where_values 92 | 93 | @staticmethod 94 | def parse_where_like_str(filter_fields, request_data): 95 | if not isinstance(filter_fields, tuple) and not 
isinstance(filter_fields, list): 96 | filter_fields = (filter_fields,) 97 | where_str = ' WHERE 1 = %s ' 98 | where_values = [1] 99 | for key in filter_fields: 100 | value = request_data.get(key) 101 | if value: 102 | where_str += ' AND ' + key + ' LIKE %s ' 103 | where_values.append('%%%%%s%%%%' % value) 104 | if not where_values: 105 | where_values = None 106 | return where_str, where_values 107 | 108 | @staticmethod 109 | def parse_count_str(sql_str, truncate=False): 110 | if truncate: 111 | if 'GROUP BY' in sql_str: 112 | sql_str = 'SELECT COUNT(*) total FROM (%s) AS TEMP' % sql_str 113 | else: 114 | sql_str = re.sub(r'SELECT[\s\S]*?FROM', 'SELECT COUNT(*) total FROM', sql_str, count=1) 115 | if 'ORDER BY' in sql_str: 116 | sql_str = sql_str[:sql_str.find('ORDER BY')] 117 | if 'LIMIT' in sql_str: 118 | sql_str = sql_str[:sql_str.find('LIMIT')] 119 | return sql_str 120 | 121 | @staticmethod 122 | def get_page_info(total, page=1, per_page=20): 123 | pages = math.ceil(total / per_page) 124 | next_num = page + 1 if page < pages else None 125 | has_next = page < pages 126 | prev_num = page - 1 if page > 1 else None 127 | has_prev = page > 1 128 | page_info = { 129 | 'page': page, 130 | 'per_page': per_page, # 每页显示的记录数 131 | 'pages': pages, # 总页数 132 | 'total': total, 133 | 'next_num': next_num, 134 | 'has_next': has_next, 135 | 'prev_num': prev_num, 136 | 'has_prev': has_prev 137 | } 138 | return page_info 139 | -------------------------------------------------------------------------------- /viper/utils/decorators.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from viper.schemas import validator 4 | from viper.utils.resp_util import abort 5 | from viper.utils.pools import run_in_thread_pool_db 6 | 7 | 8 | def validate_request(schema): 9 | def decorator(func): 10 | @wraps(func) 11 | async def decorated_function(request, *args, **kwargs): 12 | data = await request.json() 13 | 
validator(data, schema) 14 | return await func(request, *args, **kwargs) 15 | 16 | return decorated_function 17 | 18 | return decorator 19 | 20 | 21 | def permission_required(permission_name): 22 | def decorator(func): 23 | @wraps(func) 24 | async def decorated_function(request, *args, **kwargs): 25 | perm = request.state.has_perm 26 | if perm not in (permission_name,): 27 | return abort(403) 28 | return await func(request, *args, **kwargs) 29 | 30 | return decorated_function 31 | 32 | return decorator 33 | 34 | 35 | def admin_required(func): # @admin_required 36 | return permission_required('admin')(func) 37 | 38 | 39 | def sync_to_async_db(func): 40 | @wraps(func) 41 | async def async_run_in_thread_pool(*args, **kwargs): 42 | return await run_in_thread_pool_db(func, *args, **kwargs) 43 | 44 | return async_run_in_thread_pool 45 | -------------------------------------------------------------------------------- /viper/utils/errors.py: -------------------------------------------------------------------------------- 1 | class ValidationError(Exception): 2 | 3 | def __init__(self, message): 4 | self.message = message 5 | super().__init__(message) 6 | 7 | 8 | class TokenExpiredError(Exception): 9 | 10 | def __init__(self, message): 11 | self.message = message 12 | super().__init__(message) 13 | 14 | 15 | class AuthFailureError(Exception): 16 | 17 | def __init__(self, message): 18 | self.message = message 19 | super().__init__(message) 20 | -------------------------------------------------------------------------------- /viper/utils/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class FileRead: 5 | 6 | def __init__(self, path, file_name): 7 | self.xml = '' 8 | file = os.path.normpath(os.path.join(path, file_name)) 9 | # todo: sync to thread pool 10 | with open(file, mode='r', encoding='utf-8') as f: 11 | self.xml = f.read() 12 | 13 | @property 14 | def content(self): 15 | return self.xml 16 | 
-------------------------------------------------------------------------------- /viper/utils/http_util.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | 3 | from viper.core import settings 4 | 5 | limits = httpx.Limits( 6 | max_keepalive_connections=settings.CONF.get_int_conf('HTTPX', 'POOL_SIZE_COMMON'), 7 | max_connections=settings.CONF.get_int_conf('HTTPX', 'MAX_OVERFLOW_COMMON') 8 | ) 9 | 10 | timeout = httpx.Timeout( 11 | connect=2.0, 12 | read=5.0, # 从发送请求到接收完整响应数据的时间 13 | write=2.0, 14 | pool=2.0 15 | ) 16 | 17 | httpx_common = httpx.AsyncClient(limits=limits, timeout=timeout) 18 | 19 | # 流式 20 | limits = httpx.Limits( 21 | max_keepalive_connections=settings.CONF.get_int_conf('HTTPX', 'POOL_SIZE_STREAM'), 22 | max_connections=settings.CONF.get_int_conf('HTTPX', 'MAX_OVERFLOW_STREAM') 23 | ) 24 | 25 | timeout = httpx.Timeout( 26 | connect=3.0, # 建立连接的时间 27 | read=10.0, # 等待每个数据块的时间 28 | write=3.0, # 向服务器发送完数据的时间 29 | pool=2.0 # 从连接池中获取连接的时间 30 | ) 31 | 32 | httpx_stream = httpx.AsyncClient(limits=limits, timeout=timeout) 33 | 34 | 35 | async def close_httpx(): 36 | await httpx_common.aclose() 37 | await httpx_stream.aclose() 38 | -------------------------------------------------------------------------------- /viper/utils/json_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime, date 3 | from decimal import Decimal 4 | 5 | from starlette.responses import JSONResponse 6 | 7 | 8 | def dict_to_json(data): 9 | return json.dumps(data, cls=JsonExtendEncoder, ensure_ascii=False) 10 | 11 | 12 | def dict_to_json_stream(data): 13 | return json.dumps(data, cls=JsonExtendEncoder, ensure_ascii=False).encode('utf-8') 14 | 15 | 16 | def json_to_dict(json_data): 17 | return json.loads(json_data) 18 | 19 | 20 | class JsonExtendEncoder(json.JSONEncoder): 21 | 22 | def default(self, obj): 23 | if isinstance(obj, datetime): # must precede the date check: datetime is a subclass of date
24 | return obj.strftime('%Y-%m-%d %H:%M:%S') 25 | elif isinstance(obj, date): 26 | return obj.strftime('%Y-%m-%d') 27 | elif isinstance(obj, Decimal): 28 | return float(obj) 29 | elif isinstance(obj, bytes): 30 | return obj.decode('utf-8') 31 | return super().default(obj) 32 | 33 | 34 | class JsonExtendResponse(JSONResponse): 35 | 36 | def render(self, content): 37 | return dict_to_json_stream(content) 38 | -------------------------------------------------------------------------------- /viper/utils/jwt_util.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | import bcrypt # python -m pip install bcrypt 4 | import jwt # python -m pip install pyjwt 5 | from jwt.exceptions import ExpiredSignatureError, InvalidTokenError 6 | 7 | from viper.core import settings 8 | from viper.utils.errors import TokenExpiredError, AuthFailureError 9 | 10 | 11 | def set_password(password): 12 | hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()) 13 | return hashed_password.decode('utf-8') 14 | 15 | 16 | def validate_password(plain_password, hashed_password): 17 | return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8')) 18 | 19 | 20 | def generate_token(data: dict): 21 | token_data = data.copy() # data = {'id': 3} 22 | expiration_delta = timedelta(minutes=settings.CONF.get_int_conf('JWT', 'EXPIRE_MINUTES')) 23 | expiration_time = datetime.now() + expiration_delta 24 | token_data.update({'exp': expiration_time.timestamp()}) 25 | encoded_token = jwt.encode( 26 | token_data, settings.CONF.get_conf('JWT', 'SEC_KEY'), 27 | algorithm=settings.CONF.get_conf('JWT', 'ALGORITHM') 28 | ) 29 | return encoded_token 30 | 31 | 32 | def verify_token(token: str): 33 | secret_key = settings.CONF.get_conf('JWT', 'SEC_KEY') 34 | jwt_algorithm = settings.CONF.get_conf('JWT', 'ALGORITHM') 35 | try: 36 | payload = jwt.decode(token, secret_key,
algorithms=[jwt_algorithm]) 37 | user_id = payload.get('id') 38 | if not user_id: 39 | raise AuthFailureError('Invalid JWT: missing field id') 40 | except ExpiredSignatureError: 41 | raise TokenExpiredError('Token has expired') 42 | except InvalidTokenError as exc: 43 | raise AuthFailureError(f'Invalid token: {str(exc)}') 44 | return user_id 45 | -------------------------------------------------------------------------------- /viper/utils/log_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from loguru import logger 5 | 6 | from viper.core import settings 7 | 8 | log_dir = Path(settings.BASE_DIR) / 'logs' 9 | log_dir.mkdir(parents=True, exist_ok=True) 10 | 11 | access_path = log_dir / 'access.log' 12 | error_path = log_dir / 'error.log' 13 | 14 | logger.remove() 15 | 16 | logger.add( 17 | access_path, 18 | rotation='10MB', 19 | format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}', 20 | encoding='utf-8', 21 | enqueue=True, # 启用异步日志处理 22 | level='DEBUG', 23 | diagnose=False, # 关闭变量值 24 | backtrace=False, # 关闭完整堆栈跟踪 25 | colorize=False, 26 | filter=lambda record: record["level"].no < 40 # 过滤掉 ERROR 及以上级别的日志 27 | ) 28 | 29 | logger.add( 30 | error_path, 31 | rotation='10MB', # 日志文件达到 10MB 时轮转 32 | format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}', 33 | encoding='utf-8', 34 | enqueue=True, 35 | diagnose=False, 36 | backtrace=False, 37 | colorize=False, 38 | level='ERROR' 39 | ) 40 | 41 | logger.add( 42 | sink=sys.stdout, # 输出到标准输出流 43 | format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}', # 日志格式 44 | level='DEBUG', 45 | diagnose=False, 46 | backtrace=False, 47 | colorize=False, 48 | filter=lambda record: record["level"].no < 40, 49 | enqueue=True 50 | ) 51 | 52 | logger.add( 53 | sink=sys.stderr, # 输出到标准错误流 54 | format='{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}', 55 | level='ERROR', 56 | diagnose=False, 57 | backtrace=False, 58 | colorize=False, 59 | 
enqueue=True 60 | ) -------------------------------------------------------------------------------- /viper/utils/meta_util.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from viper.utils.pools import run_in_thread_pool_db 4 | 5 | 6 | def async_run_in_thread_pool_db(func): 7 | @wraps(func) 8 | async def wrapper(*args, **kwargs): 9 | return await run_in_thread_pool_db(func, *args, **kwargs) 10 | 11 | return wrapper 12 | 13 | 14 | class AsyncMethodsMeta(type): 15 | 16 | def __new__(cls, name, bases, dct): 17 | for attr_name, attr_value in dct.items(): 18 | if callable(attr_value) and not attr_name.startswith('__'): 19 | dct[attr_name] = async_run_in_thread_pool_db(attr_value) # 为同步方法添加装饰器 20 | return super().__new__(cls, name, bases, dct) 21 | 22 | 23 | if __name__ == '__main__': 24 | class ParentClass(metaclass=AsyncMethodsMeta): # class ParentClass: pass 25 | 26 | @staticmethod 27 | def common_method(): 28 | print('父类方法运行') 29 | 30 | 31 | class ChildClass(ParentClass): # class ChildClass(ParentClass, metaclass=AsyncMethodsMeta): pass 32 | 33 | @staticmethod 34 | def custom_sync_method(x, y): 35 | print(f'在子类中运行同步方法: {x} + {y} = {x + y}') 36 | return x + y 37 | -------------------------------------------------------------------------------- /viper/utils/pools.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from concurrent.futures import ThreadPoolExecutor 3 | from functools import partial 4 | 5 | thread_pool_db = ThreadPoolExecutor(max_workers=2) 6 | 7 | 8 | async def run_in_thread_pool_db(func, *args, **kwargs): 9 | loop = asyncio.get_running_loop() 10 | bound_func = partial(func, *args, **kwargs) 11 | return await loop.run_in_executor(thread_pool_db, bound_func) # noqa 12 | 13 | 14 | thread_pool_delay = ThreadPoolExecutor() 15 | 16 | 17 | async def run_in_thread_pool_delay(func, *args, **kwargs): 18 | loop = 
asyncio.get_running_loop() 19 | bound_func = partial(func, *args, **kwargs) 20 | return await loop.run_in_executor(thread_pool_delay, bound_func) # noqa 21 | 22 | 23 | def close_threads(): 24 | thread_pool_db.shutdown() 25 | thread_pool_delay.shutdown() 26 | -------------------------------------------------------------------------------- /viper/utils/pwd_util.py: -------------------------------------------------------------------------------- 1 | from base64 import b64encode, b64decode 2 | 3 | from Crypto.Cipher import AES # pip install pycryptodome 4 | from Crypto.Util.Padding import pad, unpad 5 | 6 | from viper.core import settings 7 | 8 | 9 | class AESCipher: 10 | 11 | def __init__(self, key: str): 12 | key_hex = key 13 | key_bytes = bytes.fromhex(key_hex) 14 | if len(key_bytes) not in (16, 24, 32): 15 | raise ValueError('sec_key is error') 16 | self.key = key_bytes 17 | 18 | # 加密 19 | def encrypt(self, data): 20 | cipher = AES.new(self.key, AES.MODE_ECB) 21 | ct_bytes = cipher.encrypt(pad(data.encode(), AES.block_size)) 22 | ct = b64encode(ct_bytes).decode('utf-8') 23 | return ct 24 | 25 | # 解密 26 | def decrypt(self, data): 27 | ct = b64decode(data) 28 | cipher = AES.new(self.key, AES.MODE_ECB) 29 | pt = unpad(cipher.decrypt(ct), AES.block_size) 30 | return pt.decode('utf-8') 31 | 32 | 33 | secret_key = settings.CONF.get_conf('JWT', 'SEC_KEY') 34 | cipher = AESCipher(secret_key) 35 | 36 | if __name__ == '__main__': 37 | import secrets 38 | 39 | key_hex = secrets.token_hex(16) # 生成十六进制字符串 40 | key_bytes = bytes.fromhex(key_hex) # 转换为字节字符串 41 | print(key_hex) 42 | 43 | passwd = 'zaq1xsw2cde' 44 | encrypted = cipher.encrypt(passwd) # 加密 45 | print(encrypted) 46 | 47 | decrypted = cipher.decrypt(encrypted) # 解密 48 | print(decrypted) 49 | -------------------------------------------------------------------------------- /viper/utils/redis_util.py: -------------------------------------------------------------------------------- 1 | import redis.asyncio as redis 2 | 3 |
from viper.core import settings 4 | 5 | 6 | def redis_conn(host=settings.CONF.get_conf('REDIS', 'HOST'), port=6379, db=0, decode_responses=False): 7 | return redis.Redis( 8 | host=host, 9 | port=port, 10 | db=db, 11 | # password=settings.CONF.get_sec_conf('REDIS', 'PASSWD'), 12 | decode_responses=decode_responses 13 | ) 14 | 15 | 16 | redis_client = redis_conn(decode_responses=True) 17 | 18 | 19 | async def close_redis(): 20 | await redis_client.aclose() 21 | -------------------------------------------------------------------------------- /viper/utils/resp_util.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | 3 | from starlette.exceptions import HTTPException 4 | 5 | from viper.utils.json_util import JsonExtendResponse 6 | 7 | 8 | class BaseResponse: 9 | 10 | def __init__(self): 11 | self.status = True 12 | self.error_code = None 13 | self.message = None 14 | self.data = None 15 | 16 | def to_dict(self): 17 | resp_dict = { 18 | 'status': self.status, 19 | 'error_code': self.error_code, 20 | 'message': self.message, 21 | 'data': self.data 22 | } 23 | return resp_dict 24 | 25 | 26 | def jsonify(*args, **kwargs): 27 | if args and kwargs: 28 | raise TypeError('jsonify() behavior undefined when passed both args and kwargs') 29 | elif len(args) == 1: 30 | content = args[0] 31 | else: 32 | content = args or kwargs 33 | response = BaseResponse() 34 | response.data = content 35 | response = response.to_dict() 36 | return JsonExtendResponse(response) 37 | 38 | 39 | def jsonify_exc(error_code, message=None): 40 | if not message: 41 | message = HTTPStatus(error_code).phrase 42 | response = BaseResponse() 43 | response.status = False 44 | response.error_code = error_code 45 | response.message = message 46 | response = response.to_dict() 47 | return JsonExtendResponse(response) 48 | 49 | 50 | def abort(error_code, message=None): 51 | raise HTTPException(status_code=error_code, detail=message) 52 | 
-------------------------------------------------------------------------------- /viper/utils/tools.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from datetime import datetime 3 | 4 | 5 | def generate_random_id(byte_length: int = 16) -> str: 6 | return secrets.token_hex(byte_length) 7 | 8 | 9 | def current_time(): 10 | return datetime.now() 11 | -------------------------------------------------------------------------------- /viper/views/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungeer/viper/fbb39fd62b10907aadd6ec99eb0482147411054e/viper/views/__init__.py -------------------------------------------------------------------------------- /viper/views/chat_view.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import httpx 4 | from starlette.authentication import requires 5 | from starlette.responses import StreamingResponse 6 | 7 | from viper.core import settings 8 | from viper.utils import tools, json_util 9 | from viper.utils.http_util import httpx_common, httpx_stream 10 | from viper.utils.resp_util import jsonify 11 | from viper.utils.log_util import logger 12 | from viper.models.chat_model import ChatModel 13 | from viper.models.message_model import MessageModel 14 | from viper.models.content_model import ContentModel 15 | from viper.schemas import validator 16 | from viper.schemas.chat_schema import chat_id_schema, send_message_schema, get_messages_schema 17 | 18 | headers = { 19 | 'Content-Type': 'application/json', 20 | 'Access-key': settings.CONF.get_conf('AI', 'API_KEY'), 21 | 'Workspace-Id': settings.CONF.get_conf('AI', 'WORKSPACE_ID') 22 | } 23 | 24 | 25 | @requires('authenticated') 26 | async def get_chat_id(request): 27 | body = await request.json() 28 | body = validator(body, chat_id_schema) 29 | title = body['title'] 30 | 31 | url = 
f'{settings.CONF.get_conf('AI', 'URL')}/v1/oapi/agent/chat/conversation/create' 32 | data = { 33 | 'robot_id': settings.CONF.get_conf('AI', 'ROBOT_ID'), 34 | 'user': 'wangxun', 35 | 'title': title 36 | } 37 | response = await httpx_common.post(url, headers=headers, json=data) 38 | response = response.json() 39 | data = response.get('data') 40 | conversation_id = data.get('conversation_id') 41 | 42 | user = request.user 43 | user_id = user.id 44 | await ChatModel().add_chat(conversation_id, title, user_id) 45 | return jsonify(conversation_id) 46 | 47 | 48 | async def get_response(conversation_id, content): 49 | url = f'{settings.CONF.get_conf('AI', 'URL')}/v1/oapi/agent/chat' 50 | data = { 51 | 'robot_id': settings.CONF.get_conf('AI', 'ROBOT_ID'), 52 | 'conversation_id': conversation_id, 53 | 'content': content, 54 | 'response_mode': 'streaming' 55 | } 56 | error_msg = {'finish': 'error'} 57 | try: 58 | async with httpx_stream.stream('POST', url=url, headers=headers, json=data) as response: 59 | async for line in response.aiter_lines(): 60 | if not line: 61 | continue 62 | yield line 63 | except httpx.TimeoutException: 64 | logger.error(f'ai time out: 【{conversation_id}】') 65 | yield f'data: {json_util.dict_to_json(error_msg)}\n\n' 66 | except (Exception,): 67 | logger.opt(exception=True).error(f'ai error 【{conversation_id}】.') 68 | yield f'data: {json_util.dict_to_json(error_msg)}\n\n' 69 | 70 | 71 | async def stream_data(conversation_id, chat_id, trace_id, content): 72 | full_content = [] 73 | async for line in get_response(conversation_id, content): 74 | answer = line.replace('data: ', '') 75 | try: 76 | answer = json_util.json_to_dict(answer) 77 | except json.JSONDecodeError: 78 | continue 79 | is_error = answer.get('finish') 80 | if is_error: 81 | yield f'{is_error}\n' 82 | break 83 | if answer.get('type') == 'TEXT' and answer.get('status') == 'SUCCEEDED': 84 | content = answer.get('content') 85 | full_content.append(content) 86 | yield f'{content}\n' 87 | 88 | 
content_str = ''.join(full_content) if full_content else 'error' 89 | message_id = await MessageModel().add_message(chat_id, trace_id, 'robot') 90 | await ContentModel().add_content(message_id, content_str) 91 | 92 | 93 | @requires('authenticated') 94 | async def send_message(request): 95 | body = await request.json() 96 | body = validator(body, send_message_schema) 97 | conversation_id = body['conversation_id'] 98 | content = body['content'] 99 | 100 | trace_id = tools.generate_random_id() 101 | chat_info = await ChatModel().get_chat_by_conversation(conversation_id) 102 | chat_id = chat_info['ID'] 103 | 104 | message_id = await MessageModel().add_message(chat_id, trace_id, 'user') 105 | await ContentModel().add_content(message_id, content) 106 | 107 | return StreamingResponse(stream_data(conversation_id, chat_id, trace_id, content), media_type='text/event-stream') 108 | 109 | 110 | # 所有会话 111 | @requires('authenticated') 112 | async def get_chats(request): 113 | user = request.user 114 | user_id = user.id 115 | chats = await ChatModel().get_chats(user_id) 116 | return jsonify(chats) 117 | 118 | 119 | # 所有问答 120 | @requires(['authenticated', 'admin']) 121 | async def get_messages(request): 122 | body = await request.json() 123 | body = validator(body, get_messages_schema) 124 | conversation_id = body['conversation_id'] 125 | 126 | chats = await MessageModel().get_messages(conversation_id) 127 | return jsonify(chats) 128 | -------------------------------------------------------------------------------- /viper/views/delay_view.py: -------------------------------------------------------------------------------- 1 | from viper.utils.resp_util import jsonify 2 | from viper.delays.backgrounds import delay_long_task 3 | from viper.delays.huey_instance import huey 4 | 5 | 6 | async def start_task(request): 7 | data = await request.json() 8 | input_data = data.get('input', 'default') 9 | task = await delay_long_task(input_data) 10 | message = {'status': 'task started', 
'task_id': task.id} 11 | return jsonify(message) 12 | 13 | 14 | async def check_task_status(request): 15 | # task_id = request.path_params['task_id'] # get method /task-status/{task_id} 16 | data = await request.json() 17 | task_id = data['task_id'] 18 | result = huey.result(task_id) 19 | if result is None: 20 | return jsonify({'status': 'pending or failed'}) 21 | return jsonify({'status': 'completed', 'result': result}) 22 | -------------------------------------------------------------------------------- /viper/views/user_view.py: -------------------------------------------------------------------------------- 1 | from viper.utils import jwt_util 2 | from viper.utils.resp_util import jsonify, abort 3 | from viper.models.user_model import UserModel 4 | from viper.schemas import validator 5 | from viper.schemas.user_schema import access_token_schema 6 | 7 | 8 | async def get_access_token(request): 9 | body = await request.json() 10 | body = validator(body, access_token_schema) 11 | phone_number = body['phone_number'] 12 | password = body['password'] 13 | 14 | db_user = await UserModel().get_user_by_phone(phone_number) 15 | if not db_user: 16 | return abort(404, 'User not found') 17 | 18 | db_password = db_user['password_hash'] 19 | is_pwd = jwt_util.validate_password(password, db_password) 20 | if not is_pwd: 21 | return abort(403, 'Incorrect password') 22 | 23 | user_id = db_user['id'] 24 | access_token = jwt_util.generate_token({'id': user_id}) 25 | jwt_token = {'access_token': access_token, 'token_type': 'bearer'} 26 | return jsonify(jwt_token) 27 | --------------------------------------------------------------------------------