├── jian ├── __init__.py ├── enums.py ├── plugin.py ├── redprint.py ├── forma.py ├── config.py ├── sse.py ├── db.py ├── loader.py ├── log.py ├── jwt.py ├── exception.py ├── notify.py ├── util.py ├── interface.py └── core.py ├── LICENSE ├── .gitignore └── README.md /jian/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | from .core import Jian, route_meta, manager 6 | from .db import db 7 | from .jwt import login_required, group_required, admin_required 8 | -------------------------------------------------------------------------------- /jian/enums.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | from enum import Enum 6 | 7 | 8 | # status for user is super 9 | # 是否为超级管理员的枚举 10 | class UserSuper(Enum): 11 | COMMON = 1 12 | SUPER = 2 13 | 14 | 15 | # : status for user is active 16 | # : 当前用户是否为激活状态的枚举 17 | class UserActive(Enum): 18 | ACTIVE = 1 19 | NOT_ACTIVE = 2 20 | -------------------------------------------------------------------------------- /jian/plugin.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | 7 | class Plugin(object): 8 | def __init__(self, name=None): 9 | """ 10 | :param name: plugin的名称 11 | """ 12 | # container of plugin's controllers 13 | # 控制器容器 14 | self.controllers = {} 15 | # container of plugin's models 16 | # 模型层容器 17 | self.models = {} 18 | # plugin's services 19 | self.services = {} 20 | 21 | self.name = name 22 | 23 | def add_model(self, name, model): 24 | self.models[name] = model 25 | 26 | def get_model(self, name): 27 | return self.models.get(name) 28 | 29 | def add_controller(self, name, controller): 30 | self.controllers[name] = controller 31 | 32 | def 
add_service(self, name, service): 33 | self.services[name] = service 34 | 35 | def get_service(self, name): 36 | return self.services.get(name) 37 | -------------------------------------------------------------------------------- /jian/redprint.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | 7 | class Redprint: 8 | def __init__(self, name, with_prefix=True): 9 | self.name = name 10 | self.with_prefix = with_prefix 11 | self.mound = [] 12 | 13 | def route(self, rule, **options): 14 | def decorator(f): 15 | self.mound.append((f, rule, options)) 16 | return f 17 | 18 | return decorator 19 | 20 | def register(self, bp, url_prefix=None): 21 | if url_prefix is None and self.with_prefix: 22 | url_prefix = '/' + self.name 23 | else: 24 | url_prefix = '' 25 | for f, rule, options in self.mound: 26 | endpoint = self.name + '+' + options.pop("endpoint", f.__name__) 27 | if rule: 28 | url = url_prefix + rule 29 | bp.add_url_rule(url, endpoint, f, **options) 30 | else: 31 | bp.add_url_rule(url_prefix, endpoint, f, **options) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Little student. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /jian/forma.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/07/03 4 | """ 5 | 6 | from flask import request 7 | from wtforms import Form as WTForm, IntegerField 8 | from wtforms.validators import StopValidation 9 | 10 | from .exception import ParameterException 11 | 12 | 13 | class Form(WTForm): 14 | def __init__(self): 15 | data = request.get_json(silent=True) 16 | args = request.args.to_dict() 17 | super(Form, self).__init__(data=data, **args) 18 | 19 | def validate_for_api(self): 20 | valid = super(Form, self).validate() 21 | if not valid: 22 | raise ParameterException(msg=self.errors) 23 | return self 24 | 25 | 26 | def integer_check(form, field): 27 | if field.data is None: 28 | raise StopValidation('输入字段不可为空') 29 | try: 30 | field.data = int(field.data) 31 | except ValueError: 32 | raise StopValidation('不是一个有效整数') 33 | 34 | 35 | class JianIntegerField(IntegerField): 36 | """ 37 | 校验一个字段是否为正整数 38 | """ 39 | 40 | def __init__(self, label=None, validators=None, **kwargs): 41 | if validators is not None and type(validators) == list: 42 | validators.insert(0, integer_check) 43 | else: 44 | validators = [integer_check] 45 | super(JianIntegerField, self).__init__(label, validators, **kwargs) -------------------------------------------------------------------------------- /jian/config.py: 
-------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | from collections import defaultdict 7 | 8 | 9 | class Config(defaultdict): 10 | def add_plugin_config(self, plugin_name, obj): 11 | if type(obj) is dict: 12 | if self.get(plugin_name, None) is None: 13 | self[plugin_name] = {} 14 | for k, v in obj.items(): 15 | self[plugin_name][k] = v 16 | 17 | def add_plugin_config_item(self, plugin_name, key, value): 18 | if self.get(plugin_name, None) is None: 19 | self[plugin_name] = {} 20 | self[plugin_name][key] = value 21 | 22 | def get_plugin_config(self, plugin_name, default=None): 23 | return self.get(plugin_name, default) 24 | 25 | def get_plugin_config_item(self, plugin_name, key, default=None): 26 | plugin_conf = self.get(plugin_name) 27 | if plugin_conf is None: 28 | return default 29 | return plugin_conf.get(key, default) 30 | 31 | def get_config(self, key: str, default=None): 32 | """ plugin_name.key """ 33 | if '.' 
not in key: 34 | return self.get(key, default) 35 | index = key.rindex('.') 36 | plugin_name = key[:index] 37 | plugin_key = key[index + 1:] 38 | plugin_conf = self.get(plugin_name) 39 | if plugin_conf is None: 40 | return default 41 | return plugin_conf.get(plugin_key, default) 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /jian/sse.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | import json 7 | from collections import deque 8 | 9 | 10 | class Sse(object): 11 | messages = deque() 12 | _retry = None 13 | 14 | def __init__(self, default_retry=2000): 15 | self._buffer = [] 16 | self._default_id = 1 17 | self.set_retry(default_retry) 18 | 19 | def set_retry(self, num): 20 | self._retry = num 21 | self._buffer.append("retry: {0}\n".format(self._retry)) 22 | 23 | def set_event_id(self, event_id=None): 24 | if event_id: 25 | self._default_id = event_id 26 | self._buffer.append("id: {0}\n".format(event_id)) 27 | else: 28 | self._buffer.append("id: 
{0}\n".format(self._default_id)) 29 | 30 | def reset_event_id(self): 31 | self.set_event_id(1) 32 | 33 | def increase_id(self): 34 | self._default_id += 1 35 | 36 | def add_message(self, event, obj, flush=True): 37 | self.set_event_id() 38 | self._buffer.append("event: {0}\n".format(event)) 39 | line = json.dumps(obj, ensure_ascii=False) 40 | self._buffer.append("data: {0}\n".format(line)) 41 | self._buffer.append("\n") 42 | if flush: 43 | self.flush() 44 | 45 | def flush(self): 46 | self.messages.append(self.join_buffer()) 47 | self._buffer.clear() 48 | self.increase_id() 49 | 50 | def pop(self): 51 | return self.messages.popleft() 52 | 53 | def heartbeat(self, comment=None): 54 | # 发送注释 : this is a test stream\n\n 告诉客户端,服务器还活着 55 | if comment and type(comment) == 'str': 56 | self._buffer.append(comment) 57 | else: 58 | self._buffer.append(': sse sever is still alive \n\n') 59 | tmp = self.join_buffer() 60 | self._buffer.clear() 61 | return tmp 62 | 63 | def join_buffer(self): 64 | string = '' 65 | for it in self._buffer: 66 | string += it 67 | return string 68 | 69 | def exit_message(self): 70 | return len(self.messages) > 0 71 | 72 | 73 | sser = Sse() 74 | -------------------------------------------------------------------------------- /jian/db.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery 7 | from sqlalchemy import inspect, orm, func 8 | from contextlib import contextmanager 9 | 10 | from .exception import NotFound 11 | 12 | 13 | class SQLAlchemy(_SQLAlchemy): 14 | @contextmanager 15 | def auto_commit(self): 16 | try: 17 | yield 18 | self.session.commit() 19 | except Exception as e: 20 | self.session.rollback() 21 | raise e 22 | 23 | 24 | class Query(BaseQuery): 25 | 26 | def filter_by(self, soft=True, **kwargs): 27 | # soft 应用软删除 28 | soft = kwargs.get('soft') 29 | if 
soft: 30 | kwargs['delete_time'] = None 31 | if soft is not None: 32 | kwargs.pop('soft') 33 | return super(Query, self).filter_by(**kwargs) 34 | 35 | def get_or_404(self, ident): 36 | rv = self.get(ident) 37 | if not rv: 38 | raise NotFound() 39 | return rv 40 | 41 | def first_or_404(self): 42 | rv = self.first() 43 | if not rv: 44 | raise NotFound() 45 | return rv 46 | 47 | 48 | db = SQLAlchemy(query_class=Query) 49 | 50 | 51 | def get_total_nums(cls, is_soft=False, **kwargs): 52 | nums = db.session.query(func.count(cls.id)) 53 | nums = nums.filter(cls.delete_time == None).filter_by(**kwargs).scalar() if is_soft else nums.filter().scalar() 54 | if nums: 55 | return nums 56 | else: 57 | return 0 58 | 59 | 60 | class MixinJSONSerializer: 61 | @orm.reconstructor 62 | def init_on_load(self): 63 | self._fields = [] 64 | self._exclude = [] 65 | 66 | self._set_fields() 67 | self.__prune_fields() 68 | 69 | def _set_fields(self): 70 | pass 71 | 72 | def __prune_fields(self): 73 | columns = inspect(self.__class__).columns 74 | if not self._fields: 75 | all_columns = set([column.name for column in columns]) 76 | self._fields = list(all_columns - set(self._exclude)) 77 | 78 | def hide(self, *args): 79 | for key in args: 80 | self._fields.remove(key) 81 | return self 82 | 83 | def keys(self): 84 | return self._fields 85 | 86 | def __getitem__(self, key): 87 | return getattr(self, key) 88 | -------------------------------------------------------------------------------- /jian/loader.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | from importlib import import_module 6 | 7 | from .redprint import Redprint 8 | from .db import db 9 | from .plugin import Plugin 10 | from .core import gigi_config 11 | 12 | 13 | class Loader(object): 14 | plugin_path: dict = None 15 | 16 | def __init__(self, plugin_path): 17 | self.plugins = {} 18 | assert type(plugin_path) is 
dict, 'plugin_path must be a dict' 19 | self.plugin_path = plugin_path 20 | self.load_plugins_config() 21 | self.load_plugins() 22 | 23 | def load_plugins(self): 24 | for name, conf in self.plugin_path.items(): 25 | enable = conf.get('enable', None) 26 | if enable: 27 | path = conf.get('path') 28 | # load plugin 29 | path and self._load_plugin(f'{path}.app.__init__', name) 30 | 31 | def load_plugins_config(self): 32 | for name, conf in self.plugin_path.items(): 33 | path = conf.get('path', None) 34 | # load config 35 | self._load_config(f'{path}.config', name, conf) 36 | 37 | def _load_plugin(self, path, name): 38 | mod = import_module(path) 39 | plugin = Plugin(name=name) 40 | dic = mod.__dict__ 41 | for key in dic.keys(): 42 | if not key.startswith('_'): 43 | attr = dic[key] 44 | if isinstance(attr, Redprint): 45 | plugin.add_controller(attr.name, attr) 46 | elif issubclass(attr, db.Model): 47 | plugin.add_model(attr.__name__, attr) 48 | # 暂时废弃加载service,用处不大 49 | # elif issubclass(attr, ServiceInterface): 50 | # plugin.add_service(attr.__name__, attr) 51 | self.plugins[plugin.name] = plugin 52 | 53 | def _load_config(self, config_path, name, conf): 54 | default_conf = {**conf} if conf else {} 55 | try: 56 | if config_path: 57 | mod = import_module(config_path) 58 | dic = mod.__dict__ 59 | for key in dic.keys(): 60 | if not key.startswith('_'): 61 | default_conf[key] = dic[key] 62 | except ModuleNotFoundError as e: 63 | pass 64 | gigi_config.add_plugin_config(plugin_name=name, obj=default_conf) 65 | -------------------------------------------------------------------------------- /jian/log.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | from functools import wraps 7 | import re 8 | from flask import Response, request 9 | from flask_jwt_extended import get_current_user 10 | from .core import find_info_by_ep, Log 11 | 12 | REG_XP = 
r'[{](.*?)[}]' 13 | OBJECTS = ['user', 'response', 'request'] 14 | 15 | 16 | class Logger(object): 17 | # message template 18 | template = None 19 | 20 | def __init__(self, template=None): 21 | if template: 22 | self.template: str = template 23 | elif self.template is None: 24 | raise Exception('template must not be None!') 25 | self.message = '' 26 | self.response = None 27 | self.user = None 28 | 29 | def __call__(self, func): 30 | @wraps(func) 31 | def wrap(*args, **kwargs): 32 | response: Response = func(*args, **kwargs) 33 | self.response = response 34 | self.user = get_current_user() 35 | self.message = self._parse_template() 36 | self.write_log() 37 | return response 38 | 39 | return wrap 40 | 41 | def write_log(self): 42 | info = find_info_by_ep(request.endpoint) 43 | authority = info.auth if info is not None else '' 44 | status_code = getattr(self.response, 'status_code', None) 45 | if status_code is None: 46 | status_code = getattr(self.response, 'code', None) 47 | if status_code is None: 48 | status_code = 0 49 | Log.create_log(message=self.message, user_id=self.user.id, user_name=self.user.nickname, 50 | status_code=status_code, method=request.method, 51 | path=request.path, authority=authority, commit=True) 52 | 53 | # 解析自定义模板 54 | def _parse_template(self): 55 | message = self.template 56 | total = re.findall(REG_XP, message) 57 | for it in total: 58 | assert '.' in it, '%s中必须包含 . 
,且为一个' % it 59 | i = it.rindex('.') 60 | obj = it[:i] 61 | assert obj in OBJECTS, '%s只能为user,response,request中的一个' % obj 62 | prop = it[i + 1:] 63 | if obj == 'user': 64 | item = getattr(self.user, prop, '') 65 | elif obj == 'response': 66 | item = getattr(self.response, prop, '') 67 | else: 68 | item = getattr(request, prop, '') 69 | message = message.replace('{%s}' % it, str(item)) 70 | return message 71 | -------------------------------------------------------------------------------- /jian/jwt.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | from functools import wraps 7 | 8 | from flask_jwt_extended import JWTManager, verify_jwt_in_request, get_current_user, create_access_token, \ 9 | create_refresh_token 10 | 11 | from .exception import AuthFailed, InvalidTokenException, ExpiredTokenException, NotFound 12 | 13 | jwt = JWTManager() 14 | 15 | 16 | def admin_required(fn): 17 | @wraps(fn) 18 | def wrapper(*args, **kwargs): 19 | verify_jwt_in_request() 20 | current_user = get_current_user() 21 | if not current_user.is_super: 22 | raise AuthFailed(msg='只有超级管理员可操作') 23 | return fn(*args, **kwargs) 24 | 25 | return wrapper 26 | 27 | 28 | def group_required(fn): 29 | @wraps(fn) 30 | def wrapper(*args, **kwargs): 31 | verify_jwt_in_request() 32 | current_user = get_current_user() 33 | # check current user is active or not 34 | # 判断当前用户是否为激活状态 35 | if not current_user.is_active: 36 | raise AuthFailed(msg='您目前处于未激活状态,请联系超级管理员') 37 | # not super 38 | if not current_user.is_super: 39 | group_id = current_user.group_id 40 | if group_id is None: 41 | raise AuthFailed(msg='您还不属于任何权限组,请联系超级管理员获得权限') 42 | from .core import is_user_allowed 43 | it = is_user_allowed(group_id) 44 | if not it: 45 | raise AuthFailed(msg='权限不够,请联系超级管理员获得权限') 46 | else: 47 | return fn(*args, **kwargs) 48 | else: 49 | return fn(*args, **kwargs) 50 | 51 | return wrapper 52 | 53 | 
54 | def login_required(fn): 55 | @wraps(fn) 56 | def wrapper(*args, **kwargs): 57 | verify_jwt_in_request() 58 | return fn(*args, **kwargs) 59 | 60 | return wrapper 61 | 62 | 63 | @jwt.user_loader_callback_loader 64 | def user_loader_callback(identity): 65 | from .core import find_user 66 | # token is granted , user must be exit 67 | # 如果token已经被颁发,则该用户一定存在 68 | user = find_user(id=identity) 69 | if user is None: 70 | raise NotFound(msg='用户不存在') 71 | return user 72 | 73 | 74 | @jwt.expired_token_loader 75 | def expired_loader_callback(): 76 | return ExpiredTokenException() 77 | 78 | 79 | @jwt.invalid_token_loader 80 | def invalid_loader_callback(e): 81 | return InvalidTokenException() 82 | 83 | 84 | @jwt.unauthorized_loader 85 | def unauthorized_loader_callback(e): 86 | return AuthFailed(msg='认证失败,请检查请求头或者重新登陆') 87 | 88 | 89 | def get_tokens(user): 90 | access_token = create_access_token(identity=user.id) 91 | refresh_token = create_refresh_token(identity=user.id) 92 | return access_token, refresh_token -------------------------------------------------------------------------------- /jian/exception.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | from flask import json, request 6 | from werkzeug.exceptions import HTTPException 7 | from werkzeug._compat import text_type 8 | 9 | 10 | class APIException(HTTPException): 11 | code = 500 12 | msg = '抱歉,服务器未知错误' 13 | error_code = 999 14 | 15 | def __init__(self, msg=None, code=None, error_code=None, 16 | headers=None): 17 | if code: 18 | self.code = code 19 | if error_code: 20 | self.error_code = error_code 21 | if msg: 22 | self.msg = msg 23 | if headers is not None: 24 | headers_merged = headers.copy() 25 | headers_merged.update(self.headers) 26 | self.headers = headers_merged 27 | 28 | super(APIException, self).__init__(msg, None) 29 | 30 | def get_body(self, environ=None): 31 | body = dict( 32 | 
msg=self.msg, 33 | error_code=self.error_code, 34 | request=request.method + ' ' + self.get_url_no_param() 35 | ) 36 | text = json.dumps(body) 37 | return text_type(text) 38 | 39 | @staticmethod 40 | def get_url_no_param(): 41 | full_path = str(request.full_path) 42 | main_path = full_path.split('?') 43 | return main_path[0] 44 | 45 | def get_headers(self, environ=None): 46 | return [('Content-Type', 'application/json')] 47 | 48 | 49 | class Success(APIException): 50 | code = 201 51 | msg = '成功' 52 | error_code = 0 53 | 54 | 55 | class Failed(APIException): 56 | code = 400 57 | msg = '失败' 58 | error_code = 9999 59 | 60 | 61 | class AuthFailed(APIException): 62 | code = 401 63 | msg = '认证失败' 64 | error_code = 10000 65 | 66 | 67 | class NotFound(APIException): 68 | code = 404 69 | msg = '资源不存在' 70 | error_code = 10020 71 | 72 | 73 | class ParameterException(APIException): 74 | code = 400 75 | msg = '参数错误' 76 | error_code = 10030 77 | 78 | 79 | class InvalidTokenException(APIException): 80 | code = 401 81 | msg = '令牌失效' 82 | error_code = 10040 83 | 84 | 85 | class ExpiredTokenException(APIException): 86 | code = 422 87 | msg = '令牌过期' 88 | error_code = 10050 89 | 90 | 91 | class UnknownException(APIException): 92 | code = 400 93 | msg = '服务器未知错误' 94 | error_code = 999 95 | 96 | 97 | class RepeatException(APIException): 98 | code = 400 99 | msg = '字段重复' 100 | error_code = 10060 101 | 102 | 103 | class Forbidden(APIException): 104 | code = 401 105 | msg = '不可操作' 106 | error_code = 10070 107 | -------------------------------------------------------------------------------- /jian/notify.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | 7 | from functools import wraps 8 | import re 9 | from datetime import datetime 10 | from flask import Response, request 11 | from flask_jwt_extended import get_current_user 12 | from .sse import sser 13 | 14 | REG_XP 
= r'[{](.*?)[}]' 15 | OBJECTS = ['user', 'response', 'request'] 16 | SUCCESS_STATUS = [200, 201] 17 | MESSAGE_EVENTS = set() 18 | 19 | 20 | class Notify(object): 21 | def __init__(self, template=None, event=None, **kwargs): 22 | """ 23 | Notify a message or create a log 24 | :param template: message template 25 | {user.nickname}查看自己是否为激活状态 ,状态码为{response.status_code} -> pedro查看自己是否为激活状态 ,状态码为200 26 | :param write: write to db or not 27 | :param push: push to front_end or not 28 | """ 29 | if event: 30 | self.event = event 31 | elif self.event is None: 32 | raise Exception('event must not be None!') 33 | if template: 34 | self.template: str = template 35 | elif self.template is None: 36 | raise Exception('template must not be None!') 37 | # 加入所有types中 38 | MESSAGE_EVENTS.add(event) 39 | self.message = '' 40 | self.response = None 41 | self.user = None 42 | self.extra = kwargs 43 | 44 | def __call__(self, func): 45 | @wraps(func) 46 | def wrap(*args, **kwargs): 47 | response: Response = func(*args, **kwargs) 48 | self.response = response 49 | self.user = get_current_user() 50 | self.message = self._parse_template() 51 | self.push_message() 52 | return response 53 | 54 | return wrap 55 | 56 | def push_message(self): 57 | # status = '操作成功' if self.response.status_code in SUCCESS_STATUS else '操作失败' 58 | sser.add_message(self.event, 59 | {'message': self.message, 60 | 'time': int(datetime.now().timestamp()), 61 | **self.extra 62 | }) 63 | 64 | # 解析自定义消息的模板 65 | def _parse_template(self): 66 | message = self.template 67 | total = re.findall(REG_XP, message) 68 | for it in total: 69 | assert '.' in it, '%s中必须包含 . 
,且为一个' % it 70 | i = it.rindex('.') 71 | obj = it[:i] 72 | assert obj in OBJECTS, '%s只能为user,response,request中的一个' % obj 73 | prop = it[i + 1:] 74 | if obj == 'user': 75 | item = getattr(self.user, prop, '') 76 | elif obj == 'response': 77 | item = getattr(self.response, prop, '') 78 | else: 79 | item = getattr(request, prop, '') 80 | message = message.replace('{%s}' % it, str(item)) 81 | return message 82 | 83 | def _check_can_push(self): 84 | # 超级管理员不可push,暂时测试可push 85 | if self.user.is_super: 86 | return False 87 | return True 88 | -------------------------------------------------------------------------------- /jian/util.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | import time 6 | import re 7 | import errno 8 | import random 9 | import types 10 | from importlib import import_module 11 | import os 12 | import importlib.util 13 | from flask import request, current_app 14 | 15 | from .exception import ParameterException 16 | 17 | 18 | def get_timestamp(fmt='%Y-%m-%d %H:%M:%S'): 19 | return time.strftime(fmt, time.localtime(time.time())) 20 | 21 | 22 | def get_pyfile(path, module_name, silent=False): 23 | """ 24 | get all properties of a pyfile 25 | 获得一个.py文件的所有属性 26 | :param path: path of pytfile 27 | :param module_name: name 28 | :param silent: show the error or not 29 | :return: all properties of a pyfile 30 | """ 31 | d = types.ModuleType(module_name) 32 | d.__file__ = path 33 | try: 34 | with open(path, mode='rb') as config_file: 35 | exec(compile(config_file.read(), path, 'exec'), d.__dict__) 36 | except IOError as e: 37 | if silent and e.errno in ( 38 | errno.ENOENT, errno.EISDIR, errno.ENOTDIR 39 | ): 40 | return False 41 | e.strerror = 'Unable to load configuration file (%s)' % e.strerror 42 | raise 43 | return d.__dict__ 44 | 45 | 46 | def load_object(path): 47 | """ 48 | 获得一个模块中的某个属性 49 | :param path: module path 50 | :return: the 
obj of module which you want get. 51 | """ 52 | try: 53 | dot = path.rindex('.') 54 | except ValueError: 55 | raise ValueError("Error loading object '%s': not a full path" % path) 56 | 57 | module, name = path[:dot], path[dot + 1:] 58 | mod = import_module(module) 59 | 60 | try: 61 | obj = getattr(mod, name) 62 | except AttributeError: 63 | raise NameError("Module '%s' doesn't define any object named '%s'" % (module, name)) 64 | 65 | return obj 66 | 67 | 68 | def import_module_abs(name, path): 69 | """ 70 | 绝对路径导入模块 71 | :param name: name of module 72 | :param path: absolute path of module 73 | :return: the module 74 | """ 75 | spec = importlib.util.spec_from_file_location(name, path) 76 | foo = importlib.util.module_from_spec(spec) 77 | spec.loader.exec_module(foo) 78 | 79 | 80 | def get_pwd(): 81 | """ 82 | :return: absolute current work path 83 | """ 84 | return os.path.abspath(os.getcwd()) 85 | 86 | 87 | def paginate(): 88 | count = int(request.args.get('count', current_app.config.get('COUNT_DEFAULT') if current_app.config.get( 89 | 'COUNT_DEFAULT') else 5)) 90 | start = int(request.args.get('page', current_app.config.get('PAGE_DEFAULT') if current_app.config.get( 91 | 'PAGE_DEFAULT') else 0)) 92 | count = 15 if count >= 15 else count 93 | start = start * count 94 | if start < 0 or count < 0: 95 | raise ParameterException() 96 | return start, count 97 | 98 | 99 | def camel2line(camel: str): 100 | p = re.compile(r'([a-z]|\d)([A-Z])') 101 | line = re.sub(p, r'\1_\2', camel).lower() 102 | return line 103 | 104 | 105 | def get_random_str(length): 106 | seed = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 107 | sa = [] 108 | for i in range(length): 109 | sa.append(random.choice(seed)) 110 | salt = ''.join(sa) 111 | return salt 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Jian (简单)

2 | 3 |

4 | 5 | flask version 6 | Python version 7 |

8 |

9 | 简介 | 更新计划 10 |

11 | 12 | # Jian 13 | 对Flask项目中一些常用第三方进行封装. 14 | 例如: 15 | - SQLAlchemy 16 | - WTForms 17 | - JWT 18 | - Blueprint 19 | - Route_meta 20 | - Exception 21 | - Log 22 | 23 | 对这些优秀的第三方库进行二次开发,使更贴合我们项目中的使用. 24 | 25 | 26 | # 安装 install 27 | 28 | > pipenv or pip 29 | 30 | ```bash 31 | pipenv install jian 32 | ``` 33 | 34 | `or` 35 | 36 | ```bash 37 | pip install jian 38 | ``` 39 | 40 | ## 使用 41 | 42 | > 初始化项目,并且注册Jian 43 | 44 | ```python 45 | from jian import Jian 46 | 47 | 48 | def create_apps(): 49 | app = Flask(__name__) 50 | app.config.from_object('app.config.setting') 51 | Jian(app) 52 | ``` 53 | 54 | > 数据层操作 55 | 56 | ```python 57 | from sqlalchemy import Column, String, Integer 58 | 59 | from jian.interface import InfoCrud as Base 60 | 61 | 62 | class Friend(Base): 63 | id = Column(Integer, primary_key=True, autoincrement=True) 64 | url = Column(String(32), unique=True) 65 | doc = Column(String(50)) 66 | image = Column(String(50)) 67 | ``` 68 | 69 | > 表单验证 70 | ```python 71 | from wtforms import PasswordField, StringField 72 | from wtforms.validators import DataRequired, Regexp, EqualTo, length 73 | from jian.forma import Form 74 | 75 | # 注册校验 76 | class RegisterForm(Form): 77 | password = PasswordField('新密码', validators=[ 78 | DataRequired(message='新密码不可为空'), 79 | Regexp(r'^[A-Za-z0-9_*&$#@]{6,22}$', message='密码长度必须在6~22位之间,包含字符、数字和 _ '), 80 | EqualTo('confirm_password', message='两次输入的密码不一致,请输入相同的密码')]) 81 | confirm_password = PasswordField('确认新密码', validators=[DataRequired(message='请确认密码')]) 82 | nickname = StringField(validators=[DataRequired(message='昵称不可为空'), 83 | length(min=2, max=10, message='昵称长度必须在2~10之间')]) 84 | ``` 85 | 86 | > 获取JWT登录令牌 87 | 88 | ```python 89 | from jian.jwt import get_tokens 90 | 91 | # Forms , Exception, Model 等要自己引入,这里只介绍jwt的用法 92 | # 登录逻辑只是展示,并不是这样直接用,登录逻辑已经封装到jian.core中 93 | 94 | @login.route('/login', methods=['POST']) 95 | def login(): 96 | form = LoginForm().validate_for_api() 97 | user = User.query.filter_by(nickname=nickname).first() 98 | if user is None or
user.delete_time is not None: 99 | raise NotFound(msg='用户不存在') 100 | if not user.check_password(password): 101 | raise ParameterException(msg='密码错误,请输入正确密码') 102 | if not user.is_active: 103 | raise AuthFailed(msg='您目前处于未激活状态,请联系超级管理员') 104 | access_token, refresh_token = get_tokens(user) 105 | return jsonify({ 106 | 'access_token': access_token, 107 | 'refresh_token': refresh_token 108 | }) 109 | ``` 110 | 111 | > 记录日志模块 112 | 113 | ```python 114 | from jian.log import Logger 115 | 116 | # 第一种方法 117 | @Logger(template='新注册了一个用户') 118 | def register_user(): 119 | pass 120 | 121 | # 第二种方法主要是复杂的结构 122 | def register_user(): 123 | Log.create_log( 124 | message=f'{user.nickname}登陆成功获取了令牌', 125 | user_id=user.id, user_name=user.nickname, 126 | status_code=200, method='post',path='/cms/user/login', 127 | authority='无', commit=True 128 | ) 129 | ... 130 | 131 | ``` 132 | 133 | ## 更新计划 134 | 135 | - [ ] 加入Swagger文档. 136 | - [ ] 文件上传模块. -------------------------------------------------------------------------------- /jian/interface.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | from datetime import datetime 6 | 7 | from sqlalchemy import Column, Integer, String, FetchedValue, SmallInteger, TIMESTAMP 8 | from werkzeug.security import generate_password_hash, check_password_hash 9 | 10 | from .enums import UserSuper, UserActive 11 | from .db import MixinJSONSerializer, db 12 | from .util import camel2line 13 | 14 | 15 | # 基础的crud model 16 | class BaseCrud(db.Model, MixinJSONSerializer): 17 | __abstract__ = True 18 | 19 | def __init__(self): 20 | name: str = self.__class__.__name__ 21 | if not hasattr(self, '__tablename__'): 22 | self.__tablename__ = camel2line(name) 23 | 24 | def _set_fields(self): 25 | self._exclude = [] 26 | 27 | def set_attrs(self, attrs_dict): 28 | for key, value in attrs_dict.items(): 29 | if hasattr(self, key) and key != 'id': 30 | 
setattr(self, key, value) 31 | 32 | # 硬删除 33 | def delete(self, commit=False): 34 | db.session.delete(self) 35 | if commit: 36 | db.session.commit() 37 | 38 | # 查 39 | @classmethod 40 | def get(cls, start=None, count=None, one=True, **kwargs): 41 | if one: 42 | return cls.query.filter().filter_by(**kwargs).first() 43 | return cls.query.filter().filter_by(**kwargs).offset(start).limit(count).all() 44 | 45 | # 增 46 | @classmethod 47 | def create(cls, **kwargs): 48 | one = cls() 49 | for key in kwargs.keys(): 50 | if hasattr(one, key): 51 | setattr(one, key, kwargs[key]) 52 | db.session.add(one) 53 | if kwargs.get('commit') is True: 54 | db.session.commit() 55 | return one 56 | 57 | def update(self, **kwargs): 58 | for key in kwargs.keys(): 59 | if hasattr(self, key): 60 | setattr(self, key, kwargs[key]) 61 | db.session.add(self) 62 | if kwargs.get('commit') is True: 63 | db.session.commit() 64 | return self 65 | 66 | 67 | # 提供软删除,及创建时间,更新时间信息的crud model 68 | class InfoCrud(db.Model, MixinJSONSerializer): 69 | __abstract__ = True 70 | _create_time = Column('create_time', TIMESTAMP(True), default=datetime.now) 71 | update_time = Column(TIMESTAMP(True), default=datetime.now, onupdate=datetime.now) 72 | delete_time = Column(TIMESTAMP(True)) 73 | 74 | def __init__(self): 75 | name: str = self.__class__.__name__ 76 | if not hasattr(self, '__tablename__'): 77 | self.__tablename__ = camel2line(name) 78 | 79 | def _set_fields(self): 80 | self._exclude = ['update_time', 'delete_time'] 81 | 82 | @property 83 | def create_time(self): 84 | if self._create_time is None: 85 | return None 86 | return int(round(self._create_time.timestamp() * 1000)) 87 | 88 | # @property 89 | # def update_time(self): 90 | # if self._update_time is None: 91 | # return None 92 | # return int(round(self._update_time.timestamp() * 1000)) 93 | 94 | # @property 95 | # def delete_time(self): 96 | # if self._delete_time is None: 97 | # return None 98 | # return int(round(self._delete_time.timestamp() * 
1000)) 99 | 100 | def set_attrs(self, attrs_dict): 101 | for key, value in attrs_dict.items(): 102 | if hasattr(self, key) and key != 'id': 103 | setattr(self, key, value) 104 | 105 | # 软删除 106 | def delete(self, commit=False): 107 | self.delete_time = datetime.now() 108 | db.session.add(self) 109 | # 记得提交会话 110 | if commit: 111 | db.session.commit() 112 | 113 | # 硬删除 114 | def hard_delete(self, commit=False): 115 | db.session.delete(self) 116 | if commit: 117 | db.session.commit() 118 | 119 | # 查 120 | @classmethod 121 | def get(cls, start=None, count=None, one=True, **kwargs): 122 | # 应用软删除,必须带有delete_time 123 | if kwargs.get('delete_time') is None: 124 | kwargs['delete_time'] = None 125 | if one: 126 | return cls.query.filter().filter_by(**kwargs).first() 127 | return cls.query.filter().filter_by(**kwargs).offset(start).limit(count).all() 128 | 129 | # 增 130 | @classmethod 131 | def create(cls, **kwargs): 132 | one = cls() 133 | for key in kwargs.keys(): 134 | # if key == 'from': 135 | # setattr(one, '_from', kwargs[key]) 136 | # if key == 'parts': 137 | # setattr(one, '_parts', kwargs[key]) 138 | if hasattr(one, key): 139 | setattr(one, key, kwargs[key]) 140 | db.session.add(one) 141 | if kwargs.get('commit') is True: 142 | db.session.commit() 143 | return one 144 | 145 | def update(self, **kwargs): 146 | for key in kwargs.keys(): 147 | # if key == 'from': 148 | # setattr(self, '_from', kwargs[key]) 149 | if hasattr(self, key): 150 | setattr(self, key, kwargs[key]) 151 | db.session.add(self) 152 | if kwargs.get('commit') is True: 153 | db.session.commit() 154 | return self 155 | 156 | 157 | class UserInterface(InfoCrud): 158 | __tablename__ = 'jian_user' 159 | 160 | id = Column(Integer, primary_key=True) 161 | nickname = Column(String(24), nullable=False, unique=True) 162 | # : super express the user is super(super admin) ; 1 -> common | 2 -> super 163 | # : super 代表是否为超级管理员 ; 1 -> 普通用户 | 2 -> 超级管理员 164 | super = Column(SmallInteger, nullable=False, default=1, 
server_default=FetchedValue()) 165 | # : active express the user can manage the authorities or not ; 1 -> active | 2 -> not 166 | # : active 代表当前用户是否为激活状态,非激活状态默认失去用户权限 ; 1 -> 激活 | 2 -> 非激活 167 | active = Column(SmallInteger, nullable=False, default=1, server_default=FetchedValue()) 168 | # : used to send email in the future 169 | # : 预留字段,方便以后扩展 170 | email = Column(String(100), unique=True) 171 | # : which group the user belongs,nullable is true 172 | # : 用户所属的权限组id 173 | group_id = Column(Integer) 174 | _password = Column('password', String(100)) 175 | 176 | def _set_fields(self): 177 | self._exclude = ['password'] 178 | 179 | @property 180 | def password(self): 181 | return self._password 182 | 183 | @password.setter 184 | def password(self, raw): 185 | self._password = generate_password_hash(raw) 186 | 187 | @property 188 | def is_super(self): 189 | return self.super == UserSuper.SUPER.value 190 | 191 | @property 192 | def is_active(self): 193 | return self.active == UserActive.ACTIVE.value 194 | 195 | @classmethod 196 | def verify(cls, nickname, password): 197 | raise Exception('must implement this method') 198 | 199 | def check_password(self, raw): 200 | if not self._password: 201 | return False 202 | return check_password_hash(self._password, raw) 203 | 204 | def reset_password(self, new_password): 205 | raise Exception('must implement this method') 206 | 207 | def change_password(self, old_password, new_password): 208 | raise Exception('must implement this method') 209 | 210 | 211 | class AuthInterface(BaseCrud): 212 | __tablename__ = 'jian_auth' 213 | 214 | id = Column(Integer, primary_key=True) 215 | # : belongs to which group 216 | # : 所属权限组id 217 | group_id = Column(Integer, nullable=False) 218 | # : authority field 219 | # : 权限字段 220 | auth = Column(String(60)) 221 | # : authority module, default common , which can sort authorities 222 | # : 权限的模块 223 | module = Column(String(50)) 224 | 225 | 226 | class GroupInterface(BaseCrud): 227 | __tablename__ = 
'jian_group' 228 | 229 | id = Column(Integer, primary_key=True) 230 | # : name of group 231 | # : 权限组名称 232 | name = Column(String(60)) 233 | # a description of a group 234 | # 权限组描述 235 | info = Column(String(255)) 236 | 237 | 238 | class LogInterface(BaseCrud): 239 | __tablename__ = 'jian_log' 240 | 241 | id = Column(Integer, primary_key=True) 242 | # : log message 243 | # : 日志信息 244 | message = Column(String(450)) 245 | # : create time 246 | # : 日志创建时间 247 | _time = Column('time', TIMESTAMP(True), default=datetime.now) 248 | # : user id 249 | # : 用户id 250 | user_id = Column(Integer, nullable=False) 251 | # user_name at that moment 252 | # 用户当时的昵称 253 | user_name = Column(String(20)) 254 | # : status_code check request is success or not 255 | # : 请求的http返回码 256 | status_code = Column(Integer) 257 | # request method 258 | # 请求方法 259 | method = Column(String(20)) 260 | # request path 261 | # 请求路径 262 | path = Column(String(50)) 263 | # which authority is accessed 264 | # 访问那个权限 265 | authority = Column(String(100)) 266 | 267 | @property 268 | def time(self): 269 | if self._time is None: 270 | return None 271 | return int(round(self._time.timestamp() * 1000)) 272 | 273 | 274 | class EventInterface(BaseCrud): 275 | __tablename__ = 'jian_event' 276 | id = Column(Integer, primary_key=True) 277 | # : belongs to which group 278 | group_id = Column(Integer, nullable=False) 279 | # message type ['订单','修改密码'] 280 | message_events = Column(String(250)) 281 | 282 | 283 | # service暂时不用 284 | class ServiceInterface(object): 285 | pass 286 | -------------------------------------------------------------------------------- /jian/core.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | """ 3 | @ Created by Seven on 2019/01/19 4 | """ 5 | 6 | from collections import namedtuple 7 | from datetime import datetime, date 8 | 9 | from flask import Flask, current_app, request, Blueprint 10 | from flask.json import 
JSONEncoder as _JSONEncoder 11 | from werkzeug.exceptions import HTTPException 12 | from werkzeug.local import LocalProxy 13 | 14 | from .db import db 15 | from .jwt import jwt 16 | from .exception import APIException, UnknownException 17 | from .interface import UserInterface, GroupInterface, AuthInterface, LogInterface, EventInterface 18 | from .exception import NotFound, AuthFailed 19 | from .config import Config 20 | 21 | __version__ = 0.1 22 | 23 | # 路由函数的权限和模块信息(meta信息) 24 | Meta = namedtuple('meta', ['auth', 'module']) 25 | 26 | # -> endpoint -> func 27 | # auth -> module 28 | # -> endpoint -> func 29 | 30 | # 记录路由函数的权限和模块信息 31 | route_meta_infos = {} 32 | 33 | 34 | gigi_config = Config() 35 | 36 | 37 | def route_meta(auth, module='common', mount=True): 38 | """ 39 | 记录路由函数的信息 40 | 记录路由函数访问的推送信息模板 41 | 注:只有使用了 route_meta 装饰器的函数才会被记录到权限管理的map中 42 | :param auth: 权限 43 | :param module: 所属模块 44 | :param mount: 是否挂在到权限中(一些视图函数需要说明,或暂时决定不挂在到权限中,则设置为False) 45 | :return: 46 | """ 47 | 48 | def wrapper(func): 49 | if mount: 50 | name = func.__name__ 51 | exit = route_meta_infos.get(name, None) 52 | if exit: 53 | raise Exception("func's name cant't be repeat") 54 | else: 55 | route_meta_infos.setdefault(name, Meta(auth, module)) 56 | return func 57 | 58 | return wrapper 59 | 60 | 61 | def find_user(**kwargs): 62 | return manager.find_user(**kwargs) 63 | 64 | 65 | def find_group(**kwargs): 66 | return manager.find_group(**kwargs) 67 | 68 | 69 | def get_ep_infos(): 70 | """ 返回权限管理中的所有视图函数的信息,包含它所属module """ 71 | infos = {} 72 | for ep, meta in manager.ep_meta.items(): 73 | mod = infos.get(meta.module, None) 74 | if mod: 75 | sub = mod.get(meta.auth, None) 76 | if sub: 77 | sub.append(ep) 78 | else: 79 | mod[meta.auth] = [ep] 80 | else: 81 | infos.setdefault(meta.module, {meta.auth: [ep]}) 82 | 83 | return infos 84 | 85 | 86 | def find_info_by_ep(ep): 87 | """ 通过请求的endpoint寻找路由函数的meta信息""" 88 | return manager.ep_meta.get(ep) 89 | 90 | 91 | def is_user_allowed(group_id): 
92 | """查看当前user有无权限访问该路由函数""" 93 | ep = request.endpoint 94 | # 根据 endpoint 查找 authority 95 | meta = manager.ep_meta.get(ep) 96 | return manager.verity_user_in_group(group_id, meta.auth, meta.module) 97 | 98 | 99 | def find_auth_module(auth): 100 | """ 通过权限寻找meta信息""" 101 | for _, meta in manager.ep_meta.items(): 102 | if meta.auth == auth: 103 | return meta 104 | return None 105 | 106 | 107 | class Jian(object): 108 | 109 | def __init__(self, 110 | app: Flask = None, # flask app , default None 111 | group_model=None, # group model, default None 112 | user_model=None, # user model, default None 113 | auth_model=None, # authority model, default None 114 | create_all=False, # 是否创建所有数据库表, default false 115 | mount=True, # 是否挂载默认的蓝图, default True 116 | handle=True, # 是否使用全局异常处理 , default True 117 | json_encoder=True, # 是否使用自定义的json_encoder , default True 118 | ): 119 | self.app = app 120 | self.manager = None 121 | if app is not None: 122 | self.init_app(app, group_model, user_model, auth_model, create_all, mount, handle, json_encoder) 123 | 124 | def init_app(self, 125 | app: Flask, 126 | group_model=None, 127 | user_model=None, 128 | auth_model=None, 129 | create_all=False, 130 | mount=True, 131 | handle=True, 132 | json_encoder=True, 133 | ): 134 | # default config 135 | app.config.setdefault('PLUGIN_PATH', {}) 136 | # 默认蓝图的前缀 137 | app.config.setdefault('BP_URL_PREFIX', '') 138 | json_encoder and self._enable_json_encoder(app) 139 | self.app = app 140 | # 初始化 manager 141 | self.manager = Manager(app.config.get('PLUGIN_PATH'), 142 | group_model, 143 | user_model, 144 | auth_model) 145 | self.app.extensions['manager'] = self.manager 146 | db.init_app(app) 147 | create_all and self._enable_create_all(app) 148 | jwt.init_app(app) 149 | mount and self.mount(app) 150 | handle and self.handle_error(app) 151 | 152 | def mount(self, app): 153 | # 加载默认插件路由 154 | bp = Blueprint('plugin', __name__) 155 | # 加载插件的路由 156 | for plugin in self.manager.plugins.values(): 157 | for 
controller in plugin.controllers.values(): 158 | controller.register(bp) 159 | app.register_blueprint(bp, url_prefix=app.config.get('BP_URL_PREFIX')) 160 | for ep, func in app.view_functions.items(): 161 | info = route_meta_infos.get(func.__name__, None) 162 | if info: 163 | self.manager.ep_meta.setdefault(ep, info) 164 | 165 | def handle_error(self, app): 166 | @app.errorhandler(Exception) 167 | def handler(e): 168 | if isinstance(e, APIException): 169 | return e 170 | if isinstance(e, HTTPException): 171 | code = e.code 172 | msg = e.description 173 | error_code = 20000 174 | return APIException(msg, code, error_code) 175 | else: 176 | if not app.config['DEBUG']: 177 | return UnknownException() 178 | else: 179 | raise e 180 | 181 | def _enable_json_encoder(self, app): 182 | app.json_encoder = JSONEncoder 183 | 184 | def _enable_create_all(self, app): 185 | with app.app_context(): 186 | db.create_all() 187 | 188 | 189 | class Manager(object): 190 | """ manager for Jian """ 191 | 192 | # 路由函数的meta信息的容器 193 | ep_meta = {} 194 | 195 | def __init__(self, plugin_path, group_model=None, user_model=None, auth_model=None): 196 | if not group_model: 197 | self.group_model = Group 198 | else: 199 | self.group_model = group_model 200 | 201 | if not user_model: 202 | self.user_model = User 203 | else: 204 | self.user_model = user_model 205 | 206 | if not auth_model: 207 | self.auth_model = Auth 208 | else: 209 | self.auth_model = auth_model 210 | 211 | from .loader import Loader 212 | self.loader: Loader = Loader(plugin_path) 213 | 214 | def find_user(self, **kwargs): 215 | return self.user_model.query.filter_by(**kwargs).first() 216 | 217 | def verify_user(self, nickname, password): 218 | return self.user_model.verify(nickname, password) 219 | 220 | def find_group(self, **kwargs): 221 | return self.group_model.query.filter_by(**kwargs).first() 222 | 223 | def verity_user_in_group(self, group_id, auth, module): 224 | return self.auth_model.query.filter_by(group_id=group_id, 
auth=auth, module=module).first() 225 | 226 | @property 227 | def plugins(self): 228 | return self.loader.plugins 229 | 230 | def get_plugin(self, name): 231 | return self.loader.plugins.get(name) 232 | 233 | def get_model(self, name): 234 | # attention!!! if models have the same name,will return the first one 235 | # 注意!!! 如果容器内有相同的model,则默认返回第一个 236 | for plugin in self.plugins.values(): 237 | return plugin.models.get(name) 238 | 239 | def get_service(self, name): 240 | # attention!!! if services have the same name,will return the first one 241 | # 注意!!! 如果容器内有相同的service,则默认返回第一个 242 | for plugin in self.plugins.values(): 243 | return plugin.services.get(name) 244 | 245 | 246 | # a proxy for manager instance 247 | # attention, only used when context in stack 248 | 249 | # 获得manager实例 250 | # 注意,仅仅在flask的上下文栈中才可获得 251 | manager: Manager = LocalProxy(lambda: get_manager()) 252 | 253 | 254 | def get_manager(): 255 | _manager = current_app.extensions['manager'] 256 | if _manager: 257 | return _manager 258 | else: 259 | app = current_app._get_current_object() 260 | with app.app_context(): 261 | return app.extensions['manager'] 262 | 263 | 264 | 265 | 266 | class User(UserInterface, db.Model): 267 | 268 | @classmethod 269 | def verify(cls, nickname, password): 270 | user = cls.query.filter_by(nickname=nickname).first() 271 | if user is None: 272 | raise NotFound(msg='用户不存在') 273 | if not user.check_password(password): 274 | raise AuthFailed(msg='密码错误,请输入正确密码') 275 | return user 276 | 277 | def reset_password(self, new_password): 278 | #: attention,remember to commit 279 | #: 注意,修改密码后记得提交至数据库 280 | self.password = new_password 281 | 282 | def change_password(self, old_password, new_password): 283 | #: attention,remember to commit 284 | #: 注意,修改密码后记得提交至数据库 285 | if self.check_password(old_password): 286 | self.password = new_password 287 | return True 288 | return False 289 | 290 | 291 | class Group(GroupInterface): 292 | pass 293 | 294 | 295 | class Auth(AuthInterface): 
296 | pass 297 | 298 | 299 | # log model 300 | class Log(LogInterface): 301 | @staticmethod 302 | def create_log(**kwargs): 303 | log = Log() 304 | for key in kwargs.keys(): 305 | if hasattr(log, key): 306 | setattr(log, key, kwargs[key]) 307 | db.session.add(log) 308 | if kwargs.get('commit') is True: 309 | db.session.commit() 310 | return log 311 | 312 | 313 | # event model 314 | class Event(EventInterface): 315 | pass 316 | 317 | 318 | class JSONEncoder(_JSONEncoder): 319 | def default(self, o): 320 | if hasattr(o, 'keys') and hasattr(o, '__getitem__'): 321 | return dict(o) 322 | if isinstance(o, datetime): 323 | return o.strftime('%Y-%m-%dT%H:%M:%SZ') 324 | if isinstance(o, date): 325 | return o.strftime('%Y-%m-%d') 326 | return JSONEncoder.default(self, o) 327 | --------------------------------------------------------------------------------