├── .gitignore ├── CHANGES.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── aclients ├── __init__.py ├── aio_http_client.py ├── aio_mongo_client.py ├── aio_mysql_client.py ├── aio_redis_client.py ├── decorators.py ├── err_msg.py ├── exceptions.py ├── tinylibs │ ├── __init__.py │ ├── blinker.py │ └── tinymysql.py └── utils.py ├── docs └── index.rst ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── jrpc_client.py ├── jrpc_server.py ├── verify_gen_sql.py └── verify_http_client.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | # note: .idea/ will be deleted 107 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | ## aclients Changelog 2 | 3 | ### [1.0.3] - 2020-10-13 4 | 5 | #### Changed 6 | - 暂时去掉一个账户只能一处登录的限制 7 | - session的有效期由30分钟延长到1个小时 8 | - 上下文管理器改为异步上下文管理器 9 | 10 | ### [1.0.1] - 2020-6-14 11 | 12 | #### Added 13 | - 增加生成分表model功能,使得分表的使用简单高效 14 | - 配置增加pool_recycle回旋关闭连接功能 15 | 16 | #### Changed 17 | - 优化所有代码中没有类型标注的地方,都改为typing中的类型标注 18 | - 经过项目测试功能稳定,发布正式版本. 19 | - 删除未开发完成的功能. 20 | - 此扩展原则上将作为最后一个未拆分扩展版本,后续将分开维护.
21 | 22 | 23 | ### [1.0.0b45] - 2019-12-30 24 | 25 | #### Added 26 | - 新增如果查询(分页查询)的时候没有进行排序,默认按照id关键字升序排序的功能,防止出现混乱数据 27 | 28 | ### [1.0.0b44] - 2019-9-29 29 | 30 | #### Added 31 | - 增加incrbynumber方法,可以增加整数或者浮点数 32 | - 添加四个扩展到app的反向引用 33 | - 增加add_task异步任务执行 34 | 35 | #### Changed 36 | - 优化redis client中所有的设置过期时间的实现方式 37 | - 优化获取数据的方式,去掉不适宜的异常 38 | - 更改session可能存在的隐藏问题 39 | 40 | ### [1.0.0b40] - 2019-9-11 41 | 42 | #### Changed 43 | - 更改由于session过期后再次登录删除老的session报错的问题 44 | 45 | ### [1.0.0b39] - 2019-9-10 46 | 47 | #### Added 48 | - 增加新的账号登录生成session后,清除老的session的功能,保证一个账号只能一个终端登录 49 | 50 | #### Changed 51 | - 更改session的过期时间为30分钟 52 | - 更改通用数据的缓存时间为12小时 53 | - 删除session时删除和session所有相关的缓存key 54 | 55 | ### [1.0.0b37] - 2019-7-12 56 | 57 | #### Changed 58 | - 更改schema validation的中文提示信息 59 | - 更改http,mongo,mysql,redis中停止服务时会出现关闭pool报错的情况 60 | 61 | ### [1.0.0b36] - 2019-5-23 62 | 63 | #### Added 64 | - 工具类中增加生成随机长度以字母开头的字母和数字的字符串标识 65 | 66 | ### [1.0.0b35] - 2019-4-29 67 | 68 | #### Changed 69 | - 修改三方库的pymongo>=3.8.0,以上的版本实现了ObjectID 0.2版本规范,生成的ObjectID发生碰撞的可能性更小.
70 | 71 | ### [1.0.0b34] - 2019-4-24 72 | 73 | #### Added 74 | - 修改aiohttp client增加针对ip地址的URL中session接受cookie的功能开关 75 | 76 | ### [1.0.0b33] - 2019-4-23 77 | 78 | #### Added 79 | - 工具类utils中增加用于枚举实例的元类 80 | - tinylibs包中增加简单的异步信号实现blinker模块 81 | 82 | #### Changed 83 | - 移动单例类Singleton从decorators到utils工具类 84 | 85 | ### [1.0.0b32] - 2019-4-22 86 | 87 | #### Changed 88 | - 修改schema message装饰器判断错误的情况 89 | 90 | 91 | ### [1.0.0b31] - 2019-4-21 92 | 93 | #### Changed 94 | - 修改sanic版本为长期支持版本18.12LTS 95 | 96 | ### [1.0.0b30] - 2019-4-20 97 | 98 | #### Changed 99 | - 优化redis客户端,修改出现的错误 100 | 101 | ### [1.0.0b29] - 2019-4-19 102 | 103 | #### Changed 104 | - 修改redis中获取session返回session对象出错的问题 105 | 106 | ### [1.0.0b28] - 2019-4-19 107 | 108 | #### Added 109 | - 增加保存和更新hash数据时对单个键值进行保存和更新的功能 110 | 111 | ### [1.0.0b27] - 2019-4-19 112 | 113 | #### Changed 114 | - 修改获取redis数据时可能出现的没有把字符串转换为对象的情况 115 | - 修改保存redis数据时指定是否进行dump以便进行性能的提高 116 | - 修改获取redis数据时指定是否进行load以便进行性能的提高 117 | 118 | 119 | ### [1.0.0b26] - 2019-4-18 120 | 121 | #### Changed 122 | - 修改Session中增加page_id和page_menu_id两个用于账户的页面权限管理 123 | 124 | ### [1.0.0b25] - 2019-4-16 125 | 126 | #### Changed 127 | - 修改TinyMySQL中execute的参数,修改find中的args参数名称 128 | 129 | ### [1.0.0b24] - 2019-4-2 130 | 131 | #### Added 132 | - 工具类中增加由对象名生成类名的功能 133 | - 工具类中增加解析yaml文件的功能 134 | - 工具类中增加返回objectid的功能 135 | - 增加基于pymysql的简单TinyMySQL功能,用于简单操作MySQL的时候使用 136 | 137 | ### [1.0.0b22] - 2019-3-25 138 | 139 | #### Changed 140 | - 修改update data中设置前缀的错误 141 | 142 | ### [1.0.0b20] - 2019-3-22 143 | 144 | #### Changed 145 | - 修改mongo插入中的错误,修改id时的错误 146 | 147 | ### [1.0.0b19] - 2019-3-22 148 | 149 | #### Changed 150 | - 修改mongo查询中的错误 151 | 152 | ### [1.0.0b18] - 2019-3-21 153 | 154 | #### Added 155 | - redis的session中增加角色ID 156 | - redis的session中增加静态权限ID 157 | - redis的session中增加动态权限ID 158 | 159 | #### Changed 160 | - redis的session中user_id更改为account_id 161 | 162 | ### [1.0.0b17] - 2019-3-21 163 | 164 | #### Added 165 | - 增加同步方法包装为异步方法的功能 166 | 167 | ### [1.0.0b16] - 
2019-3-16 168 | 169 | #### Changed 170 | - 优化insert document中的ID处理逻辑 171 | - 优化update data处理逻辑 172 | - 优化query key处理逻辑,可以直接使用in、like等查询 173 | 174 | ### [1.0.0b15] - 2019-1-31 175 | 176 | #### Changed 177 | - 修改schema装饰器实现 178 | 179 | ### [1.0.0b14] - 2019-1-31 180 | 181 | #### Added 182 | - 增加元类单例实现 183 | 184 | #### Changed 185 | - 修改httpclient为元类单例的子类 186 | 187 | ### [1.0.0b13] - 2019-1-28 188 | 189 | #### Changed 190 | - 优化MySQL或查询支持列表 191 | 192 | ### [1.0.0b12] - 2019-1-28 193 | 194 | #### Changed 195 | - 删除http message 中不用的消息 196 | - 优化exceptions的实现方式 197 | 198 | ### [1.0.0b11] - 2019-1-25 199 | 200 | #### Changed 201 | - 修改schema_validate装饰器能够修改提示消息的功能,如果多个地方用到此装饰器,在其中一处修改即可 202 | 203 | ### [1.0.0b10] - 2019-1-25 204 | 205 | #### Added 206 | - 添加schema_validate装饰器用于校验schema 207 | #### Changed 208 | - 修改http message默认值 209 | 210 | ### [1.0.0b9] - 2019-1-21 211 | 212 | #### Added 213 | - 增加多数据库、多实例应用方式 214 | - 增加http client的测试 215 | - 增加在没有app时脚本中使用时的初始化功能,这样便于通用性 216 | - 增加错误类型,能够对错误进行定制 217 | - 增加单例的装饰器,修改httpclient为单例 218 | #### Changed 219 | - 修改一处可能引起错误的地方 220 | 221 | ### [1.0.0b8] - 2019-1-18 222 | 223 | #### Changed 224 | - 修改初始化方式,更改为sanic扩展的初始化方式,即init_app 225 | - 修改初始化时配置的加载顺序,默认先加载 226 | 227 | ### [1.0.0b7] - 2019-1-18 228 | 229 | #### Changed 230 | - 从b2到b7版本的修改记录忘记了,这里先不记录了 231 | 232 | ### [1.0.0b1] - 2018-12-26 233 | 234 | #### Added 235 | 236 | - MySQL基于aiomysql和sqlalchemy的CRUD封装 237 | - http基于aiohttp的CRUD封装 238 | - session基于aredis的CRUD封装 239 | - redis基于aredis的常用封装 240 | - mongo基于motor的CRUD封装 241 | - 所有消息可自定义配置,否则为默认配置 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tiny Bees 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-exclude * __pycache__ 2 | recursive-exclude * *.py[co] 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aclients 2 | 基于sanic扩展,各种数据库异步crud操作,各种请求异步crud操作,本repo是个基础封装库,直接用于各个业务系统中 3 | 4 | # Installing aclients 5 | - ```pip install aclients``` 6 | 7 | # Usage 8 | ### 后续添加,现在没时间.
9 | -------------------------------------------------------------------------------- /aclients/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-25 下午4:40 8 | """ 9 | from .decorators import * 10 | from .aio_mongo_client import * 11 | from .aio_mysql_client import * 12 | from .aio_redis_client import * 13 | from .aio_http_client import * 14 | 15 | __version__ = "1.0.3" 16 | -------------------------------------------------------------------------------- /aclients/aio_http_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-26 上午11:49 8 | """ 9 | import asyncio 10 | import atexit 11 | from typing import Dict 12 | 13 | import aelog 14 | import aiohttp 15 | 16 | from .err_msg import http_msg 17 | from .exceptions import ClientConnectionError, ClientError, ClientResponseError, HttpError 18 | from .utils import Singleton, verify_message 19 | 20 | __all__ = ("AIOHttpClient", "AsyncResponse") 21 | 22 | 23 | class AsyncResponse(object): 24 | """ 25 | 异步响应对象,需要重新封装对象 26 | """ 27 | __slots__ = ["status_code", "reason", "headers", "cookies", "resp_body", "content"] 28 | 29 | def __init__(self, status_code: int, reason: str, headers: Dict, cookies: Dict, *, resp_body: Dict, 30 | content: bytes): 31 | """ 32 | 33 | Args: 34 | 35 | """ 36 | self.status_code = status_code 37 | self.reason = reason 38 | self.headers = headers 39 | self.cookies = cookies 40 | self.resp_body = resp_body 41 | self.content = content 42 | 43 | def json(self, ): 44 | """ 45 | 为了适配 46 | Args: 47 | 48 | Returns: 49 | 50 | """ 51 | return self.resp_body 52 | 53 | 54 | class AIOHttpClient(Singleton): 55 | """ 56 | 基于aiohttp的异步封装 57 | """ 58 | 59 | def __init__(self, app=None, *, 
timeout: int = 5 * 60, verify_ssl: bool = True, message: Dict = None, 60 | use_zh: bool = True, cookiejar_unsafe: bool = False): 61 | """ 62 | 基于aiohttp的异步封装 63 | Args: 64 | app: app应用 65 | timeout:request timeout 66 | verify_ssl:verify ssl 67 | message: 提示消息 68 | use_zh: 消息提示是否使用中文,默认中文 69 | cookiejar_unsafe: 是否打开cookiejar的非严格模式,默认false 70 | """ 71 | self.app = app 72 | self.session = None 73 | self.timeout = timeout 74 | self.verify_ssl = verify_ssl 75 | self.message = message or {} 76 | self.use_zh = use_zh 77 | self.msg_zh = None 78 | # 默认clientsession使用严格版本的cookiejar, 禁止ip地址的访问共享cookie 79 | # 如果访问的是ip地址的URL,并且需要保持cookie则需要打开 80 | self.cookiejar_unsafe = cookiejar_unsafe 81 | 82 | if app is not None: 83 | self.init_app(app, timeout=self.timeout, verify_ssl=self.verify_ssl, message=self.message, 84 | use_zh=self.use_zh) 85 | 86 | def init_app(self, app, *, timeout: int = None, verify_ssl: bool = None, message: Dict = None, 87 | use_zh: bool = None): 88 | """ 89 | 基于aiohttp的异步封装 90 | Args: 91 | app: app应用 92 | timeout:request timeout 93 | verify_ssl:verify ssl 94 | message: 提示消息 95 | use_zh: 消息提示是否使用中文,默认中文 96 | Returns: 97 | 98 | """ 99 | self.app = app 100 | self.timeout = timeout or app.config.get("ACLIENTS_HTTP_TIMEOUT", None) or self.timeout 101 | self.verify_ssl = verify_ssl or app.config.get("ACLIENTS_HTTP_VERIFYSSL", None) or self.verify_ssl 102 | message = message or app.config.get("ACLIENTS_HTTP_MESSAGE", None) or self.message 103 | use_zh = use_zh or app.config.get("ACLIENTS_HTTP_MSGZH", None) or self.use_zh 104 | self.message = verify_message(http_msg, message) 105 | self.msg_zh = "msg_zh" if use_zh else "msg_en" 106 | 107 | @app.listener('before_server_start') 108 | async def open_connection(app_, loop): 109 | """ 110 | 111 | Args: 112 | 113 | Returns: 114 | 115 | """ 116 | jar = aiohttp.CookieJar(unsafe=self.cookiejar_unsafe) 117 | self.session = aiohttp.ClientSession(cookie_jar=jar) 118 | 119 | @app.listener('after_server_stop') 120 | async def 
close_connection(app_, loop): 121 | """ 122 | 释放session连接池所有连接 123 | Args: 124 | 125 | Returns: 126 | 127 | """ 128 | if self.session: 129 | await self.session.close() 130 | 131 | def init_session(self, *, timeout: int = None, verify_ssl: bool = None, message: Dict = None, 132 | use_zh: bool = None): 133 | """ 134 | 基于aiohttp的异步封装 135 | Args: 136 | timeout:request timeout 137 | verify_ssl:verify ssl 138 | message: 提示消息 139 | use_zh: 消息提示是否使用中文,默认中文 140 | Returns: 141 | 142 | """ 143 | self.timeout = timeout or self.timeout 144 | self.verify_ssl = verify_ssl or self.verify_ssl 145 | use_zh = use_zh or self.use_zh 146 | self.message = verify_message(http_msg, message or self.message) 147 | self.msg_zh = "msg_zh" if use_zh else "msg_en" 148 | loop = asyncio.get_event_loop() 149 | 150 | async def open_connection(): 151 | """ 152 | 153 | Args: 154 | 155 | Returns: 156 | 157 | """ 158 | jar = aiohttp.CookieJar(unsafe=self.cookiejar_unsafe) 159 | self.session = aiohttp.ClientSession(cookie_jar=jar) 160 | 161 | async def close_connection(): 162 | """ 163 | 释放session连接池所有连接 164 | Args: 165 | 166 | Returns: 167 | 168 | """ 169 | if self.session: 170 | await self.session.close() 171 | 172 | loop.run_until_complete(open_connection()) 173 | atexit.register(lambda: loop.run_until_complete(close_connection())) 174 | 175 | async def _request(self, method: str, url: str, *, params: Dict = None, data: Dict = None, 176 | json: Dict = None, headers: Dict = None, timeout: int = None, verify_ssl: bool = None, 177 | **kwargs) -> AsyncResponse: 178 | """ 179 | 180 | Args: 181 | method, url, *, params=None, data=None, json=None, headers=None, **kwargs 182 | Returns: 183 | 184 | """ 185 | 186 | async def _async_get(): 187 | """ 188 | 189 | Args: 190 | 191 | Returns: 192 | 193 | """ 194 | return await self.session.get(url, params=params, headers=headers, timeout=timeout, verify_ssl=verify_ssl, 195 | **kwargs) 196 | 197 | async def _async_post(): 198 | """ 199 | 200 | Args: 201 | 202 | 
Returns: 203 | 204 | """ 205 | res = await self.session.post(url, params=params, data=data, json=json, headers=headers, timeout=timeout, 206 | verify_ssl=verify_ssl, **kwargs) 207 | return res 208 | 209 | async def _async_put(): 210 | """ 211 | 212 | Args: 213 | 214 | Returns: 215 | 216 | """ 217 | return await self.session.put(url, params=params, data=data, json=json, headers=headers, timeout=timeout, 218 | verify_ssl=verify_ssl, **kwargs) 219 | 220 | async def _async_patch(): 221 | """ 222 | 223 | Args: 224 | 225 | Returns: 226 | 227 | """ 228 | return await self.session.patch(url, params=params, data=data, json=json, headers=headers, timeout=timeout, 229 | verify_ssl=verify_ssl, **kwargs) 230 | 231 | async def _async_delete(): 232 | """ 233 | 234 | Args: 235 | 236 | Returns: 237 | 238 | """ 239 | return await self.session.delete(url, params=params, data=data, json=json, headers=headers, timeout=timeout, 240 | verify_ssl=verify_ssl, **kwargs) 241 | 242 | get_resp = {"GET": _async_get, "POST": _async_post, "PUT": _async_put, "DELETE": _async_delete, 243 | "PATCH": _async_patch} 244 | resp = None 245 | try: 246 | resp = await get_resp[method.upper()]() 247 | resp.raise_for_status() 248 | except KeyError as e: 249 | raise ClientError(url=url, message="error method {0}".format(str(e))) 250 | except (aiohttp.ClientConnectionError, asyncio.TimeoutError) as e: 251 | raise ClientConnectionError(url=url, message=str(e)) 252 | except aiohttp.ClientResponseError as e: 253 | try: 254 | resp_data = await resp.json() if resp else "" 255 | except (ValueError, TypeError, aiohttp.ContentTypeError): 256 | resp_data = await resp.text() if resp else "" 257 | raise ClientResponseError(url=url, status_code=e.status, message=e.message, headers=e.headers, 258 | body=resp_data) 259 | except aiohttp.ClientError as e: 260 | raise ClientError(url=url, message="aiohttp.ClientError: {}".format(vars(e))) 261 | 262 | async with resp: 263 | try: 264 | resp_json = await resp.json() 265 | except 
(ValueError, TypeError, aiohttp.ContentTypeError): 266 | try: 267 | resp_text = await resp.text() 268 | except (ValueError, TypeError): 269 | try: 270 | resp_bytes = await resp.read() 271 | except (aiohttp.ClientResponseError, aiohttp.ClientError) as e: 272 | aelog.exception(e) 273 | raise HttpError(e.code, message=self.message[200][self.msg_zh], error=e) 274 | else: 275 | return AsyncResponse(resp.status, resp.reason, resp.headers, resp.cookies, resp_body="", 276 | content=resp_bytes) 277 | else: 278 | return AsyncResponse(resp.status, resp.reason, resp.headers, resp.cookies, resp_body=resp_text, 279 | content=b"") 280 | else: 281 | return AsyncResponse(resp.status, resp.reason, resp.headers, resp.cookies, resp_body=resp_json, 282 | content=b"") 283 | 284 | async def async_request(self, method: str, url: str, *, params: Dict = None, data: Dict = None, 285 | json: Dict = None, headers: Dict = None, timeout: int = None, verify_ssl: bool = None, 286 | **kwargs) -> AsyncResponse: 287 | """ 288 | 289 | Args: 290 | 291 | Returns: 292 | 293 | """ 294 | verify_ssl = self.verify_ssl if verify_ssl is None else verify_ssl 295 | timeout = self.timeout if timeout is None else timeout 296 | return await self._request(method, url, params=params, data=data, json=json, headers=headers, 297 | timeout=timeout, verify_ssl=verify_ssl, **kwargs) 298 | 299 | async def async_get(self, url: str, *, params: Dict = None, headers: Dict = None, timeout: int = None, 300 | verify_ssl: bool = None, **kwargs) -> AsyncResponse: 301 | """ 302 | 303 | Args: 304 | 305 | Returns: 306 | 307 | """ 308 | verify_ssl = self.verify_ssl if verify_ssl is None else verify_ssl 309 | timeout = self.timeout if timeout is None else timeout 310 | return await self._request("GET", url, params=params, headers=headers, timeout=timeout, verify_ssl=verify_ssl, 311 | **kwargs) 312 | 313 | async def async_post(self, url: str, *, params: Dict = None, data: Dict = None, json: Dict = None, 314 | headers: Dict = None, 
timeout: int = None, verify_ssl: bool = None, 315 | **kwargs) -> AsyncResponse: 316 | """ 317 | 318 | Args: 319 | 320 | Returns: 321 | 322 | """ 323 | verify_ssl = self.verify_ssl if verify_ssl is None else verify_ssl 324 | timeout = self.timeout if timeout is None else timeout 325 | return await self._request("POST", url, params=params, data=data, json=json, headers=headers, timeout=timeout, 326 | verify_ssl=verify_ssl, **kwargs) 327 | 328 | async def async_put(self, url: str, *, params: Dict = None, data: Dict = None, json: Dict = None, 329 | headers: Dict = None, timeout: int = None, verify_ssl: bool = None, 330 | **kwargs) -> AsyncResponse: 331 | """ 332 | 333 | Args: 334 | 335 | Returns: 336 | 337 | """ 338 | verify_ssl = self.verify_ssl if verify_ssl is None else verify_ssl 339 | timeout = self.timeout if timeout is None else timeout 340 | return await self._request("PUT", url, params=params, data=data, json=json, headers=headers, timeout=timeout, 341 | verify_ssl=verify_ssl, **kwargs) 342 | 343 | async def async_patch(self, url: str, *, params: Dict = None, data: Dict = None, json: Dict = None, 344 | headers: Dict = None, timeout: int = None, verify_ssl: bool = None, 345 | **kwargs) -> AsyncResponse: 346 | """ 347 | 348 | Args: 349 | 350 | Returns: 351 | 352 | """ 353 | verify_ssl = self.verify_ssl if verify_ssl is None else verify_ssl 354 | timeout = self.timeout if timeout is None else timeout 355 | return await self._request("PATCH", url, params=params, data=data, json=json, headers=headers, timeout=timeout, 356 | verify_ssl=verify_ssl, **kwargs) 357 | 358 | async def async_delete(self, url, *, params: Dict = None, data: Dict = None, json: Dict = None, 359 | headers: Dict = None, verify_ssl: bool = None, timeout: int = None, 360 | **kwargs) -> AsyncResponse: 361 | """ 362 | 363 | Args: 364 | 365 | Returns: 366 | 367 | """ 368 | verify_ssl = self.verify_ssl if verify_ssl is None else verify_ssl 369 | timeout = self.timeout if timeout is None else timeout 
370 | return await self._request("DELETE", url, params=params, data=data, json=json, headers=headers, 371 | timeout=timeout, verify_ssl=verify_ssl, **kwargs) 372 | -------------------------------------------------------------------------------- /aclients/aio_mongo_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-25 下午3:41 8 | """ 9 | import atexit 10 | from collections.abc import MutableMapping, MutableSequence 11 | from typing import Dict, List, NoReturn, Optional, Tuple, Union 12 | 13 | import aelog 14 | # noinspection PyProtectedMember 15 | from bson import ObjectId 16 | from bson.errors import BSONError 17 | from motor.motor_asyncio import AsyncIOMotorClient 18 | # noinspection PyPackageRequirements 19 | from pymongo.errors import (ConnectionFailure, DuplicateKeyError, InvalidName, PyMongoError) 20 | 21 | from .err_msg import mongo_msg 22 | from .exceptions import FuncArgsError, HttpError, MongoDuplicateKeyError, MongoError, MongoInvalidNameError 23 | from .utils import verify_message 24 | 25 | __all__ = ("AIOMongoClient",) 26 | 27 | 28 | class AIOMongoClient(object): 29 | """ 30 | mongo 非阻塞工具类 31 | """ 32 | 33 | def __init__(self, app=None, *, username: str = "mongo", passwd: str = None, host: str = "127.0.0.1", 34 | port: int = 27017, dbname: str = None, pool_size: int = 50, **kwargs): 35 | """ 36 | mongo 非阻塞工具类 37 | Args: 38 | app: app应用 39 | host:mongo host 40 | port:mongo port 41 | dbname: database name 42 | username: mongo user 43 | passwd: mongo password 44 | pool_size: mongo pool size 45 | """ 46 | self.app = app 47 | self.client = None 48 | self.db = None 49 | self.username = username 50 | self.passwd = passwd 51 | self.host = host 52 | self.port = port 53 | self.dbname = dbname 54 | self.pool_size = pool_size 55 | self.message = kwargs.get("message", {}) 56 | self.use_zh = 
kwargs.get("use_zh", True) 57 | self.msg_zh = None 58 | 59 | if app is not None: 60 | self.init_app(app, username=self.username, passwd=self.passwd, host=self.host, port=self.port, 61 | dbname=self.dbname, pool_size=self.pool_size, **kwargs) 62 | 63 | def init_app(self, app, *, username: str = None, passwd: str = None, host: str = None, port: int = None, 64 | dbname: str = None, pool_size: int = None, **kwargs): 65 | """ 66 | mongo 实例初始化 67 | Args: 68 | app: app应用 69 | host:mongo host 70 | port:mongo port 71 | dbname: database name 72 | username: mongo user 73 | passwd: mongo password 74 | pool_size: mongo pool size 75 | Returns: 76 | 77 | """ 78 | username = username or app.config.get("ACLIENTS_MONGO_USERNAME", None) or self.username 79 | passwd = passwd or app.config.get("ACLIENTS_MONGO_PASSWD", None) or self.passwd 80 | host = host or app.config.get("ACLIENTS_MONGO_HOST", None) or self.host 81 | port = port or app.config.get("ACLIENTS_MONGO_PORT", None) or self.port 82 | dbname = dbname or app.config.get("ACLIENTS_MONGO_DBNAME", None) or self.dbname 83 | pool_size = pool_size or app.config.get("ACLIENTS_MONGO_POOL_SIZE", None) or self.pool_size 84 | message = kwargs.get("message") or app.config.get("ACLIENTS_MONGO_MESSAGE", None) or self.message 85 | use_zh = kwargs.get("use_zh") or app.config.get("ACLIENTS_MONGO_MSGZH", None) or self.use_zh 86 | 87 | passwd = passwd if passwd is None else str(passwd) 88 | self.message = verify_message(mongo_msg, message) 89 | self.msg_zh = "msg_zh" if use_zh else "msg_en" 90 | self.app = app 91 | 92 | @app.listener('before_server_start') 93 | async def open_connection(app_, loop): 94 | """ 95 | 96 | Args: 97 | 98 | Returns: 99 | 100 | """ 101 | self.create_db_conn(host, port, pool_size, username, passwd, dbname) 102 | 103 | @app.listener('after_server_stop') 104 | async def close_connection(app_, loop): 105 | """ 106 | 107 | Args: 108 | 109 | Returns: 110 | 111 | """ 112 | if self.client: 113 | self.client.close() 114 | 115 | 
def init_engine(self, *, username: str = None, passwd: str = None, host: str = None, port: int = None, 116 | dbname: str = None, pool_size: int = None, **kwargs): 117 | """ 118 | mongo 实例初始化 119 | Args: 120 | host:mongo host 121 | port:mongo port 122 | dbname: database name 123 | username: mongo user 124 | passwd: mongo password 125 | pool_size: mongo pool size 126 | Returns: 127 | 128 | """ 129 | username = username or self.username 130 | passwd = passwd or self.passwd 131 | host = host or self.host 132 | port = port or self.port 133 | dbname = dbname or self.dbname 134 | pool_size = pool_size or self.pool_size 135 | message = kwargs.get("message") or self.message 136 | use_zh = kwargs.get("use_zh") or self.use_zh 137 | 138 | passwd = passwd if passwd is None else str(passwd) 139 | self.message = verify_message(mongo_msg, message) 140 | self.msg_zh = "msg_zh" if use_zh else "msg_en" 141 | 142 | # engine 143 | self.create_db_conn(host, port, pool_size, username, passwd, dbname) 144 | 145 | @atexit.register 146 | def close_connection(): 147 | """ 148 | 149 | Args: 150 | 151 | Returns: 152 | 153 | """ 154 | if self.client: 155 | self.client.close() 156 | 157 | def create_db_conn(self, host: str, port: int, pool_size: int, username: str, passwd: str, dbname: str 158 | ) -> NoReturn: 159 | try: 160 | self.client = AsyncIOMotorClient(host, port, maxPoolSize=pool_size, username=username, password=passwd) 161 | self.db = self.client.get_database(name=dbname) 162 | except ConnectionFailure as e: 163 | aelog.exception("Mongo connection failed host={} port={} error:{}".format(host, port, e)) 164 | raise MongoError("Mongo connection failed host={} port={} error:{}".format(host, port, e)) 165 | except InvalidName as e: 166 | aelog.exception("Invalid mongo db name {} {}".format(dbname, e)) 167 | raise MongoInvalidNameError("Invalid mongo db name {} {}".format(dbname, e)) 168 | except PyMongoError as err: 169 | aelog.exception("Mongo DB init failed! 
error: {}".format(err)) 170 | raise MongoError("Mongo DB init failed!") from err 171 | 172 | async def _insert_document(self, name: str, document: Union[Dict, List[Dict]], insert_one: bool = True 173 | ) -> Union[str, Tuple[str]]: 174 | """ 175 | 插入一个单独的文档 176 | Args: 177 | name:collection name 178 | document: document obj 179 | insert_one: insert_one insert_many的过滤条件,默认True 180 | Returns: 181 | 返回插入的Objectid 182 | """ 183 | try: 184 | if insert_one: 185 | result = await self.db.get_collection(name).insert_one(document) 186 | else: 187 | result = await self.db.get_collection(name).insert_many(document) 188 | except InvalidName as e: 189 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 190 | except DuplicateKeyError as e: 191 | raise MongoDuplicateKeyError("Duplicate key error, {}".format(e)) 192 | except PyMongoError as err: 193 | aelog.exception("Insert one document failed, {}".format(err)) 194 | raise HttpError(400, message=self.message[100][self.msg_zh], error=err) 195 | else: 196 | return str(result.inserted_id) if insert_one else (str(val) for val in result.inserted_ids) 197 | 198 | async def _insert_documents(self, name: str, documents: List[Dict]) -> Tuple[str]: 199 | """ 200 | 批量插入文档 201 | Args: 202 | name:collection name 203 | documents: documents obj 204 | Returns: 205 | 返回插入的Objectid列表 206 | """ 207 | return await self._insert_document(name, documents, insert_one=False) 208 | 209 | async def _find_document(self, name: str, query_key: Dict, filter_key: Dict = None) -> Optional[Dict]: 210 | """ 211 | 查询一个单独的document文档 212 | Args: 213 | name: collection name 214 | query_key: 查询document的过滤条件 215 | filter_key: 过滤返回值中字段的过滤条件 216 | Returns: 217 | 返回匹配的document或者None 218 | """ 219 | try: 220 | find_data = await self.db.get_collection(name).find_one(query_key, projection=filter_key) 221 | except InvalidName as e: 222 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 223 | except PyMongoError as err: 224 | 
aelog.exception("Find one document failed, {}".format(err)) 225 | raise HttpError(400, message=self.message[103][self.msg_zh], error=err) 226 | else: 227 | if find_data and find_data.get("_id", None) is not None: 228 | find_data["id"] = str(find_data.pop("_id")) 229 | return find_data 230 | 231 | async def _find_documents(self, name: str, query_key: Dict, filter_key: Dict = None, limit: int = None, 232 | skip: int = None, sort: List[Tuple] = None) -> List[Dict]: 233 | """ 234 | 批量查询documents文档 235 | Args: 236 | name: collection name 237 | query_key: 查询document的过滤条件 238 | filter_key: 过滤返回值中字段的过滤条件 239 | limit: 限制返回的document条数 240 | skip: 从查询结果中调过指定数量的document 241 | sort: 排序方式,可以自定多种字段的排序,值为一个列表的键值对, eg:[('field1', pymongo.ASCENDING)] 242 | Returns: 243 | 返回匹配的document列表 244 | """ 245 | try: 246 | find_data = [] 247 | cursor = self.db.get_collection(name).find(query_key, projection=filter_key, limit=limit, skip=skip, 248 | sort=sort) 249 | # find_data = await cursor.to_list(None) 250 | async for doc in cursor: 251 | if doc.get("_id", None) is not None: 252 | doc["id"] = str(doc.pop("_id")) 253 | find_data.append(doc) 254 | except InvalidName as e: 255 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 256 | except PyMongoError as err: 257 | aelog.exception("Find many documents failed, {}".format(err)) 258 | raise HttpError(400, message=self.message[104][self.msg_zh], error=err) 259 | else: 260 | return find_data 261 | 262 | async def _find_count(self, name: str, query_key: Dict) -> int: 263 | """ 264 | 查询documents的数量 265 | Args: 266 | name: collection name 267 | query_key: 查询document的过滤条件 268 | Returns: 269 | 返回匹配的document数量 270 | """ 271 | try: 272 | return await self.db.get_collection(name).count(query_key) 273 | except InvalidName as e: 274 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 275 | except PyMongoError as err: 276 | aelog.exception("Find many documents failed, {}".format(err)) 277 | raise 
HttpError(400, message=self.message[104][self.msg_zh], error=err) 278 | 279 | async def _update_document(self, name: str, query_key: Dict, update_data: Dict, upsert: bool = False, 280 | update_one: bool = True) -> Dict: 281 | """ 282 | 更新匹配到的一个的document 283 | Args: 284 | name: collection name 285 | query_key: 查询document的过滤条件 286 | update_data: 对匹配的document进行更新的document 287 | upsert: 没有匹配到document的话执行插入操作,默认False 288 | update_one: update_one or update_many的匹配条件 289 | Returns: 290 | 返回匹配的数量和修改数量的dict, eg:{"matched_count": 1, "modified_count": 1, "upserted_id":"f"} 291 | """ 292 | try: 293 | if update_one: 294 | result = await self.db.get_collection(name).update_one(query_key, update_data, upsert=upsert) 295 | else: 296 | result = await self.db.get_collection(name).update_many(query_key, update_data, upsert=upsert) 297 | except InvalidName as e: 298 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 299 | except DuplicateKeyError as e: 300 | raise MongoDuplicateKeyError("Duplicate key error, {}".format(e)) 301 | except PyMongoError as err: 302 | aelog.exception("Update documents failed, {}".format(err)) 303 | raise HttpError(400, message=self.message[101][self.msg_zh], error=err) 304 | else: 305 | return {"matched_count": result.matched_count, "modified_count": result.modified_count, 306 | "upserted_id": str(result.upserted_id) if result.upserted_id else None} 307 | 308 | async def _update_documents(self, name: str, query_key: Dict, update_data: Dict, upsert: bool = False) -> Dict: 309 | """ 310 | 更新匹配到的所有的document 311 | Args: 312 | name: collection name 313 | query_key: 查询document的过滤条件 314 | update_data: 对匹配的document进行更新的document 315 | upsert: 没有匹配到document的话执行插入操作,默认False 316 | Returns: 317 | 返回匹配的数量和修改数量的dict, eg:{"matched_count": 2, "modified_count": 2, "upserted_id":"f"} 318 | """ 319 | return await self._update_document(name, query_key, update_data, upsert, update_one=False) 320 | 321 | async def _delete_document(self, name: str, 
query_key: Dict, delete_one: bool = True) -> int: 322 | """ 323 | 删除匹配到的一个的document 324 | Args: 325 | name: collection name 326 | query_key: 查询document的过滤条件 327 | delete_one: delete_one delete_many的匹配条件 328 | Returns: 329 | 返回删除的数量 330 | """ 331 | try: 332 | if delete_one: 333 | result = await self.db.get_collection(name).delete_one(query_key) 334 | else: 335 | result = await self.db.get_collection(name).delete_many(query_key) 336 | except InvalidName as e: 337 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 338 | except PyMongoError as err: 339 | aelog.exception("Delete documents failed, {}".format(err)) 340 | raise HttpError(400, message=self.message[102][self.msg_zh], error=err) 341 | else: 342 | return result.deleted_count 343 | 344 | async def _delete_documents(self, name: str, query_key: Dict) -> int: 345 | """ 346 | 删除匹配到的所有的document 347 | Args: 348 | name: collection name 349 | query_key: 查询document的过滤条件 350 | Returns: 351 | 返回删除的数量 352 | """ 353 | return await self._delete_document(name, query_key, delete_one=False) 354 | 355 | async def _aggregate(self, name: str, pipline: List[Dict]) -> List[Dict]: 356 | """ 357 | 根据pipline进行聚合查询 358 | Args: 359 | name: collection name 360 | pipline: 聚合查询的pipeline,包含一个后者多个聚合命令 361 | Returns: 362 | 返回聚合后的documents 363 | """ 364 | result = [] 365 | try: 366 | async for doc in self.db.get_collection(name).aggregate(pipline): 367 | if doc.get("_id", None) is not None: 368 | doc["id"] = str(doc.pop("_id")) 369 | result.append(doc) 370 | except InvalidName as e: 371 | raise MongoInvalidNameError("Invalid collention name {} {}".format(name, e)) 372 | except PyMongoError as err: 373 | aelog.exception("Aggregate documents failed, {}".format(err)) 374 | raise HttpError(400, message=self.message[105][self.msg_zh], error=err) 375 | else: 376 | return result 377 | 378 | # noinspection PyAsyncCall 379 | async def insert_documents(self, name: str, documents: List[Dict]) -> Tuple[str]: 380 | """ 381 | 
批量插入文档 382 | Args: 383 | name:collection name 384 | documents: documents obj 385 | Returns: 386 | 返回插入的转换后的_id列表 387 | """ 388 | if not isinstance(documents, MutableSequence): 389 | aelog.error("insert many document failed, documents is not a iterable type.") 390 | raise MongoError("insert many document failed, documents is not a iterable type.") 391 | documents = list(documents) 392 | for document in documents: 393 | if not isinstance(document, MutableMapping): 394 | aelog.error("insert one document failed, document is not a mapping type.") 395 | raise MongoError("insert one document failed, document is not a mapping type.") 396 | self._update_doc_id(document) 397 | return await self._insert_documents(name, documents) 398 | 399 | async def insert_document(self, name: str, document: Dict) -> str: 400 | """ 401 | 插入一个单独的文档 402 | Args: 403 | name:collection name 404 | document: document obj 405 | Returns: 406 | 返回插入的转换后的_id 407 | """ 408 | if not isinstance(document, MutableMapping): 409 | aelog.error("insert one document failed, document is not a mapping type.") 410 | raise MongoError("insert one document failed, document is not a mapping type.") 411 | document = dict(document) 412 | return await self._insert_document(name, self._update_doc_id(document)) 413 | 414 | @staticmethod 415 | def _update_doc_id(document: Dict) -> Dict: 416 | """ 417 | 修改文档中的_id 418 | Args: 419 | document: document obj 420 | Returns: 421 | 返回处理后的document 422 | """ 423 | if "id" in document: 424 | try: 425 | document["_id"] = ObjectId(document.pop("id")) 426 | except BSONError as e: 427 | raise FuncArgsError(str(e)) 428 | return document 429 | 430 | async def find_document(self, name: str, query_key: Dict = None, filter_key: Dict = None) -> Optional[Dict]: 431 | """ 432 | 查询一个单独的document文档 433 | Args: 434 | name: collection name 435 | query_key: 查询document的过滤条件 436 | filter_key: 过滤返回值中字段的过滤条件 437 | Returns: 438 | 返回匹配的document或者None 439 | """ 440 | return await self._find_document(name, 
self._update_query_key(query_key), filter_key=filter_key) 441 | 442 | async def find_documents(self, name: str, query_key: Dict = None, filter_key: Dict = None, limit: int = 0, 443 | page: int = 1, sort: List[Tuple] = None) -> List[Dict]: 444 | """ 445 | 批量查询documents文档 446 | Args: 447 | name: collection name 448 | query_key: 查询document的过滤条件 449 | filter_key: 过滤返回值中字段的过滤条件 450 | limit: 每页数据的数量 451 | page: 查询第几页的数据 452 | sort: 排序方式,可以自定多种字段的排序,值为一个列表的键值对, eg:[('field1', pymongo.ASCENDING)] 453 | Returns: 454 | 返回匹配的document列表 455 | """ 456 | skip = (int(page) - 1) * int(limit) 457 | return await self._find_documents(name, self._update_query_key(query_key), filter_key=filter_key, 458 | limit=int(limit), skip=skip, sort=sort) 459 | 460 | async def find_count(self, name: str, query_key: Dict = None) -> int: 461 | """ 462 | 查询documents的数量 463 | Args: 464 | name: collection name 465 | query_key: 查询document的过滤条件 466 | Returns: 467 | 返回匹配的document数量 468 | """ 469 | return await self._find_count(name, self._update_query_key(query_key)) 470 | 471 | @staticmethod 472 | def _update_query_key(query_key: Dict) -> Dict: 473 | """ 474 | 更新查询的query 475 | Args: 476 | query_key: 查询document的过滤条件 477 | Returns: 478 | 返回处理后的query key 479 | """ 480 | query_key = dict(query_key) if query_key else {} 481 | try: 482 | for key, val in query_key.items(): 483 | if isinstance(val, MutableMapping): 484 | if key != "id": 485 | query_key[key] = {key if key.startswith("$") else f"${key}": val for key, val in val.items()} 486 | else: 487 | query_key["_id"] = { 488 | key if key.startswith("$") else f"${key}": [ObjectId(val) for val in val] 489 | if "in" in key else val for key, val in query_key.pop(key).items()} 490 | else: 491 | if key == "id": 492 | query_key["_id"] = ObjectId(query_key.pop("id")) 493 | except BSONError as e: 494 | raise FuncArgsError(str(e)) 495 | else: 496 | return query_key 497 | 498 | async def update_documents(self, name: str, query_key: Dict, update_data: Dict, upsert: bool = 
False) -> Dict: 499 | """ 500 | 更新匹配到的所有的document 501 | Args: 502 | name: collection name 503 | query_key: 查询document的过滤条件 504 | update_data: 对匹配的document进行更新的document 505 | upsert: 没有匹配到document的话执行插入操作,默认False 506 | Returns: 507 | 返回匹配的数量和修改数量的dict, eg:{"matched_count": 2, "modified_count": 2, "upserted_id":"f"} 508 | """ 509 | update_data = dict(update_data) 510 | return await self._update_documents(name, self._update_query_key(query_key), 511 | self._update_update_data(update_data), upsert=upsert) 512 | 513 | @staticmethod 514 | def _update_update_data(update_data: Dict) -> Dict: 515 | """ 516 | 处理update data, 包装最常使用的操作 517 | Args: 518 | update_data: 需要更新的文档值 519 | Returns: 520 | 返回处理后的update data 521 | """ 522 | 523 | # $set用的比较多,这里默认做个封装 524 | if len(update_data) > 1: 525 | update_data = {"$set": update_data} 526 | else: 527 | operator, doc = update_data.popitem() 528 | pre_flag = operator.startswith("$") 529 | update_data = {"$set" if not pre_flag else operator: {operator: doc} if not pre_flag else doc} 530 | return update_data 531 | 532 | async def update_document(self, name: str, query_key: Dict, update_data: Dict, upsert: bool = False) -> Dict: 533 | """ 534 | 更新匹配到的一个的document 535 | Args: 536 | name: collection name 537 | query_key: 查询document的过滤条件 538 | update_data: 对匹配的document进行更新的document 539 | upsert: 没有匹配到document的话执行插入操作,默认False 540 | Returns: 541 | 返回匹配的数量和修改数量的dict, eg:{"matched_count": 1, "modified_count": 1, "upserted_id":"f"} 542 | """ 543 | update_data = dict(update_data) 544 | return await self._update_document(name, self._update_query_key(query_key), 545 | self._update_update_data(update_data), upsert=upsert) 546 | 547 | async def delete_documents(self, name: str, query_key: Dict) -> int: 548 | """ 549 | 删除匹配到的所有的document 550 | Args: 551 | name: collection name 552 | query_key: 查询document的过滤条件 553 | Returns: 554 | 返回删除的数量 555 | """ 556 | return await self._delete_documents(name, self._update_query_key(query_key)) 557 | 558 | async def 
delete_document(self, name: str, query_key: Dict) -> int: 559 | """ 560 | 删除匹配到的一个的document 561 | Args: 562 | name: collection name 563 | query_key: 查询document的过滤条件 564 | Returns: 565 | 返回删除的数量 566 | """ 567 | return await self._delete_document(name, self._update_query_key(query_key)) 568 | 569 | async def aggregate(self, name: str, pipline: List[Dict], page: int = None, limit: int = None) -> List[Dict]: 570 | """ 571 | 根据pipline进行聚合查询 572 | Args: 573 | name: collection name 574 | pipline: 聚合查询的pipeline,包含一个后者多个聚合命令 575 | limit: 每页数据的数量 576 | page: 查询第几页的数据 577 | Returns: 578 | 返回聚合后的documents 579 | """ 580 | if not isinstance(pipline, MutableSequence): 581 | aelog.error("Aggregate query failed, pipline arg is not a iterable type.") 582 | raise MongoError("Aggregate query failed, pipline arg is not a iterable type.") 583 | if page is not None and limit is not None: 584 | pipline.extend([{'$skip': (int(page) - 1) * int(limit)}, {'$limit': int(limit)}]) 585 | return await self._aggregate(name, pipline) 586 | -------------------------------------------------------------------------------- /aclients/aio_mysql_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-25 下午4:58 8 | """ 9 | import asyncio 10 | import atexit 11 | from typing import (Dict, List, MutableMapping, MutableSequence, Optional, Tuple, Union) 12 | 13 | import aelog 14 | import sqlalchemy as sa 15 | from aiomysql.sa import create_engine 16 | from aiomysql.sa.exc import Error 17 | from aiomysql.sa.result import ResultProxy 18 | from pymysql.err import IntegrityError, MySQLError 19 | from sqlalchemy.exc import SQLAlchemyError 20 | from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base 21 | from sqlalchemy.orm.attributes import InstrumentedAttribute 22 | from sqlalchemy.sql import (all_, and_, any_, asc, bindparam, case, cast, 
column, delete, desc, 23 | distinct, except_, except_all, exists, extract, false, func, funcfilter, insert, intersect, 24 | intersect_all, join, label, not_, null, nullsfirst, nullslast, or_, outerjoin, over, select, 25 | table, text, true, tuple_, type_coerce, union, union_all, update, within_group) 26 | 27 | from .err_msg import mysql_msg 28 | from .exceptions import FuncArgsError, HttpError, MysqlDuplicateKeyError, MysqlError, QueryArgsError 29 | from .utils import gen_class_name, verify_message 30 | 31 | __all__ = ("AIOMysqlClient", "all_", "any_", "and_", "or_", "bindparam", "select", "text", "table", "column", 32 | "over", "within_group", "label", "case", "cast", "extract", "tuple_", "except_", "except_all", "intersect", 33 | "intersect_all", "union", "union_all", "exists", "nullsfirst", "nullslast", "asc", "desc", "distinct", 34 | "type_coerce", "true", "false", "null", "join", "outerjoin", "funcfilter", "func", "not_") 35 | 36 | 37 | class AIOMysqlClient(object): 38 | """ 39 | MySQL异步操作指南 40 | """ 41 | Model = declarative_base() 42 | 43 | def __init__(self, app=None, *, username="root", passwd=None, host="127.0.0.1", port=3306, dbname=None, 44 | pool_size=50, **kwargs): 45 | """ 46 | mysql 非阻塞工具类 47 | Args: 48 | app: app应用 49 | host:mysql host 50 | port:mysql port 51 | dbname: database name 52 | username: mysql user 53 | passwd: mysql password 54 | pool_size: mysql pool size 55 | """ 56 | self.app = app 57 | self.aio_engine = None 58 | # default bind connection 59 | self.username = username 60 | self.passwd = passwd 61 | self.host = host 62 | self.port = port 63 | self.dbname = dbname 64 | self.pool_size = pool_size 65 | # other info 66 | self.pool_recycle = kwargs.get("pool_recycle", 3600) # free close time 67 | self.charset = "utf8mb4" 68 | self.message = kwargs.get("message", {}) 69 | self.use_zh = kwargs.get("use_zh", True) 70 | self.msg_zh = None 71 | 72 | if app is not None: 73 | self.init_app(app, username=self.username, passwd=self.passwd, 
host=self.host, port=self.port, 74 | dbname=self.dbname, pool_size=self.pool_size, **kwargs) 75 | 76 | def init_app(self, app, *, username=None, passwd=None, host=None, port=None, dbname=None, 77 | pool_size=None, **kwargs): 78 | """ 79 | mysql 实例初始化 80 | Args: 81 | app: app应用 82 | host:mysql host 83 | port:mysql port 84 | dbname: database name 85 | username: mysql user 86 | passwd: mysql password 87 | pool_size: mysql pool size 88 | 89 | Returns: 90 | 91 | """ 92 | username = username or app.config.get("ACLIENTS_MYSQL_USERNAME", None) or self.username 93 | passwd = passwd or app.config.get("ACLIENTS_MYSQL_PASSWD", None) or self.passwd 94 | host = host or app.config.get("ACLIENTS_MYSQL_HOST", None) or self.host 95 | port = port or app.config.get("ACLIENTS_MYSQL_PORT", None) or self.port 96 | dbname = dbname or app.config.get("ACLIENTS_MYSQL_DBNAME", None) or self.dbname 97 | self.pool_size = pool_size or app.config.get("ACLIENTS_MYSQL_POOL_SIZE", None) or self.pool_size 98 | 99 | self.pool_recycle = kwargs.get("pool_recycle") or app.config.get( 100 | "ACLIENTS_POOL_RECYCLE", None) or self.pool_recycle 101 | 102 | message = kwargs.get("message") or app.config.get("ACLIENTS_MYSQL_MESSAGE", None) or self.message 103 | use_zh = kwargs.get("use_zh") or app.config.get("ACLIENTS_MYSQL_MSGZH", None) or self.use_zh 104 | 105 | passwd = passwd if passwd is None else str(passwd) 106 | self.message = verify_message(mysql_msg, message) 107 | self.msg_zh = "msg_zh" if use_zh else "msg_en" 108 | self.app = app 109 | 110 | @app.listener('before_server_start') 111 | async def open_connection(app_, loop): 112 | """ 113 | 114 | Args: 115 | 116 | Returns: 117 | 118 | """ 119 | # engine 120 | self.aio_engine = await create_engine( 121 | host=host, port=port, user=username, password=passwd, db=dbname, maxsize=self.pool_size, 122 | pool_recycle=self.pool_recycle, charset=self.charset, connect_timeout=60) 123 | 124 | @app.listener('after_server_stop') 125 | async def 
close_connection(app_, loop): 126 | """ 127 | 128 | Args: 129 | 130 | Returns: 131 | 132 | """ 133 | if self.aio_engine: 134 | self.aio_engine.close() 135 | await self.aio_engine.wait_closed() 136 | 137 | def init_engine(self, *, username="root", passwd=None, host="127.0.0.1", port=3306, dbname=None, 138 | pool_size=50, **kwargs): 139 | """ 140 | mysql 实例初始化 141 | Args: 142 | host:mysql host 143 | port:mysql port 144 | dbname: database name 145 | username: mysql user 146 | passwd: mysql password 147 | pool_size: mysql pool size 148 | 149 | Returns: 150 | 151 | """ 152 | username = username or self.username 153 | passwd = passwd or self.passwd 154 | host = host or self.host 155 | port = port or self.port 156 | dbname = dbname or self.dbname 157 | self.pool_size = pool_size or self.pool_size 158 | 159 | self.pool_recycle = kwargs.get("pool_recycle") or self.pool_recycle 160 | 161 | message = kwargs.get("message") or self.message 162 | use_zh = kwargs.get("use_zh") or self.use_zh 163 | 164 | passwd = passwd if passwd is None else str(passwd) 165 | self.message = verify_message(mysql_msg, message) 166 | self.msg_zh = "msg_zh" if use_zh else "msg_en" 167 | loop = asyncio.get_event_loop() 168 | 169 | async def open_connection(): 170 | """ 171 | 172 | Args: 173 | 174 | Returns: 175 | 176 | """ 177 | # engine 178 | self.aio_engine = await create_engine( 179 | host=host, port=port, user=username, password=passwd, db=dbname, maxsize=self.pool_size, 180 | pool_recycle=self.pool_recycle, charset=self.charset, connect_timeout=60) 181 | 182 | async def close_connection(): 183 | """ 184 | 185 | Args: 186 | 187 | Returns: 188 | 189 | """ 190 | if self.aio_engine: 191 | self.aio_engine.close() 192 | await self.aio_engine.wait_closed() 193 | 194 | loop.run_until_complete(open_connection()) 195 | atexit.register(lambda: loop.run_until_complete(close_connection())) 196 | 197 | @staticmethod 198 | def _get_model_default_value(model) -> Dict: 199 | """ 200 | 201 | Args: 202 | model 203 
| Returns: 204 | 205 | """ 206 | default_values = {} 207 | for key, val in model.__dict__.items(): 208 | if not key.startswith("_") and isinstance(val, InstrumentedAttribute): 209 | if val.default: 210 | if val.default.is_callable: 211 | default_values[key] = val.default.arg.__wrapped__() 212 | else: 213 | default_values[key] = val.default.arg 214 | return default_values 215 | 216 | @staticmethod 217 | def _get_model_onupdate_value(model) -> Dict: 218 | """ 219 | 220 | Args: 221 | model 222 | Returns: 223 | 224 | """ 225 | update_values = {} 226 | for key, val in model.__dict__.items(): 227 | if not key.startswith("_") and isinstance(val, InstrumentedAttribute): 228 | if val.onupdate and val.onupdate.is_callable: 229 | update_values[key] = val.onupdate.arg.__wrapped__() 230 | return update_values 231 | 232 | async def _insert_one(self, model, insert_data: Dict) -> Tuple[int, str]: 233 | """ 234 | 插入数据 235 | Args: 236 | model: model 237 | insert_data: 值类型 238 | Returns: 239 | 返回插入的条数 240 | """ 241 | try: 242 | query = insert(model).values(insert_data) 243 | new_values = self._get_model_default_value(model) 244 | new_values.update(insert_data) 245 | except SQLAlchemyError as e: 246 | aelog.exception(e) 247 | raise QueryArgsError(message="Cloumn args error: {}".format(str(e))) 248 | else: 249 | async with self.aio_engine.acquire() as conn: 250 | async with conn.begin() as trans: 251 | try: 252 | cursor = await conn.execute(query, new_values) 253 | await trans.commit() 254 | except IntegrityError as e: 255 | await trans.rollback() 256 | aelog.exception(e) 257 | if "Duplicate" in str(e): 258 | raise MysqlDuplicateKeyError(e) 259 | else: 260 | raise MysqlError(e) 261 | except (MySQLError, Error) as e: 262 | await trans.rollback() 263 | aelog.exception(e) 264 | raise MysqlError(e) 265 | except Exception as e: 266 | await trans.rollback() 267 | aelog.exception(e) 268 | raise HttpError(500, message=self.message[1][self.msg_zh], error=e) 269 | 270 | return cursor.rowcount, 
new_values.get("id") or cursor.lastrowid 271 | 272 | async def _update_data(self, model, query_key: Dict, or_query_key: Dict, update_data: Dict) -> int: 273 | """ 274 | 更新数据 275 | Args: 276 | model: model 277 | query_key: 更新的查询条件 278 | update_data: 值类型 279 | or_query_key: 或查询model的过滤条件 280 | Returns: 281 | 返回更新的条数 282 | """ 283 | try: 284 | query = update(model) 285 | if query_key or or_query_key: 286 | query = self._column_expression(model, query, query_key, or_query_key) 287 | query = query.values(update_data) 288 | new_values = self._get_model_onupdate_value(model) 289 | new_values.update(update_data) 290 | except SQLAlchemyError as e: 291 | aelog.exception(e) 292 | raise QueryArgsError(message="Cloumn args error: {}".format(str(e))) 293 | else: 294 | async with self.aio_engine.acquire() as conn: 295 | async with conn.begin() as trans: 296 | try: 297 | cursor = await conn.execute(query, new_values) 298 | await trans.commit() 299 | except IntegrityError as e: 300 | await trans.rollback() 301 | aelog.exception(e) 302 | if "Duplicate" in str(e): 303 | raise MysqlDuplicateKeyError(e) 304 | else: 305 | raise MysqlError(e) 306 | except (MySQLError, Error) as e: 307 | await trans.rollback() 308 | aelog.exception(e) 309 | raise MysqlError(e) 310 | except Exception as e: 311 | await trans.rollback() 312 | aelog.exception(e) 313 | raise HttpError(500, message=self.message[2][self.msg_zh], error=e) 314 | 315 | return cursor.rowcount 316 | 317 | async def _delete_data(self, model, query_key: Dict, or_query_key: Dict) -> int: 318 | """ 319 | 更新数据 320 | Args: 321 | model: model 322 | query_key: 删除的查询条件 323 | or_query_key: 或查询model的过滤条件 324 | Returns: 325 | 返回删除的条数 326 | """ 327 | try: 328 | query = delete(model) 329 | if query_key or or_query_key: 330 | query = self._column_expression(model, query, query_key, or_query_key) 331 | except SQLAlchemyError as e: 332 | aelog.exception(e) 333 | raise QueryArgsError(message="Cloumn args error: {}".format(str(e))) 334 | else: 335 | 
async with self.aio_engine.acquire() as conn: 336 | async with conn.begin() as trans: 337 | try: 338 | cursor = await conn.execute(query) 339 | await trans.commit() 340 | except (MySQLError, Error) as e: 341 | await trans.rollback() 342 | aelog.exception(e) 343 | raise MysqlError(e) 344 | except Exception as e: 345 | await trans.rollback() 346 | aelog.exception(e) 347 | raise HttpError(500, message=self.message[3][self.msg_zh], error=e) 348 | 349 | return cursor.rowcount 350 | 351 | async def _find_one(self, model: List, query_key: Dict, or_query_key: Dict) -> Optional[Dict]: 352 | """ 353 | 查询单条数据 354 | Args: 355 | model: 查询的model名称 356 | query_key: 查询model的过滤条件 357 | or_query_key: 或查询model的过滤条件 358 | Returns: 359 | 返回匹配的数据或者None 360 | """ 361 | try: 362 | query = select(model) 363 | if query_key or or_query_key: 364 | query = self._column_expression(model, query, query_key, or_query_key) 365 | except SQLAlchemyError as e: 366 | aelog.exception(e) 367 | raise QueryArgsError(message="Cloumn args error: {}".format(str(e))) 368 | else: 369 | try: 370 | async with self.aio_engine.acquire() as conn: 371 | async with conn.execute(query) as cursor: 372 | resp = await cursor.fetchone() 373 | await conn.execute('commit') # 理论上不应该加这个的,但是这里默认就会启动一个事务,很奇怪 374 | except (MySQLError, Error) as err: 375 | aelog.exception("Find one data failed, {}".format(err)) 376 | raise HttpError(400, message=self.message[4][self.msg_zh], error=err) 377 | else: 378 | return dict(resp) if resp else None 379 | 380 | async def _find_data(self, model: List, query_key: Dict, or_query_key: Dict, limit: int, 381 | skip: int, order: Tuple) -> List[Dict]: 382 | """ 383 | 查询单条数据 384 | Args: 385 | model: 查询的model名称 386 | query_key: 查询model的过滤条件 387 | limit: 每页条数 388 | skip: 需要跳过的条数 389 | order: 排序条件 390 | or_query_key: 或查询model的过滤条件 391 | Returns: 392 | 返回匹配的数据或者None 393 | """ 394 | try: 395 | query = select(model) 396 | if query_key or or_query_key: 397 | query = self._column_expression(model, query, 
query_key, or_query_key) 398 | if order: 399 | query = query.order_by(desc(order[0])) if order[1] == 1 else query.order_by(order[0]) 400 | else: 401 | model_ = model[0] if isinstance(model, MutableSequence) else model 402 | if getattr(model_, "id", None) is not None: 403 | query = query.order_by(asc("id")) 404 | if limit: 405 | query = query.limit(limit) 406 | if skip: 407 | query = query.offset(skip) 408 | except SQLAlchemyError as e: 409 | aelog.exception(e) 410 | raise QueryArgsError(message="Cloumn args error: {}".format(str(e))) 411 | else: 412 | try: 413 | async with self.aio_engine.acquire() as conn: 414 | async with conn.execute(query) as cursor: 415 | resp = await cursor.fetchall() 416 | await conn.execute('commit') 417 | except (MySQLError, Error) as err: 418 | aelog.exception("Find data failed, {}".format(err)) 419 | raise HttpError(400, message=self.message[5][self.msg_zh], error=err) 420 | else: 421 | return [dict(val) for val in resp] if resp else [] 422 | 423 | async def _find_count(self, model, query_key: Dict, or_query_key: Dict) -> int: 424 | """ 425 | 查询单条数据 426 | Args: 427 | model: 查询的model名称 428 | query_key: 查询model的过滤条件 429 | or_query_key: 或查询model的过滤条件 430 | Returns: 431 | 返回条数 432 | """ 433 | try: 434 | query = select([func.count().label("count")]).select_from(model) 435 | if query_key or or_query_key: 436 | query = self._column_expression(model, query, query_key, or_query_key) 437 | except SQLAlchemyError as e: 438 | aelog.exception(e) 439 | raise QueryArgsError(message="Cloumn args error: {}".format(str(e))) 440 | else: 441 | try: 442 | async with self.aio_engine.acquire() as conn: 443 | async with conn.execute(query) as cursor: 444 | resp = await cursor.fetchone() 445 | await conn.execute('commit') 446 | except (MySQLError, Error) as err: 447 | aelog.exception("Find data failed, {}".format(err)) 448 | raise HttpError(400, message=self.message[5][self.msg_zh], error=err) 449 | else: 450 | return resp.count 451 | 452 | @staticmethod 453 | 
def _column_expression(model: List, query, query_key: Dict, or_query_key: Dict): 454 | """ 455 | 查询单条数据 456 | Args: 457 | model: 查询的model名称 458 | query: 查询的query基本expression 459 | query_key: 查询model的过滤条件 460 | or_query_key: 或查询model的过滤条件 461 | Returns: 462 | 返回匹配的数据或者None 463 | """ 464 | model = model if isinstance(model, MutableSequence) else [model] 465 | query_key = query_key if isinstance(query_key, MutableMapping) else {} 466 | or_query_key = or_query_key if isinstance(or_query_key, MutableMapping) else {} 467 | 468 | maps = { 469 | "eq": lambda column_name, column_val: query.where(column_name == column_val), 470 | "ne": lambda column_name, column_val: query.where(column_name != column_val), 471 | "gt": lambda column_name, column_val: query.where(column_name > column_val), 472 | "gte": lambda column_name, column_val: query.where(column_name >= column_val), 473 | "lt": lambda column_name, column_val: query.where(column_name < column_val), 474 | "lte": lambda column_name, column_val: query.where(column_name <= column_val), 475 | "in": lambda column_name, column_val: query.where(column_name.in_(column_val)), 476 | "nin": lambda column_name, column_val: query.where(column_name.notin_(column_val)), 477 | "like": lambda column_name, column_val: query.where(column_name.like(column_val)), 478 | "ilike": lambda column_name, column_val: query.where(column_name.ilike(column_val)), 479 | "between": lambda column_name, column_val: query.where( 480 | column_name.between(column_val[0], column_val[1]))} 481 | or_maps = { 482 | "eq": lambda column_name, column_val: column_name == column_val, 483 | "ne": lambda column_name, column_val: column_name != column_val, 484 | "gt": lambda column_name, column_val: column_name > column_val, 485 | "gte": lambda column_name, column_val: column_name >= column_val, 486 | "lt": lambda column_name, column_val: column_name < column_val, 487 | "lte": lambda column_name, column_val: column_name <= column_val, 488 | "in": lambda column_name, 
column_val: column_name.in_(column_val), 489 | "nin": lambda column_name, column_val: column_name.notin_(column_val), 490 | "like": lambda column_name, column_val: column_name.like(column_val), 491 | "ilike": lambda column_name, column_val: column_name.ilike(column_val), 492 | "between": lambda column_name, column_val: column_name.between(column_val[0], column_val[1]) 493 | } 494 | 495 | def or_query(column_val: MutableSequence): 496 | """ 497 | 组装or查询表达式 498 | Args: 499 | 500 | Returns: 501 | 502 | """ 503 | return query.where(or_(*column_val)) 504 | 505 | # 如果出现[model1, model2]则只考虑第一个,因为如果多表查询,则必须在query_key中指定清楚,这里只处理大多数情况 506 | model = model[0] if not isinstance(model[0], InstrumentedAttribute) else getattr(model[0], "class_") 507 | for field_name, val in query_key.items(): 508 | field_name = getattr(model, field_name) if not isinstance(field_name, InstrumentedAttribute) else field_name 509 | # 因为判断相等的查询比较多,因此默认就是== 510 | if not isinstance(val, MutableMapping): 511 | query = maps["eq"](field_name, val) 512 | else: 513 | # 其他情况则需要指定是什么查询,大于、小于或者是like等 514 | # 可能会出现多个查询的情况,比如{"key": {"gt": 3, "lt": 9}} 515 | for operate, value in val.items(): 516 | if operate in maps: 517 | query = maps[operate](field_name, value) 518 | # or 查询 {"key": {"gt": 3, "lt": 9}} 519 | for field_name, or_value in or_query_key.items(): 520 | field_name = getattr(model, field_name) if not isinstance(field_name, InstrumentedAttribute) else field_name 521 | or_args = [] 522 | for operate, sub_or_value in or_value.items(): 523 | if operate in or_maps: 524 | if not isinstance(sub_or_value, MutableSequence): 525 | or_args.append(or_maps[operate](field_name, sub_or_value)) 526 | else: 527 | for val in sub_or_value: 528 | or_args.append(or_maps[operate](field_name, val)) 529 | else: 530 | query = or_query(or_args) 531 | return query 532 | 533 | async def execute(self, query) -> ResultProxy: 534 | """ 535 | 插入数据,更新或者删除数据 536 | Args: 537 | query: SQL的查询字符串或者sqlalchemy表达式 538 | Returns: 539 | 
不确定执行的是什么查询,直接返回ResultProxy实例 540 | """ 541 | async with self.aio_engine.acquire() as conn: 542 | async with conn.begin() as trans: 543 | try: 544 | cursor = await conn.execute(query) 545 | await trans.commit() 546 | except IntegrityError as e: 547 | await trans.rollback() 548 | aelog.exception(e) 549 | if "Duplicate" in str(e): 550 | raise MysqlDuplicateKeyError(e) 551 | else: 552 | raise MysqlError(e) 553 | except (MySQLError, Error) as e: 554 | await trans.rollback() 555 | aelog.exception(e) 556 | raise MysqlError(e) 557 | except Exception as e: 558 | await trans.rollback() 559 | aelog.exception(e) 560 | raise HttpError(500, message=self.message[6][self.msg_zh], error=e) 561 | 562 | return cursor 563 | 564 | async def query(self, query) -> Optional[List[Dict]]: 565 | """ 566 | 查询数据,用于复杂的查询 567 | Args: 568 | query: SQL的查询字符串或者sqlalchemy表达式 569 | Returns: 570 | 不确定执行的是什么查询,直接返回ResultProxy实例 571 | """ 572 | try: 573 | async with self.aio_engine.acquire() as conn: 574 | async with conn.execute(query) as cursor: 575 | resp = await cursor.fetchall() 576 | await conn.execute('commit') 577 | except (MySQLError, Error) as err: 578 | aelog.exception("Find data failed, {}".format(err)) 579 | raise HttpError(400, message=self.message[5][self.msg_zh], error=err) 580 | else: 581 | return [dict(val) for val in resp] if resp else None 582 | 583 | async def insert_one(self, model, *, insert_data: Dict) -> Tuple[int, str]: 584 | """ 585 | 插入数据 586 | Args: 587 | model: model 588 | insert_data: 值类型 589 | Returns: 590 | 返回插入的条数 591 | """ 592 | return await self._insert_one(model, insert_data) 593 | 594 | async def find_one(self, model: Union[DeclarativeMeta, List], *, query_key: Dict = None, 595 | or_query_key: Dict = None) -> Optional[Dict]: 596 | """ 597 | 查询单条数据 598 | Args: 599 | model: 查询的model名称 600 | query_key: 查询model的过滤条件 601 | or_query_key: 或查询model的过滤条件 602 | Returns: 603 | 返回匹配的数据或者None 604 | """ 605 | model = model if isinstance(model, MutableSequence) else [model] 606 | 
return await self._find_one(model, query_key, or_query_key) 607 | 608 | async def find_data(self, model: Union[DeclarativeMeta, List], *, query_key: Dict = None, 609 | or_query_key: Dict = None, limit: int = 0, page: int = 1, 610 | order: Tuple = None) -> List[Dict]: 611 | """ 612 | 插入数据 613 | Args: 614 | model: model 615 | query_key: 查询表的过滤条件, {"key": {"gt": 3, "lt": 9}} 616 | or_query_key: 或查询model的过滤条件,{"key": {"gt": 3, "lt": 9}},{"key": {"eq": [3, 8]}} 617 | limit: 限制返回的表的条数 618 | page: 从查询结果中调过指定数量的行 619 | order: 排序条件 620 | Returns: 621 | 622 | """ 623 | if order and not isinstance(order, (list, tuple)): 624 | raise FuncArgsError("order must be tuple or list!") 625 | limit = int(limit) 626 | skip = (int(page) - 1) * limit 627 | model = model if isinstance(model, MutableSequence) else [model] 628 | return await self._find_data(model, query_key, or_query_key, limit=limit, skip=skip, order=order) 629 | 630 | async def find_count(self, model, *, query_key: Dict = None, or_query_key: Dict = None) -> int: 631 | """ 632 | 查询单条数据 633 | Args: 634 | model: 查询的model名称 635 | query_key: 查询model的过滤条件 636 | or_query_key: 或查询model的过滤条件 637 | Returns: 638 | 返回总条数 639 | """ 640 | return await self._find_count(model, query_key, or_query_key) 641 | 642 | async def update_data(self, model, *, query_key: Dict, or_query_key: Dict = None, update_data: Dict) -> int: 643 | """ 644 | 更新数据 645 | Args: 646 | model: model 647 | query_key: 更新的查询条件 648 | or_query_key: 或查询model的过滤条件 649 | update_data: 值类型 650 | Returns: 651 | 返回更新的条数 652 | """ 653 | return await self._update_data(model, query_key, or_query_key, update_data) 654 | 655 | async def delete_data(self, model, *, query_key: Dict, or_query_key: Dict = None) -> int: 656 | """ 657 | 更新数据 658 | Args: 659 | model: model 660 | query_key: 删除的查询条件, 必须有query_key,不允许删除整张表 661 | or_query_key: 或查询model的过滤条件 662 | Returns: 663 | 返回删除的条数 664 | """ 665 | if not query_key: 666 | raise FuncArgsError("query_key must be provide!") 667 | return await 
self._delete_data(model, query_key, or_query_key) 668 | 669 | def gen_model(self, model_cls, suffix: str = None, **kwargs): 670 | """ 671 | 用于根据现有的model生成新的model类 672 | 673 | 主要用于分表的查询和插入 674 | Args: 675 | model_cls: 要生成分表的model类 676 | suffix: 新的model类名的后缀 677 | kwargs: 其他的参数 678 | Returns: 679 | 680 | """ 681 | if kwargs: 682 | aelog.info(kwargs) 683 | if not issubclass(model_cls, DeclarativeMeta): 684 | raise ValueError("model_cls must be db.Model type.") 685 | 686 | table_name = f"{getattr(model_cls, '__tablename__', model_cls.__name__)}_{suffix}" 687 | class_name = f"{gen_class_name(table_name)}Model" 688 | if getattr(model_cls, "_cache_class", None) is None: 689 | setattr(model_cls, "_cache_class", {}) 690 | 691 | model_cls_ = getattr(model_cls, "_cache_class").get(class_name, None) 692 | if model_cls_ is None: 693 | model_fields = {} 694 | for attr_name, field in model_cls.__dict__.items(): 695 | if isinstance(field, InstrumentedAttribute) and not attr_name.startswith("_"): 696 | model_fields[attr_name] = sa.Column( 697 | type_=field.type, primary_key=field.primary_key, index=field.index, nullable=field.nullable, 698 | default=field.default, onupdate=field.onupdate, unique=field.unique, 699 | autoincrement=field.autoincrement, doc=field.doc) 700 | model_cls_ = type(class_name, (self.Model,), { 701 | "__doc__": model_cls.__doc__, 702 | "__table_args__ ": getattr( 703 | model_cls, "__table_args__", None) or {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'}, 704 | "__tablename__": table_name, 705 | "__module__": model_cls.__module__, 706 | **model_fields}) 707 | getattr(model_cls, "_cache_class")[class_name] = model_cls_ 708 | 709 | return model_cls_ 710 | -------------------------------------------------------------------------------- /aclients/aio_redis_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | 
@time: 18-12-25 下午5:15 8 | """ 9 | import atexit 10 | import secrets 11 | import uuid 12 | from collections import MutableMapping 13 | from typing import Any, Dict, List, NoReturn, Union 14 | 15 | import aelog 16 | import aredis 17 | import ujson 18 | from aredis import RedisError 19 | 20 | from .exceptions import RedisClientError 21 | from .utils import async_ignore_error 22 | 23 | __all__ = ("Session", "AIORedisClient") 24 | 25 | LONG_EXPIRED: int = 24 * 60 * 60 # 最长过期时间 26 | EXPIRED: int = 12 * 60 * 60 # 通用过期时间 27 | SESSION_EXPIRED: int = 60 * 60 # session过期时间 28 | 29 | 30 | class Session(object): 31 | """ 32 | 保存实际看结果的session实例 33 | Args: 34 | 35 | """ 36 | 37 | def __init__(self, account_id: str, *, session_id: str = None, org_id: str = None, role_id: str = None, 38 | permission_id: str = None, **kwargs): 39 | self.account_id = account_id # 账户ID 40 | self.session_id = secrets.token_urlsafe() if not session_id else session_id # session ID 41 | self.org_id = org_id or uuid.uuid4().hex # 账户的组织结构在redis中的ID 42 | self.role_id = role_id or uuid.uuid4().hex # 账户的角色在redis中的ID 43 | self.permission_id = permission_id or uuid.uuid4().hex # 账户的权限在redis中的ID 44 | self.static_permission_id = uuid.uuid4().hex # 账户的静态权限在redis中的ID 45 | self.dynamic_permission_id = uuid.uuid4().hex # 账户的动态权限在redis中的ID 46 | self.page_id = uuid.uuid4().hex # 账户的页面权限在redis中的ID 47 | self.page_menu_id = uuid.uuid4().hex # 账户的页面菜单权限在redis中的ID 48 | for k, v in kwargs.items(): 49 | setattr(self, k, v) 50 | 51 | 52 | class AIORedisClient(object): 53 | """ 54 | redis 非阻塞工具类 55 | """ 56 | 57 | def __init__(self, app=None, *, host: str = "127.0.0.1", port: int = 6379, dbname: int = 0, passwd: str = "", 58 | pool_size: int = 50): 59 | """ 60 | redis 非阻塞工具类 61 | Args: 62 | app: app应用 63 | host:redis host 64 | port:redis port 65 | dbname: database name 66 | passwd: redis password 67 | pool_size: redis pool size 68 | """ 69 | self.app = app 70 | self.pool = None 71 | self.redis_db: aredis.StrictRedis = None 72 | 
self.host = host 73 | self.port = port 74 | self.dbname = dbname 75 | self.passwd = passwd 76 | self.pool_size = pool_size 77 | self._account_key = "account_to_session" 78 | 79 | if app is not None: 80 | self.init_app(app, host=self.host, port=self.port, dbname=self.dbname, passwd=self.passwd, 81 | pool_size=self.pool_size) 82 | 83 | def init_app(self, app, *, host: str = None, port: int = None, dbname: int = None, passwd: str = "", 84 | pool_size: int = None): 85 | """ 86 | redis 非阻塞工具类 87 | Args: 88 | app: app应用 89 | host:redis host 90 | port:redis port 91 | dbname: database name 92 | passwd: redis password 93 | pool_size: redis pool size 94 | Returns: 95 | 96 | """ 97 | self.app = app 98 | host = host or app.config.get("ACLIENTS_REDIS_HOST", None) or self.host 99 | port = port or app.config.get("ACLIENTS_REDIS_PORT", None) or self.port 100 | dbname = dbname or app.config.get("ACLIENTS_REDIS_DBNAME", None) or self.dbname 101 | passwd = passwd or app.config.get("ACLIENTS_REDIS_PASSWD", None) or self.passwd 102 | pool_size = pool_size or app.config.get("ACLIENTS_REDIS_POOL_SIZE", None) or self.pool_size 103 | 104 | passwd = passwd if passwd is None else str(passwd) 105 | 106 | @app.listener('before_server_start') 107 | async def open_connection(app_, loop): 108 | """ 109 | 110 | Args: 111 | 112 | Returns: 113 | 114 | """ 115 | # 返回值都做了解码,应用层不需要再decode 116 | self.pool = aredis.ConnectionPool(host=host, port=port, db=dbname, password=passwd, decode_responses=True, 117 | max_connections=pool_size) 118 | self.redis_db = aredis.StrictRedis(connection_pool=self.pool, decode_responses=True) 119 | 120 | @app.listener('after_server_stop') 121 | async def close_connection(app_, loop): 122 | """ 123 | 释放redis连接池所有连接 124 | Args: 125 | 126 | Returns: 127 | 128 | """ 129 | self.redis_db = None 130 | if self.pool: 131 | self.pool.disconnect() 132 | 133 | def init_engine(self, *, host: str = None, port: int = None, dbname: int = None, passwd: str = "", 134 | pool_size: int = 
None): 135 | """ 136 | redis 非阻塞工具类 137 | Args: 138 | host:redis host 139 | port:redis port 140 | dbname: database name 141 | passwd: redis password 142 | pool_size: redis pool size 143 | Returns: 144 | 145 | """ 146 | host = host or self.host 147 | port = port or self.port 148 | dbname = dbname or self.dbname 149 | passwd = passwd or self.passwd 150 | pool_size = pool_size or self.pool_size 151 | 152 | passwd = passwd if passwd is None else str(passwd) 153 | # 返回值都做了解码,应用层不需要再decode 154 | self.pool = aredis.ConnectionPool(host=host, port=port, db=dbname, password=passwd, decode_responses=True, 155 | max_connections=pool_size) 156 | self.redis_db = aredis.StrictRedis(connection_pool=self.pool, decode_responses=True) 157 | 158 | @atexit.register 159 | def close_connection(): 160 | """ 161 | 释放redis连接池所有连接 162 | Args: 163 | 164 | Returns: 165 | 166 | """ 167 | self.redis_db = None 168 | if self.pool: 169 | self.pool.disconnect() 170 | 171 | async def save_session(self, session: Session, dump_responses: bool = False, ex: int = SESSION_EXPIRED) -> str: 172 | """ 173 | 利用hash map保存session 174 | Args: 175 | session: Session 实例 176 | dump_responses: 是否对每个键值进行dump 177 | ex: 过期时间,单位秒 178 | Returns: 179 | 180 | """ 181 | session_data = await self.response_dumps(dump_responses, session) 182 | 183 | try: 184 | if not await self.redis_db.hmset(session_data["session_id"], session_data): 185 | raise RedisClientError("save session failed, session_id={}".format(session_data["session_id"])) 186 | if not await self.redis_db.expire(session_data["session_id"], ex): 187 | aelog.error("set session expire failed, session_id={}".format(session_data["session_id"])) 188 | except RedisError as e: 189 | aelog.exception("save session error: {}, {}".format(session.session_id, e)) 190 | raise RedisClientError(str(e)) 191 | else: 192 | # 清除老的令牌 193 | # try: 194 | # old_session_id = await self.get_hash_data(self._account_key, field_name=session.account_id) 195 | # except RedisClientError as e: 196 
| # aelog.info(f"{session.account_id} no old token token, {str(e)}") 197 | # else: 198 | # async with async_ignore_error(): 199 | # await self.delete_session(old_session_id, False) 200 | # # 更新新的令牌 201 | # await self.save_update_hash_data(self._account_key, field_name=session.account_id, 202 | # hash_data=session.session_id, ex=LONG_EXPIRED) 203 | return session.session_id 204 | 205 | @staticmethod 206 | async def response_dumps(dump_responses: bool, session: Session) -> Dict: 207 | session_data = dict(vars(session)) 208 | # 是否对每个键值进行dump 209 | if dump_responses: 210 | hash_data = {} 211 | for hash_key, hash_val in session_data.items(): 212 | if not isinstance(hash_val, str): 213 | async with async_ignore_error(): 214 | hash_val = ujson.dumps(hash_val) 215 | hash_data[hash_key] = hash_val 216 | session_data = hash_data 217 | return session_data 218 | 219 | async def delete_session(self, session_id: str, delete_key: bool = True) -> NoReturn: 220 | """ 221 | 利用hash map删除session 222 | Args: 223 | session_id: session id 224 | delete_key: 删除account到session的account key 225 | Returns: 226 | 227 | """ 228 | 229 | try: 230 | session_id_ = await self.redis_db.hget(session_id, "session_id") 231 | if session_id_ != session_id: 232 | raise RedisClientError("invalid session_id, session_id={}".format(session_id)) 233 | exist_keys = [] 234 | session_data = await self.get_session(session_id, cls_flag=False) 235 | exist_keys.append(session_data["org_id"]) 236 | exist_keys.append(session_data["role_id"]) 237 | exist_keys.append(session_data["permission_id"]) 238 | exist_keys.append(session_data["static_permission_id"]) 239 | exist_keys.append(session_data["dynamic_permission_id"]) 240 | exist_keys.append(session_data["page_id"]) 241 | exist_keys.append(session_data["page_menu_id"]) 242 | 243 | async with async_ignore_error(): # 删除已经存在的和账户相关的缓存key 244 | await self.delete_keys(exist_keys) 245 | # if delete_key is True: 246 | # await self.redis_db.hdel(self._account_key, 
session_data["account_id"]) 247 | 248 | if not await self.redis_db.delete(session_id): 249 | aelog.error("delete session failed, session_id={}".format(session_id)) 250 | except RedisError as e: 251 | aelog.exception("delete session error: {}, {}".format(session_id, e)) 252 | raise RedisClientError(str(e)) 253 | 254 | async def update_session(self, session: Session, dump_responses: bool = False, 255 | ex: int = SESSION_EXPIRED) -> NoReturn: 256 | """ 257 | 利用hash map更新session 258 | Args: 259 | session: Session实例 260 | ex: 过期时间,单位秒 261 | dump_responses: 是否对每个键值进行dump 262 | Returns: 263 | 264 | """ 265 | session_data = await self.response_dumps(dump_responses, session) 266 | 267 | try: 268 | if not await self.redis_db.hmset(session_data["session_id"], session_data): 269 | raise RedisClientError("update session failed, session_id={}".format(session_data["session_id"])) 270 | if not await self.redis_db.expire(session_data["session_id"], ex): 271 | aelog.error("set session expire failed, session_id={}".format(session_data["session_id"])) 272 | except RedisError as e: 273 | aelog.exception("update session error: {}, {}".format(session_data["session_id"], e)) 274 | raise RedisClientError(str(e)) 275 | # else: 276 | # 更新令牌 277 | # await self.save_update_hash_data(self._account_key, field_name=session.account_id, 278 | # hash_data=session.session_id, ex=LONG_EXPIRED) 279 | 280 | async def get_session(self, session_id: str, ex: int = SESSION_EXPIRED, cls_flag: bool = True, 281 | load_responses: bool = False) -> Union[Session, Dict[str, str]]: 282 | """ 283 | 获取session 284 | Args: 285 | session_id: session id 286 | ex: 过期时间,单位秒 287 | cls_flag: 是否返回session的类实例 288 | load_responses: 结果的键值是否进行load 289 | Returns: 290 | 291 | """ 292 | 293 | try: 294 | session_data = await self.redis_db.hgetall(session_id) 295 | if not session_data: 296 | raise RedisClientError("not found session, session_id={}".format(session_id)) 297 | 298 | if not await self.redis_db.expire(session_id, ex): 299 
| aelog.error("set session expire failed, session_id={}".format(session_id)) 300 | except RedisError as e: 301 | aelog.exception("get session error: {}, {}".format(session_id, e)) 302 | raise RedisClientError(e) 303 | else: 304 | # 返回的键值对是否做load 305 | if load_responses: 306 | hash_data = {} 307 | for hash_key, hash_val in session_data.items(): 308 | async with async_ignore_error(): 309 | hash_val = ujson.loads(hash_val) 310 | hash_data[hash_key] = hash_val 311 | session_data = hash_data 312 | 313 | if cls_flag: 314 | return Session(session_data.pop('account_id'), session_id=session_data.pop('session_id'), 315 | org_id=session_data.pop("org_id"), role_id=session_data.pop("role_id"), 316 | permission_id=session_data.pop("permission_id"), **session_data) 317 | else: 318 | return session_data 319 | 320 | async def verify(self, session_id: str) -> Session: 321 | """ 322 | 校验session,主要用于登录校验 323 | Args: 324 | session_id 325 | Returns: 326 | 327 | """ 328 | try: 329 | session = await self.get_session(session_id) 330 | except RedisClientError as e: 331 | raise RedisClientError(str(e)) 332 | else: 333 | if not session: 334 | raise RedisClientError("invalid session_id, session_id={}".format(session_id)) 335 | return session 336 | 337 | async def save_update_hash_data(self, name: str, hash_data: Dict, field_name: str = None, ex: int = EXPIRED, 338 | dump_responses: bool = False) -> str: 339 | """ 340 | 获取hash对象field_name对应的值 341 | Args: 342 | name: redis hash key的名称 343 | field_name: 保存的hash mapping 中的某个字段 344 | hash_data: 获取的hash对象中属性的名称 345 | ex: 过期时间,单位秒 346 | dump_responses: 是否对每个键值进行dump 347 | Returns: 348 | 反序列化对象 349 | """ 350 | if field_name is None and not isinstance(hash_data, MutableMapping): 351 | raise ValueError("hash data error, must be MutableMapping.") 352 | 353 | try: 354 | if not field_name: 355 | # 是否对每个键值进行dump 356 | if dump_responses: 357 | rs_data = {} 358 | for hash_key, hash_val in hash_data.items(): 359 | if not isinstance(hash_val, str): 360 | async 
with async_ignore_error(): 361 | hash_val = ujson.dumps(hash_val) 362 | rs_data[hash_key] = hash_val 363 | hash_data = rs_data 364 | 365 | if not await self.redis_db.hmset(name, hash_data): 366 | raise RedisClientError("save hash data mapping failed, session_id={}".format(name)) 367 | else: 368 | hash_data = hash_data if isinstance(hash_data, str) else ujson.dumps(hash_data) 369 | await self.redis_db.hset(name, field_name, hash_data) 370 | 371 | if not await self.redis_db.expire(name, ex): 372 | aelog.error("set hash data expire failed, session_id={}".format(name)) 373 | except RedisError as e: 374 | raise RedisClientError(str(e)) 375 | else: 376 | return name 377 | 378 | async def get_hash_data(self, name: str, field_name: str = None, ex: int = EXPIRED, 379 | load_responses: bool = False) -> Dict: 380 | """ 381 | 获取hash对象field_name对应的值 382 | Args: 383 | name: redis hash key的名称 384 | field_name: 获取的hash对象中属性的名称 385 | ex: 过期时间,单位秒 386 | load_responses: 结果的键值是否进行load 387 | Returns: 388 | 反序列化对象 389 | """ 390 | try: 391 | if field_name: 392 | hash_data = await self.redis_db.hget(name, field_name) 393 | # 返回的键值对是否做load 394 | if load_responses: 395 | async with async_ignore_error(): 396 | hash_data = ujson.loads(hash_data) 397 | else: 398 | hash_data = await self.redis_db.hgetall(name) 399 | # 返回的键值对是否做load 400 | if load_responses: 401 | rs_data = {} 402 | for hash_key, hash_val in hash_data.items(): 403 | async with async_ignore_error(): 404 | hash_val = ujson.loads(hash_val) 405 | rs_data[hash_key] = hash_val 406 | hash_data = rs_data 407 | if not hash_data: 408 | raise RedisClientError("not found hash data, name={}, field_name={}".format(name, field_name)) 409 | 410 | if not await self.redis_db.expire(name, ex): 411 | aelog.error("set expire failed, name={}".format(name)) 412 | except RedisError as e: 413 | raise RedisClientError(str(e)) 414 | else: 415 | return hash_data 416 | 417 | async def get_list_data(self, name: str, start: int = 0, end: int = -1, ex: int = 
EXPIRED) -> List: 418 | """ 419 | 获取redis的列表中的数据 420 | Args: 421 | name: redis key的名称 422 | start: 获取数据的起始位置,默认列表的第一个值 423 | end: 获取数据的结束位置,默认列表的最后一个值 424 | ex: 过期时间,单位秒 425 | Returns: 426 | 427 | """ 428 | try: 429 | data = await self.redis_db.lrange(name, start=start, end=end) 430 | if not await self.redis_db.expire(name, ex): 431 | aelog.error("set expire failed, name={}".format(name)) 432 | except RedisError as e: 433 | raise RedisClientError(str(e)) 434 | else: 435 | return data 436 | 437 | async def save_list_data(self, name: str, list_data: Union[List, str], save_to_left: bool = True, 438 | ex: int = EXPIRED) -> str: 439 | """ 440 | 保存数据到redis的列表中 441 | Args: 442 | name: redis key的名称 443 | list_data: 保存的值,可以是单个值也可以是元祖 444 | save_to_left: 是否保存到列表的左边,默认保存到左边 445 | ex: 过期时间,单位秒 446 | Returns: 447 | 448 | """ 449 | list_data = (list_data,) if isinstance(list_data, str) else list_data 450 | try: 451 | if save_to_left: 452 | if not await self.redis_db.lpush(name, *list_data): 453 | raise RedisClientError("lpush value to head failed.") 454 | else: 455 | if not await self.redis_db.rpush(name, *list_data): 456 | raise RedisClientError("lpush value to tail failed.") 457 | if not await self.redis_db.expire(name, ex): 458 | aelog.error("set expire failed, name={}".format(name)) 459 | except RedisError as e: 460 | raise RedisClientError(str(e)) 461 | else: 462 | return name 463 | 464 | async def save_update_usual_data(self, name: str, value: Any, ex: int = EXPIRED) -> str: 465 | """ 466 | 保存列表、映射对象为普通的字符串 467 | Args: 468 | name: redis key的名称 469 | value: 保存的值,可以是可序列化的任何职 470 | ex: 过期时间,单位秒 471 | Returns: 472 | 473 | """ 474 | value = ujson.dumps(value) if not isinstance(value, str) else value 475 | try: 476 | if not await self.redis_db.set(name, value, ex): 477 | raise RedisClientError("set serializable value failed!") 478 | except RedisError as e: 479 | raise RedisClientError(str(e)) 480 | else: 481 | return name 482 | 483 | async def incrbynumber(self, name: str, 
amount: int = 1, ex: int = EXPIRED) -> str: 484 | """ 485 | 486 | Args: 487 | 488 | Returns: 489 | 490 | """ 491 | try: 492 | if isinstance(amount, int): 493 | if not await self.redis_db.incr(name, amount): 494 | raise RedisClientError("Increments int value failed!") 495 | else: 496 | if not await self.redis_db.incrbyfloat(name, amount): 497 | raise RedisClientError("Increments float value failed!") 498 | if not await self.redis_db.expire(name, ex): 499 | aelog.error("set expire failed, name={}".format(name)) 500 | except RedisError as e: 501 | raise RedisClientError(str(e)) 502 | else: 503 | return name 504 | 505 | async def get_usual_data(self, name: str, load_responses: bool = True, update_expire: bool = True, 506 | ex: int = EXPIRED) -> Union[Dict, str]: 507 | """ 508 | 获取name对应的值 509 | Args: 510 | name: redis key的名称 511 | load_responses: 是否转码默认转码 512 | update_expire: 是否更新过期时间 513 | ex: 过期时间,单位秒 514 | Returns: 515 | 反序列化对象 516 | """ 517 | data = await self.redis_db.get(name) 518 | 519 | if data is not None and update_expire: # 保证key存在时设置过期时间 520 | if not await self.redis_db.expire(name, ex): 521 | aelog.error("set expire failed, name={}".format(name)) 522 | 523 | if load_responses: 524 | async with async_ignore_error(): 525 | data = ujson.loads(data) 526 | 527 | return data 528 | 529 | async def is_exist_key(self, name: str) -> bool: 530 | """ 531 | 判断redis key是否存在 532 | Args: 533 | name: redis key的名称 534 | Returns: 535 | 536 | """ 537 | return await self.redis_db.exists(name) 538 | 539 | async def delete_keys(self, names: List[str]) -> NoReturn: 540 | """ 541 | 删除一个或多个redis key 542 | Args: 543 | names: redis key的名称 544 | Returns: 545 | 546 | """ 547 | names = (names,) if isinstance(names, str) else names 548 | if not await self.redis_db.delete(*names): 549 | aelog.error("Delete redis keys failed {}.".format(*names)) 550 | 551 | async def get_keys(self, pattern_name: str) -> List: 552 | """ 553 | 根据正则表达式获取redis的keys 554 | Args: 555 | pattern_name:正则表达式的名称 556 | 
Returns: 557 | 558 | """ 559 | return await self.redis_db.keys(pattern_name) 560 | -------------------------------------------------------------------------------- /aclients/decorators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 19-1-21 下午6:47 8 | """ 9 | from functools import wraps 10 | from typing import Dict, List, Tuple, Union 11 | 12 | import aelog 13 | from marshmallow import EXCLUDE, Schema, ValidationError 14 | from sanic.request import Request 15 | 16 | from .err_msg import schema_msg 17 | from .exceptions import FuncArgsError, HttpError 18 | from .utils import verify_message 19 | 20 | __all__ = ("singleton", "schema_validate") 21 | 22 | 23 | def singleton(cls): 24 | """ 25 | singleton for class of app 26 | """ 27 | 28 | instances = {} 29 | 30 | @wraps(cls) 31 | def _singleton(*args, **kwargs): 32 | """ 33 | singleton for class of app 34 | """ 35 | cls_name = "_{0}".format(cls) 36 | if cls_name not in instances: 37 | instances[cls_name] = cls(*args, **kwargs) 38 | return instances[cls_name] 39 | 40 | return _singleton 41 | 42 | 43 | def schema_validate(schema_obj, required: Union[Tuple, List] = tuple(), is_extends: bool = True, 44 | excluded: Union[Tuple, List] = tuple(), use_zh: bool = True, message: Dict = None): 45 | """ 46 | 校验post的json格式和类型是否正确 47 | Args: 48 | schema_obj: 定义的schema对象 49 | required: 需要标记require的字段 50 | excluded: 排除不需要的字段 51 | is_extends: 是否继承schemea本身其他字段的require属性, 默认继承 52 | use_zh: 消息提示是否使用中文,默认中文 53 | message: 提示消息 54 | Returns: 55 | """ 56 | 57 | if not issubclass(schema_obj, Schema): 58 | raise FuncArgsError(message="schema_obj type error!") 59 | if not isinstance(required, (tuple, list)): 60 | raise FuncArgsError(message="required type error!") 61 | if not isinstance(excluded, (tuple, list)): 62 | raise FuncArgsError(message="excluded type error!") 63 | 64 | msg_zh = "msg_zh" 
if use_zh else "msg_en" 65 | # 此处的功能保证,如果调用了多个校验装饰器,则其中一个更改了,所有的都会更改 66 | if not getattr(schema_validate, "message", None): 67 | setattr(schema_validate, "message", verify_message(schema_msg, message or {})) 68 | 69 | def _validated(func): 70 | """ 71 | 校验post的json格式和类型是否正确 72 | """ 73 | 74 | @wraps(func) 75 | async def _wrapper(*args, **kwargs): 76 | """ 77 | 校验post的json格式和类型是否正确 78 | """ 79 | schema_message = getattr(schema_validate, "message", None) 80 | 81 | request = args[0] if isinstance(args[0], Request) else args[1] 82 | new_schema_obj = schema_obj(unknown=EXCLUDE) 83 | if required: 84 | for key, val in new_schema_obj.fields.items(): 85 | if key in required: # 反序列化期间,把特别需要的字段标记为required 86 | setattr(new_schema_obj.fields[key], "required", True) 87 | setattr(new_schema_obj.fields[key], "dump_only", False) 88 | elif not is_extends: 89 | setattr(new_schema_obj.fields[key], "required", False) 90 | try: 91 | valid_data = new_schema_obj.load(request.json, unknown=EXCLUDE) 92 | # 把load后不需要的字段过滤掉,主要用于不允许修改的字段load后过滤掉 93 | for val in excluded: 94 | valid_data.pop(val, None) 95 | except ValidationError as err: 96 | # 异常退出 97 | aelog.exception('Request body validation error, please check! {} {} error={}'.format( 98 | request.method, request.path, err.messages)) 99 | raise HttpError(400, message=schema_message[201][msg_zh], error=err.messages) 100 | except Exception as err: 101 | aelog.exception("Request body validation unknow error, please check!. 
{} {} error={}".format( 102 | request.method, request.path, str(err))) 103 | raise HttpError(500, message=schema_message[202][msg_zh], error=str(err)) 104 | else: 105 | request["json"] = valid_data 106 | return await func(*args, **kwargs) 107 | 108 | return _wrapper 109 | 110 | return _validated 111 | -------------------------------------------------------------------------------- /aclients/err_msg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-25 下午2:42 8 | 可配置消息模块 9 | """ 10 | 11 | __all__ = ("mysql_msg", "mongo_msg", "http_msg", "schema_msg") 12 | 13 | # mysql 从1到100 14 | mysql_msg = { 15 | 1: {"msg_code": 1, "msg_zh": "MySQL插入数据失败.", "msg_en": "MySQL insert data failed.", 16 | "description": "MySQL插入数据时最终失败的提示"}, 17 | 2: {"msg_code": 2, "msg_zh": "MySQL更新数据失败.", "msg_en": "MySQL update data failed.", 18 | "description": "MySQL更新数据时最终失败的提示"}, 19 | 3: {"msg_code": 3, "msg_zh": "MySQL删除数据失败.", "msg_en": "MySQL delete data failed.", 20 | "description": "MySQL删除数据时最终失败的提示"}, 21 | 4: {"msg_code": 4, "msg_zh": "MySQL查找单条数据失败.", "msg_en": "MySQL find one data failed.", 22 | "description": "MySQL查找单条数据时最终失败的提示"}, 23 | 5: {"msg_code": 5, "msg_zh": "MySQL查找多条数据失败.", "msg_en": "MySQL find many data failed.", 24 | "description": "MySQL查找多条数据时最终失败的提示"}, 25 | 6: {"msg_code": 6, "msg_zh": "MySQL执行SQL失败.", "msg_en": "MySQL execute sql failed.", 26 | "description": "MySQL执行SQL失败的提示"}, 27 | } 28 | 29 | # mongo 从100到200 30 | mongo_msg = { 31 | 100: {"msg_code": 100, "msg_zh": "MongoDB插入数据失败.", "msg_en": "MongoDB insert data failed.", 32 | "description": "MongoDB插入数据时最终失败的提示"}, 33 | 101: {"msg_code": 101, "msg_zh": "MongoDB更新数据失败.", "msg_en": "MongoDB update data failed.", 34 | "description": "MongoDB更新数据时最终失败的提示"}, 35 | 102: {"msg_code": 102, "msg_zh": "MongoDB删除数据失败.", "msg_en": "MongoDB delete data failed.", 36 | 
"description": "MongoDB删除数据时最终失败的提示"}, 37 | 103: {"msg_code": 103, "msg_zh": "MongoDB查找单条数据失败.", "msg_en": "MongoDB find one data failed.", 38 | "description": "MongoDB查找单条数据时最终失败的提示"}, 39 | 104: {"msg_code": 104, "msg_zh": "MongoDB查找多条数据失败.", "msg_en": "MongoDB find many data failed.", 40 | "description": "MongoDB查找多条数据时最终失败的提示"}, 41 | 105: {"msg_code": 105, "msg_zh": "MongoDB聚合查询数据失败.", "msg_en": "MongoDB aggregate query data failed.", 42 | "description": "MongoDB聚合查询数据时最终失败的提示"}, 43 | } 44 | 45 | # request and schema 从200到300 46 | http_msg = { 47 | 200: {"msg_code": 200, "msg_zh": "获取API响应结果失败.", "msg_en": "Failed to get API response result.", 48 | "description": "async request 获取API响应结果失败时的提示"}, 49 | } 50 | 51 | schema_msg = { 52 | # schema valication message 53 | 201: {"msg_code": 201, "msg_zh": "数据提交有误,请重新检查.", "msg_en": "Request body validation error, please check!", 54 | "description": "marmallow校验body错误时的提示"}, 55 | 202: {"msg_code": 202, "msg_zh": "数据提交未知错误,请重新检查.", 56 | "msg_en": "Request body validation unknow error, please check!", 57 | "description": "marmallow校验body未知错误时的提示"}, 58 | } 59 | -------------------------------------------------------------------------------- /aclients/exceptions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-25 下午2:08 8 | """ 9 | from sanic.exceptions import SanicException 10 | 11 | __all__ = ("ClientError", "ClientResponseError", "ClientConnectionError", "HttpError", "RedisClientError", 12 | "RedisConnectError", "MysqlDuplicateKeyError", "MysqlError", "MysqlInvalidNameError", "FuncArgsError", 13 | "Error", "PermissionDeniedError", "QueryArgsError", "MongoError", "MongoDuplicateKeyError", 14 | "MongoInvalidNameError", "CommandArgsError", "EmailError", "ConfigError") 15 | 16 | 17 | class Error(Exception): 18 | """ 19 | 异常基类 20 | """ 21 | 22 | def __init__(self, 
message=None): 23 | self.message = message 24 | 25 | def __str__(self): 26 | return "Error: message='{}'".format(self.message) 27 | 28 | def __repr__(self): 29 | return "<{} '{}'>".format(self.__class__.__name__, self.message) 30 | 31 | 32 | class ClientError(Error): 33 | """ 34 | 主要处理异步请求的error 35 | """ 36 | 37 | def __init__(self, url, *, message=None): 38 | self.url = url 39 | self.message = message 40 | super().__init__(message) 41 | 42 | def __str__(self): 43 | return "Error: url='{}', message='{}'".format(self.url, self.message) 44 | 45 | def __repr__(self): 46 | return "<{} '{}, {}'>".format(self.__class__.__name__, self.url, self.message) 47 | 48 | 49 | class ClientResponseError(ClientError): 50 | """ 51 | 响应异常 52 | """ 53 | 54 | def __init__(self, url, *, status_code=None, message=None, headers=None, body=None): 55 | self.url = url 56 | self.status_code = status_code 57 | self.message = message 58 | self.headers = headers 59 | self.body = body 60 | super().__init__(self.url, message=self.message) 61 | 62 | def __str__(self): 63 | return "Error: code={}, url='{}', message='{}', body='{}'".format( 64 | self.status_code, self.url, self.message, self.body) 65 | 66 | def __repr__(self): 67 | return "<{} '{}, {}, {}'>".format(self.__class__.__name__, self.status_code, self.url, self.message) 68 | 69 | 70 | class ClientConnectionError(ClientError): 71 | """连接异常""" 72 | 73 | pass 74 | 75 | 76 | class HttpError(Error, SanicException): 77 | """ 78 | 主要处理http 错误,从接口返回 79 | """ 80 | 81 | def __init__(self, status_code, *, message=None, error=None): 82 | self.status_code = status_code 83 | self.message = message 84 | self.error = error 85 | super(SanicException, self).__init__(message, status_code) 86 | 87 | def __str__(self): 88 | return "{}, '{}':'{}'".format(self.status_code, self.message, self.message or self.error) 89 | 90 | def __repr__(self): 91 | return "<{} '{}: {}'>".format(self.__class__.__name__, self.status_code, self.error or self.message) 92 | 93 | 94 | 
class RedisClientError(Error): 95 | """ 96 | 主要处理redis的error 97 | """ 98 | 99 | pass 100 | 101 | 102 | class RedisConnectError(RedisClientError): 103 | """ 104 | 主要处理redis的connect error 105 | """ 106 | pass 107 | 108 | 109 | class EmailError(Error): 110 | """ 111 | 主要处理email error 112 | """ 113 | 114 | pass 115 | 116 | 117 | class ConfigError(Error): 118 | """ 119 | 主要处理config error 120 | """ 121 | 122 | pass 123 | 124 | 125 | class MysqlError(Error): 126 | """ 127 | 主要处理mongo错误 128 | """ 129 | 130 | pass 131 | 132 | 133 | class MysqlDuplicateKeyError(MysqlError): 134 | """ 135 | 处理键重复引发的error 136 | """ 137 | 138 | pass 139 | 140 | 141 | class MysqlInvalidNameError(MysqlError): 142 | """ 143 | 处理名称错误引发的error 144 | """ 145 | 146 | pass 147 | 148 | 149 | class MongoError(Error): 150 | """ 151 | 主要处理mongo错误 152 | """ 153 | 154 | pass 155 | 156 | 157 | class MongoDuplicateKeyError(MongoError): 158 | """ 159 | 处理键重复引发的error 160 | """ 161 | 162 | pass 163 | 164 | 165 | class MongoInvalidNameError(MongoError): 166 | """ 167 | 处理名称错误引发的error 168 | """ 169 | 170 | pass 171 | 172 | 173 | class FuncArgsError(Error): 174 | """ 175 | 处理函数参数不匹配引发的error 176 | """ 177 | 178 | pass 179 | 180 | 181 | class PermissionDeniedError(Error): 182 | """ 183 | 处理权限被拒绝时的错误 184 | """ 185 | 186 | pass 187 | 188 | 189 | class QueryArgsError(Error): 190 | """ 191 | 处理salalemy 拼接query错误 192 | """ 193 | 194 | pass 195 | 196 | 197 | class CommandArgsError(Error): 198 | """ 199 | 处理执行命令时,命令失败错误 200 | """ 201 | 202 | pass 203 | -------------------------------------------------------------------------------- /aclients/tinylibs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 19-4-2 上午9:11 8 | """ 9 | 10 | from .tinymysql import * 11 | from .blinker import * 12 | -------------------------------------------------------------------------------- 
/aclients/tinylibs/blinker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 19-4-22 下午11:35 8 | 9 | 实现简单的信号,用于业务解耦 10 | """ 11 | import asyncio 12 | from typing import NoReturn, Tuple 13 | 14 | from sanic import Sanic 15 | 16 | from aclients.utils import Cached 17 | 18 | __all__ = ("Signal", "add_task") 19 | 20 | 21 | def add_task(app: Sanic, func, **kwargs) -> NoReturn: 22 | """ 23 | 添加异步执行任务 24 | Args: 25 | app: sanic的应用 26 | func: 要执行的函数 27 | kwargs: 执行函数需要的参数 28 | Returns: 29 | 30 | """ 31 | if not isinstance(app, Sanic): 32 | raise TypeError("app type must be Sanic type.") 33 | if not asyncio.iscoroutinefunction(func): 34 | raise TypeError("func type must be coroutine function.") 35 | app.loop.create_task(func(**kwargs)) 36 | 37 | 38 | class Signal(Cached): 39 | """ 40 | 异步信号实现 41 | """ 42 | 43 | def __init__(self, signal_name): 44 | """ 45 | 异步信号实现 46 | Args: 47 | signal_name: 信号名称 48 | 49 | """ 50 | self.signal_name = signal_name 51 | self.receiver = [] 52 | 53 | def connect(self, receiver) -> NoReturn: 54 | """ 55 | 连接信号的订阅者 56 | Args: 57 | receiver: 信号订阅者 58 | Returns: 59 | 60 | """ 61 | self.receiver.append(receiver) 62 | 63 | def disconnect(self, receiver) -> NoReturn: 64 | """ 65 | 取消连接信号的订阅者 66 | Args: 67 | receiver: 信号订阅者 68 | Returns: 69 | 70 | """ 71 | self.receiver.remove(receiver) 72 | 73 | def send(self, app: Sanic, **kwargs) -> Tuple: 74 | """ 75 | 发出信号到信号的订阅者,订阅者执行各自的功能 76 | Args: 77 | app: sanic的应用 78 | kwargs: 订阅者执行需要的参数 79 | Returns: 80 | 81 | """ 82 | if not isinstance(app, Sanic): 83 | raise TypeError("app type must be Sanic type.") 84 | for func in self.receiver: 85 | app.loop.create_task(func(**kwargs)) 86 | return app, kwargs 87 | -------------------------------------------------------------------------------- /aclients/tinylibs/tinymysql.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 19-4-2 上午9:04 8 | """ 9 | from typing import Dict, List, Tuple 10 | 11 | import aelog 12 | import pymysql 13 | 14 | __all__ = ("TinyMysql",) 15 | 16 | 17 | class TinyMysql(object): 18 | """ 19 | pymysql 操作数据库的各种方法 20 | Args: 21 | 22 | Returns: 23 | 24 | """ 25 | 26 | def __init__(self, db_host: str = "127.0.0.1", db_port: int = 3306, db_name: str = None, 27 | db_user: str = "root", db_pwd: str = "123456"): 28 | """ 29 | pymysql 操作数据库的各种方法 30 | Args: 31 | 32 | Returns: 33 | 34 | """ 35 | self.conn_pool = {} # 各个不同连接的连接池 36 | self.db_host = db_host 37 | self.db_port = db_port 38 | self.db_user = db_user 39 | self.db_pwd = db_pwd 40 | self.db_name = db_name 41 | 42 | @property 43 | def conn(self, ) -> pymysql.Connection: 44 | """ 45 | 获取MySQL的连接对象 46 | """ 47 | name = "{0}{1}{2}".format(self.db_host, self.db_port, self.db_name) 48 | 49 | def get_connection(): 50 | return pymysql.connect(host=self.db_host, port=self.db_port, db=self.db_name, user=self.db_user, 51 | passwd=self.db_pwd, charset="utf8", cursorclass=pymysql.cursors.DictCursor) 52 | 53 | if not self.conn_pool.get(name): 54 | self.conn_pool[name] = get_connection() 55 | return self.conn_pool[name] 56 | else: 57 | try: 58 | self.conn_pool[name].ping() 59 | except pymysql.Error: 60 | self.conn_pool[name] = get_connection() 61 | return self.conn_pool[name] 62 | 63 | def execute_many(self, sql: str, args_data: List[Tuple]) -> int: 64 | """ 65 | 批量插入数据 66 | Args: 67 | sql: 插入的SQL语句 68 | args_data: 批量插入的数据,为一个包含元祖的列表 69 | Returns: 70 | INSERT INTO traffic_100 (IMEI,lbs_dict_id,app_key) VALUES(%s,%s,%s) 71 | [('868403022323171', None, 'EB23B21E6E1D930E850E7267E3F00095'), 72 | ('865072026982119', None, 'EB23B21E6E1D930E850E7267E3F00095')] 73 | 74 | """ 75 | 76 | count = None 77 | try: 78 | with self.conn.cursor() as 
cursor: 79 | count = cursor.executemany(sql, args_data) 80 | except pymysql.Error as e: 81 | self.conn.rollback() 82 | aelog.exception(e) 83 | except Exception as e: 84 | self.conn.rollback() 85 | aelog.exception(e) 86 | else: 87 | self.conn.commit() 88 | return count 89 | 90 | def execute(self, sql: str, args_data: Tuple) -> int: 91 | """ 92 | 执行单条记录,更新、插入或者删除 93 | Args: 94 | sql: 插入的SQL语句 95 | args_data: 批量插入的数据,为一个包含元祖的列表 96 | Returns: 97 | INSERT INTO traffic_100 (IMEI,lbs_dict_id,app_key) VALUES(%s,%s,%s) 98 | ('868403022323171', None, 'EB23B21E6E1D930E850E7267E3F00095') 99 | 100 | """ 101 | 102 | count = None 103 | try: 104 | with self.conn.cursor() as cursor: 105 | count = cursor.execute(sql, args_data) 106 | except pymysql.Error as e: 107 | self.conn.rollback() 108 | aelog.exception(e) 109 | except Exception as e: 110 | self.conn.rollback() 111 | aelog.exception(e) 112 | else: 113 | self.conn.commit() 114 | return count 115 | 116 | def find_one(self, sql: str, args: Tuple = None) -> Dict: 117 | """ 118 | 查询单条记录 119 | Args: 120 | sql: sql 语句 121 | args: 查询参数 122 | Returns: 123 | 返回单条记录的返回值 124 | """ 125 | 126 | try: 127 | with self.conn.cursor() as cursor: 128 | cursor.execute(sql, args) 129 | except pymysql.Error as e: 130 | aelog.exception(e) 131 | else: 132 | return cursor.fetchone() 133 | 134 | def find_data(self, sql: str, args: Tuple = None, size: int = None) -> List[Dict]: 135 | """ 136 | 查询指定行数的数据 137 | Args: 138 | sql: sql 语句 139 | args: 查询参数 140 | size: 返回记录的条数 141 | Returns: 142 | 返回包含指定行数数据的列表,或者所有行数数据的列表 143 | """ 144 | 145 | try: 146 | with self.conn.cursor() as cursor: 147 | cursor.execute(sql, args) 148 | except pymysql.Error as e: 149 | aelog.exception(e) 150 | else: 151 | return cursor.fetchall() if not size else cursor.fetchmany(size) 152 | -------------------------------------------------------------------------------- /aclients/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-26 下午3:32 8 | """ 9 | import asyncio 10 | import multiprocessing 11 | import secrets 12 | import string 13 | import sys 14 | import weakref 15 | from collections import MutableMapping, MutableSequence 16 | from concurrent.futures import ThreadPoolExecutor 17 | from typing import Dict, List, Union 18 | 19 | import yaml 20 | from aiocontext import async_contextmanager 21 | from bson import ObjectId 22 | 23 | from aclients.exceptions import Error, FuncArgsError 24 | 25 | try: 26 | from yaml import CLoader as Loader 27 | except ImportError: 28 | from yaml import Loader 29 | 30 | __all__ = ("async_ignore_error", "verify_message", "wrap_async_func", "analysis_yaml", 31 | "gen_class_name", "objectid", "Singleton", "Cached", "gen_ident") 32 | 33 | # 执行任务的线程池 34 | pool = ThreadPoolExecutor(multiprocessing.cpu_count() * 10 + multiprocessing.cpu_count()) 35 | 36 | 37 | def gen_ident(ident_len: int = 8): 38 | """ 39 | 获取随机的标识码以字母开头, 默认8个字符的长度 40 | Args: 41 | 42 | Returns: 43 | 44 | """ 45 | ident_len = ident_len - 1 46 | alphabet = f"{string.ascii_lowercase}{string.digits}" 47 | ident = ''.join(secrets.choice(alphabet) for _ in range(ident_len)) 48 | return f"{secrets.choice(string.ascii_lowercase)}{ident}" 49 | 50 | 51 | @async_contextmanager 52 | async def async_ignore_error(error=Exception): 53 | """ 54 | 个别情况下会忽略遇到的错误 55 | Args: 56 | 57 | Returns: 58 | 59 | """ 60 | # noinspection PyBroadException 61 | try: 62 | yield 63 | except error: 64 | pass 65 | 66 | 67 | def verify_message(src_message: Dict, message: Union[List, Dict]): 68 | """ 69 | 对用户提供的message进行校验 70 | Args: 71 | src_message: 默认提供的消息内容 72 | message: 指定的消息内容 73 | Returns: 74 | 75 | """ 76 | src_message = dict(src_message) 77 | message = message if isinstance(message, MutableSequence) else [message] 78 | required_field = {"msg_code", "msg_zh", "msg_en"} 79 | 80 | for msg in message: 81 | if isinstance(msg, 
MutableMapping): 82 | if set(msg.keys()).intersection(required_field) == required_field and msg["msg_code"] in src_message: 83 | src_message[msg["msg_code"]].update(msg) 84 | return src_message 85 | 86 | 87 | async def wrap_async_func(func, *args, **kwargs): 88 | """ 89 | 包装同步阻塞请求为异步非阻塞 90 | Args: 91 | func: 实际请求的函数名或者方法名 92 | args: 函数参数 93 | kwargs: 函数参数 94 | Returns: 95 | 返回执行后的结果 96 | """ 97 | try: 98 | result = await asyncio.wrap_future(pool.submit(func, *args, **kwargs)) 99 | except TypeError as e: 100 | raise FuncArgsError("Args error: {}".format(e)) 101 | except Exception as e: 102 | raise Error("Error: {}".format(e)) 103 | else: 104 | return result 105 | 106 | 107 | def gen_class_name(underline_name: str): 108 | """ 109 | 由下划线的名称变为驼峰的名称 110 | Args: 111 | underline_name 112 | Returns: 113 | 114 | """ 115 | return "".join([name.capitalize() for name in underline_name.split("_")]) 116 | 117 | 118 | def analysis_yaml(full_conf_path: str): 119 | """ 120 | 解析yaml文件 121 | Args: 122 | full_conf_path: yaml配置文件路径 123 | Returns: 124 | 125 | """ 126 | with open(full_conf_path, 'rt', encoding="utf8") as f: 127 | try: 128 | conf = yaml.load(f, Loader=Loader) 129 | except yaml.YAMLError as e: 130 | print("Yaml配置文件出错, {}".format(e)) 131 | sys.exit() 132 | return conf 133 | 134 | 135 | def objectid(): 136 | """ 137 | 138 | Args: 139 | 140 | Returns: 141 | 142 | """ 143 | return str(ObjectId()) 144 | 145 | 146 | class _Singleton(type): 147 | """ 148 | singleton for class 149 | """ 150 | 151 | def __init__(cls, *args, **kwargs): 152 | cls.__instance = None 153 | super().__init__(*args, **kwargs) 154 | 155 | def __call__(cls, *args, **kwargs): 156 | if cls.__instance is None: 157 | cls.__instance = super().__call__(*args, **kwargs) 158 | return cls.__instance 159 | else: 160 | return cls.__instance 161 | 162 | 163 | class _Cached(type): 164 | def __init__(cls, *args, **kwargs): 165 | super().__init__(*args, **kwargs) 166 | cls.__cache = weakref.WeakValueDictionary() 167 | 168 
| def __call__(cls, *args, **kwargs): 169 | cached_name = f"{args}{kwargs}" 170 | if cached_name in cls.__cache: 171 | return cls.__cache[cached_name] 172 | else: 173 | obj = super().__call__(*args, **kwargs) 174 | cls.__cache[cached_name] = obj # 这里是弱引用不能直接赋值,否则会被垃圾回收期回收 175 | return obj 176 | 177 | 178 | class Singleton(metaclass=_Singleton): 179 | pass 180 | 181 | 182 | class Cached(metaclass=_Cached): 183 | pass 184 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | aclients 2 | ======== 3 | 4 | 基于sanic扩展,各种数据库异步crud操作,各种请求异步crud操作,本repo是个基础封装库,直接用于各个业务系统中 5 | 6 | Installing aelog 7 | ================ 8 | 9 | - ``pip install aclients`` 10 | 11 | Usage 12 | ===== 13 | 14 | 后续添加,现在没时间. 15 | ~~~~~~~~~~~~~~~~~~~~~ 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sanic>=18.12 2 | aelog>=1.0.3 3 | aredis>=1.1.3 4 | hiredis 5 | aiomysql>=0.0.20 6 | sqlalchemy>=1.2.12 7 | aiohttp>=3.3.2 8 | cchardet 9 | aiodns 10 | motor>=1.2.2 11 | ujson 12 | marshmallow>=3.0.0rc3 13 | PyYAML>=3.13 14 | pymongo>=3.8.0 15 | asyncio-contextmanager==1.0.1 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | MIT License 5 | 6 | Copyright (c) 2020 Tiny Bees 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | 
furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | 26 | """ 27 | from setuptools import setup 28 | 29 | from aclients import __version__ 30 | 31 | setup(name='aclients', 32 | version=__version__, 33 | description='基础封装库', 34 | long_description=open('README.md').read(), 35 | long_description_content_type='text/markdown', 36 | author='TinyBees', 37 | author_email='a598824322@qq.com', 38 | url='https://github.com/tinybees/aclients', 39 | packages=['aclients', 'aclients.tinylibs'], 40 | entry_points={}, 41 | requires=['sanic', 'aelog', 'aredis', 'hiredis', 'aiomysql', 'sqlalchemy', 'aiohttp', 'cchardet', 'aiodns', 42 | 'motor', 'ujson', 'marshmallow', 'PyYAML'], 43 | install_requires=['sanic>=18.12', 44 | 'aelog>=1.0.3', 45 | 'aredis>=1.1.3', 46 | 'hiredis', 47 | 'aiomysql>=0.0.20', 48 | 'sqlalchemy>=1.2.12', 49 | 'aiohttp>=3.3.2', 50 | 'cchardet', 51 | 'aiodns', 52 | 'motor>=1.2.2', 53 | 'ujson', 54 | 'marshmallow>=3.0.0rc3', 55 | 'PyYAML>=3.13', 56 | 'pymongo>=3.8.0'], 57 | python_requires=">=3.5", 58 | keywords="mysql, mongo, redis, http, asyncio, crud, session", 59 | license='MIT', 60 | classifiers=[ 61 | 'Development Status :: 4 - Beta', 62 | 'Intended Audience :: Developers', 63 | 'License :: OSI Approved :: MIT License', 64 | 'Natural Language :: Chinese (Simplified)', 65 | 'Operating System :: POSIX :: Linux', 66 | 
'Operating System :: Microsoft :: Windows', 67 | 'Operating System :: MacOS :: MacOS X', 68 | 'Topic :: Software Development :: Libraries :: Python Modules', 69 | 'Topic :: Utilities', 70 | 'Programming Language :: Python', 71 | 'Programming Language :: Python :: 3.5', 72 | 'Programming Language :: Python :: 3.6', 73 | 'Programming Language :: Python :: 3.7', 74 | 'Programming Language :: Python :: 3.8'] 75 | ) 76 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-12-25 下午4:42 8 | """ 9 | 10 | -------------------------------------------------------------------------------- /tests/jrpc_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 2020/2/21 下午2:14 8 | """ 9 | from aclients import AIOHttpClient 10 | from aclients.jsonrpc import AIOJRPCClient 11 | 12 | 13 | async def sub(client: AIOJRPCClient): 14 | """ 15 | 16 | Args: 17 | 18 | Returns: 19 | 20 | """ 21 | r = await client["local"].sub(5, 2).done() 22 | print(r) 23 | s = await client["local"].test().done() 24 | print(s) 25 | 26 | s2 = await client["local"].sub(5, 2).test(3).done() 27 | print(s2) 28 | 29 | 30 | if __name__ == '__main__': 31 | aio_http = AIOHttpClient() 32 | aio_http.init_session() 33 | aio_jrpc = AIOJRPCClient(aio_http) 34 | aio_jrpc.register("local", ("127.0.0.1", 8000)) 35 | aio_jrpc.register("loca2", ("127.0.0.1", 8001)) 36 | import asyncio 37 | 38 | loop = asyncio.get_event_loop() 39 | loop.run_until_complete(sub(aio_jrpc)) 40 | -------------------------------------------------------------------------------- /tests/jrpc_server.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 2020/2/21 下午1:47 8 | """ 9 | import asyncio 10 | 11 | from sanic import Sanic 12 | 13 | from aclients.jsonrpc import SanicJsonRPC 14 | 15 | app = Sanic() 16 | jsonrpc = SanicJsonRPC() 17 | jsonrpc.init_app(app) 18 | 19 | 20 | @jsonrpc.jrpc 21 | async def sub(a: int, b: int) -> int: 22 | await asyncio.sleep(0.1) 23 | return a - b 24 | 25 | 26 | @jsonrpc.jrpc 27 | async def test() -> str: 28 | await asyncio.sleep(0.1) 29 | return "中文" 30 | 31 | 32 | if __name__ == '__main__': 33 | app.run(host='127.0.0.1', port=8000) 34 | -------------------------------------------------------------------------------- /tests/verify_gen_sql.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 2020/2/28 下午3:30 8 | """ 9 | import unittest 10 | from datetime import datetime 11 | 12 | import sqlalchemy as sa 13 | 14 | from aclients import AIOMysqlClient 15 | from aclients.utils import objectid 16 | 17 | mysql_db = AIOMysqlClient() 18 | 19 | 20 | class MessageDisplayModel(mysql_db.Model): 21 | """ 22 | 消息展示 23 | """ 24 | __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8mb4'} 25 | 26 | __tablename__ = "message_display" 27 | 28 | id = sa.Column(sa.String(24), primary_key=True, default=objectid, nullable=False, doc='实例ID') 29 | msg_code = sa.Column(sa.Integer, index=True, unique=True, nullable=False, doc='消息编码') 30 | msg_zh = sa.Column(sa.String(200), unique=True, nullable=False, doc='中文消息') 31 | msg_en = sa.Column(sa.String(200), doc='英文消息') 32 | description = sa.Column(sa.String(255), doc='描述') 33 | created_time = sa.Column(sa.DateTime, default=datetime.now, nullable=False, doc='创建时间') 34 | updated_time = sa.Column(sa.DateTime, 
default=datetime.now, onupdate=datetime.now, nullable=False, doc='更新时间') 35 | 36 | 37 | class TestSQL(unittest.TestCase): 38 | """ 39 | 测试Query类中的生成SQL功能 40 | """ 41 | 42 | def test_select_sql(self): 43 | """ 44 | Args: 45 | """ 46 | sql = mysql_db.query.model(MessageDisplayModel).where( 47 | MessageDisplayModel.id == "5e53bb135b64856045ccb8dc").select_query().sql() 48 | self.assertEqual(sql, { 49 | 'sql': 'SELECT message_display.id, message_display.msg_code, message_display.msg_zh, message_display.msg_en, message_display.description, message_display.created_time, message_display.updated_time \nFROM message_display \nWHERE message_display.id = %(id_1)s', 50 | 'params': {'id_1': '5e53bb135b64856045ccb8dc'}}) 51 | 52 | def test_select_count_sql(self, ): 53 | """ 54 | Args: 55 | """ 56 | sql = mysql_db.query.model(MessageDisplayModel).where( 57 | MessageDisplayModel.id == "5e53bb135b64856045ccb8dc").select_query(True).sql() 58 | self.assertEqual(sql, { 59 | 'sql': 'SELECT count(*) AS count \nFROM message_display \nWHERE message_display.id = %(id_1)s', 60 | 'params': {'id_1': '5e53bb135b64856045ccb8dc'}}) 61 | 62 | def test_insert_sql(self, ): 63 | """ 64 | Args: 65 | """ 66 | sql = mysql_db.query.model(MessageDisplayModel).insert_query({ 67 | "msg_code": 3100, "id": "5e53bb135b64856045ccb8dc", "msg_zh": "fdfdf"}).sql() 68 | self.assertEqual(sql["sql"], 69 | 'INSERT INTO message_display (id, msg_code, msg_zh, created_time, updated_time) VALUES (%(id)s, %(msg_code)s, %(msg_zh)s, %(created_time)s, %(updated_time)s)' 70 | ) 71 | self.assertEqual(sql["params"]["id"], "5e53bb135b64856045ccb8dc") 72 | self.assertEqual(sql["params"]["msg_code"], 3100) 73 | self.assertEqual(sql["params"]["msg_zh"], "fdfdf") 74 | 75 | def test_update_sql(self, ): 76 | """ 77 | Args: 78 | """ 79 | sql = mysql_db.query.model(MessageDisplayModel).where( 80 | MessageDisplayModel.id == "5e53bb135b64856045ccb8dc").update_query({"msg_code": 3100}).sql() 81 | self.assertEqual(sql["sql"], 82 | 'UPDATE 
message_display SET msg_code=%(msg_code)s, updated_time=%(updated_time)s WHERE message_display.id = %(id_1)s') 83 | self.assertEqual(sql["params"]["id_1"], "5e53bb135b64856045ccb8dc") 84 | self.assertEqual(sql["params"]["msg_code"], 3100) 85 | 86 | def test_delete_sql(self, ): 87 | """ 88 | Args: 89 | """ 90 | sql = mysql_db.query.model(MessageDisplayModel).where( 91 | MessageDisplayModel.id == "5e53bb135b64856045ccb8dc").delete_query().sql() 92 | self.assertEqual(sql, {'sql': 'DELETE FROM message_display WHERE message_display.id = %(id_1)s', 93 | 'params': {'id_1': '5e53bb135b64856045ccb8dc'}}) 94 | 95 | def test_paginate_sql(self, ): 96 | """ 97 | Args: 98 | """ 99 | sql = mysql_db.query.model(MessageDisplayModel).where( 100 | MessageDisplayModel.id == "5e53bb135b64856045ccb8dc").paginate_query().sql() 101 | self.assertEqual(len(sql), 2) 102 | self.assertEqual(len(sql[0]), 2) 103 | self.assertEqual(len(sql[1]), 2) 104 | 105 | 106 | if __name__ == '__main__': 107 | unittest.TextTestRunner(verbosity=2).run(TestSQL) 108 | -------------------------------------------------------------------------------- /tests/verify_http_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | 4 | """ 5 | @author: guoyanfeng 6 | @software: PyCharm 7 | @time: 18-10-2 上午10:08 8 | """ 9 | 10 | import asyncio 11 | import atexit 12 | 13 | from aclients import AIOHttpClient 14 | from aclients.exceptions import ClientResponseError 15 | 16 | requests = AIOHttpClient() 17 | 18 | 19 | async def get_verify(url, params, headers): 20 | """ 21 | 22 | Args: 23 | 24 | Returns: 25 | 26 | """ 27 | text = await requests.async_get(url, params=params, headers=headers) 28 | print("get", text.json()) 29 | print("header", text.headers) 30 | print(text.headers["Connection"]) 31 | assert text.json()["args"] == params 32 | assert list(headers.values())[0] in text.json()["headers"].values() 33 | 34 | 35 | async def 
get_verify_error(url, params, headers): 36 | """ 37 | 38 | Args: 39 | 40 | Returns: 41 | 42 | """ 43 | try: 44 | await requests.async_get(url, params=params, headers=headers) 45 | except ClientResponseError as e: 46 | assert isinstance(e, ClientResponseError) 47 | 48 | 49 | async def post_verify_data(url, data, headers): 50 | """ 51 | 52 | Args: 53 | 54 | Returns: 55 | 56 | """ 57 | text = await requests.async_post(url, data=data, headers=headers) 58 | print("post_data", text.json()) 59 | assert text.json()["data"] == data 60 | assert list(headers.values())[0] in text.json()["headers"].values() 61 | 62 | 63 | async def post_verify_json(url, json, headers): 64 | """ 65 | 66 | Args: 67 | 68 | Returns: 69 | 70 | """ 71 | text = await requests.async_post(url, json=json, headers=headers) 72 | print("post_json", text.json()) 73 | assert text.json()["json"] == json 74 | assert list(headers.values())[0] in text.json()["headers"].values() 75 | 76 | 77 | async def put_verify_data(url, data, headers): 78 | """ 79 | 80 | Args: 81 | 82 | Returns: 83 | 84 | """ 85 | text = await requests.async_put(url, data=data, headers=headers) 86 | print("put_data", text.json()) 87 | assert text.json()["data"] == data 88 | assert list(headers.values())[0] in text.json()["headers"].values() 89 | 90 | 91 | async def put_verify_json(url, json, headers): 92 | """ 93 | 94 | Args: 95 | 96 | Returns: 97 | 98 | """ 99 | text = await requests.async_put(url, json=json, headers=headers) 100 | print("put_json", text.json()) 101 | assert text.json()["json"] == json 102 | assert list(headers.values())[0] in text.json()["headers"].values() 103 | 104 | 105 | async def patch_verify_data(url, data, headers): 106 | """ 107 | 108 | Args: 109 | 110 | Returns: 111 | 112 | """ 113 | text = await requests.async_patch(url, data=data, headers=headers) 114 | print("patch_data", text.json()) 115 | assert text.json()["data"] == data 116 | assert list(headers.values())[0] in text.json()["headers"].values() 117 | 118 | 119 
| async def patch_verify_json(url, json, headers): 120 | """ 121 | 122 | Args: 123 | 124 | Returns: 125 | 126 | """ 127 | text = await requests.async_patch(url, json=json, headers=headers) 128 | print("patch_json", text.json()) 129 | assert text.json()["json"] == json 130 | assert list(headers.values())[0] in text.json()["headers"].values() 131 | 132 | 133 | async def delete_verify(url): 134 | """ 135 | 136 | Args: 137 | 138 | Returns: 139 | 140 | """ 141 | text = await requests.async_delete(url) 142 | print("delete", text.json()) 143 | 144 | 145 | if __name__ == '__main__': 146 | loop = asyncio.get_event_loop() 147 | atexit.register(loop.close) 148 | requests.init_session() 149 | tasks = [get_verify("http://httpbin.org/get", params={"test": "get"}, headers={"test": "header"}), 150 | get_verify_error("http://httpbin.org/get344", params={"test": "get"}, headers={"test": "header"}), 151 | post_verify_data("http://httpbin.org/post", data="post data", headers={"test": "header data"}), 152 | post_verify_json("http://httpbin.org/post", json={"test": "post json"}, headers={"test": "header json"}), 153 | put_verify_data("http://httpbin.org/put", data="put data", headers={"test": "header data"}), 154 | put_verify_json("http://httpbin.org/put", json={"test": "put json"}, headers={"test": "header json"}), 155 | patch_verify_data("http://httpbin.org/patch", data="patch data", headers={"test": "header patch"}), 156 | patch_verify_json("http://httpbin.org/patch", json={"test": "patch json"}, headers={"test": "header " 157 | "patch"}), 158 | delete_verify("http://httpbin.org/delete")] 159 | loop.run_until_complete(asyncio.wait(tasks)) 160 | --------------------------------------------------------------------------------