├── fastapi_plus ├── config │ ├── __init__.py │ └── anonymous.py ├── dao │ ├── __init__.py │ ├── demo.py │ ├── event.py │ └── base.py ├── model │ ├── __init__.py │ ├── demo.py │ ├── event_log.py │ └── base.py ├── utils │ ├── __init__.py │ ├── json_encoders.py │ ├── obj2json.py │ ├── locker.py │ ├── list2dict.py │ ├── custom_route.py │ ├── json_custom.py │ ├── obj2dict.py │ ├── wxapp.py │ ├── auth.py │ ├── mongo.py │ ├── db.py │ ├── redis.py │ ├── generate_model.py │ ├── request_log.py │ └── sync_model.py ├── schema │ ├── __init__.py │ ├── demo.py │ └── base.py ├── service │ ├── __init__.py │ ├── demo.py │ ├── sys_info.py │ └── base.py ├── controller │ ├── __init__.py │ ├── base.py │ └── demo.py └── __init__.py ├── .gitignore ├── README.md ├── LICENSE └── setup.py /fastapi_plus/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 配置目录 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/dao/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 数据处理层目录 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 模型对象目录 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具箱目录 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/schema/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Schema目录 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/service/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | 业务逻辑层目录 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/controller/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 控制层目录,API入口 3 | """ 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .idea 3 | /venv/ 4 | 5 | #build 6 | /build/ 7 | /dist/ 8 | /fastapi_plus.egg-info/ -------------------------------------------------------------------------------- /fastapi_plus/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 这是一个Python FastAPI项目工程库,包含DB、Redis、MongoDB、JSON等工具和基础服务类。 3 | """ 4 | -------------------------------------------------------------------------------- /fastapi_plus/dao/demo.py: -------------------------------------------------------------------------------- 1 | from fastapi_plus.dao.base import BaseDao 2 | 3 | 4 | class DemoDao(BaseDao): 5 | pass 6 | -------------------------------------------------------------------------------- /fastapi_plus/utils/json_encoders.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | class JSONEncoders(object): 5 | """ 6 | 定义JSONEncoders 7 | """ 8 | json_encoders = { 9 | datetime: lambda dt: dt.isoformat(' ') # 解决日期和时间中“T”字符的格式问题 10 | } 11 | -------------------------------------------------------------------------------- /fastapi_plus/config/anonymous.py: -------------------------------------------------------------------------------- 1 | """ 2 | 匿名访问接口列表 3 | 格式:req.method + req.url.path,其中method为大写 4 | """ 5 | anonymous_path_list = [ 6 | 'GET/', 7 | 'GET/robots.txt', 8 | 'GET/favicon.ico', 9 | 'GET/docs', 10 | 'GET/docs/oauth2-redirect', 11 | 
'GET/redoc', 12 | 'GET/openapi.json', 13 | 'GET/sys_info', 14 | ] 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #Fastapi Plus 2 | 这是一个Python FastAPI项目工程库,包含DB、Redis、MongoDB、JSON等工具和基础服务类。 3 | 4 | [Github: Fastapi Plus](https://github.com/zhenqiang-sun/fastapi_plus/) 5 | 6 | ### 组件: 7 | - FastAPI: https://fastapi.tiangolo.com/ 8 | - uvicorn: https://www.uvicorn.org/ 9 | - SQLAlchemy: https://www.sqlalchemy.org/ 10 | - PyMySQL: https://pymysql.readthedocs.io/ 11 | - REDIS: https://github.com/andymccurdy/redis-py 12 | -------------------------------------------------------------------------------- /fastapi_plus/utils/obj2json.py: -------------------------------------------------------------------------------- 1 | """ 2 | object转换json string函数 3 | :version: 1.0 4 | :date: 2020-02-16 5 | """ 6 | import json 7 | 8 | from .json_custom import CustomJSONEncoder 9 | from .obj2dict import obj2dict 10 | 11 | 12 | def obj2json(obj) -> str: 13 | x_dict = obj2dict(obj) 14 | x_json = json.dumps(x_dict, ensure_ascii=False, cls=CustomJSONEncoder) 15 | 16 | # 返回json字符串 17 | return x_json 18 | -------------------------------------------------------------------------------- /fastapi_plus/schema/demo.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from fastapi_plus.schema.base import InfoSchema, RespDetailSchema 4 | 5 | 6 | class DemoInfoSchema(InfoSchema): 7 | parent_name: str 8 | 9 | 10 | class DemoDetailSchema(DemoInfoSchema): 11 | created_time: datetime 12 | updated_time: datetime 13 | 14 | 15 | class DemoRespDetailSchema(RespDetailSchema): 16 | detail: DemoDetailSchema = None 17 | -------------------------------------------------------------------------------- /fastapi_plus/service/demo.py: -------------------------------------------------------------------------------- 
1 | from fastapi_plus.service.base import BaseService 2 | 3 | from ..dao.demo import DemoDao 4 | from ..model.demo import Demo 5 | 6 | 7 | class DemoService(BaseService): 8 | def __init__(self, auth_data: dict = {}): 9 | user_id = auth_data.get('user_id', 0) 10 | self.Model = Demo 11 | self.dao = DemoDao(user_id) 12 | self.dao.Model = Demo 13 | 14 | super().__init__(user_id, auth_data) 15 | -------------------------------------------------------------------------------- /fastapi_plus/service/sys_info.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | 4 | class SysInfoService(object): 5 | """信息信息服务类. 6 | 7 | 获取服务运行系统的基础信息 8 | """ 9 | 10 | @staticmethod 11 | def get_sys_info(): 12 | return { 13 | 'platform': platform.platform(), 14 | 'machine': platform.machine(), 15 | 'node': platform.node(), 16 | 'processor': platform.processor(), 17 | 'python_version': platform.python_version(), 18 | } 19 | -------------------------------------------------------------------------------- /fastapi_plus/model/demo.py: -------------------------------------------------------------------------------- 1 | from fastapi_plus.model.base import * 2 | 3 | 4 | class Demo(Base): 5 | __tablename__ = 'demo' 6 | __table_args__ = {'comment': 'Demo'} 7 | 8 | user_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='用户ID') 9 | category_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='分类ID') 10 | category_name = Column(String(255), nullable=False, server_default=text("''"), comment='分类名称') 11 | data = Column(String(1000), nullable=False, server_default=text("''"), comment='数据') 12 | -------------------------------------------------------------------------------- /fastapi_plus/utils/locker.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from .redis import RedisUtils 4 | 5 | 6 | class Locker(object): 7 | """ 8 | Locker 基于redis的锁 9 | 
:version: 1.2 10 | :date: 2020-02-11 11 | """ 12 | 13 | redis: RedisUtils 14 | 15 | def __init__(self): 16 | self.redis = RedisUtils() 17 | 18 | # 判断是否存在锁 19 | def has_lock(self, key): 20 | return self.redis.get_string('locker:' + key) 21 | 22 | # 加锁 23 | def lock(self, key, expiration=None): 24 | self.redis.set_string('locker:' + key, str(time.time()), expiration) 25 | 26 | # 解锁 27 | def unlock(self, key): 28 | self.redis.delete('locker:' + key) 29 | -------------------------------------------------------------------------------- /fastapi_plus/controller/base.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from starlette.responses import Response 3 | 4 | base_router = APIRouter() 5 | 6 | 7 | @base_router.get('/') 8 | async def get_root(): 9 | """ 10 | 访问根路径 11 | """ 12 | return Response(content='', media_type='text/plain') 13 | 14 | 15 | @base_router.get('/robots.txt') 16 | async def get_robots(): 17 | """ 18 | 获取爬虫权限 19 | """ 20 | return Response(content='User-agent: * \nDisallow: /', media_type='text/plain') 21 | 22 | 23 | @base_router.get('/sys_info') 24 | async def get_sys_info(): 25 | """ 26 | 获取系统基本信息 27 | """ 28 | from ..service.sys_info import SysInfoService 29 | 30 | return SysInfoService().get_sys_info() 31 | -------------------------------------------------------------------------------- /fastapi_plus/utils/list2dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | list2dict, list转换为dict 3 | :version: 1.0 4 | :date: 2020-02-15 5 | """ 6 | 7 | 8 | def list_list2dict(x_list: list, key_index: int = 0, value_index: int = 1): 9 | x_dict = {} 10 | 11 | for x_item in x_list: 12 | if isinstance(x_item, list) or isinstance(x_item, tuple): 13 | if x_item[key_index]: 14 | x_dict[x_item[key_index]] = x_item[value_index] 15 | 16 | return x_dict 17 | 18 | 19 | def list_dict2dict(x_list: list, key_key: str, value_key: str): 20 | x_dict = {} 21 | 
22 | for x_item in x_list: 23 | if isinstance(x_item, dict): 24 | if key_key in x_item and value_key in x_item: 25 | x_dict[x_item[key_key]] = x_item[value_key] 26 | 27 | return x_dict 28 | -------------------------------------------------------------------------------- /fastapi_plus/utils/custom_route.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from fastapi.routing import APIRoute 4 | from starlette.requests import Request 5 | from starlette.responses import Response 6 | 7 | from .request_log import create_log, update_log 8 | 9 | 10 | class CustomRoute(APIRoute): 11 | """ 12 | 自定义APIRouter 13 | """ 14 | 15 | def get_route_handler(self) -> Callable: 16 | original_route_handler = super().get_route_handler() 17 | 18 | async def custom_route_handler(request: Request) -> Response: 19 | request.state.log = await create_log(request) 20 | response = await original_route_handler(request) 21 | await update_log(request.state.log, response) 22 | return response 23 | 24 | return custom_route_handler 25 | -------------------------------------------------------------------------------- /fastapi_plus/utils/json_custom.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import decimal 3 | 4 | from json import JSONEncoder 5 | 6 | 7 | class CustomJSONEncoder(JSONEncoder): 8 | """ 9 | 自定义JSON编码处理 10 | :version: 1.1 11 | :date: 2019-01-08 12 | """ 13 | 14 | def default(self, obj): 15 | try: 16 | if isinstance(obj, datetime.date): 17 | return obj.isoformat().replace('T', ' ') 18 | elif isinstance(obj, datetime.datetime): 19 | return obj.isoformat().replace('T', " ") 20 | elif isinstance(obj, decimal.Decimal): 21 | return str(obj) 22 | iterable = iter(obj) 23 | except TypeError: 24 | pass 25 | else: 26 | return list(iterable) 27 | return JSONEncoder.default(self, obj) 28 | 
-------------------------------------------------------------------------------- /fastapi_plus/utils/obj2dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | object与dict转换函数 3 | :version: 1.1 4 | :date: 2019-01-08 5 | """ 6 | 7 | 8 | def obj2dict(obj): 9 | if not obj: 10 | return None 11 | 12 | # 判断是否是Query 13 | # 定义一个字典对象 14 | dictionary = {} 15 | # 检索记录中的成员 16 | for field in [x for x in dir(obj) if 17 | # 过滤属性 18 | not x.startswith('_') 19 | # 过滤掉方法属性 20 | and hasattr(obj.__getattribute__(x), '__call__') == False 21 | # 过滤掉不需要的属性 22 | and x != 'metadata' 23 | and x != 'query']: 24 | data = obj.__getattribute__(field) 25 | 26 | if hasattr(data, 'query'): 27 | data = obj2dict(data) 28 | 29 | try: 30 | dictionary[field] = data 31 | except TypeError: 32 | dictionary[field] = None 33 | 34 | # 返回字典对象 35 | return dictionary 36 | -------------------------------------------------------------------------------- /fastapi_plus/utils/wxapp.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | # 配置文件:微信小程序 5 | class WxappConfig(object): 6 | appid = '' 7 | secret = '' 8 | 9 | 10 | class WxappUtils(object): 11 | 12 | def __init__(self, config: WxappConfig): 13 | self.config: WxappConfig = config 14 | 15 | def jscode2session(self, code): 16 | """ 17 | 登录 18 | :url https://developers.weixin.qq.com/miniprogram/dev/api-backend/open-api/login/auth.code2Session.html 19 | :param code: 小程序登录时获取的 code 20 | :return: 21 | """ 22 | url = 'https://api.weixin.qq.com/sns/jscode2session?appid={APPID}&secret={SECRET}&js_code={JSCODE}&grant_type=authorization_code' 23 | url = url.format(**{ 24 | 'APPID': self.config.appid, 25 | 'SECRET': self.config.secret, 26 | 'JSCODE': code, 27 | }) 28 | 29 | resp = requests.get(url) 30 | return resp.json() 31 | -------------------------------------------------------------------------------- /fastapi_plus/model/event_log.py: 
-------------------------------------------------------------------------------- 1 | from .base import * 2 | 3 | 4 | class EventLog(Base): 5 | """ 6 | 事件记录模型 7 | """ 8 | __tablename__ = 'event_log' 9 | __table_args__ = {'comment': '事件记录'} 10 | 11 | user_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='用户ID') 12 | relation_obj = Column(String(255), nullable=False, server_default=text("''"), comment='相关对象') 13 | relation_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='相关ID') 14 | relation_name = Column(String(255), nullable=False, server_default=text("''"), comment='相关名称') 15 | event_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='事件id') 16 | event_time = Column(TIMESTAMP, comment='事件发生时间') 17 | event_from = Column(String(255), nullable=False, server_default=text("''"), comment='事件发生来源') 18 | before_data = Column(LONGTEXT, comment='之前数据') 19 | change_data = Column(LONGTEXT, comment='变化数据') 20 | after_data = Column(LONGTEXT, comment='之后数据') 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | FastAPI Plus 3 | """ 4 | 5 | import setuptools 6 | 7 | with open("README.md", "r", encoding='utf-8') as fh: 8 | long_description = fh.read() 9 | 10 | setuptools.setup( 11 | name="fastapi_plus", 12 | version='0.1.4.20201125', 13 | author="Zhenqiang Sun", 14 | author_email="zhenqiang.sun@gmail.com", 15 | description="This is a Python FastAPI project engineering library that includes tools and basic service classes.", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/zhenqiang-sun/fastapi_plus", 19 | packages=setuptools.find_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ], 25 | python_requires='>=3.6', 26 | install_requires=[ 27 | 'fastapi==0.61.2', 28 | 'uvicorn==0.12.2', 29 | 'sqlalchemy==1.3.19', 30 | 'pymysql==0.10.0', 31 | 'sqlacodegen==2.3.0', 32 | 'redis==3.5.3', 33 | 'pymongo==3.11.1', 34 | 'requests==2.25.0', 35 | 'python-multipart==0.0.5', 36 | 'aiofiles==0.6.0' 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /fastapi_plus/utils/auth.py: -------------------------------------------------------------------------------- 1 | from fastapi import 
Header 2 | 3 | from .redis import RedisUtils 4 | 5 | 6 | async def get_auth_data(authorization: str = Header(None)): 7 | """ 8 | 获取登录用户认证数据, 通常用于controller层 9 | :param authorization: 请求header中的authorization 10 | :return: 11 | """ 12 | return get_auth_data_by_authorization(authorization) 13 | 14 | 15 | def get_auth_data_by_authorization(authorization: str, ex: int = None): 16 | """ 17 | 获取登录用户认证数据 18 | :param authorization: 19 | :param prefix: 前缀 20 | :param ex: 数据过期秒数 21 | :return: 22 | """ 23 | if authorization: 24 | return get_auth_data_by_token(authorization, ex) 25 | 26 | return None 27 | 28 | 29 | def get_auth_data_by_token(token: str, ex: int = None): 30 | """ 31 | 获取登录用户认证数据, 从redis中读取 32 | :param token: 登录的token 33 | :param ex: 数据过期秒数 34 | :return: 登录认证数据 35 | """ 36 | 37 | auth_data = RedisUtils().get('token:' + token) 38 | 39 | if ex and auth_data: 40 | RedisUtils().expire('token:' + token, ex) 41 | 42 | return auth_data 43 | 44 | 45 | def update_auth_data(auth_data: dict, ex: int = None): 46 | """ 47 | 更新认证数据 48 | :param auth_data: 登录认证数据 49 | :param ex: 数据过期秒数 50 | """ 51 | RedisUtils().set('token:' + auth_data.get('token'), auth_data, ex) 52 | -------------------------------------------------------------------------------- /fastapi_plus/controller/demo.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends 2 | from fastapi_plus.schema.base import ListArgsSchema, RespBaseSchema, RespIdSchema, RespListSchema 3 | from fastapi_plus.utils.auth import get_auth_data 4 | from fastapi_plus.utils.custom_route import CustomRoute 5 | 6 | from ..schema.demo import DemoInfoSchema, DemoRespDetailSchema 7 | from ..service.demo import DemoService 8 | 9 | router = APIRouter(route_class=CustomRoute) 10 | 11 | 12 | @router.post('/list', response_model=RespListSchema) 13 | async def list(*, args: ListArgsSchema, auth_data: dict = Depends(get_auth_data)): 14 | args.user_id = auth_data.get('user_id') 15 
| return DemoService(auth_data).list(args) 16 | 17 | 18 | @router.get('/{id}', response_model=DemoRespDetailSchema) 19 | async def read(id: int, auth_data: dict = Depends(get_auth_data)): 20 | resp = DemoRespDetailSchema() 21 | resp.detail = DemoService(auth_data).read(id) 22 | return resp 23 | 24 | 25 | @router.post('', response_model=RespIdSchema, response_model_exclude_none=True) 26 | async def create(*, info: DemoInfoSchema, auth_data: dict = Depends(get_auth_data)): 27 | return DemoService(auth_data).create(info) 28 | 29 | 30 | @router.put('/{id}', response_model=RespBaseSchema) 31 | async def update(*, info: DemoInfoSchema, auth_data: dict = Depends(get_auth_data)): 32 | return DemoService(auth_data).update(info) 33 | 34 | 35 | @router.delete('/{id}', response_model=RespBaseSchema) 36 | async def delete(id: int, auth_data: dict = Depends(get_auth_data)): 37 | return DemoService(auth_data).delete(id) 38 | -------------------------------------------------------------------------------- /fastapi_plus/model/base.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, TIMESTAMP, text, DECIMAL, Date 2 | from sqlalchemy.dialects.mysql import BIGINT, INTEGER, LONGTEXT, TINYINT 3 | from sqlalchemy.ext.declarative import declarative_base 4 | 5 | DeclarativeBase = declarative_base() 6 | 7 | 8 | class Base(DeclarativeBase): 9 | """ 10 | 基础Model模型对象 11 | """ 12 | __abstract__ = True 13 | 14 | id = Column(BIGINT(20), primary_key=True, comment='序号') 15 | parent_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='父序号') 16 | type = Column(INTEGER(11), nullable=False, server_default=text("0"), comment='类型') 17 | sort = Column(INTEGER(11), nullable=False, server_default=text("0"), comment='排序') 18 | status = Column(TINYINT(2), nullable=False, server_default=text("0"), comment='状态') 19 | is_deleted = Column(TINYINT(1), nullable=False, server_default=text("0"), comment='软删') 20 | 
created_by = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='创建人') 21 | created_time = Column(TIMESTAMP, nullable=False, server_default=text("current_timestamp()"), comment='创建时间') 22 | updated_by = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='更新人') 23 | updated_time = Column(TIMESTAMP, nullable=False, 24 | server_default=text("current_timestamp() ON UPDATE current_timestamp()"), comment='更新时间') 25 | code = Column(String(255), nullable=False, server_default=text("''"), comment='编码') 26 | name = Column(String(255), nullable=False, server_default=text("''"), comment='名称') 27 | label = Column(String(255), nullable=False, server_default=text("''"), comment='标签') 28 | logo = Column(String(255), nullable=False, server_default=text("''"), comment='图标') 29 | url = Column(String(255), nullable=False, server_default=text("''"), comment='URL') 30 | info = Column(String(1000), nullable=False, server_default=text("''"), comment='内容') 31 | remark = Column(String(1000), nullable=False, server_default=text("''"), comment='备注') 32 | search = Column(LONGTEXT, comment='搜索') 33 | -------------------------------------------------------------------------------- /fastapi_plus/utils/mongo.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | from pymongo.collection import Collection 3 | 4 | 5 | class MongoConfig(object): 6 | """ 7 | MongoConfig MongoDB配置类 8 | :version: 1.1 9 | :date: 2020-02-12 10 | """ 11 | 12 | host = 'mongodb' 13 | port = '27017' 14 | username = 'root' 15 | password = '' 16 | database = '' 17 | 18 | def get_url(self): 19 | config = [ 20 | 'mongodb://', 21 | self.username, 22 | ':', 23 | self.password, 24 | '@', 25 | self.host, 26 | ':', 27 | self.port, 28 | '/', 29 | self.database, 30 | '?authSource=', 31 | self.database, 32 | '&authMechanism=SCRAM-SHA-256', 33 | ] 34 | 35 | return ''.join(config) 36 | 37 | 38 | class MongoUtils(object): 39 | """ 40 | 
MongoUtils MongoDB工具类 41 | :version: 1.1 42 | :date: 2020-02-12 43 | """ 44 | 45 | _config: MongoConfig = None 46 | default_config: MongoConfig = None 47 | 48 | def __init__(self, config: MongoConfig = None): 49 | if config: 50 | self._config = config 51 | else: 52 | self._config = self.default_config 53 | 54 | def _get_client(self): 55 | """ 56 | 返回Mongo数据库连接,同步 57 | :return: 58 | """ 59 | try: 60 | client = MongoClient(self._config.get_url()) 61 | return client 62 | except Exception as e: 63 | raise str(e) 64 | 65 | def _get_db(self): 66 | """ 67 | 返回Mongo数据库实例 68 | :param database: 69 | :return: 70 | """ 71 | 72 | try: 73 | client = self._get_client() 74 | db = client[self._config.database] 75 | return db 76 | except Exception as e: 77 | raise str(e) 78 | 79 | def get_collection(self, collection_name): 80 | """ 81 | 返回输入的名称对应的集合 82 | :param collection_name: 83 | :return: 84 | """ 85 | 86 | try: 87 | db = self._get_db() 88 | collection: Collection = db[collection_name] 89 | return collection 90 | except Exception as e: 91 | raise str(e) 92 | -------------------------------------------------------------------------------- /fastapi_plus/utils/db.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.orm import sessionmaker, scoped_session, Session 3 | 4 | 5 | class DbConfig(object): 6 | """ 7 | DbConfig DB配置类 8 | :version: 1.4 9 | :date: 2020-02-11 10 | """ 11 | 12 | driver = 'mysql+pymysql' 13 | host = 'mariadb' 14 | port = '3306' 15 | username = 'root' 16 | password = '' 17 | database = '' 18 | charset = 'utf8mb4' 19 | table_name_prefix = '' 20 | echo = True 21 | pool_size = 100 22 | max_overflow = 100 23 | pool_recycle = 60 24 | 25 | def get_url(self): 26 | config = [ 27 | self.driver, 28 | '://', 29 | self.username, 30 | ':', 31 | self.password, 32 | '@', 33 | self.host, 34 | ':', 35 | self.port, 36 | '/', 37 | self.database, 38 | '?charset=', 39 | self.charset, 40 | ] 41 
| 42 | return ''.join(config) 43 | 44 | 45 | class DbUtils(object): 46 | """ 47 | DbUtils DB工具类 48 | :version: 1.4 49 | :date: 2020-02-11 50 | """ 51 | 52 | sess: Session = None 53 | default_config: DbConfig = None 54 | 55 | def __init__(self, config: DbConfig = None): 56 | if not config: 57 | config = self.default_config 58 | 59 | self.sess = self._create_scoped_session(config) 60 | 61 | def __del__(self): 62 | self.sess.close() 63 | 64 | @staticmethod 65 | def _create_scoped_session(config: DbConfig): 66 | engine = create_engine( 67 | config.get_url(), 68 | pool_size=config.pool_size, 69 | max_overflow=config.max_overflow, 70 | pool_recycle=config.pool_recycle, 71 | echo=config.echo 72 | ) 73 | 74 | session_factory = sessionmaker(autocommit=True, autoflush=False, bind=engine) 75 | 76 | return scoped_session(session_factory) 77 | 78 | # 根据文件获取SQL文件 79 | @staticmethod 80 | def get_sql_by_file(file_path, params={}): 81 | sql = DbUtils._get_file(file_path) 82 | return sql.format(**params) 83 | 84 | # 获取SQL文件 85 | @staticmethod 86 | def _get_file(file_path): 87 | with open('app/sql/' + file_path, 'r', encoding='utf-8') as f: 88 | return f.read() 89 | -------------------------------------------------------------------------------- /fastapi_plus/utils/redis.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from redis import ConnectionPool, Redis 4 | 5 | from .json_custom import CustomJSONEncoder 6 | 7 | 8 | class RedisConfig(object): 9 | """ 10 | RedisConfig Redis配置类 11 | :version: 1.2 12 | :date: 2020-02-11 13 | """ 14 | 15 | host = 'redis' 16 | port = '6379' 17 | username = 'root' 18 | password = '' 19 | database = 0 20 | max_connections = 100 21 | 22 | 23 | class RedisUtils: 24 | """ 25 | RedisUtils redis工具类 26 | :version: 1.2 27 | :date: 2020-02-11 28 | """ 29 | 30 | _conn = None 31 | _default_conn_pool = None 32 | default_config: RedisConfig = None 33 | 34 | def __init__(self, config: RedisConfig = None): 
35 | if config: 36 | self._conn = self._get_conn(config) 37 | else: 38 | if not self._default_conn_pool: 39 | RedisUtils._default_conn_pool = self._create_pool(self.default_config) 40 | 41 | self._conn = Redis(connection_pool=self._default_conn_pool) 42 | 43 | @staticmethod 44 | def _create_pool(config: RedisConfig): 45 | return ConnectionPool( 46 | host=config.host, 47 | port=config.port, 48 | max_connections=config.max_connections, 49 | username=config.username, 50 | password=config.password, 51 | db=config.database 52 | ) 53 | 54 | @staticmethod 55 | def _get_conn(config: RedisConfig): 56 | return Redis( 57 | host=config.host, 58 | port=config.port, 59 | max_connections=config.max_connections, 60 | username=config.username, 61 | password=config.password, 62 | db=config.database 63 | ) 64 | 65 | def delete(self, key): 66 | return self._conn.delete(key) 67 | 68 | def set_string(self, key, value, ex=None): 69 | return self._conn.set(key, value, ex) 70 | 71 | def get_string(self, key): 72 | value = self._conn.get(key) 73 | 74 | if value: 75 | return str(value, 'utf-8') 76 | else: 77 | return None 78 | 79 | def set(self, key, value, ex=None): 80 | return self._conn.set(key, json.dumps(value, ensure_ascii=False, cls=CustomJSONEncoder), ex) 81 | 82 | def get(self, key): 83 | value = self._conn.get(key) 84 | 85 | if value: 86 | return json.loads(str(value, 'utf-8')) 87 | else: 88 | return None 89 | 90 | def expire(self, key, ex=int): 91 | return self._conn.expire(key, ex) 92 | -------------------------------------------------------------------------------- /fastapi_plus/utils/generate_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class GenerateModel(object): 5 | """ 6 | 基于Demo代码生成新的模块,省去复制粘贴 7 | """ 8 | 9 | lib_path: str # fastapi_plus库路径 10 | app_path: str # app应用路径 11 | model_name: str # model名称,小写+下划线式,snake 12 | model_name_pascal: str # model名称,大驼峰,pascal 13 | 14 | def __init__(self, app_path: 
str, model_name: str): 15 | # 接收、处理入参 16 | self.app_path = app_path 17 | self.model_name = model_name 18 | self.model_name_pascal = self._transform_name(model_name) 19 | self.lib_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 20 | 21 | # 生成文件 22 | self._generate_file('controller') 23 | self._generate_file('dao') 24 | self._generate_file('model') 25 | self._generate_file('schema') 26 | self._generate_file('service') 27 | 28 | @staticmethod 29 | def _transform_name(src: str, first_upper: bool = True) -> str: 30 | """ 31 | 将下划线分隔的名字,转换为驼峰模式 32 | :param src: 33 | :param first_upper: 转换后的首字母是否指定大写(如 34 | :return: 35 | """ 36 | arr = src.split('_') 37 | res = '' 38 | for i in arr: 39 | res = res + i[0].upper() + i[1:] 40 | 41 | if not first_upper: 42 | res = res[0].lower() + res[1:] 43 | return res 44 | 45 | @staticmethod 46 | def _read_file_content(file_path: str) -> str: 47 | """ 48 | 读取文件 49 | :param file_path: 50 | :return: 51 | """ 52 | with open(file_path, 'r', encoding='utf-8') as f: 53 | return f.read() 54 | 55 | @staticmethod 56 | def _save_file_content(file_path: str, file_content: str): 57 | """ 58 | 保存文件 59 | :param file_path: 60 | :param file_content: 61 | :return: 62 | """ 63 | 64 | with open(file_path, 'w', encoding='utf-8') as f: 65 | f.write(file_content) 66 | 67 | def _generate_file(self, dir_name: str): 68 | """ 69 | 生成文件 70 | :param dir_name: 71 | :return: 72 | """ 73 | 74 | src_file_path = self.lib_path + os.sep + dir_name + os.sep + 'demo.py' 75 | new_file_path = self.app_path + os.sep + dir_name + os.sep + self.model_name + '.py' 76 | 77 | file_content = self._read_file_content(src_file_path) 78 | 79 | file_content = file_content.replace('demo', self.model_name) 80 | file_content = file_content.replace('Demo', self.model_name_pascal) 81 | 82 | self._save_file_content(new_file_path, file_content) 83 | -------------------------------------------------------------------------------- /fastapi_plus/utils/request_log.py: 
# Request logging: persist each HTTP request/response pair to MongoDB.
import json
import time

from starlette.requests import Request
from starlette.responses import Response

from .auth import get_auth_data_by_authorization, get_auth_data_by_token
from .list2dict import list_list2dict
from .mongo import MongoUtils
from .obj2dict import obj2dict

# MongoDB collection holding one document per request.
REQUEST_LOG_COLLECTION = 'request_log'
# strftime format of the human-readable in/out timestamps.
_DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'


class Log(object):
    """One request log record, stored in the ``request_log`` collection.

    Class attributes are defaults; ``create_log`` and ``update_log`` fill
    them in per request.
    """
    id = None  # Mongo ``_id`` returned by insert_one in create_log
    status = 1  # 1 = request recorded (create_log), 2 = response recorded (update_log)
    user_id = None
    in_datetime = None
    in_millisecond = None
    out_datetime = None
    out_millisecond = None
    use_millisecond = None  # out_millisecond - in_millisecond
    ip = None
    url = None
    method = None
    path = None
    path_params = None
    query_params = None
    header = None  # built from request.headers.items()
    body = None  # parsed JSON request body, when available
    response_status_code = None
    response = None  # parsed JSON response body, when available


def _timestamp(seconds: float):
    """Return ``(formatted local datetime, integer milliseconds)`` for an epoch time."""
    return (time.strftime(_DATETIME_FORMAT, time.localtime(seconds)),
            int(round(seconds * 1000)))


async def create_log(request: Request) -> Log:
    """Insert a log document for an incoming request and return the populated Log.

    The JSON body and the user id (decoded from an ``authorization`` header)
    are captured on a best-effort basis and never fail the request.

    :param request: the incoming Starlette request
    :return: the Log, with ``id`` set to the inserted document's id
    """
    mongo_log = MongoUtils().get_collection(REQUEST_LOG_COLLECTION)

    log = Log()
    log.status = 1
    log.in_datetime, log.in_millisecond = _timestamp(time.time())
    # request.client is None when the server does not report the peer address.
    log.ip = request.client.host if request.client else None
    log.url = str(request.url)
    log.method = request.method
    log.path = request.url.path
    log.path_params = request.path_params
    log.query_params = request.url.query
    log.header = list_list2dict(request.headers.items())

    try:
        log.body = await request.json()
    except Exception:
        # Empty or non-JSON body: keep log.body as None (best effort).
        pass

    if 'authorization' in log.header:
        auth_data = get_auth_data_by_authorization(log.header['authorization'])

        if auth_data:
            log.user_id = auth_data.get('user_id')

    mongo_result = mongo_log.insert_one(obj2dict(log))
    log.id = mongo_result.inserted_id

    return log


async def update_log(log: Log, response: Response):
    """Complete ``log`` with response data and timing, then update its document.

    :param log: the Log returned by create_log for this request
    :param response: the response being sent
    """
    mongo_log = MongoUtils().get_collection(REQUEST_LOG_COLLECTION)

    log.status = 2
    log.out_datetime, log.out_millisecond = _timestamp(time.time())
    log.use_millisecond = log.out_millisecond - log.in_millisecond
    log.response_status_code = response.status_code

    try:
        log.response = json.loads(str(response.body, 'utf8'))
    except Exception:
        # No ``body`` attribute (e.g. streaming responses) or non-JSON body: not stored.
        pass

    # If the request did not identify a user, fall back to a token returned in
    # the response body. Only a JSON object (dict) can carry a 'token' key.
    if not log.user_id and isinstance(log.response, dict) and 'token' in log.response:
        auth_data = get_auth_data_by_token(log.response['token'])
        if auth_data:
            log.user_id = auth_data.get('user_id')

    mongo_log.update_one({'_id': log.id}, {'$set': obj2dict(log)})


# ------------------------------------------------------------------------------
# /fastapi_plus/dao/event.py
# ------------------------------------------------------------------------------
import json

from ..model.event_log import EventLog
from ..utils.db import DbUtils
from ..utils.json_custom import CustomJSONEncoder
from ..utils.obj2json import obj2json

# Keys skipped when diffing before/after snapshots.
_IGNORED_CHANGE_KEYS = frozenset({'updated_time'})


class EventDao(object):
    """Event DAO: records business events such as data modifications.

    Attributes:
        user_id: id of the user performing the operation
        db: database utility whose ``sess`` is used to persist records
    """

    def __init__(self, user_id: int = 0):
        self.user_id = user_id
        self.db = DbUtils()

    def get_event_log(self, event_id: int, relation_obj: str, before_data=None) -> EventLog:
        """Build (without saving) an event log record.

        :param event_id: event type id
        :param relation_obj: related object name (BaseService passes the table name)
        :param before_data: model instance before the change; when given, its id
            and its JSON serialization (taken now) are stored on the record
        :return: the unsaved EventLog
        """
        event_log = EventLog()
        event_log.user_id = self.user_id
        event_log.event_id = event_id
        event_log.relation_obj = relation_obj

        if before_data:
            event_log.relation_id = before_data.id
            event_log.before_data = obj2json(before_data)

        return event_log

    def create_event_log(self, event_log: EventLog, after_data=None):
        """Save an event log, computing change_data first if it is not set.

        :param event_log: record built by get_event_log
        :param after_data: model instance after the change, serialized to JSON
        """
        if after_data:
            event_log.after_data = obj2json(after_data)

        if not event_log.change_data:
            self.calculate_change(event_log)

        self.db.sess.add(event_log)
        self.db.sess.flush()

    def update_event_log(self, event_log: EventLog):
        """Save changes to an existing event log record (session add + flush).

        :param event_log: the record to save
        """
        self.db.sess.add(event_log)
        self.db.sess.flush()

    @staticmethod
    def calculate_change(event_log: EventLog):
        """Diff before_data/after_data and store the result in change_data.

        change_data becomes a JSON object ``{key: {"before": ..., "after": ...}}``
        for every key whose value differs, including keys present on only one
        side. ``updated_time`` is ignored. change_data is None when either
        snapshot is missing or nothing changed.

        :param event_log: record whose before_data/after_data are JSON strings
        """
        if not event_log.before_data or not event_log.after_data:
            event_log.change_data = None
            return None

        before: dict = json.loads(event_log.before_data)
        after: dict = json.loads(event_log.after_data)
        change = {}

        # Keys of `before` first (original order), then keys only in `after`.
        keys = list(before) + [key for key in after if key not in before]
        for key in keys:
            if key in _IGNORED_CHANGE_KEYS:
                continue
            # .get(): a key missing on one side is a change, not a KeyError.
            if before.get(key) != after.get(key):
                change[key] = {
                    'before': before.get(key),
                    'after': after.get(key),
                }

        if not change:
            event_log.change_data = None
        else:
            event_log.change_data = json.dumps(change, ensure_ascii=False, cls=CustomJSONEncoder)
        return None


# ------------------------------------------------------------------------------
# /fastapi_plus/schema/base.py
# ------------------------------------------------------------------------------
from datetime import datetime
from typing import List, Dict

from pydantic import BaseModel

from ..utils.json_encoders import JSONEncoders


class BaseSchema(BaseModel):
    """Base schema using the project's custom JSON encoders."""

    class Config:
        json_encoders = JSONEncoders.json_encoders  # custom JSON encoding (e.g. datetime format)


class BaseObjSchema(BaseModel):
    """Base schema that can be built from ORM model instances."""

    class Config:
        json_encoders = JSONEncoders.json_encoders
        orm_mode = True  # allow construction from ORM objects


class RespBaseSchema(BaseSchema):
    """Base response schema."""
    code: int = 0  # result code; 0 on success
    message: str = 'SUCCESS'  # result message


class RespIdSchema(RespBaseSchema):
    """Response schema carrying a record id."""
    id: int = 0  # id of the affected record


class RespDetailSchema(RespBaseSchema):
    """Response schema carrying a detail object."""
    detail: dict = None  # detail payload


class RespListSchema(RespBaseSchema):
    """Paged list response schema."""
    page: int = 0  # current page number
    size: int = 0  # page size
    count: int = 0  # total number of matching records
    page_count: int = 0  # total number of pages
    list: List[Dict] = None  # records of the current page


class ListFilterSchema(BaseModel):
    """List argument: one filter condition."""
    key: str  # field name
    condition: str  # one of = != < > <= >= like in !in null !null
    value: str  # condition value; for "in"/"!in" a comma-separated list of values


class ListOrderSchema(BaseModel):
    """List argument: one ordering condition."""
    key: str  # field name
    condition: str  # "desc", "asc" (legacy spelling "acs" also accepted) or "rand"


class ListKeySchema(BaseModel):
    """List argument: one returned field."""
    key: str  # field name
    rename: str = None  # output name for the field; empty keeps the original name


class ListArgsSchema(BaseModel):
    """List request arguments."""
    page: int = 1  # current page number
    size: int = 10  # records per page
    keywords: str = None  # space-separated keywords matched against the `search` column
    is_deleted: str = None  # "all" includes soft-deleted rows; otherwise only live rows
    user_id: int = None  # restrict to rows owned by this user
    filters: List[ListFilterSchema] = None  # filter conditions
    orders: List[ListOrderSchema] = None  # ordering conditions
    keys: List[ListKeySchema] = None  # fields to return (all fields when empty)


class UserBaseSchema(BaseObjSchema):
    """Basic user schema."""
    id: int = None  # user id
    name: str = None  # user name


class FileBaseSchema(BaseObjSchema):
    """Basic file schema."""
    id: int  # file id
    name: str  # file name
    suffix: str  # file extension


class InfoSchema(BaseObjSchema):
    """Common optional info fields shared by simple entities."""
    id: int = None
    parent_id: int = None
    type: int = None
    sort: int = None
    status: int = None
    code: str = None
    name: str = None
    label: str = None
    logo: str = None
    url: str = None
    info: str = None
    remark: str = None


class DetailSchema(InfoSchema):
    """InfoSchema plus the (required) created/updated timestamps."""
    created_time: datetime
    updated_time: datetime


# ------------------------------------------------------------------------------
# /fastapi_plus/service/base.py
# ------------------------------------------------------------------------------
from ..dao.event import EventDao
from ..schema.base import ListArgsSchema, RespListSchema, RespIdSchema, RespBaseSchema

# Label used in messages when a model's table has no comment.
_DEFAULT_MODEL_LABEL = '数据'


def _model_label(model_or_cls) -> str:
    """Return the table comment of a model (class or instance), or the default label.

    SQLAlchemy allows ``__table_args__`` to be a dict or a tuple whose last
    element is a dict of keyword arguments, so both forms are accepted. A
    missing model (None) or missing ``__table_args__`` yields the default.
    """
    table_args = getattr(model_or_cls, '__table_args__', None)
    if isinstance(table_args, tuple) and table_args and isinstance(table_args[-1], dict):
        table_args = table_args[-1]
    if isinstance(table_args, dict):
        return table_args.get('comment', _DEFAULT_MODEL_LABEL)
    return _DEFAULT_MODEL_LABEL


class BaseService(object):
    """Base CRUD service meant to be subclassed with concrete ``dao`` and ``Model``.

    Attributes:
        auth_data: authentication data (user, permissions, ...)
        user_id: id of the user performing the operation
        event_dao: DAO recording business events
        dao: DAO for this service's business data
        Model: model class instantiated by ``create``
    """

    auth_data: dict = {}
    user_id: int = 0
    event_dao: EventDao
    dao = None
    Model = None

    def __init__(self, user_id: int = 0, auth_data: dict = None):
        """Initialize the service.

        :param user_id: current user id, stamped on rows and event logs
        :param auth_data: authentication data; a fresh empty dict when omitted
        """
        self.user_id = user_id
        # A fresh dict per instance instead of a shared mutable default.
        self.auth_data = {} if auth_data is None else auth_data
        self.event_dao = EventDao(user_id)

    def read(self, id: int) -> Model:
        """Read one record by id.

        :param id: record id
        :return: whatever ``self.dao.read`` returns (BaseDao.read: instance or None)
        """
        return self.dao.read(id)

    def list(self, args: ListArgsSchema) -> RespListSchema:
        """Read a page of records.

        :param args: list arguments, see ListArgsSchema
        :return: paged result, see RespListSchema
        """
        return self.dao.read_list(args)

    def create(self, schema) -> RespIdSchema:
        """Create one record from ``schema`` and record a create event (event id 2).

        :param schema: schema whose fields are copied onto a new ``self.Model``
        :return: RespIdSchema carrying the new record's id
        """
        resp = RespIdSchema()
        model = self.Model()

        self.set_model_by_schema(schema, model)
        model.user_id = self.user_id
        model.created_by = self.user_id
        model.updated_by = self.user_id
        self.dao.create(model)

        event_log = self.event_dao.get_event_log(2, model.__tablename__)
        event_log.name = '创建{}:{}'.format(_model_label(model), model.name)
        event_log.relation_id = model.id
        self.event_dao.create_event_log(event_log, model)

        resp.id = model.id

        return resp

    @staticmethod
    def set_model_by_schema(schema, model):
        """Copy every ``(field, value)`` pair of ``schema`` onto ``model``.

        If the model has a ``search`` attribute it is reset to ``model.name``,
        which BaseDao.read_list uses for keyword search. Returns None.
        """
        for key, value in schema:
            setattr(model, key, value)

        if hasattr(model, 'search'):
            model.search = model.name

    def update(self, schema) -> RespBaseSchema:
        """Update the record identified by ``schema.id`` and record a modify event (event id 1).

        :param schema: schema with ``id`` and the fields to copy onto the record
        :return: RespBaseSchema; code 2002191527 when the record is not found
        """
        resp = RespBaseSchema()

        model = self.dao.read(schema.id)

        if not model:
            resp.code = 2002191527
            # `model` is None here, so the label comes from the model class.
            resp.message = '找不到对应的:{}'.format(_model_label(self.Model))
            return resp

        # get_event_log serializes the "before" snapshot now, before the update.
        event_log = self.event_dao.get_event_log(1, model.__tablename__, model)
        event_log.name = '修改{}:{}'.format(_model_label(model), model.name)

        self.set_model_by_schema(schema, model)
        model.updated_by = self.user_id
        self.dao.update(model)

        self.event_dao.create_event_log(event_log, model)

        return resp

    def delete(self, id: int) -> RespBaseSchema:
        """Delete one record via ``self.dao.delete`` and record a delete event (event id 5).

        BaseDao.delete is a soft delete (sets ``is_deleted = 1``).

        :param id: record id
        :return: RespBaseSchema; code 2002191553 when the record is not found
        """
        resp = RespBaseSchema()

        model = self.dao.read(id)

        if not model:
            resp.code = 2002191553
            # `model` is None here, so the label comes from the model class.
            resp.message = '找不到对应的:{}'.format(_model_label(self.Model))
            return resp

        event_log = self.event_dao.get_event_log(5, model.__tablename__, model)
        event_log.name = '删除{}:{}'.format(_model_label(model), model.name)

        self.dao.delete(model)

        self.event_dao.create_event_log(event_log, model)

        return resp


# ------------------------------------------------------------------------------
# /fastapi_plus/utils/sync_model.py
# ------------------------------------------------------------------------------
import os
import re
import subprocess

from fastapi_plus.utils.db import DbConfig

# Captures the table name from a generated `__tablename__ = '...'` line.
_TABLE_NAME_PATTERN = re.compile(r".*__tablename__ = '(.*)'.*")


class SyncModel(object):
    """Generate one model file per database table.

    Runs ``sqlacodegen`` against the configured database, splits the output on
    class definitions and writes each class to ``<app_dir>/model/<name>.py``.

    Attributes:
        app_dir: application directory
        model_header: text written at the top of every generated file
        is_use_base_model: drop lines already declared in the base model
        base_model_path: path of the base model file
        base_model_lines: stripped, non-empty lines of the base ``Base`` class body
        db_config: database configuration
    """

    app_dir: str  # application directory name
    model_header = 'from fastapi_plus.model.base import *\n\n\n'
    is_use_base_model = False
    base_model_path = os.path.dirname(os.path.dirname(__file__)) + os.sep + 'model' + os.sep + 'base.py'
    base_model_lines = []
    db_config: DbConfig = None

    def sqlacodegen(self, file_path):
        """Run sqlacodegen for ``db_config``, writing the models to ``file_path``.

        :raises subprocess.CalledProcessError: if sqlacodegen exits with a non-zero status
        """
        subprocess.run(['sqlacodegen', self.db_config.get_url(), '--outfile', file_path], check=True)

    def get_models_content(self, file_path):
        """Generate models into ``file_path`` and return the generated source text."""
        self.sqlacodegen(file_path)
        return self._read_file_content(file_path)

    @staticmethod
    def _read_file_content(file_path: str) -> str:
        """Return the UTF-8 text of ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()

    @staticmethod
    def _transform_name(src: str, first_upper: bool = True):
        """Convert an underscore-separated name to CamelCase.

        Empty segments (leading, trailing or doubled underscores) are skipped.

        :param src: e.g. ``'user_info'``
        :param first_upper: when False the first letter is lower-cased (``'userInfo'``)
        :return: the converted name, e.g. ``'UserInfo'``
        """
        res = ''.join(part[0].upper() + part[1:] for part in src.split('_') if part)

        if not first_upper and res:
            res = res[0].lower() + res[1:]
        return res

    @staticmethod
    def _get_table_name(content):
        """Return the first ``__tablename__`` value in ``content``, or None."""
        match = _TABLE_NAME_PATTERN.search(content)
        return match.group(1) if match else None

    def sync(self, app_dir: str = 'app', db_config: DbConfig = None, is_use_base_model: bool = False,
             base_model_path: str = None,
             model_header: str = None):
        """Regenerate the model files under ``<app_dir>/model`` from the database.

        :param app_dir: application directory
        :param db_config: database configuration (required)
        :param is_use_base_model: drop columns already declared in the base model
        :param base_model_path: overrides the base model file path
        :param model_header: overrides the header of each generated file; when
            omitted and the base model is not used, sqlacodegen's own header is used
        :raises ValueError: if ``db_config`` is not given
        """
        if db_config is None:
            raise ValueError('db_config is required')

        self.app_dir = app_dir
        self.db_config = db_config
        file_path = os.path.join(self.app_dir, 'temporary', 'models.py')
        # In case the temporary directory does not exist yet.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        try:
            models_content = self.get_models_content(file_path)
        finally:
            # Always clean up the temporary file, even if generation failed.
            if os.path.exists(file_path):
                os.remove(file_path)

        content_list = models_content.split('\n\nclass ')

        if base_model_path:
            self.base_model_path = base_model_path

        if model_header:
            self.model_header = model_header

        self.is_use_base_model = is_use_base_model
        if is_use_base_model:
            self._init_base_model()
        elif not model_header:
            # The first chunk is sqlacodegen's import/Base header.
            self.model_header = content_list[0]

        for content in content_list:
            self._save_model(content)

    def _init_base_model(self):
        """Load the stripped, non-empty lines of ``class Base(DeclarativeBase):`` body.

        :raises ValueError: if the base model file has no such class
        """
        content = self._read_file_content(self.base_model_path)
        class_list = content.split('class Base(DeclarativeBase):')
        if len(class_list) < 2:
            raise ValueError('class Base(DeclarativeBase) not found in {}'.format(self.base_model_path))

        # Fresh per-instance list: appending to the class-level default would
        # accumulate duplicates across calls and instances.
        self.base_model_lines = [line.strip() for line in class_list[1].split('\n') if line.strip()]

    def _use_base_model(self, model_content):
        """Drop the class-name line and every line containing a base-model line."""
        model_lines = model_content.split('\n')[1:]
        return '\n'.join(
            line for line in model_lines
            if not any(base_line in line for base_line in self.base_model_lines)
        )

    def _save_model(self, content):
        """Write one generated class chunk to ``<app_dir>/model/<file_name>.py``.

        Chunks without ``__tablename__`` (e.g. the import header) are skipped.
        """
        table_name = self._get_table_name(content)
        if not table_name:
            return

        if self.is_use_base_model:
            content = self._use_base_model(content)
        else:
            # Drop the "<ClassName>(Base):" remainder of the generated class line;
            # the class line is rewritten below with the prefix-stripped name.
            content = content.partition('\n')[2]

        prefix = self.db_config.table_name_prefix or ''
        # Strip the prefix only when the table name actually starts with it.
        file_name = table_name[len(prefix):] if table_name.startswith(prefix) else table_name
        class_name = self._transform_name(file_name)
        file_path = os.path.join(self.app_dir, 'model', file_name + '.py')
        file_content = self.model_header + 'class ' + class_name + '(Base):\n' + content
        # Generated JSON columns are declared as LONGTEXT.
        file_content = file_content.replace('Column(JSON,', 'Column(LONGTEXT,')

        with open(file_path, 'w', encoding='utf-8', newline='\n') as f:
            f.write(file_content)


# ------------------------------------------------------------------------------
# /fastapi_plus/dao/base.py
# ------------------------------------------------------------------------------
import math
from typing import List, Optional

from sqlalchemy import and_, func

from ..schema.base import ListArgsSchema, ListOrderSchema, ListKeySchema, ListFilterSchema, RespListSchema
from ..utils.db import DbUtils
from ..utils.obj2dict import obj2dict

# Maps a list-API filter condition to a builder of the SQLAlchemy criterion.
_LIST_FILTER_BUILDERS = {
    '=': lambda attr, value: attr == value,
    '!=': lambda attr, value: attr != value,
    '<': lambda attr, value: attr < value,
    '>': lambda attr, value: attr > value,
    '<=': lambda attr, value: attr <= value,
    '>=': lambda attr, value: attr >= value,
    'like': lambda attr, value: attr.like('%' + value + '%'),
    'in': lambda attr, value: attr.in_(value.split(',')),
    '!in': lambda attr, value: ~attr.in_(value.split(',')),
    'null': lambda attr, value: attr.is_(None),
    # IS NOT NULL; a leading `~` here would invert it back to IS NULL.
    '!null': lambda attr, value: attr.isnot(None),
}


class BaseDao(object):
    """Base CRUD DAO with soft deletion, meant to be subclassed with a ``Model``.

    Attributes:
        user_id: id of the user performing the operation
        db: database utility whose ``sess`` runs the queries
        Model: model class; queries use its ``id`` and ``is_deleted`` columns
    """
    Model = None

    def __init__(self, user_id=0):
        self.user_id = user_id
        self.db = DbUtils()

    def create(self, model: Model):
        """Add ``model`` to the session and flush (no commit), so generated ids are available.

        :param model: model instance
        """
        self.db.sess.add(model)
        self.db.sess.flush()

    def read(self, id: int, user_id: int = None, is_deleted: int = 0) -> Model:
        """Read one record by id.

        :param id: record id
        :param user_id: when truthy, only match records owned by this user
        :param is_deleted: 1 = only soft-deleted, 2 = deleted or not,
            anything else (default 0) = only live records
        :return: the model instance, or None if not found
        """
        filters = []

        if is_deleted == 1:
            filters.append(self.Model.is_deleted == 1)
        elif is_deleted != 2:
            filters.append(self.Model.is_deleted == 0)

        if user_id:
            filters.append(self.Model.user_id == user_id)

        return self.db.sess.query(self.Model).filter(
            self.Model.id == id,
            *filters
        ).first()

    def update(self, model: Model):
        """Add ``model`` to the session and flush (no commit).

        :param model: model instance
        """
        self.db.sess.add(model)
        self.db.sess.flush()

    def delete(self, model: Model):
        """Soft-delete ``model`` by setting ``is_deleted = 1``.

        :param model: model instance
        """
        model.is_deleted = 1
        self.update(model)

    def read_list(self, args: ListArgsSchema) -> RespListSchema:
        """Read a page of records.

        ``page`` values below 1 and ``size`` values below 1 are treated as 1
        (avoiding a negative OFFSET and a division by zero); the response
        reports the values actually used.

        :param args: list arguments, see ListArgsSchema
        :return: paged result, see RespListSchema
        """
        filters = []

        if args.is_deleted != 'all':
            filters.append(self.Model.is_deleted == 0)

        if args.user_id:
            filters.append(self.Model.user_id == args.user_id)

        filters.extend(self._handle_list_filters(args.filters))

        if args.keywords and hasattr(self.Model, 'search'):
            # Every whitespace-separated keyword must occur in `search` (AND).
            keywords = args.keywords.split()
            if keywords:
                filters.append(and_(*[self.Model.search.like('%' + kw + '%') for kw in keywords]))

        query = self.db.sess.query(self.Model).filter(*filters)
        count = query.count()

        page = max(args.page, 1)
        size = max(args.size, 1)

        if count > 0:
            orders = self._handle_list_orders(args.orders)
            obj_list = query.order_by(*orders).offset((page - 1) * size).limit(size).all()
        else:
            obj_list = []

        resp = RespListSchema()
        resp.page = page
        resp.size = size
        resp.count = count
        resp.page_count = math.ceil(count / size)
        resp.list = self._handle_list_keys(args.keys, obj_list)

        return resp

    def _handle_list_filters(self, args_filters: Optional[List[ListFilterSchema]]):
        """Convert list-API filter conditions into SQLAlchemy criteria.

        Conditions on unknown model attributes, and unknown conditions, are ignored.

        :param args_filters: filter conditions (may be None)
        :return: list of SQLAlchemy criteria
        """
        filters = []

        for item in args_filters or []:
            build = _LIST_FILTER_BUILDERS.get(item.condition)
            if build is not None and hasattr(self.Model, item.key):
                filters.append(build(getattr(self.Model, item.key), item.value))

        return filters

    def _handle_list_orders(self, args_orders: Optional[List[ListOrderSchema]]):
        """Convert list-API ordering conditions into SQLAlchemy ORDER BY clauses.

        :param args_orders: ordering conditions (may be None); "desc", "asc"
            (legacy "acs" accepted) or "rand"; others are ignored
        :return: list of order-by clauses
        """
        orders = []

        for item in args_orders or []:
            if not hasattr(self.Model, item.key):
                continue
            attr = getattr(self.Model, item.key)

            if item.condition == 'desc':
                orders.append(attr.desc())
            elif item.condition in ('asc', 'acs'):
                orders.append(attr.asc())
            elif item.condition == 'rand':
                # RAND() is MySQL's function name; other databases may differ.
                orders.append(func.rand())

        return orders

    def _handle_list_keys(self, args_keys: Optional[List[ListKeySchema]], obj_list: List):
        """Convert records to dicts, optionally keeping/renaming only requested fields.

        :param args_keys: requested fields (may be None); keys that are not model
            attributes are ignored, and if none remain every field is returned
        :param obj_list: model instances
        :return: list of dicts
        """
        keys = [item for item in (args_keys or []) if hasattr(self.Model, item.key)]

        resp_list = []
        for obj in obj_list:
            data = obj2dict(obj)
            if keys:
                # .get(): an attribute obj2dict did not serialize yields None instead of KeyError.
                data = {(item.rename or item.key): data.get(item.key) for item in keys}
            resp_list.append(data)

        return resp_list