├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── demo ├── connections │ ├── http_client_demo.py │ └── sqlalchemy_demo │ │ ├── __init__.py │ │ ├── demo.py │ │ ├── manager.py │ │ └── table.py ├── decorators │ ├── base.py │ ├── cache.py │ └── run_on_executor_demo.py ├── logging │ └── logging_demo.py ├── meta_cls │ └── singleton_demo.py └── utils │ ├── async_util_demo.py │ ├── excel_util_demo.py │ ├── jwt_util_demo.py │ ├── serializer_util_demo.py │ ├── time_util_demo.py │ └── tree_util_demo.py ├── py_tools ├── __init__.py ├── chatbot │ ├── __init__.py │ ├── app_server.py │ ├── chatbot.py │ └── factory.py ├── connections │ ├── __init__.py │ ├── db │ │ ├── __init__.py │ │ ├── mysql │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ └── orm_model.py │ │ └── redis_client.py │ ├── http │ │ ├── __init__.py │ │ └── client.py │ ├── mq │ │ ├── __init__.py │ │ ├── kafka_client.py │ │ └── rabbitmq_client.py │ └── oss │ │ ├── __init__.py │ │ └── minio_client.py ├── constants │ ├── __init__.py │ └── const.py ├── data_schemas │ ├── __init__.py │ ├── time.py │ └── unit.py ├── decorators │ ├── __init__.py │ ├── base.py │ └── cache.py ├── enums │ ├── __init__.py │ ├── base.py │ ├── error.py │ ├── feishu.py │ ├── http.py │ ├── pub_biz.py │ └── time.py ├── exceptions │ ├── __init__.py │ └── base.py ├── logging │ ├── __init__.py │ ├── base.py │ └── default_logging_conf.py ├── meta_cls │ ├── __init__.py │ └── base.py └── utils │ ├── __init__.py │ ├── async_util.py │ ├── excel_util.py │ ├── file_util.py │ ├── func_util.py │ ├── jwt_util.py │ ├── mask_util.py │ ├── project_templates │ ├── __init__.py │ ├── make_pro.py │ └── python_project │ │ └── src │ │ ├── constants │ │ └── __init__.py │ │ ├── dao │ │ ├── __init__.py │ │ ├── orm │ │ │ ├── manage │ │ │ │ ├── __init__.py │ │ │ │ └── user.py │ │ │ └── table │ │ │ │ ├── __init__.py │ │ │ │ └── user.py │ │ └── redis │ │ │ ├── __init__.py │ │ │ ├── cache_info.py │ │ │ └── client.py │ │ ├── data_schemas │ │ 
├── __init__.py │ │ ├── api_schemas │ │ │ └── __init__.py │ │ └── logic_schemas │ │ │ └── __init__.py │ │ ├── enums │ │ ├── __init__.py │ │ └── base.py │ │ ├── handlers │ │ └── __init__.py │ │ ├── middlewares │ │ └── __init__.py │ │ ├── routes │ │ └── __init__.py │ │ ├── server.py │ │ ├── services │ │ ├── __init__.py │ │ └── base.py │ │ ├── settings │ │ ├── __init__.py │ │ ├── base_setting.py │ │ ├── db_setting.py │ │ └── log_setting.py │ │ └── utils │ │ ├── __init__.py │ │ ├── context_util.py │ │ ├── log_util.py │ │ ├── trace_util.py │ │ └── web_util.py │ ├── re_util.py │ ├── serializer_util.py │ ├── time_util.py │ └── tree_util.py ├── requirements.txt ├── ruff.toml ├── setup.py └── tests ├── __init__.py ├── chatbot └── test_chatbot.py ├── connections └── test_http_client.py ├── decorators └── base.py ├── meta_cls └── singleton.py └── utils ├── test_jwt_util.py └── test_re_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.xlsx 7 | 8 | .idea 9 | # C extensions 10 | *.so 11 | tmp 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_stages: [ commit ] 2 | 3 | # Install 4 | # 1. pip install pre-commit 5 | # 2. pre-commit install 6 | # 3. pre-commit run --all-files # 检查全部文件 7 | # 4. pre-commit run --files py_tools/connections/* # 检查指定目录或文件 8 | repos: 9 | - repo: https://github.com/pycqa/isort 10 | rev: 5.11.5 11 | hooks: 12 | - id: isort 13 | args: ['--profile', 'black'] 14 | exclude: >- 15 | (?x)^( 16 | .*__init__\.py$ 17 | ) 18 | 19 | - repo: https://github.com/astral-sh/ruff-pre-commit 20 | # Ruff version. 21 | rev: v0.2.0 22 | hooks: 23 | - id: ruff 24 | args: [ --fix ] 25 | - id: ruff-format 26 | 27 | - repo: https://github.com/psf/black 28 | rev: 23.3.0 29 | hooks: 30 | - id: black 31 | args: ['--line-length', '120'] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include py_tools/utils/project_templates * 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Py-Tools 2 | 3 | > Py-Tools 是一个实用的 Python 工具集和可复用组件库,旨在简化常见任务,提高 Python 项目的开发效率。 4 | > 5 | > 设计细节请移步到掘金查看:https://juejin.cn/column/7131286129713610766 6 | 7 | ## 安装 8 | - 环境要求:python version >= 3.8 9 | - 历史版本记录:https://pypi.org/project/hui-tools/#history 10 | 11 | > 根据 [PEP 625](https://peps.python.org/pep-0625/) 要求,从 0.6.0 版本开始,将包名从 `hui-tools` 改成了 `huidevkit`, 历史版本的包名仍然可用,但是不推荐使用。 12 | 13 | ### 默认安装 14 | ```bash 15 | pip install huidevkit 16 | ``` 17 | 默认安装后即可使用以下功能: 18 | - 时间工具类 19 | - http客户端 20 | - 同步异步互转装饰器 21 | - 常用枚举 22 | - pydantic 23 | - loguru的日志器 24 | - jwt工具类 25 | - 等...
26 | 27 | ### 全部安装 28 | ```bash 29 | pip install "huidevkit[all]" 30 | ``` 31 | 32 | ### 可选安装 33 | ```bash 34 | pip install "huidevkit[db-orm,db-redis,excel-tools]" 35 | ``` 36 | 37 | 可选参数参考: 38 | ```python 39 | extras_require = { 40 | "db-orm": ["sqlalchemy[asyncio]==2.0.20", "aiomysql==0.2.0"], 41 | "db-redis": ["redis>=4.5.4"], 42 | "cache-proxy": ["redis>=4.5.4", "python-memcached==1.62", "cacheout==0.14.1"], 43 | "minio": ["minio==7.1.17"], 44 | "excel-tools": ["pandas==2.2.2", "openpyxl==3.0.10"], 45 | "test": ["pytest==7.3.1", "pytest-mock==3.14.0", "pytest-asyncio==0.23.8"], 46 | } 47 | ``` 48 | 49 | ### 简单使用 50 | > 所有功能都是从 py_tools 包中导入使用 51 | > 详细使用请查看项目的DEMO: https://github.com/HuiDBK/py-tools/tree/master/demo 52 | 53 | 生成 Python Web 项目结构模板 54 | ```bash 55 | py_tools make_project WebDemo 56 | ``` 57 | 58 | 快速配置项目日志 59 | ```python 60 | from py_tools.constants import BASE_DIR 61 | from py_tools.logging import logger, setup_logging 62 | 63 | 64 | def main(): 65 | setup_logging(log_dir=BASE_DIR / "logs") 66 | logger.info("use log dir") 67 | logger.error("test error") 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | ``` 73 | 74 | 异步http客户端 75 | ```python 76 | import asyncio 77 | from py_tools.connections.http import AsyncHttpClient 78 | 79 | 80 | async def main(): 81 | url = "https://juejin.cn/" 82 | resp = await AsyncHttpClient().get(url).execute() 83 | text_data = await AsyncHttpClient().get(url, params={"test": "hui"}).text() 84 | json_data = await AsyncHttpClient().post(url, data={"test": "hui"}).json() 85 | byte_data = await AsyncHttpClient().get(url).bytes() 86 | 87 | async with AsyncHttpClient() as client: 88 | upload_file_ret = await client.upload_file(url, file="test.txt").json() 89 | 90 | async for chunk in AsyncHttpClient().get(url).stream(chunk_size=512): 91 | # 流式调用 92 | print(chunk) 93 | 94 | await AsyncHttpClient.close() 95 | 96 | 97 | if __name__ == "__main__": 98 | asyncio.run(main()) 99 | ``` 100 | 101 | mysql数据库操作demo 102 |
```python 103 | import asyncio 104 | import uuid 105 | from typing import List 106 | 107 | from connections.sqlalchemy_demo.manager import UserFileManager 108 | from connections.sqlalchemy_demo.table import UserFileTable 109 | from sqlalchemy import func 110 | 111 | from py_tools.connections.db.mysql import BaseOrmTable, DBManager, SQLAlchemyManager 112 | 113 | db_client = SQLAlchemyManager( 114 | host="127.0.0.1", 115 | port=3306, 116 | user="root", 117 | password="123456", 118 | db_name="hui-demo", 119 | ) 120 | 121 | 122 | async def create_and_transaction_demo(): 123 | async with UserFileManager.transaction() as session: 124 | await UserFileManager(session).bulk_add(table_objs=[{"filename": "aaa", "oss_key": uuid.uuid4().hex}]) 125 | user_file_obj = UserFileTable(filename="eee", oss_key=uuid.uuid4().hex) 126 | file_id = await UserFileManager(session).add(table_obj=user_file_obj) 127 | print("file_id", file_id) 128 | 129 | ret: UserFileTable = await UserFileManager(session).query_by_id(2) 130 | print("query_by_id", ret) 131 | 132 | # a = 1 / 0 133 | 134 | ret = await UserFileManager(session).query_one( 135 | cols=[UserFileTable.filename, UserFileTable.oss_key], conds=[UserFileTable.filename == "ccc"], 136 | ) 137 | print("ret", ret) 138 | 139 | 140 | async def query_demo(): 141 | ret = await UserFileManager().query_one(conds=[UserFileTable.filename == "ccc"]) 142 | print("ret", ret) 143 | 144 | file_count = await UserFileManager().query_one(cols=[func.count()], flat=True) 145 | print("str col one ret", file_count) 146 | 147 | filename = await UserFileManager().query_one(cols=[UserFileTable.filename], conds=[UserFileTable.id == 2], flat=True) 148 | print("filename", filename) 149 | 150 | ret = await UserFileManager().query_all(cols=[UserFileTable.filename, UserFileTable.oss_key]) 151 | print("ret", ret) 152 | 153 | ret = await UserFileManager().query_all(cols=["filename", "oss_key"]) 154 | print("str col ret", ret) 155 | 156 | ret: List[UserFileTable] = await 
UserFileManager().query_all() 157 | print("ret", ret) 158 | 159 | ret = await UserFileManager().query_all(cols=[UserFileTable.id], flat=True) 160 | print("ret", ret) 161 | 162 | 163 | async def list_page_demo(): 164 | """分页查询demo""" 165 | total_count, data_list = await UserFileManager().list_page( 166 | cols=["filename", "oss_key", "file_size"], curr_page=2, page_size=10 167 | ) 168 | print("total_count", total_count, f"data_list[{len(data_list)}]", data_list) 169 | 170 | 171 | async def run_raw_sql_demo(): 172 | """运行原生sql demo""" 173 | count_sql = "select count(*) as total_count from user_file" 174 | count_ret = await UserFileManager().run_sql(count_sql, query_one=True) 175 | print("count_ret", count_ret) 176 | 177 | data_sql = "select * from user_file where id > :id_val and file_size >= :file_size_val" 178 | params = {"id_val": 20, "file_size_val": 0} 179 | data_ret = await UserFileManager().run_sql(data_sql, params=params) 180 | print("dict data_ret", data_ret) 181 | 182 | # 连表查询 183 | data_sql = """ 184 | select 185 | user.id as user_id, 186 | username, 187 | user_file.id as file_id, 188 | filename, 189 | oss_key 190 | from 191 | user_file 192 | join user on user.id = user_file.creator 193 | where 194 | user_file.creator = :user_id 195 | """ 196 | data_ret = await UserFileManager().run_sql(data_sql, params={"user_id": 1}) 197 | print("join sql data_ret", data_ret) 198 | 199 | 200 | async def curd_demo(): 201 | await create_and_transaction_demo() 202 | await query_demo() 203 | await list_page_demo() 204 | await run_raw_sql_demo() 205 | 206 | 207 | async def create_tables(): 208 | # 根据映射创建库表 209 | async with DBManager.connection() as conn: 210 | await conn.run_sync(BaseOrmTable.metadata.create_all) 211 | 212 | 213 | async def main(): 214 | db_client.init_mysql_engine() 215 | DBManager.init_db_client(db_client) 216 | await create_tables() 217 | await curd_demo() 218 | 219 | 220 | if __name__ == "__main__": 221 | asyncio.run(main()) 222 | 223 | ``` 224 | 225 | ## 
Todo List 226 | 227 | ### 连接客户端 228 | 1. [x] http 同步异步客户端 229 | 2. [x] MySQL 客户端 - SQLAlchemy-ORM 封装 230 | 3. [x] Redis 客户端 231 | 4. [x] Minio 客户端 232 | 5. 消息队列客户端,rabbitmq、kafka 233 | 6. websocket 客户端 234 | 235 | ### 工具类 236 | - [x] 同异步函数转化工具类 237 | - [x] excel 工具类 238 | - [x] 文件 工具类 239 | - [x] 实用函数工具模块 240 | - [x] 数据掩码工具类 241 | - [x] 常用正则工具类 242 | - [x] 时间工具类 243 | - [x] 树结构转换工具类 244 | - [x] pydantic model、dataclass 与 SQLAlchemy table 序列化与反序列化工具类 245 | - 认证相关工具类 246 | - [x] JWT 工具类 247 | - 图片操作工具类,例如校验图片分辨率 248 | - 邮件服务工具类 249 | - 配置解析工具类 250 | - 编码工具类,统一 base64、md5等编码入口 251 | - 加密工具类 252 | 253 | ### 装饰器 254 | 1. [x] 超时装饰器 255 | 2. [x] 重试装饰器 256 | 3. [x] 缓存装饰器 257 | 4. [x] 异步执行装饰器 258 | 259 | ### 枚举 260 | 1. [x] 通用枚举类封装 261 | 2. [x] 错误码枚举封装 262 | 3. [x] 常用枚举 263 | 264 | ### 异常 265 | 1. [x] 业务异常类封装 266 | 267 | ### 日志 268 | 1. [x] logger 日志器(loguru) 269 | 2. [x] 快速配置项目日志函数 270 | 271 | ## 工程目录结构 272 | 273 | ``` 274 | py-tools/ 275 | ├── py_tools/ 276 | │ ├── chatbot/ 277 | │ ├── connections/ 278 | │ ├── constants/ 279 | │ ├── data_schemas/ 280 | │ ├── decorators/ 281 | │ ├── enums/ 282 | │ ├── exceptions/ 283 | │ ├── meta_cls/ 284 | │ └── utils/ 285 | ├── docs/ 286 | ├── demo/ 287 | ├── tests/ 288 | ├── .gitignore 289 | ├── LICENSE 290 | ├── README.md 291 | └── requirements.txt 292 | ``` 293 | 294 | 295 | 296 | ### 项目模块 297 | 298 | - **chatbot**: 用于构建和管理聊天机器人互动的工具集。 299 | - **connections**: 用于连接各种服务和 API 的连接管理工具。 300 | - **constants**: Python 项目中常用的常量。 301 | - **data_schemas**: 用于处理结构化数据的数据模型及相关工具。 302 | - **decorators**: 一系列有用的装饰器,用以增强函数和方法。 303 | - **enums**: 定义常用枚举类型,方便在项目中使用。 304 | - **exceptions**: 自定义异常类,用于项目中的错误处理。 305 | - **meta_cls**: 元类和元编程相关的工具和技术。 306 | - **utils**: 包含各种实用函数和工具,用于简化日常编程任务。 307 | 308 | 309 | 310 | ### 项目文档 311 | 312 | 在 `docs` 目录下存放一些项目相关文档。 313 | 314 | 315 | 316 | ### 示例 317 | 318 | 在 `demo` 目录下,您可以找到一些使用 Py-Tools 的示例代码,这些代码展示了如何使用这些工具集实现实际项目中的任务。 319 | 320 | demo:https://github.com/HuiDBK/py-tools/tree/master/demo 321 | 322 | ### 测试
323 | 324 | 在 `tests` 目录下,包含了针对 Py-Tools 的各个组件的单元测试。通过运行这些测试,您可以确保工具集在您的环境中正常工作。 325 | 326 | 327 | 328 | ## 一起贡献 329 | > 欢迎您对本项目进行贡献。请在提交 Pull Request 之前阅读项目的贡献指南,并确保您的代码符合项目的代码风格。 330 | 331 | 1. Fork后克隆本项目到本地: 332 | ```bash 333 | git clone https://github.com/<your-username>/py-tools.git 334 | ``` 335 | 336 | 2. 安装依赖: 337 | ```bash 338 | pip install -r requirements.txt 339 | ``` 340 | 341 | 3. 将 Python 代码风格检查配置到 git hook 中 342 | 343 | 安装 pre-commit 344 | ```bash 345 | pip install pre-commit 346 | ``` 347 | 348 | 在项目目录下执行 349 | ```bash 350 | pre-commit install 351 | ``` 352 | 安装成功后,每次执行 git commit 时都会先自动进行代码检查 353 | 354 | 4. 提交 PR 355 | -------------------------------------------------------------------------------- /demo/connections/http_client_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { http客户端案例 } 5 | # @Date: 2023/08/10 11:39 6 | import asyncio 7 | 8 | import aiohttp 9 | 10 | from py_tools.connections.http import AsyncHttpClient, HttpClient 11 | from py_tools.constants.const import BASE_DIR 12 | from py_tools.logging import logger 13 | from py_tools.utils.async_util import AsyncUtil 14 | 15 | 16 | async def upload_file_demo(): 17 | form = aiohttp.FormData() 18 | file_path = BASE_DIR / "README.md" 19 | form.add_field("file", open(file_path, "rb"), filename="new_name.md", content_type="application/octet-stream") 20 | url = "http://localhost:8000/file_upload/file_params" 21 | upload_ret = await AsyncHttpClient().post(url=url, data=form).json() 22 | logger.debug(f"upload_ret {upload_ret}") 23 | 24 | # file_path 25 | upload_ret = await AsyncHttpClient().upload_file(url=url, file=file_path, filename="hui.md").json() 26 | logger.debug(f"upload_ret {upload_ret}") 27 | 28 | # file_bytes 29 | with open(file_path, "rb") as f: 30 | file_bytes = f.read() 31 | 32 | upload_ret = await AsyncHttpClient().upload_file(url=url, file=file_bytes,
filename="hui_bytes.md").json()
    logger.debug(f"upload_ret {upload_ret}")


async def async_http_client_demo():
    """Exercise the AsyncHttpClient request styles against https://juejin.cn/.

    Covers a raw ``execute()`` call, ``text()`` on a client constructed with
    ``new_session=True``, ``bytes()``, chunked iteration with ``stream()``, and
    using the client as an async context manager. Results are only logged or
    printed; nothing is returned.
    """
    logger.debug("async_http_client_demo")
    url = "https://juejin.cn/"

    # Invoke the request in several different ways
    resp = await AsyncHttpClient().get(url).execute()
    # json_data = await AsyncHttpClient().get(url).json()
    text_data = await AsyncHttpClient(new_session=True).get(url).text()
    byte_data = await AsyncHttpClient().get(url).bytes()

    logger.debug(f"resp {resp}")
    # logger.debug(f"json_data {json_data}")
    logger.debug(f"text_data {text_data}")
    logger.debug(f"byte_data {byte_data}")

    # File upload (disabled; it needs the local upload server used by upload_file_demo)
    # await upload_file_demo()

    # Streaming call: print each chunk as it arrives
    async for chunk in AsyncHttpClient().get(url).stream():
        print(chunk)

    async with AsyncHttpClient() as client:
        # Independent aiohttp.ClientSession, closed by the context manager when done
        text = await client.get("https://juejin.cn/").text()
        print(text)


def sync_http_client_demo():
    """Fetch https://juejin.cn/ twice with one synchronous HttpClient, logging each response's ``.text``."""
    logger.debug("sync_http_client_demo")
    url = "https://juejin.cn/"
    # A single client instance is reused for both requests
    http_client = HttpClient()
    for i in range(2):
        text_content = http_client.get(url).text
        logger.debug(text_content)


async def main():
    """Run the async demo twice through AsyncUtil.run_jobs, then the sync demo.

    ``AsyncHttpClient.close()`` is awaited after the jobs finish. One more request is
    then made and ``close()`` is awaited again — presumably to show the client is
    usable after a close; confirm against AsyncHttpClient.
    """
    jobs = [async_http_client_demo(), async_http_client_demo()]
    # await asyncio.gather(*jobs)
    await AsyncUtil.run_jobs(jobs, show_progress=True)
    await AsyncHttpClient.close()

    await AsyncHttpClient().get("https://juejin.cn/").bytes()
    await AsyncHttpClient.close()

    sync_http_client_demo()


if __name__ == "__main__":
    asyncio.run(main())
--------------------------------------------------------------------------------
/demo/connections/sqlalchemy_demo/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @File: __init__.py
# @Desc: { module description }
# @Date: 2024/03/28 22:54
def main():
    """No-op placeholder entry point for the sqlalchemy_demo package."""
    pass


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/demo/connections/sqlalchemy_demo/demo.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { sqlalchemy demo }
# @Date: 2023/09/04 14:22
import asyncio
import uuid
from typing import List

from connections.sqlalchemy_demo.manager import UserFileManager
from connections.sqlalchemy_demo.table import UserFileTable
from sqlalchemy import func

from py_tools.connections.db.mysql import BaseOrmTable, DBManager, SQLAlchemyManager

# Demo database client: MySQL on 127.0.0.1:3306, database "db_demo".
# The credentials are hard-coded local-demo values; never reuse them outside a local sandbox.
db_client = SQLAlchemyManager(
    host="127.0.0.1",
    port=3306,
    user="root",
    password="123456",
    db_name="db_demo",
)


async def create_and_transaction_demo():
    """Run several writes and reads that share one ``UserFileManager.transaction()`` session.

    Steps:
    - bulk-add one row given as a dict;
    - add one ``UserFileTable`` object and print the value ``add`` returns as ``file_id``;
    - query by id 2;
    - query the ``filename``/``oss_key`` columns of a row whose filename is "ccc".

    Every manager is built with the same ``session``. Presumably this makes the steps
    commit or roll back together — confirm in ``UserFileManager.transaction``.
    """
    async with UserFileManager.transaction() as session:
        await UserFileManager(session).bulk_add(table_objs=[{"filename": "aaa", "oss_key": uuid.uuid4().hex}])
        user_file_obj = UserFileTable(filename="eee", oss_key=uuid.uuid4().hex)
        file_id = await UserFileManager(session).add(table_obj=user_file_obj)
        print("file_id", file_id)

        ret: UserFileTable = await UserFileManager(session).query_by_id(2)
        print("query_by_id", ret)

        # Uncomment to raise inside the transaction (presumably to demonstrate a rollback)
        # a = 1 / 0

        ret = await UserFileManager(session).query_one(
            cols=[UserFileTable.filename, UserFileTable.oss_key], conds=[UserFileTable.filename == "ccc"]
        )
        print("ret", ret)


async def query_demo():
    """Show the ``query_one`` / ``query_all`` variants of UserFileManager.

    Covers:
    - full-row queries;
    - ORM-column and string-column projections;
    - ``flat=True`` for a single scalar (``func.count()``, one column) or a flat list of ids.
    """
    ret = await UserFileManager().query_one(conds=[UserFileTable.filename == "ccc"])
    print("ret", ret)

    file_count = await UserFileManager().query_one(cols=[func.count()], flat=True)
    print("str col one ret", file_count)

    filename = await UserFileManager().query_one(
        cols=[UserFileTable.filename], conds=[UserFileTable.id == 2], flat=True
    )
    print("filename", filename)

    ret = await UserFileManager().query_one(conds=[UserFileTable.id == 3])
    print(ret)

    ret = await UserFileManager().query_one(cols=["filename", "oss_key"], conds=[UserFileTable.id == 3])
    # The assignment shows that column-projected results support item assignment.
    # NOTE(review): if query_one returns None when no row has id 3, this raises TypeError — confirm.
    ret["test"] = "hui"
    print(ret)

    ret = await UserFileManager().query_all(cols=[UserFileTable.filename, UserFileTable.oss_key])
    # NOTE(review): raises IndexError when the table is empty.
    ret[0]["test"] = "hui"
    print("ret", ret)

    ret = await UserFileManager().query_all(cols=["filename", "oss_key"])
    print("str col ret", ret)

    ret: List[UserFileTable] = await UserFileManager().query_all()
    print("ret", ret)

    ret = await UserFileManager().query_all(cols=[UserFileTable.id], flat=True)
    print("ret", ret)


async def delete_demo():
    """Show the delete variants: by id, bulk by ids, by condition, and with ``logic_del=True``.

    The ids (``file_count``, 10-13, 5, 6) are hard-coded. The demo assumes those rows
    exist in the demo table.
    """
    file_count = await UserFileManager().query_one(cols=[func.count()], flat=True)
    print("file_count", file_count)

    # Uses the current row count as the id to delete. This targets the newest row only if
    # ids are contiguous from 1 (an assumption, not enforced here).
    ret = await UserFileManager().delete_by_id(file_count)
    print("delete_by_id ret", ret)

    ret = await UserFileManager().bulk_delete_by_ids(pk_ids=[10, 11, 12])
    print("bulk_delete_by_ids ret", ret)

    ret = await UserFileManager().delete(conds=[UserFileTable.id == 13])
    print("delete ret", ret)

    # logic_del=True presumably performs a soft delete (flag update) instead of removing the row
    ret = await UserFileManager().delete(conds=[UserFileTable.id == 5], logic_del=True)
    print("logic_del ret", ret)

    # Same, with an explicit flag column name and the value to write into it
    ret = await UserFileManager().delete(
        conds=[UserFileTable.id == 6], logic_del=True, logic_field="is_del", logic_del_set_value=1
    )
    print("logic_del set logic_field ret", ret)


async def update_demo():
    """Show ``update`` by condition and both paths of ``update_or_add``.

    ``update_or_add`` is first given a dict and prints its result as "add". The returned
    object is then modified and passed back, and that result is printed as "update".
    """
    ret = await UserFileManager().update(values={"filename": "hui"}, conds=[UserFileTable.id == 1])
    print("update ret", ret)

    # Add
    user_file_info = {"filename": "huidbk", "oss_key": uuid.uuid4().hex}
    user_file: UserFileTable = await UserFileManager().update_or_add(table_obj=user_file_info)
    print("update_or_add add", user_file)

    # Update
    user_file.file_suffix = "png"
    user_file.file_size = 100
    user_file.filename = "hui-update_or_add"
    ret = await UserFileManager().update_or_add(table_obj=user_file)
    print("update_or_add update", ret)


async def list_page_demo():
    """Paginated query demo: page 2, 10 rows per page, three selected columns.

    ``list_page`` returns ``(total_count, data_list)``, which are printed together.
    """
    total_count, data_list = await UserFileManager().list_page(
        cols=["filename", "oss_key", "file_size"], curr_page=2, page_size=10
    )
    print("total_count", total_count, f"data_list[{len(data_list)}]", data_list)


async def run_raw_sql_demo():
    """Raw SQL demo using ``UserFileManager().run_sql``.

    Covers:
    - a count query with ``query_one=True``;
    - two selects whose ``:name`` placeholders are bound through ``params``;
    - a join between ``user_file`` and ``user`` filtered by creator.
    """
    count_sql = "select count(*) as total_count from user_file"
    count_ret = await UserFileManager().run_sql(count_sql, query_one=True)
    print("count_ret", count_ret)

    # Values are bound via params rather than formatted into the SQL string
    data_sql = "select * from user_file where id > :id_val and file_size >= :file_size_val"
    params = {"id_val": 20, "file_size_val": 0}
    data_ret = await UserFileManager().run_sql(data_sql, params=params)
    print("dict data_ret", data_ret)

    data_sql = "select * from user_file where id > :id_val"
    data_ret = await UserFileManager().run_sql(sql=data_sql, params={"id_val": 4})
    print("dict data_ret", data_ret)

    # Join query
    data_sql = """
        select
            user.id as user_id,
            username,
            user_file.id as file_id,
            filename,
            oss_key
        from
            user_file
            join user on user.id = user_file.creator
        where
            user_file.creator = :user_id
    """
    data_ret = await UserFileManager().run_sql(data_sql, params={"user_id": 1})
    print("join sql data_ret", data_ret)


async def curd_demo():
    """Run every CRUD demo in this module in a fixed order.

    All steps hit the same tables, so earlier steps change the rows later ones see.
    The order is: create/transaction, query, list_page, update, delete, raw SQL.
    """
    await create_and_transaction_demo()
    await query_demo()
    await list_page_demo()
    await update_demo()
    await delete_demo()
    await run_raw_sql_demo()
| async def create_tables(): 166 | # 根据映射创建库表 167 | async with DBManager.connection() as conn: 168 | await conn.run_sync(BaseOrmTable.metadata.create_all) 169 | 170 | 171 | async def main(): 172 | db_client.init_mysql_engine() 173 | DBManager.init_db_client(db_client) 174 | await create_tables() 175 | await curd_demo() 176 | 177 | 178 | if __name__ == "__main__": 179 | asyncio.run(main()) 180 | -------------------------------------------------------------------------------- /demo/connections/sqlalchemy_demo/manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { sqlalchemy demo } 5 | # @Date: 2023/09/04 14:22 6 | from connections.sqlalchemy_demo.table import UserTable, UserFileTable 7 | from py_tools.connections.db.mysql import DBManager 8 | 9 | 10 | class UserManager(DBManager): 11 | orm_table = UserTable 12 | 13 | 14 | class UserFileManager(DBManager): 15 | orm_table = UserFileTable 16 | -------------------------------------------------------------------------------- /demo/connections/sqlalchemy_demo/table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { sqlalchemy demo } 5 | # @Date: 2023/09/04 14:22 6 | from datetime import datetime 7 | 8 | from sqlalchemy import String 9 | from sqlalchemy.orm import Mapped, mapped_column 10 | 11 | from py_tools.connections.db.mysql import BaseOrmTable, BaseOrmTableWithTS 12 | 13 | 14 | class UserTable(BaseOrmTableWithTS): 15 | """用户表""" 16 | 17 | __tablename__ = "user" 18 | username: Mapped[str] = mapped_column(String(100), default="", comment="用户昵称") 19 | age: Mapped[int] = mapped_column(default=0, comment="年龄") 20 | password: Mapped[str] = mapped_column(String(100), default="", comment="用户密码") 21 | phone: Mapped[str] = mapped_column(String(100), default="", comment="手机号") 22 | email: 
Mapped[str] = mapped_column(String(100), default="", comment="邮箱") 23 | avatar: Mapped[str] = mapped_column(String(100), default="", comment="头像") 24 | 25 | 26 | class UserFileTable(BaseOrmTable): 27 | """用户文件表""" 28 | 29 | __tablename__ = "user_file" 30 | filename: Mapped[str] = mapped_column(String(100), default="", comment="文件名称") 31 | creator: Mapped[int] = mapped_column(default=0, comment="文件创建者") 32 | file_suffix: Mapped[str] = mapped_column(String(100), default="", comment="文件后缀") 33 | file_size: Mapped[int] = mapped_column(default=0, comment="文件大小") 34 | oss_key: Mapped[str] = mapped_column(String(100), default="", comment="oss key(minio)") 35 | is_del: Mapped[int] = mapped_column(default=0, comment="是否删除") 36 | deleted_at: Mapped[datetime] = mapped_column(nullable=True, comment="删除时间") 37 | -------------------------------------------------------------------------------- /demo/decorators/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 模块描述 } 5 | # @Date: 2023/02/12 22:15 6 | import asyncio 7 | import time 8 | 9 | from loguru import logger 10 | 11 | from py_tools.decorators import retry, set_timeout 12 | from py_tools.exceptions import MaxRetryException, MaxTimeoutException 13 | 14 | 15 | @retry() 16 | def user_place_order_success_demo(): 17 | """用户下单成功模拟""" 18 | logger.debug("user place order success") 19 | return {"code": 0, "msg": "ok"} 20 | 21 | 22 | @retry(max_count=3, interval=3) 23 | def user_place_order_fail_demo(): 24 | """用户下单失败模拟""" 25 | a = 1 / 0 # 使用除零异常模拟业务错误 26 | logger.debug("user place order success") 27 | return {"code": 0, "msg": "ok"} 28 | 29 | 30 | @set_timeout(2) 31 | @retry(max_count=3) 32 | def user_place_order_timeout_demo(): 33 | """用户下单失败模拟""" 34 | time.sleep(5) # 模拟业务超时 35 | logger.debug("user place order success") 36 | return {"code": 0, "msg": "ok"} 37 | 38 | 39 | @retry(max_count=2, interval=3) 40 | 
async def async_user_place_order_demo(): 41 | logger.debug("user place order success") 42 | return {"code": 0, "msg": "ok"} 43 | 44 | 45 | @retry(max_count=2) 46 | async def async_user_place_order_fail_demo(): 47 | a = 1 / 0 48 | logger.debug("user place order success") 49 | return {"code": 0, "msg": "ok"} 50 | 51 | 52 | @set_timeout(2) 53 | @retry(max_count=3) 54 | async def async_user_place_order_timeout_demo(): 55 | await asyncio.sleep(3) 56 | logger.debug("user place order success") 57 | return {"code": 0, "msg": "ok"} 58 | 59 | 60 | def sync_demo(): 61 | """同步案例""" 62 | user_place_order_success_demo() 63 | 64 | try: 65 | user_place_order_fail_demo() 66 | except MaxRetryException as e: 67 | # 超出最大重新次数异常,业务逻辑处理 68 | logger.debug(f"sync 超出最大重新次数 {e}") 69 | 70 | try: 71 | user_place_order_timeout_demo() 72 | except MaxTimeoutException as e: 73 | # 超时异常,业务逻辑处理 74 | logger.debug(f"sync 超时异常, {e}") 75 | 76 | 77 | async def async_demo(): 78 | """异步案例""" 79 | await async_user_place_order_demo() 80 | 81 | try: 82 | await async_user_place_order_fail_demo() 83 | except MaxRetryException as e: 84 | logger.debug(f"async 超出最大重新次数 {e}") 85 | 86 | try: 87 | await async_user_place_order_timeout_demo() 88 | except MaxTimeoutException as e: 89 | logger.debug(f"async 超时异常, {e}") 90 | 91 | 92 | async def main(): 93 | # sync_demo() 94 | 95 | await async_demo() 96 | 97 | 98 | if __name__ == '__main__': 99 | asyncio.run(main()) 100 | -------------------------------------------------------------------------------- /demo/decorators/cache.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: cache.py 5 | # @Desc: { cache demo 模块 } 6 | # @Date: 2024/04/23 11:11 7 | import asyncio 8 | import time 9 | from datetime import timedelta 10 | 11 | import cacheout 12 | 13 | from py_tools.connections.db.redis_client import BaseRedisManager 14 | from py_tools.decorators.cache import 
AsyncRedisCacheProxy, MemoryCacheProxy, RedisCacheProxy, cache_json 15 | 16 | 17 | class RedisManager(BaseRedisManager): 18 | client = None 19 | 20 | 21 | class AsyncRedisManager(BaseRedisManager): 22 | client = None 23 | 24 | 25 | RedisManager.init_redis_client(async_client=False) 26 | AsyncRedisManager.init_redis_client(async_client=True) 27 | 28 | memory_proxy = MemoryCacheProxy(cache_client=cacheout.Cache()) 29 | redis_proxy = RedisCacheProxy(cache_client=RedisManager.client) 30 | aredis_proxy = AsyncRedisCacheProxy(cache_client=AsyncRedisManager.client) 31 | 32 | 33 | @cache_json(key_prefix="demo", ttl=3) 34 | def memory_cache_demo_func(name: str, age: int): 35 | return {"test_memory_cache": "hui-test", "name": name, "age": age} 36 | 37 | 38 | @cache_json(cache_proxy=redis_proxy, ttl=10) 39 | def redis_cache_demo_func(name: str, age: int): 40 | return {"test_redis_cache": "hui-test", "name": name, "age": age} 41 | 42 | 43 | @cache_json(cache_proxy=aredis_proxy, ttl=timedelta(minutes=1)) 44 | async def aredis_cache_demo_func(name: str, age: int): 45 | return {"test_async_redis_cache": "hui-test", "name": name, "age": age} 46 | 47 | 48 | @AsyncRedisManager.cache_json(ttl=30) 49 | async def aredis_manager_cache_demo_func(name: str, age: int): 50 | return {"test_async_redis_manager_cache": "hui-test", "name": name, "age": age} 51 | 52 | 53 | def memory_cache_demo(): 54 | print("memory_cache_demo") 55 | ret1 = memory_cache_demo_func(name="hui", age=18) 56 | print("ret1", ret1) 57 | print() 58 | 59 | ret2 = memory_cache_demo_func(name="hui", age=18) 60 | print("ret2", ret2) 61 | print() 62 | 63 | time.sleep(3) 64 | ret3 = memory_cache_demo_func(age=18, name="hui") 65 | print("ret3", ret3) 66 | print() 67 | 68 | assert ret1 == ret2 == ret3 69 | 70 | # ret4 = memory_cache_demo_func(name="huidbk", age=18) 71 | # print("ret4", ret4) 72 | # print() 73 | # 74 | # ret5 = memory_cache_demo_func(name="huidbk", age=20) 75 | # print("ret5", ret5) 76 | # print() 77 | # 78 | # 
assert ret4 != ret5 79 | # 80 | # ret6 = memory_cache_demo_func(name="huidbk", age=20) 81 | # print("ret6", ret6) 82 | # print() 83 | # 84 | # assert ret5 == ret6 85 | 86 | 87 | def redis_cache_demo(): 88 | print("redis_cache_demo") 89 | ret1 = redis_cache_demo_func(name="hui", age=18) 90 | print("ret1", ret1) 91 | print() 92 | 93 | ret2 = redis_cache_demo_func(name="hui", age=18) 94 | print("ret2", ret2) 95 | 96 | assert ret1 == ret2 97 | 98 | 99 | async def aredis_cache_demo(): 100 | print("aredis_cache_demo") 101 | ret1 = await aredis_cache_demo_func(name="hui", age=18) 102 | print("ret1", ret1) 103 | print() 104 | 105 | ret2 = await aredis_cache_demo_func(name="hui", age=18) 106 | print("ret2", ret2) 107 | 108 | assert ret1 == ret2 109 | 110 | 111 | async def aredis_manager_cache_demo(): 112 | print("aredis_manager_cache_demo") 113 | ret1 = await aredis_manager_cache_demo_func(name="hui", age=18) 114 | print("ret1", ret1) 115 | print() 116 | 117 | ret2 = await aredis_manager_cache_demo_func(name="hui", age=18) 118 | print("ret2", ret2) 119 | 120 | assert ret1 == ret2 121 | 122 | 123 | async def main(): 124 | memory_cache_demo() 125 | 126 | redis_cache_demo() 127 | 128 | await aredis_cache_demo() 129 | 130 | await aredis_manager_cache_demo() 131 | 132 | 133 | if __name__ == "__main__": 134 | asyncio.run(main()) 135 | -------------------------------------------------------------------------------- /demo/decorators/run_on_executor_demo.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from concurrent.futures import ThreadPoolExecutor 4 | 5 | from py_tools.decorators.base import run_on_executor 6 | from loguru import logger 7 | 8 | thread_executor = ThreadPoolExecutor(max_workers=3) 9 | 10 | 11 | @run_on_executor(background=True) 12 | async def async_func_bg_task(): 13 | logger.debug("async_func_bg_task start") 14 | await asyncio.sleep(1) 15 | logger.debug("async_func_bg_task running") 16 | 
await asyncio.sleep(1) 17 | logger.debug("async_func_bg_task end") 18 | return "async_func_bg_task ret end" 19 | 20 | 21 | @run_on_executor() 22 | async def async_func(): 23 | logger.debug("async_func start") 24 | await asyncio.sleep(1) 25 | logger.debug("async_func running") 26 | await asyncio.sleep(1) 27 | return "async_func ret end" 28 | 29 | 30 | @run_on_executor(background=True, executor=thread_executor) 31 | def sync_func_bg_task(): 32 | logger.debug("sync_func_bg_task start") 33 | time.sleep(1) 34 | logger.debug("sync_func_bg_task running") 35 | time.sleep(1) 36 | logger.debug("sync_func_bg_task end") 37 | return "sync_func_bg_task end" 38 | 39 | 40 | @run_on_executor() 41 | def sync_func(): 42 | logger.debug("sync_func start") 43 | time.sleep(1) 44 | logger.debug("sync_func running") 45 | time.sleep(1) 46 | return "sync_func ret end" 47 | 48 | 49 | async def main(): 50 | ret = await async_func() 51 | logger.debug(ret) 52 | 53 | async_bg_task = await async_func_bg_task() 54 | logger.debug(f"async bg task {async_bg_task}") 55 | logger.debug("async_func_bg_task 等待后台执行中") 56 | 57 | loop = asyncio.get_event_loop() 58 | for i in range(3): 59 | loop.create_task(async_func()) 60 | 61 | ret = await sync_func() 62 | logger.debug(ret) 63 | 64 | sync_bg_task = sync_func_bg_task() 65 | logger.debug(f"sync bg task {sync_bg_task}") 66 | logger.debug("sync_func_bg_task 等待后台执行") 67 | 68 | await asyncio.sleep(10) 69 | 70 | if __name__ == '__main__': 71 | asyncio.run(main()) 72 | -------------------------------------------------------------------------------- /demo/logging/logging_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: logging_demo.py 5 | # @Desc: { 日志使用案例 } 6 | # @Date: 2024/08/12 14:53 7 | import logging 8 | 9 | from py_tools.constants import BASE_DIR 10 | from py_tools.logging import logger, setup_logging 11 | from 
py_tools.logging.default_logging_conf import default_logging_conf 12 | 13 | 14 | def main(): 15 | setup_logging(log_dir=BASE_DIR / "logs") 16 | logger.info("use log dir") 17 | logger.error("test error") 18 | 19 | log_conf = default_logging_conf.get("server_handler") 20 | log_conf["sink"] = BASE_DIR / "logs/server.log" 21 | setup_logging(log_conf=log_conf, console_log_level=logging.WARN) 22 | 23 | logger.info("use log conf") 24 | logger.error("test error") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /demo/meta_cls/singleton_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 元类demo模块 } 5 | # @Date: 2023/08/28 11:18 6 | from py_tools.meta_cls import SingletonMetaCls 7 | 8 | 9 | class Foo(metaclass=SingletonMetaCls): 10 | 11 | def __init__(self): 12 | print("Foo __init__") 13 | self.bar = "bar" 14 | 15 | def __new__(cls, *args, **kwargs): 16 | print("Foo __new__") 17 | return super().__new__(cls, *args, **kwargs) 18 | 19 | def tow_bar(self): 20 | return self.bar * 2 21 | 22 | 23 | foo1 = Foo() 24 | foo2 = Foo() 25 | print("foo1 is foo2", foo1 is foo2) 26 | print("foo2 two_bar", foo2.tow_bar()) 27 | 28 | 29 | class Demo(Foo): 30 | 31 | def __init__(self): 32 | self.bar = "demo_bar" 33 | 34 | 35 | demo1 = Demo() 36 | demo2 = Demo() 37 | print("demo1 is demo2", demo1 is demo2) 38 | print("demo2 two_bar", demo2.tow_bar()) 39 | -------------------------------------------------------------------------------- /demo/utils/async_util_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: async.py 5 | # @Desc: { 异步相关工具函数demo } 6 | # @Date: 2024/04/24 15:32 7 | import asyncio 8 | import time 9 | from concurrent.futures import ThreadPoolExecutor 10 
| 11 | from py_tools.utils.async_util import AsyncUtil 12 | 13 | BASE_EXECUTOR = ThreadPoolExecutor(max_workers=3) 14 | 15 | 16 | async def async_bg_task(name, age): 17 | print(f"async_bg_task run... {name}, {age}") 18 | await asyncio.sleep(1) 19 | print("async_bg_task done") 20 | return name, age 21 | 22 | 23 | def sync_bg_task(name, age): 24 | print(f"sync_bg_task run... {name}, {age}") 25 | time.sleep(1) 26 | print("sync_bg_task done") 27 | return name, age 28 | 29 | 30 | async def main(): 31 | AsyncUtil.run_bg_task(sync_bg_task, name="sync-hui", age=18) 32 | 33 | AsyncUtil.run_bg_task(async_bg_task(name="async-hui", age=18)) 34 | ret = await AsyncUtil.run_bg_task(async_bg_task(name="async-hui", age=18)) 35 | print(ret) 36 | 37 | future_task = AsyncUtil.run_bg_task(sync_bg_task, name="executor-sync-hui", age=18, executor=BASE_EXECUTOR) 38 | print(future_task.result()) 39 | 40 | await asyncio.sleep(5) 41 | 42 | ret = await AsyncUtil.async_run(sync_bg_task, name="async to sync", age=18, executor=BASE_EXECUTOR) 43 | print(ret) 44 | 45 | await AsyncUtil.SyncToAsync(sync_bg_task)(name="sync to async", age=18) 46 | 47 | 48 | def async_to_sync_demo(): 49 | ret = AsyncUtil.sync_run(async_bg_task(name="sync run async", age=18)) 50 | print("sync_run", ret) 51 | 52 | ret = AsyncUtil.AsyncToSync(async_bg_task)(name="async to async", age=18) 53 | print(ret) 54 | 55 | 56 | if __name__ == "__main__": 57 | asyncio.run(main()) 58 | async_to_sync_demo() 59 | -------------------------------------------------------------------------------- /demo/utils/excel_util_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { excel操作案例 } 5 | # @Date: 2023/04/17 0:31 6 | from io import BytesIO 7 | 8 | from py_tools.constants import DEMO_DATA 9 | from py_tools.logging import logger 10 | from py_tools.utils import ExcelUtil 11 | from py_tools.utils.excel_util import 
ColumnMapping, DataCollect, SheetMapping 12 | 13 | 14 | def list_to_excel_demo(): 15 | user_list = [ 16 | dict(id=1, name="hui", age=20), 17 | dict(id=2, name="wang", age=22), 18 | dict(id=3, name="zack", age=25), 19 | ] 20 | user_col_mappings = [ 21 | ColumnMapping(column_name="id", column_alias="用户id"), 22 | ColumnMapping(column_name="name", column_alias="用户名"), 23 | ColumnMapping(column_name="age", column_alias="年龄"), 24 | ] 25 | 26 | file_path = DEMO_DATA / "user.xlsx" 27 | ExcelUtil.list_to_excel(file_path, user_list, col_mappings=user_col_mappings) 28 | 29 | # 导出为excel文件字节流处理 30 | excel_bio = BytesIO() 31 | ExcelUtil.list_to_excel(excel_bio, data_list=user_list, col_mappings=user_col_mappings, sheet_name="buffer_demo") 32 | excel_bytes = excel_bio.getvalue() 33 | logger.debug(f"excel_bytes type => {type(excel_bytes)}") 34 | 35 | # 这里以重新写到文件里为例,字节流再业务中按需操作即可 36 | with open(f"{DEMO_DATA}/user_byte.xlsx", mode="wb") as f: 37 | f.write(excel_bytes) 38 | 39 | 40 | def multi_list_to_excel_demo(): 41 | user_list = [ 42 | {"id": 1, "name": "hui", "age": 18}, 43 | {"id": 2, "name": "wang", "age": 19}, 44 | {"id": 3, "name": "zack", "age": 20}, 45 | ] 46 | 47 | book_list = [ 48 | {"id": 1, "name": "Python基础教程", "author": "hui", "price": 30}, 49 | {"id": 2, "name": "Java高级编程", "author": "wang", "price": 50}, 50 | {"id": 3, "name": "机器学习实战", "author": "zack", "price": 70}, 51 | ] 52 | 53 | user_col_mappings = [ 54 | ColumnMapping(column_name="id", column_alias="编号"), 55 | ColumnMapping(column_name="name", column_alias="姓名"), 56 | ColumnMapping(column_name="age", column_alias="年龄"), 57 | ] 58 | book_col_mappings = [ 59 | ColumnMapping(column_name="id", column_alias="编号"), 60 | ColumnMapping(column_name="name", column_alias="书名"), 61 | ColumnMapping(column_name="author", column_alias="作者"), 62 | ColumnMapping(column_name="price", column_alias="价格"), 63 | ] 64 | 65 | data_collects = [ 66 | DataCollect(data_list=user_list, col_mappings=user_col_mappings, sheet_name="用户信息"), 
67 | DataCollect(data_list=book_list, col_mappings=book_col_mappings, sheet_name="图书信息"), 68 | ] 69 | 70 | ExcelUtil.multi_list_to_excel(f"{DEMO_DATA}/multi_sheet_data.xlsx", data_collects) 71 | 72 | 73 | def read_excel_demo(): 74 | data = [ 75 | {"id": 1, "name": "hui", "age": 30}, 76 | {"id": 2, "name": "zack", "age": 25}, 77 | {"id": 3, "name": "", "age": 40}, 78 | ] 79 | 80 | user_col_mappings = [ 81 | ColumnMapping(column_name="id", column_alias="用户id"), 82 | ColumnMapping(column_name="name", column_alias="用户名"), 83 | ColumnMapping(column_name="age", column_alias="年龄"), 84 | ] 85 | 86 | user_id_and_name_mappings = [ 87 | ColumnMapping(column_name="用户id", column_alias="id"), 88 | ColumnMapping(column_name="用户名", column_alias="name"), 89 | ] 90 | 91 | # 将数据写入Excel文件 92 | file_path = DEMO_DATA / "read_demo.xlsx" 93 | ExcelUtil.list_to_excel(file_path, data, col_mappings=user_col_mappings) 94 | 95 | # 读取Excel文件 96 | result = ExcelUtil.read_excel(file_path, col_mappings=user_id_and_name_mappings, all_col=False, nan_replace="") 97 | 98 | logger.debug(f"read_excel {result}") 99 | 100 | 101 | def merge_excel_files_demo(): 102 | # 合并多个Excel文件 103 | ExcelUtil.merge_excel_files( 104 | input_files=[f"{DEMO_DATA}/user.xlsx", f"{DEMO_DATA}/multi_sheet_data.xlsx"], 105 | output_file=f"{DEMO_DATA}/merged_data.xlsx", 106 | sheet_mappings=[ 107 | SheetMapping(file_name="user.xlsx", sheet_name="user"), 108 | SheetMapping(file_name="multi_sheet_data.xlsx", sheet_name="multi_sheet_data"), 109 | ], 110 | ) 111 | 112 | 113 | def main(): 114 | list_to_excel_demo() 115 | 116 | multi_list_to_excel_demo() 117 | 118 | read_excel_demo() 119 | 120 | merge_excel_files_demo() 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /demo/utils/jwt_util_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: 
Hui 4 | # @File: jwt_util_demo.py 5 | # @Desc: { jwt util demo } 6 | # @Date: 2024/11/04 15:20 7 | from datetime import timedelta 8 | 9 | from loguru import logger 10 | 11 | from py_tools.utils import JWTUtil 12 | 13 | 14 | def main(): 15 | # 初始化密钥和算法 16 | jwt_util = JWTUtil(secret_key="your_secret_key", algorithm="HS256") 17 | 18 | # 生成 JWT 19 | data = {"user_id": "12345", "role": "admin"} 20 | token = jwt_util.generate_token(data) 21 | logger.info(f"Generated Token: {token}") 22 | 23 | # 验证 JWT 24 | decoded_data = jwt_util.verify_token(token) 25 | if decoded_data: 26 | logger.info(f"Decoded Data: {decoded_data}") 27 | else: 28 | logger.info("Token is invalid or expired.") 29 | 30 | # 刷新 JWT 31 | refreshed_token = jwt_util.refresh_token(token, expires_delta=timedelta(days=1)) 32 | logger.info(f"Refreshed Token: {refreshed_token}") 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /demo/utils/serializer_util_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: serializer_util_demo.py 5 | # @Desc: { 序列号工具类 demo } 6 | # @Date: 2024/11/15 17:13 7 | from dataclasses import dataclass 8 | 9 | from pydantic import BaseModel 10 | from sqlalchemy import Column, String 11 | 12 | from py_tools.connections.db.mysql import BaseOrmTable 13 | from py_tools.utils import SerializerUtil 14 | 15 | 16 | # sqlalchemy 示例 17 | class UserTable(BaseOrmTable): 18 | __tablename__ = "user" 19 | username = Column(String(20)) 20 | email = Column(String(50)) 21 | 22 | 23 | # Pydantic 示例 24 | class UserModel(BaseModel): 25 | id: int 26 | username: str 27 | email: str 28 | 29 | 30 | @dataclass 31 | class UserDataclass: 32 | id: int 33 | username: str 34 | email: str 35 | 36 | 37 | class UserCustomModel: 38 | def __init__(self, id: int, username: str, email: str): 39 | self.id = id 40 | 
self.username = username 41 | self.email = email 42 | 43 | def to_dict(self): 44 | return {"id": self.id, "username": self.username, "email": self.email} 45 | 46 | 47 | def serializer_demo(): 48 | user_table_obj = UserTable(id=2, username="wang", email="wang@example.com") 49 | user_model_obj = UserModel(id=3, username="zack", email="zack@example.com") 50 | user_dataclass_obj = UserDataclass(id=4, username="lisa", email="lisa@example.com") 51 | user_custom_model = UserCustomModel(id=5, username="lily", email="lily@example.com") 52 | user_infos = [ 53 | {"id": 1, "username": "hui", "email": "hui@example.com"}, 54 | user_table_obj, 55 | user_model_obj, 56 | user_dataclass_obj, 57 | user_custom_model, 58 | ] 59 | 60 | print("data_to_model") 61 | user_models = SerializerUtil.data_to_model(data_obj=user_infos, to_model=UserModel) 62 | print(type(user_models), user_models) 63 | 64 | user_models = SerializerUtil.data_to_model(data_obj=user_infos, to_model=UserTable) 65 | print(type(user_models), user_models) 66 | 67 | user_models = SerializerUtil.data_to_model(data_obj=user_infos, to_model=UserDataclass) 68 | print(type(user_models), user_models) 69 | 70 | user_models = SerializerUtil.data_to_model(data_obj=user_infos, to_model=UserCustomModel) 71 | print(type(user_models), user_models) 72 | 73 | user_model = SerializerUtil.data_to_model(data_obj=user_infos[0], to_model=UserModel) 74 | user_table = SerializerUtil.data_to_model(data_obj=user_infos[0], to_model=UserTable) 75 | user_dataclass = SerializerUtil.data_to_model(data_obj=user_infos[0], to_model=UserDataclass) 76 | print(type(user_model), user_model) 77 | print(type(user_table), user_table) 78 | print(type(user_dataclass), user_dataclass) 79 | 80 | # model_to_data 81 | print("\n\nmodel_to_data") 82 | user_infos = SerializerUtil.model_to_data(user_infos) 83 | print(type(user_infos), user_infos) 84 | 85 | user_info = SerializerUtil.model_to_data(user_model_obj) 86 | print(type(user_info), user_info) 87 | 88 | 
user_info = SerializerUtil.model_to_data(user_table_obj) 89 | print(type(user_info), user_info) 90 | 91 | user_info = SerializerUtil.model_to_data(user_dataclass_obj) 92 | print(type(user_info), user_info) 93 | 94 | user_info = SerializerUtil.model_to_data(user_custom_model) 95 | print(type(user_info), user_info) 96 | 97 | 98 | def main(): 99 | serializer_demo() 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /demo/utils/time_util_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 时间工具类案例 } 5 | # @Date: 2023/04/30 21:08 6 | import time 7 | from datetime import datetime 8 | 9 | from py_tools.enums import TimeFormatEnum 10 | from py_tools.utils import TimeUtil 11 | 12 | 13 | def time_util_demo(): 14 | # 创建一个TimeUtil实例,默认使用当前时间 15 | time_util = TimeUtil.instance() 16 | 17 | print("昨天的日期:", time_util.yesterday) 18 | 19 | print("明天的日期:", time_util.tomorrow) 20 | 21 | print("一周后的日期:", time_util.week_later) 22 | 23 | print("一个月后的日期:", time_util.month_later) 24 | 25 | # 从现在开始增加10天 26 | print("10天后的日期:", time_util.add_time(days=10)) 27 | 28 | # 从现在开始减少5天 29 | print("5天前的日期:", time_util.sub_time(days=5)) 30 | 31 | date_str = "2023-05-01 12:00:00" 32 | print("字符串转换为datetime对象:", time_util.str_to_datetime(date_str)) 33 | 34 | print("datetime对象转换为字符串:", time_util.datetime_to_str()) 35 | 36 | timestamp = time.time() 37 | print("时间戳转换为时间字符串:", time_util.timestamp_to_str(timestamp)) 38 | 39 | time_str = "2023-05-01 12:00:00" 40 | print("时间字符串转换为时间戳:", time_util.str_to_timestamp(time_str)) 41 | 42 | print("时间戳转换为datetime对象:", time_util.timestamp_to_datetime(timestamp)) 43 | 44 | print("当前时间的时间戳:", time_util.timestamp) 45 | 46 | date1 = datetime(2023, 4, 24) # 2023年4月24日,星期一 47 | date2 = datetime(2023, 5, 1) # 2023年5月1日,星期一 48 | time_util = 
TimeUtil(datetime_obj=date1) 49 | 50 | # 计算两个日期之间的工作日数量 51 | weekday_count = time_util.count_weekdays_between(date2, include_end_date=True) 52 | print(f"从 {date1} 到 {date2} 之间有 {weekday_count} 个工作日。(包含末尾日期)") 53 | 54 | weekday_count = time_util.count_weekdays_between(date2, include_end_date=False) 55 | print(f"从 {date1} 到 {date2} 之间有 {weekday_count} 个工作日。(不包含末尾日期)") 56 | 57 | # 获取两个日期之间的差值 58 | date_diff = time_util.date_diff(date2) 59 | print(date_diff) 60 | 61 | datetime_ret = time_util.datetime_to_str(format_str=TimeFormatEnum.DateTime_CN) 62 | date_only_ret = time_util.datetime_to_str(format_str=TimeFormatEnum.DateOnly_CN) 63 | time_only_ret = time_util.datetime_to_str(format_str=TimeFormatEnum.TimeOnly_CN) 64 | print(datetime_ret) 65 | print(date_only_ret) 66 | print(time_only_ret) 67 | 68 | 69 | def main(): 70 | time_util_demo() 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /demo/utils/tree_util_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: tree.py 5 | # @Desc: { 模块描述 } 6 | # @Date: 2024/04/24 11:55 7 | import copy 8 | from pprint import pprint 9 | 10 | from py_tools.utils.tree_util import ( 11 | list_to_tree_bfs, 12 | list_to_tree_dfs, 13 | tree_to_list_bfs, 14 | tree_to_list_dfs, 15 | ) 16 | 17 | depart_list = [ 18 | {"id": 1, "name": "a1", "pid": 0}, 19 | {"id": 2, "name": "a1_2", "pid": 1}, 20 | {"id": 3, "name": "a1_3", "pid": 1}, 21 | {"id": 4, "name": "a1_4", "pid": 1}, 22 | {"id": 5, "name": "a2_1", "pid": 2}, 23 | {"id": 6, "name": "a2_2", "pid": 2}, 24 | {"id": 7, "name": "a3", "pid": 0}, 25 | {"id": 8, "name": "a3_1", "pid": 7}, 26 | {"id": 9, "name": "a3_2", "pid": 7}, 27 | {"id": 10, "name": "a4", "pid": 0}, 28 | ] 29 | 30 | 31 | def main(): 32 | depart_tree_list = list_to_tree_dfs(copy.deepcopy(depart_list), root_pid=0, 
sub_field="subs", need_level=False) 33 | depart_tree_list = list_to_tree_bfs(copy.deepcopy(depart_list), root_pid=0, sub_field="subs", need_level=True) 34 | 35 | print("原来列表") 36 | pprint(depart_list) 37 | 38 | print("列表 bfs=> 树形列表") 39 | pprint(depart_tree_list) 40 | 41 | print("树形列表 dfs=> 列表") 42 | pprint(tree_to_list_dfs(copy.deepcopy(depart_tree_list), sub_field="subs", need_level=True)) 43 | 44 | print("树形列表 bfs=> 列表") 45 | pprint(tree_to_list_bfs(copy.deepcopy(depart_tree_list), sub_field="subs", need_level=True)) 46 | 47 | print("原来树形列表") 48 | pprint(depart_tree_list) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /py_tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /py_tools/chatbot/__init__.py: -------------------------------------------------------------------------------- 1 | from py_tools.chatbot.chatbot import ( 2 | DingTalkChatBot, 3 | FeiShuChatBot, 4 | WeComChatbot, 5 | BaseChatBot, 6 | ) 7 | 8 | from py_tools.chatbot.app_server import FeiShuAppServer, FeiShuTaskChatBot 9 | from py_tools.chatbot.factory import ChatBotFactory, ChatBotType 10 | 11 | __all__ = [ 12 | "DingTalkChatBot", 13 | "FeiShuChatBot", 14 | "FeiShuTaskChatBot", 15 | "WeComChatbot", 16 | "BaseChatBot", 17 | "FeiShuAppServer", 18 | "FeiShuTaskChatBot", 19 | "ChatBotFactory", 20 | "ChatBotType", 21 | ] 22 | -------------------------------------------------------------------------------- /py_tools/chatbot/app_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 机器人应用服务模块 } 5 | # @Date: 2023/05/03 18:51 6 | import json 7 | from typing import List 8 | 9 | import requests 10 | from cacheout import Cache 11 | from loguru import logger 12 
# (continuation of /py_tools/chatbot/app_server.py from its line 13)
from py_tools.enums.feishu import FeishuReceiveType
from py_tools.exceptions import SendMsgException


class FeiShuAppServer:
    """Feishu (Lark) self-built app server client.

    Wraps the Feishu open-platform HTTP APIs used by the notification bots:
    tenant access-token retrieval, user id lookup, chat listing and message sending.
    """

    # Class-level in-memory cache shared by every instance. Entries are keyed per
    # app (see ``token_cache_key``) so repeated calls reuse one token instead of
    # issuing an HTTP request each time.
    token_cache = Cache()

    # Seconds subtracted from the token lifetime reported by Feishu before caching,
    # so a cached token is refreshed before it actually expires.
    TOKEN_EXPIRE_MARGIN = 5 * 60

    # Exchange app_id / app_secret for a tenant_access_token.
    # API reference: https://open.feishu.cn/document/ukTMukTMukTM/uMTNz4yM1MjLzUzM
    GET_FEISHU_TENANT_ACCESS_TOKEN_URL = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"

    # Look up Feishu user ids (open_id by default) by mobile number / email.
    # API reference: https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/contact-v3/user/batch_get_id
    GET_FEISHU_USER_OPEN_ID_URL = "https://open.feishu.cn/open-apis/contact/v3/users/batch_get_id"

    # Send a message to a user or a group chat.
    # API reference: https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/create
    NOTIFY_FEISHU_USER_MSG_URL = "https://open.feishu.cn/open-apis/im/v1/messages"

    # List the group chats the user / bot belongs to.
    # API reference: https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/chat/list
    GET_FEISHU_USER_OF_GROUP_URL = "https://open.feishu.cn/open-apis/im/v1/chats"

    def __init__(self, app_id: str, app_secret: str, timeout=10):
        """
        Initialise the app server client.

        Args:
            app_id: Feishu app id
            app_secret: Feishu app secret
            timeout: per-request timeout in seconds passed to ``requests``, default 10
        """
        self.app_id = app_id
        self.app_secret = app_secret
        self.token_cache_key = f"{app_id}:{app_secret}:token"  # cache key for this app's access token
        self.req_timeout = timeout

    def _get_tenant_access_token(self):
        """
        Return a tenant_access_token for calling the Feishu APIs.

        The in-process cache is consulted first; on a miss the token is requested
        from Feishu and cached for its lifetime minus ``TOKEN_EXPIRE_MARGIN``.
        API reference: https://open.feishu.cn/document/ukTMukTMukTM/uMTNz4yM1MjLzUzM

        Raises:
            ValueError: when Feishu answers with a non-zero ``code``
        """
        tenant_access_token = self.token_cache.get(key=self.token_cache_key)
        if tenant_access_token:
            # Do not log the token itself: it is a bearer credential.
            logger.debug("cache get tenant_access_token")
            return tenant_access_token

        app_info = {"app_id": self.app_id, "app_secret": self.app_secret}
        resp = requests.post(url=self.GET_FEISHU_TENANT_ACCESS_TOKEN_URL, json=app_info, timeout=self.req_timeout)
        ret_info = resp.json()
        if ret_info.get("code") != 0:
            raise ValueError(f"FeiShuAppServer get_tenant_access_token error, {ret_info}")

        expire = ret_info.get("expire", 0)
        tenant_access_token = ret_info.get("tenant_access_token")
        ttl = expire - self.TOKEN_EXPIRE_MARGIN

        # cacheout treats a non-positive ttl as "never expire", so a missing or very
        # short ``expire`` would otherwise pin a soon-invalid token forever.
        if tenant_access_token and ttl > 0:
            self.token_cache.set(key=self.token_cache_key, value=tenant_access_token, ttl=ttl)

        return tenant_access_token

    def _get_user_open_id(
        self,
        mobiles: list = None,
        emails: list = None,
        user_id_type: FeishuReceiveType = FeishuReceiveType.OPEN_ID,
    ) -> List[dict]:
        """
        Look up Feishu user ids by mobile numbers and/or emails.
        API reference: https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/contact-v3/user/batch_get_id

        Args:
            mobiles: Feishu user mobile numbers
            emails: Feishu user emails
            user_id_type: id type to return, default open_id

        Raises:
            ValueError: when neither mobiles nor emails is given, or Feishu returns a non-zero code

        Returns: user_list (empty list when Feishu returns none), e.g.
            [
                {"mobile": "130xxxx1752", "user_id": "ou_xxx"},
                {"email": "liuminhui@fuzhi.ai", "user_id": "ou_xxx"}
            ]
        """
        if not mobiles and not emails:
            raise ValueError("FeiShuAppServer _get_user_open_id error, 手机号或邮箱需必填一项")

        receiver_info = {"mobiles": mobiles, "emails": emails}

        headers = {"Authorization": f"Bearer {self._get_tenant_access_token()}"}
        resp = requests.post(
            url=self.GET_FEISHU_USER_OPEN_ID_URL,
            params={"user_id_type": user_id_type.value},
            json=receiver_info,
            headers=headers,
            timeout=self.req_timeout,
        )
        ret_info = resp.json()
        if ret_info.get("code") != 0:
            raise ValueError(
                f"FeiShuAppServer _get_user_open_id error, mobiles is {mobiles}, emails is {emails}, {ret_info}"
            )

        # Normalise a missing/null ``data`` or ``user_list`` to [] so callers can iterate safely.
        return (ret_info.get("data") or {}).get("user_list") or []

    def _get_user_or_bot_groups(self, user_id_type: str = FeishuReceiveType.OPEN_ID.value, page_size: int = 100):
        """
        List every group chat the user / bot belongs to, following pagination.
        API reference: https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/chat/list

        Args:
            user_id_type: user id type, default open_id
            page_size: page size per request, default 100

        Raises:
            ValueError: when Feishu returns a non-zero code

        Returns: all_groups, the concatenated ``items`` of every page
        """
        all_groups = []
        # Built once, outside the loop, so the ``page_token`` set below is actually
        # sent with the next request (previously it was discarded every iteration).
        query_params = {"user_id_type": user_id_type, "page_size": page_size}
        has_more = True
        while has_more:
            headers = {"Authorization": f"Bearer {self._get_tenant_access_token()}"}
            resp = requests.get(
                url=self.GET_FEISHU_USER_OF_GROUP_URL,
                params=query_params,
                headers=headers,
                timeout=self.req_timeout,
            )
            ret_info = resp.json()
            if ret_info.get("code") != 0:
                raise ValueError(f"FeiShuAppServer _get_user_chat_id error, {ret_info}")

            group_data = ret_info.get("data") or {}
            all_groups.extend(group_data.get("items") or [])
            page_token = group_data.get("page_token")
            # Stop when there are no more pages, or when more pages are claimed but no
            # cursor is given (continuing would refetch the same page forever).
            has_more = bool(group_data.get("has_more")) and bool(page_token)
            query_params["page_token"] = page_token

        return all_groups

    def send_msg(self, content: str, receive_id_type: FeishuReceiveType, receive_id: str):
        """
        Send a text message to a single user or a group chat.
        API reference: https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/create

        Args:
            content: message text
            receive_id_type: receiver id type, open_id/user_id/union_id/email/chat_id
            receive_id: receiver id, matching ``receive_id_type``

        Raises:
            SendMsgException: when the request fails or Feishu returns a non-zero code
        """
        msg_data = {
            "receive_id": receive_id,
            "msg_type": "text",
            # The message API expects ``content`` as a JSON-encoded string.
            "content": json.dumps({"text": content}, ensure_ascii=False),
        }
        headers = {"Authorization": f"Bearer {self._get_tenant_access_token()}"}
        try:
            resp = requests.post(
                url=self.NOTIFY_FEISHU_USER_MSG_URL,
                params={"receive_id_type": receive_id_type.value},
                json=msg_data,
                headers=headers,
                timeout=self.req_timeout,
            )
            ret_info = resp.json()
            if ret_info.get("code") != 0:
                raise ValueError(f"FeiShuTaskChatBot user_task_notify error, {ret_info}")
        except Exception as e:
            raise SendMsgException(e) from e


class FeiShuTaskChatBot(FeiShuAppServer):
    """
    Feishu task-notification bot.
    Supports direct (single-user) and group-chat notifications.
    """

    def user_task_notify(self, content: str, receive_mobiles: list = None, receive_emails: list = None):
        """
        Notify users directly (single chat).

        Steps:
            1. exchange app_id / app_secret for an access token
            2. resolve users' open_id from mobile numbers / emails
            3. send the message to each open_id

        Args:
            content: notification text
            receive_mobiles: receivers' Feishu mobile numbers
            receive_emails: receivers' Feishu emails

        Per-user send failures are logged and skipped so one bad receiver does not
        block the others.
        """
        user_list = self._get_user_open_id(receive_mobiles, receive_emails)
        # A user given by both mobile and email resolves to the same open_id; the set dedupes.
        open_ids = {user_item.get("user_id") for user_item in user_list if user_item.get("user_id")}
        for open_id in open_ids:
            try:
                self.send_msg(content, receive_id_type=FeishuReceiveType.OPEN_ID, receive_id=open_id)
            except Exception as e:
                logger.error(str(e))
                continue

    def user_group_task_notify(
        self,
        content: str,
        group_name: str,
        receive_mobiles: list = None,
        receive_emails: list = None,
    ):
        """
        Notify a group chat, optionally @-mentioning specific users.

        Steps:
            1. exchange app_id / app_secret for an access token
            2. resolve the chat_id of ``group_name`` among the bot's chats
            3. resolve users' open_id from mobile numbers / emails (if given)
            4. send the message, with @mentions, to the chat

        Args:
            content: notification text
            group_name: group chat name
            receive_mobiles: mobile numbers of users to @-mention
            receive_emails: emails of users to @-mention

        Raises:
            SendMsgException: when the group is not found or sending fails
        """
        group_items = self._get_user_or_bot_groups()
        # If several chats share a name, the last one listed wins.
        group_dict = {group_info.get("name"): group_info.get("chat_id") for group_info in group_items}
        chat_id = group_dict.get(group_name)
        if not chat_id:
            raise SendMsgException(f"未找到 {group_name} 的群聊")

        at_user_str = ""
        if receive_mobiles or receive_emails:
            user_list = self._get_user_open_id(receive_mobiles, receive_emails)
            open_ids = {user_item.get("user_id") for user_item in user_list if user_item.get("user_id")}
            # Feishu text-message @mention markup (the previous empty f-string produced no
            # mention at all). NOTE(review): display name left empty - confirm Feishu fills it.
            at_user_str = "".join(f'<at user_id="{open_id}"></at>' for open_id in open_ids)

        msg_content = f"{content}\n{at_user_str}"
        self.send_msg(msg_content, receive_id_type=FeishuReceiveType.CHAT_ID, receive_id=chat_id)


# --------------------------------------------------------------------------------
# /py_tools/chatbot/chatbot.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { webhook chatbot module }
# @Date: 2023/02/19 19:48
import base64
import hashlib
import hmac
import time
from urllib.parse import quote_plus

import requests

from py_tools.exceptions.base import SendMsgException


class BaseChatBot(object):
    """Base class for group-chat webhook bots."""

    def __init__(self, webhook_url: str, secret: str = None):
        """
        Initialise the bot.

        Args:
            webhook_url: bot webhook URL
            secret: signing secret (optional)
        """
        self.webhook_url = webhook_url
        self.secret = secret

    def _get_sign(self, timestamp: str, secret: str):
        """
        Compute the request signature (abstract).

        Args:
            timestamp: timestamp used for signing
            secret: secret used for signing
        """
        raise NotImplementedError

    def send_msg(self, content: str, timeout=10):
        """
        Send a message (abstract).

        Args:
            content: message text
            timeout: request timeout in seconds, default 10
        """
        raise NotImplementedError


class FeiShuChatBot(BaseChatBot):
    """Feishu custom webhook bot."""

    def _get_sign(self, timestamp: str, secret: str) -> str:
        """
        Compute the Feishu signature.

        ``timestamp + "\\n" + secret`` is used as the HMAC-SHA256 key over an empty
        message; the digest is then base64-encoded.

        Args:
            timestamp: timestamp used for signing (seconds)
            secret: secret used for signing

        Returns: sign
        """
        string_to_sign = "{}\n{}".format(timestamp, secret)
        hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
        return base64.b64encode(hmac_code).decode("utf-8")

    def send_msg(self, content: str, timeout=10):
        """
        Send a text message.

        Args:
            content: message text
            timeout: request timeout in seconds, default 10

        Raises:
            SendMsgException: when the request fails or Feishu returns a non-zero code

        Returns: the decoded JSON response
        """
        msg_data = {"msg_type": "text", "content": {"text": f"{content}"}}
        if self.secret:
            timestamp = str(round(time.time()))
            msg_data["timestamp"] = timestamp
            msg_data["sign"] = self._get_sign(timestamp=timestamp, secret=self.secret)

        try:
            resp = requests.post(url=self.webhook_url, json=msg_data, timeout=timeout)
            resp_info = resp.json()
            if resp_info.get("code") != 0:
                raise SendMsgException(f"FeiShuRobot send msg error, {resp_info}")
            return resp_info
        except SendMsgException:
            # Already descriptive; do not wrap it a second time below.
            raise
        except Exception as e:
            raise SendMsgException(f"FeiShuRobot send msg error {e}") from e


class DingTalkChatBot(BaseChatBot):
    """DingTalk custom webhook bot."""

    def _get_sign(self, timestamp: str, secret: str):
        """
        Compute the DingTalk signature.

        HMAC-SHA256 of ``timestamp + "\\n" + secret`` keyed with ``secret``,
        base64-encoded and then URL-encoded with ``quote_plus``.

        Args:
            timestamp: timestamp used for signing (milliseconds)
            secret: secret used for signing

        Returns: sign
        """
        secret_enc = secret.encode("utf-8")
        string_to_sign = "{}\n{}".format(timestamp, secret)
        string_to_sign_enc = string_to_sign.encode("utf-8")
        hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest()
        return quote_plus(base64.b64encode(hmac_code))

    def send_msg(self, content: str, timeout=10):
        """
        Send a text message.

        The request is signed only when a secret is configured, mirroring
        ``FeiShuChatBot`` (previously a missing secret crashed in ``_get_sign``).

        Args:
            content: message text
            timeout: request timeout in seconds, default 10

        Raises:
            SendMsgException: when the request fails or DingTalk returns a non-zero errcode

        Returns: the decoded JSON response
        """
        params = {}
        if self.secret:
            timestamp = str(round(time.time() * 1000))
            params = {"timestamp": timestamp, "sign": self._get_sign(timestamp=timestamp, secret=self.secret)}

        msg_data = {"msgtype": "text", "text": {"content": content}}
        try:
            resp = requests.post(url=self.webhook_url, json=msg_data, params=params, timeout=timeout)
            resp_info = resp.json()
            if resp_info.get("errcode") != 0:
                raise SendMsgException(f"DingTalkRobot send msg error, {resp_info}")
            return resp_info
        except SendMsgException:
            raise
        except Exception as e:
            raise SendMsgException(f"DingTalkRobot send msg error {e}") from e


class WeComChatbot(BaseChatBot):
    """WeCom (WeChat Work) webhook bot."""

    def _get_sign(self, timestamp: str, secret: str):
        """This bot sends unsigned requests, so no signature is computed."""
        pass

    def send_msg(self, content: str, timeout=10):
        """
        Send a text message.

        Args:
            content: message text
            timeout: request timeout in seconds, default 10 (previously ignored)

        Raises:
            SendMsgException: when the HTTP status is not 200

        Returns: the raw ``requests.Response``
        """
        msg_data = {"msgtype": "text", "text": {"content": content}}
        resp = requests.post(self.webhook_url, json=msg_data, timeout=timeout)
        if resp.status_code != 200:
            raise SendMsgException("Failed to send message")
        return resp


# --------------------------------------------------------------------------------
# /py_tools/chatbot/factory.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { chatbot factory module }
# @Date: 2023/02/19 20:03
from typing import Dict, Type, Union

from py_tools.chatbot import BaseChatBot, DingTalkChatBot, FeiShuChatBot, WeComChatbot
from py_tools.enums import StrEnum


class ChatBotType(StrEnum):
    """Supported group-chat bot types."""

    FEISHU_CHATBOT = "feishu"
    DINGTALK_CHATBOT = "dingtalk"
    WECOM_CHATBOT = "wecom"


class ChatBotFactory(object):
    """
    Chatbot factory.
    Builds Feishu, DingTalk and WeCom webhook bots from a type tag.
    """

    # Bot type -> bot implementation class.
    CHATBOT_HANDLER_CLS_MAPPING: Dict[ChatBotType, Type[BaseChatBot]] = {
        ChatBotType.FEISHU_CHATBOT: FeiShuChatBot,
        ChatBotType.DINGTALK_CHATBOT: DingTalkChatBot,
        ChatBotType.WECOM_CHATBOT: WeComChatbot,
    }

    def __init__(self, chatbot_type: Union[str, ChatBotType]):
        """
        Args:
            chatbot_type: a ``ChatBotType`` or its string value

        Raises:
            ValueError: for an unknown or unsupported type
        """
        if isinstance(chatbot_type, str):
            chatbot_type = ChatBotType(chatbot_type)

        if chatbot_type not in self.CHATBOT_HANDLER_CLS_MAPPING:
            raise ValueError(f"不支持 {chatbot_type} 类型的机器人")
        self.robot_type = chatbot_type

    def build(self, webhook_url: str, secret: str = None) -> BaseChatBot:
        """
        Instantiate the bot class for ``robot_type``.

        Args:
            webhook_url: bot webhook URL
            secret: bot signing secret

        Returns: the bot instance
        """
        chatbot_handle_cls = self.CHATBOT_HANDLER_CLS_MAPPING.get(self.robot_type)
        return chatbot_handle_cls(webhook_url=webhook_url, secret=secret)


def main():
    feishu_webhook = "xxx"
    feishu_webhook_secret = "xxx"

    dingtalk_webhook = "xxx"
    dingtalk_webhook_secret = "xxx"

    feishu_chatbot = ChatBotFactory(chatbot_type=ChatBotType.FEISHU_CHATBOT.value).build(
        webhook_url=feishu_webhook, secret=feishu_webhook_secret
    )
    content = "飞书自定义机器人使用指南:\n https://open.feishu.cn/document/ukTMukTMukTM/ucTM5YjL3ETO24yNxkjN"
    feishu_chatbot.send_msg(content)

    dingtalk_chatbot = ChatBotFactory(chatbot_type=ChatBotType.DINGTALK_CHATBOT.value).build(
        webhook_url=dingtalk_webhook, secret=dingtalk_webhook_secret
    )
    content = "钉钉自定义机器人使用指南:\n https://open.dingtalk.com/document/robots/custom-robot-access"
    dingtalk_chatbot.send_msg(content)


if __name__ == "__main__":
    main()


# --------------------------------------------------------------------------------
# /py_tools/connections/__init__.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { connection handling }
# @Date: 2023/05/03 21:10


# --------------------------------------------------------------------------------
# /py_tools/connections/db/__init__.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { database connection handling }
# @Date: 2023/05/03 21:11


def main():
    pass


if __name__ == '__main__':
    main()


# --------------------------------------------------------------------------------
# /py_tools/connections/db/mysql/__init__.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { module description }
# @Date: 2023/08/17 23:54
from py_tools.connections.db.mysql.orm_model import BaseOrmTable, BaseOrmTableWithTS
from py_tools.connections.db.mysql.client import SQLAlchemyManager, DBManager

__all__ = ["SQLAlchemyManager", "DBManager", "BaseOrmTable", "BaseOrmTableWithTS"]


# --------------------------------------------------------------------------------
# /py_tools/connections/db/mysql/orm_model.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { module description }
# @Date: 2023/08/17 23:55
from datetime import datetime
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class BaseOrmTable(AsyncAttrs, DeclarativeBase):
    """SQLAlchemy base ORM model with an integer primary key ``id``."""

    __abstract__ = True

    # sort_order=-1 places ``id`` first among the table's columns.
    id: Mapped[int] = mapped_column(primary_key=True, sort_order=-1, comment="主键ID")

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.to_dict()}>"

    @classmethod
    def all_columns(cls):
        """Return the mapped table's columns as a list."""
        return list(cls.__table__.columns)

    def to_dict(self, alias_dict: dict = None, exclude_none=False) -> dict:
        """
        Convert the model instance to a dict keyed by column name.

        Args:
            alias_dict: column-name aliases, e.g. {"id": "user_id"} renames ``id`` to ``user_id``
            exclude_none: drop columns whose value is None; default False (None values are kept)

        Returns: dict
        """
        alias_dict = alias_dict or {}
        result = {}
        for column in self.all_columns():
            value = getattr(self, column.name)
            if exclude_none and value is None:
                continue
            result[alias_dict.get(column.name, column.name)] = value
        return result


class TimestampColumns(AsyncAttrs, DeclarativeBase):
    """Timestamp columns (created / updated / soft-deleted).

    NOTE(review): this also subclasses ``DeclarativeBase`` directly, which gives it a
    registry separate from ``BaseOrmTable``'s; a plain mixin may be intended - confirm.
    """

    __abstract__ = True

    created_at: Mapped[datetime] = mapped_column(default=datetime.now, comment="创建时间")

    updated_at: Mapped[datetime] = mapped_column(default=datetime.now, onupdate=datetime.now, comment="更新时间")

    # Optional[...] matches the explicit nullable=True (a soft-delete marker is unset by default).
    deleted_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, comment="删除时间")


class BaseOrmTableWithTS(BaseOrmTable, TimestampColumns):
    """Base ORM model with ``id`` plus the timestamp columns."""

    __abstract__ = True


# --------------------------------------------------------------------------------
# /py_tools/connections/db/redis_client.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { redis connection module }
# @Date: 2023/05/03 21:13
from datetime import timedelta
from typing import Optional, Union

from redis import Redis
from redis import asyncio as aioredis

from py_tools import constants
from py_tools.decorators.cache import AsyncRedisCacheProxy, CacheMeta, RedisCacheProxy, cache_json


class BaseRedisManager:
    """Redis client manager holding one class-level client."""

    client: Union[Redis, aioredis.Redis] = None
    cache_key_prefix = constants.CACHE_KEY_PREFIX

    @classmethod
    def init_redis_client(
        cls,
        async_client: bool = False,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: Optional[int] = None,
        **kwargs
    ):
        """
        Create the Redis client once; later calls return the existing client unchanged.

        Args:
            async_client: use ``redis.asyncio.Redis`` instead of the sync client, default False
            host: Redis host, default 'localhost'
            port: Redis port, default 6379
            db: database number, default 0
            password: optional password
            max_connections: connection-pool limit, default None
            **kwargs: extra arguments forwarded to the Redis client

        Returns:
            the (possibly pre-existing) client
        """
        if cls.client is None:
            redis_client_cls = aioredis.Redis if async_client else Redis
            cls.client = redis_client_cls(
                host=host, port=port, db=db, password=password, max_connections=max_connections, **kwargs
            )

        return cls.client

    @classmethod
    def cache_json(
        cls,
        ttl: Union[int, timedelta] = 60,
        key_prefix: str = None,
    ):
        """
        Cache decorator for JSON-serialisable function results.

        Args:
            ttl: expiry in seconds or as a timedelta, default 60s
            key_prefix: key prefix used when no explicit key is given

        NOTE(review): the proxy captures ``cls.client`` at decoration time, so
        ``init_redis_client`` presumably must run before decorating - confirm.
        """
        key_prefix = key_prefix or cls.cache_key_prefix
        if isinstance(ttl, timedelta):
            ttl = int(ttl.total_seconds())

        if isinstance(cls.client, aioredis.Redis):
            cache_proxy = AsyncRedisCacheProxy(cls.client)
        else:
            cache_proxy = RedisCacheProxy(cls.client)

        return cache_json(cache_proxy=cache_proxy, key_prefix=key_prefix, ttl=ttl)
# --------------------------------------------------------------------------------
# /py_tools/connections/http/__init__.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { module description }
# @Date: 2023/08/10 09:32
from py_tools.connections.http.client import HttpClient, AsyncHttpClient, AsyncRequest

__all__ = ["HttpClient", "AsyncHttpClient", "AsyncRequest"]


# --------------------------------------------------------------------------------
# /py_tools/connections/http/client.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { http client }
# @Date: 2023/08/10 09:33
from datetime import timedelta
from pathlib import Path
from typing import Any, Optional, Union

import aiohttp
import requests
from aiohttp import ClientResponse

from py_tools.enums.http import HttpMethod
from py_tools.utils.file_util import FileUtil


class AsyncRequest:
    """A lazily executed async HTTP request bound to an ``AsyncHttpClient``.

    Nothing is sent until ``execute`` / ``json`` / ``text`` / ``bytes`` / ``stream``
    is awaited, which enables chains like ``await AsyncHttpClient().get(url).json()``.
    """

    def __init__(self, client, method: HttpMethod, url, **kwargs):
        self.client = client
        self.method = method
        self.url = url
        self.params = kwargs.pop("params", None)
        self.data = kwargs.pop("data", None)
        self.timeout = kwargs.pop("timeout", None)
        self.headers = kwargs.pop("headers", None)
        self.kwargs = kwargs

    async def execute(self) -> ClientResponse:
        """Send the request and return the raw response (the caller should release it)."""
        return await self.client._request(
            self.method,
            self.url,
            params=self.params,
            data=self.data,
            timeout=self.timeout,
            headers=self.headers,
            **self.kwargs,
        )

    async def json(self):
        """Send the request and return the decoded JSON body."""
        async with await self.execute() as response:
            return await response.json()

    async def text(self):
        """Send the request and return the body as text."""
        async with await self.execute() as response:
            return await response.text()

    async def bytes(self):
        """Send the request and return the raw body bytes."""
        async with await self.execute() as response:
            return await response.read()

    async def stream(self, chunk_size=1024):
        """Send the request and yield the body in chunks of at most ``chunk_size`` bytes."""
        async with await self.execute() as response:
            async for chunk in response.content.iter_chunked(chunk_size):
                yield chunk


class AsyncHttpClient:
    """Async HTTP client with chainable calls, built on aiohttp.

    Examples:
        >>> url = "https://juejin.cn/"
        >>> resp = await AsyncHttpClient().get(url).execute()
        >>> text_data = await AsyncHttpClient().get(url, params={"test": "hui"}).text()
        >>> json_data = await AsyncHttpClient().post(url, data={"test": "hui"}).json()
        >>> byte_data = await AsyncHttpClient().get(url).bytes()
        >>> upload_file_ret = await AsyncHttpClient().upload_file(url, file="test.txt").json()
        >>>
        >>> async for chunk in AsyncHttpClient().get(url).stream(chunk_size=512):
        >>>     print(chunk)

    Attributes:
        default_timeout: default request timeout (``aiohttp.ClientTimeout``)
        default_headers: default request headers
        new_session: create a dedicated ClientSession instead of sharing the global one
    """

    # Globally shared aiohttp session, plus every session created by this class so
    # ``close()`` can release them all.
    client_session: aiohttp.ClientSession = None
    client_session_set = set()

    def __init__(self, timeout=timedelta(seconds=10), headers: dict = None, new_session=False, **kwargs):
        """
        Args:
            timeout: default timeout as a timedelta, seconds (int/float) or ``aiohttp.ClientTimeout``
            headers: default headers sent with every request
            new_session: use a dedicated session instead of the shared one
            **kwargs: extra arguments for ``aiohttp.ClientSession``
        """
        self.default_timeout = self._to_client_timeout(timeout)
        self.default_headers = headers or {}
        self.new_session = new_session
        self.cur_session: aiohttp.ClientSession = None
        self.kwargs = kwargs

    async def __aenter__(self):
        # Inside ``async with`` a dedicated session is used and closed on exit.
        self.new_session = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._close_cur_session()

    @staticmethod
    def _to_client_timeout(timeout) -> Optional[aiohttp.ClientTimeout]:
        """Normalise a timedelta / number of seconds to ``aiohttp.ClientTimeout``; pass others through."""
        if isinstance(timeout, timedelta):
            return aiohttp.ClientTimeout(total=timeout.total_seconds())
        if isinstance(timeout, (int, float)):
            return aiohttp.ClientTimeout(total=float(timeout))
        return timeout

    async def _close_cur_session(self):
        """Close the session used by this instance and forget it."""
        if self.cur_session:
            await self.cur_session.close()
            if self.cur_session == AsyncHttpClient.client_session:
                AsyncHttpClient.client_session = None

            if self.cur_session in self.client_session_set:
                self.client_session_set.remove(self.cur_session)

    async def _get_client_session(self):
        """Return a dedicated new session, or the shared one (recreated if closed)."""
        if self.new_session:
            client_session = aiohttp.ClientSession(
                headers=self.default_headers, timeout=self.default_timeout, **self.kwargs
            )
            self.client_session_set.add(client_session)
            return client_session

        if self.client_session is not None and not self.client_session.closed:
            return self.client_session

        AsyncHttpClient.client_session = aiohttp.ClientSession(
            headers=self.default_headers, timeout=self.default_timeout, **self.kwargs
        )
        self.client_session_set.add(AsyncHttpClient.client_session)
        return self.client_session

    @classmethod
    async def close(cls):
        """Close every session created by this class."""
        # Iterate a snapshot: the set may change while we await each close().
        for client_session in list(cls.client_session_set):
            await client_session.close()

        cls.client_session_set.clear()
        cls.client_session = None

    async def _request(
        self,
        method: HttpMethod,
        url: str,
        params: dict = None,
        data: dict = None,
        timeout: timedelta = None,
        headers: dict = None,
        **kwargs,
    ):
        """Send one HTTP request on the appropriate session.

        Args:
            method: HttpMethod, e.g. GET / POST
            url: request URL
            params: query-string parameters
            data: request body
            timeout: timedelta, seconds or ``aiohttp.ClientTimeout``; falls back to the default
            headers: per-request headers, overriding same-named default headers
            kwargs: extra arguments for ``ClientSession.request``

        Returns:
            aiohttp.ClientResponse
        """
        timeout = self._to_client_timeout(timeout or self.default_timeout)

        # Build a new dict: never mutate the caller's headers, and let per-request
        # headers win over defaults (previously defaults overwrote them).
        merged_headers = {**self.default_headers, **(headers or {})}
        client_session = await self._get_client_session()
        self.cur_session = client_session
        return await client_session.request(
            method.value, url, params=params, data=data, timeout=timeout, headers=merged_headers, **kwargs
        )

    def get(self, url: str, params: dict = None, timeout: timedelta = None, **kwargs) -> AsyncRequest:
        """GET request.

        Args:
            url: request URL
            params: query-string parameters
            timeout: request timeout

        Returns: AsyncRequest
        """
        return AsyncRequest(self, HttpMethod.GET, url, params=params, timeout=timeout, **kwargs)

    def post(self, url: str, data: Union[dict, Any] = None, timeout: timedelta = None, **kwargs) -> AsyncRequest:
        """POST request.

        Args:
            url: request URL
            data: request body
            timeout: request timeout

        Returns: AsyncRequest
        """
        return AsyncRequest(self, HttpMethod.POST, url, data=data, timeout=timeout, **kwargs)

    def put(self, url: str, data: Union[dict, Any] = None, timeout: timedelta = None, **kwargs) -> AsyncRequest:
        """PUT request (``data`` now defaults to None, matching ``HttpClient.put``).

        Args:
            url: request URL
            data: request body
            timeout: request timeout

        Returns: AsyncRequest
        """
        return AsyncRequest(self, HttpMethod.PUT, url, data=data, timeout=timeout, **kwargs)

    def delete(self, url: str, data: Union[dict, Any] = None, timeout: timedelta = None, **kwargs) -> AsyncRequest:
        """DELETE request (``data`` now defaults to None, matching ``HttpClient.delete``).

        Args:
            url: request URL
            data: request body
            timeout: request timeout

        Returns: AsyncRequest
        """
        return AsyncRequest(self, HttpMethod.DELETE, url, data=data, timeout=timeout, **kwargs)

    def upload_file(
        self,
        url: str,
        file: Union[str, bytes, Path],
        file_field: str = "file",
        filename: str = None,
        method=HttpMethod.POST,
        timeout: timedelta = None,
        content_type: str = None,
        **kwargs,
    ) -> AsyncRequest:
        """
        Upload a file as multipart form data.

        Args:
            url: request URL
            file: file path or raw bytes
            file_field: form field name, default "file"
            filename: file name (derived from ``file`` when omitted)
            method: HTTP method, default POST
            timeout: request timeout
            content_type: content type (detected when omitted)

        Returns: AsyncRequest
        """
        form_data = aiohttp.FormData()
        _filename, file_bytes, mime_type = FileUtil.get_file_info(file, filename=filename)
        filename = filename or _filename
        content_type = content_type or mime_type
        form_data.add_field(name=file_field, value=file_bytes, filename=filename, content_type=content_type)
        return AsyncRequest(self, method, url, data=form_data, timeout=timeout, **kwargs)


class HttpClient:
    """Synchronous HTTP client with chainable calls, built on ``requests``.

    Examples:
        >>> HttpClient().get("http://www.baidu.com").text
        >>> HttpClient().get("http://www.google.com", params={"name": "hui"}).bytes
        >>> HttpClient().post("http://www.google.com", data={"name": "hui"}).json

    Attributes:
        default_timeout: default request timeout (timedelta or seconds)
        default_headers: default request headers
        client: the ``requests`` session
        response: the response of the most recent request
    """

    def __init__(self, timeout=timedelta(seconds=10), headers: dict = None):
        """Build a synchronous HTTP client."""
        self.default_timeout = timeout
        self.default_headers = headers or {}
        self.client = requests.session()
        self.response: requests.Response = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying ``requests`` session and its pooled connections."""
        self.client.close()

    def _request(
        self, method: HttpMethod, url: str, params: dict = None, data: dict = None, timeout: timedelta = None, **kwargs
    ):
        """Send one HTTP request on the session and remember the response.

        Args:
            method: HttpMethod, e.g. GET / POST
            url: request URL
            params: query-string parameters
            data: request body
            timeout: timedelta or seconds; falls back to the default
            kwargs: extra arguments for ``requests.Session.request``; a ``headers``
                entry is merged over the default headers

        Returns:
            requests.Response
        """
        timeout = timeout or self.default_timeout
        if isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        # Pop caller headers so they are merged, not passed twice (previously a
        # ``headers=`` kwarg raised "got multiple values for keyword argument").
        headers = {**self.default_headers, **(kwargs.pop("headers", None) or {})}
        self.response = self.client.request(
            method=method.value,
            url=url,
            params=params,
            data=data,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )
        return self.response

    @property
    def json(self):
        """Decoded JSON body of the last response."""
        return self.response.json()

    @property
    def bytes(self):
        """Raw body bytes of the last response."""
        return self.response.content

    @property
    def text(self):
        """Body text of the last response."""
        return self.response.text

    def get(self, url: str, params: dict = None, timeout: timedelta = None, **kwargs):
        """GET request; returns self for chaining."""
        self._request(HttpMethod.GET, url, params=params, timeout=timeout, **kwargs)
        return self

    def post(self, url: str, data: dict = None, timeout: timedelta = None, **kwargs):
        """POST request; returns self for chaining."""
        self._request(HttpMethod.POST, url, data=data, timeout=timeout, **kwargs)
        return self

    def put(self, url: str, data: dict = None, timeout: timedelta = None, **kwargs):
        """PUT request; returns self for chaining."""
        self._request(HttpMethod.PUT, url, data=data, timeout=timeout, **kwargs)
        return self

    def delete(self, url: str, data: dict = None, timeout: timedelta = None, **kwargs):
        """DELETE request; returns self for chaining."""
        self._request(HttpMethod.DELETE, url, data=data, timeout=timeout, **kwargs)
        return self


# --------------------------------------------------------------------------------
# /py_tools/connections/mq/__init__.py:
# --------------------------------------------------------------------------------
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { message-queue connection handling }
# @Date: 2023/05/03 21:12


def main():
    pass


if __name__ == '__main__':
    main()


# --------------------------------------------------------------------------------
/py_tools/connections/mq/kafka_client.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Kafka connection handling module (placeholder) }
5 | # @Date: 2023/05/03 21:14
6 |
7 |
8 | def main():
9 |     pass
10 |
11 |
12 | if __name__ == '__main__':
13 |     main()
14 |
--------------------------------------------------------------------------------
/py_tools/connections/mq/rabbitmq_client.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { RabbitMQ connection handling module (placeholder) }
5 | # @Date: 2023/05/03 21:13
6 |
7 |
8 | def main():
9 |     pass
10 |
11 |
12 | if __name__ == '__main__':
13 |     main()
14 |
--------------------------------------------------------------------------------
/py_tools/connections/oss/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Object-storage (OSS) connections package }
5 | # @Date: 2023/11/07 17:40
6 |
7 |
8 | def main():
9 |     pass
10 |
11 |
12 | if __name__ == '__main__':
13 |     main()
14 |
--------------------------------------------------------------------------------
/py_tools/connections/oss/minio_client.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { MinIO client module }
5 | # @Date: 2023/11/07 17:41
6 | from datetime import timedelta
7 | from io import BytesIO
8 |
9 | from minio import Minio
10 |
11 |
12 | class MinioClient(Minio):
13 |     """
14 |     Custom MinIO client, a subclass of Minio that also keeps its connection settings as attributes.
15 |     """
16 |
17 |     def __init__(self, endpoint, access_key, secret_key, secure=False, **kwargs):
18 |         """
19 |         Initialise the MinioClient.
20 |         Args:
21 |             endpoint: MinIO server endpoint
22 |             access_key: MinIO access key
23 |             secret_key: MinIO secret key
24 |             secure: whether to use a secure (HTTPS) connection, default False
25 |             **kwargs: other keyword arguments, forwarded to Minio.__init__
26 |         """
27 |         self.endpoint = endpoint
28 |         self.access_key = access_key
29 |         self.secret_key = secret_key
30 |         self.secure = secure
31 |
32 |         super().__init__(
33 |             endpoint=endpoint,
34 |             access_key=access_key,
35 |             secret_key=secret_key,
36 |             secure=secure,
37 |             **kwargs
38 |         )
39 |
40 |     def put_object_get_sign_url(
41 |         self,
42 |         bucket_name: str,
43 |         object_name: str,
44 |         data: bytes,
45 |         content_type: str = "application/octet-stream",
46 |         sign_expires=timedelta(days=1),
47 |         **kwargs
48 |     ):
49 |         """
50 |         Upload an object, then return a presigned GET URL for it.
51 |         Args:
52 |             bucket_name: bucket name
53 |             object_name: object name
54 |             data: object content as bytes
55 |             content_type: content type of the object
56 |             sign_expires: expiry of the presigned URL, default 1 day
57 |             **kwargs: other keyword arguments, forwarded to put_object
58 |
59 |         Returns:
60 |             the presigned URL
61 |         """
62 |         data_size = len(data)
63 |         self.put_object(bucket_name, object_name, BytesIO(data), data_size, content_type, **kwargs)
64 |
65 |         obj_sign_url = self.presigned_get_object(bucket_name, object_name, expires=sign_expires)
66 |         return obj_sign_url
67 |
--------------------------------------------------------------------------------
/py_tools/constants/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Constants package }
5 | # @Date: 2022/11/26 18:11
6 | from py_tools.constants.const import CACHE_KEY_PREFIX, BASE_DIR, DEMO_DIR, DEMO_DATA, TEST_DIR, PROJECT_DIR
7 |
8 | __all__ = ["CACHE_KEY_PREFIX", "BASE_DIR", "DEMO_DIR", "DEMO_DATA", "TEST_DIR", "PROJECT_DIR"]
9 |
--------------------------------------------------------------------------------
/py_tools/constants/const.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @File: const.py
5 | # @Desc: { Constants module }
6 | # @Date: 2024/07/23 13:49
7 | from pathlib import Path
8 |
9 | # Default cache-key prefix
10 | CACHE_KEY_PREFIX = "py-tools"
11 |
12 | # Repository base directory (three levels up from py_tools/constants/const.py)
13 | BASE_DIR = Path(__file__).parent.parent.parent
14 |
15 | # Demo directory
16 | DEMO_DIR = BASE_DIR / "demo"
17 |
18 | # Demo data directory
19 | DEMO_DATA = DEMO_DIR / "data"
20 |
21 | # Package source directory
22 | PROJECT_DIR = BASE_DIR / "py_tools"
23 |
24 | # Tests directory
25 | TEST_DIR = BASE_DIR / "tests"
26 |
--------------------------------------------------------------------------------
/py_tools/data_schemas/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Data schemas package }
5 | # @Date: 2023/05/07 13:39
6 |
--------------------------------------------------------------------------------
/py_tools/data_schemas/time.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Time data models }
5 | # @Date: 2023/04/30 23:35
6 | from pydantic import BaseModel
7 |
8 |
9 | class DateDiff(BaseModel):  # a date difference split into years/months/days/hours/minutes/seconds
10 |     years: int
11 |     months: int
12 |     days: int
13 |     hours: int
14 |     minutes: int
15 |     seconds: int
16 |
--------------------------------------------------------------------------------
/py_tools/data_schemas/unit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @File: unit.py
5 | # @Desc: { Unit data models }
6 | # @Date: 2024/04/24 11:05
7 | import re
8 |
9 |
10 | class ByteUnit:
11 |     """Byte-size units (base unit: byte)"""
12 |
13 |     B = 1  # byte
14 |     KB = 1024 * B  # kilobyte (1 KB = 1024 bytes)
15 |     MB = 1024 * KB  # megabyte (1 MB = 1024 KB)
16 |     GB = 1024 * MB  # gigabyte (1 GB = 1024 MB)
17 |     TB = 1024 * GB  # terabyte (1 TB = 1024 GB)
18 |     PB = 1024 * TB  # petabyte (1 PB = 1024 TB)
19 |     EB = 1024 * PB  # exabyte (1 EB = 1024 PB)
20 |     ZB = 1024 * EB  # zettabyte (1 ZB = 1024 EB)
21 |     YB = 1024 * ZB  # yottabyte (1 YB = 1024 ZB)
22 |
23 |     @classmethod
24 |     def convert_size_to_bytes(cls, str_size):
25 |         """
26 |         Convert a size string into a number of bytes
27 |         Args:
28 |             str_size: size string, case-insensitive, e.g. 1B, 1kb, 1MB, 1GB (a space before the unit is allowed)
29 |
30 |         Returns: int; raises ValueError for a malformed size or an unknown unit
31 |         """
32 |         str_size = str_size.strip().upper()
33 |         match_ret = re.match(r"^(\d+)\s*([A-Z]+)$", str_size)  # integer amount + alphabetic unit name
34 |         # Only the int unit constants above are valid; anything else falls through to ValueError
35 |         unit_size = getattr(cls, match_ret.group(2), None) if match_ret else None
36 |         if isinstance(unit_size, int):
37 |             num = int(match_ret.group(1))
38 |             return num * unit_size
39 |         raise ValueError(f"Invalid size format: {str_size}")
40 |
41 |
42 | class LengthUnit:
43 |     """Length units (base unit: millimetre)"""
44 |
45 |     MM = 1  # millimetre
46 |     CM = 10 * MM  # centimetre
47 |     DM = 10 * CM  # decimetre
48 |     M = 10 * DM  # metre
49 |     KM = 1000 * M  # kilometre
50 |     INCH = 25.4 * MM  # inch (1 inch = 25.4 mm)
51 |     FOOT = 12 * INCH  # foot (1 foot = 12 inches)
52 |     YARD = 3 * FOOT  # yard (1 yard = 3 feet)
53 |     MILE = 1760 * YARD  # mile (1 mile = 1760 yards)
54 |
55 |
56 | class WeightUnit:
57 |     """Weight units (base unit: milligram)"""
58 |
59 |     MG = 1  # milligram
60 |     G = 1000 * MG  # gram
61 |     KG = 1000 * G  # kilogram
62 |     TONNE = 1000 * KG  # metric tonne
63 |     OUNCE = 28.3495 * G  # ounce (1 oz = 28.3495 g)
64 |     POUND = 16 * OUNCE  # pound (1 lb = 16 oz)
65 |     STONE = 14 * POUND  # stone (1 st = 14 lb)
66 |     TON = 2000 * POUND  # short ton (1 short ton = 2000 lb)
67 |
68 |
69 | class VolumeUnit:
70 |     """Volume units (base unit: millilitre)"""
71 |
72 |     ML = 1  # millilitre
73 |     L = 1000 * ML  # litre
74 |     CUBIC_METER = 1000 * L  # cubic metre
75 |     CUBIC_INCH = 16.387064 * ML  # cubic inch
76 |     CUBIC_FOOT = 28.3168466 * L  # cubic foot
77 |
78 |
79 | class SpeedUnit:
80 |     """Speed units (base unit: metre per second)"""
81 |
82 |     M_S = 1  # metre per second
83 |     KM_H = M_S / 3.6  # kilometre per hour (1 km/h = 1000 m / 3600 s = 1/3.6 m/s)
84 |     MPH = 0.44704 * M_S  # mile per hour (1 mph = 0.44704 m/s)
85 |
86 |
87 | class PowerUnit:
88 |     """Power units (base unit: watt)"""
89 |
90 |     WATT = 1  # watt
91 |     KW = 1000 * WATT  # kilowatt
92 |     HP = 735.49875 * WATT  # metric horsepower (1 hp ≈ 735.49875 W)
93 |
94 |
95 | class TimeUnit:
96 |     """Time units (base unit: second)"""
97 |
98 |     SECOND = 1  # second
99 |     MINUTE = 60 * SECOND  # minute
100 |     HOUR = 60 * MINUTE  # hour
101 |     DAY = 24 * HOUR  # day
102 |     WEEK = 7 * DAY  # week
103 |     MONTH = 30.44 * DAY  # month (average length)
104 |     YEAR = 365.24 * DAY  # year (average length)
105 |     DECADE = 10 * YEAR  # decade
106 |     CENTURY = 100 * YEAR  # century
107 |
--------------------------------------------------------------------------------
/py_tools/decorators/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Decorators package }
5 | # @Date: 2022/11/26 16:15
6 | from py_tools.decorators.base import retry, timing, set_timeout, singleton, synchronized, run_on_executor
7 |
8 | __all__ = ["singleton", "synchronized", "run_on_executor", "retry", "timing", "set_timeout"]
9 |
--------------------------------------------------------------------------------
/py_tools/decorators/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Common decorators module }
5 | # @Date: 2022/11/26 16:16
6 | import asyncio
7 | import functools
8 | import signal
9 | import threading
10 | import time
11 | import traceback
12 | from asyncio import iscoroutinefunction
13 | from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
14 | from datetime import datetime
15 | from typing import Callable, Type
16 |
17 | from loguru import logger
18 |
19 | from py_tools.exceptions import MaxRetryException, MaxTimeoutException
20 |
21 |
22 | def synchronized(func):
23 |     """Lock decorator: serialise calls to func through a per-function threading.Lock"""
24 |     func.__lock__ = threading.Lock()
25 |
26 |     @functools.wraps(func)
27 |     def lock_func(*args, **kwargs):
28 |         with func.__lock__:
29 |             return func(*args, **kwargs)
30 |
31 |     return lock_func
32 |
33 |
34 | def singleton(cls_obj):
35 |     """Singleton decorator: the first call creates the instance, later calls return it (args ignored)"""
36 |     _instance_dic = {}
37 |     _instance_lock = threading.Lock()
38 |
39 |     @functools.wraps(cls_obj)
40 |     def wrapper(*args, **kwargs):
41 |         if cls_obj in _instance_dic:
42 |             return _instance_dic.get(cls_obj)
43 |
44 |         with _instance_lock:
45 |             if cls_obj not in _instance_dic:
46 |                 _instance_dic[cls_obj] = cls_obj(*args, **kwargs)
47 |             return _instance_dic.get(cls_obj)
48 |
49 |     return wrapper
50 |
51 |
52 | def timing(method):
53 |     """
54 |     Log the start time plus the wall-clock and CPU time of each call (sync or async). Example:
55 |
56 |     @timing
57 |     def my_func():
58 |         pass
59 |     """
60 |
61 |     def before_call():
62 |         start_time, cpu_start_time = time.perf_counter(), time.process_time()
63 |         logger.info(f"[{method.__name__}] started at: " f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
64 |         return start_time, cpu_start_time
65 |
66 |     def after_call(start_time, cpu_start_time):
67 |         end_time, cpu_end_time = time.perf_counter(), time.process_time()
68 |         logger.info(
69 |             f"[{method.__name__}] ended. "
70 |             f"Time elapsed: {end_time - start_time:.4} sec, CPU elapsed: {cpu_end_time - cpu_start_time:.4} sec"
71 |         )
72 |
73 |     @functools.wraps(method)
74 |     def timeit_wrapper(*args, **kwargs):
75 |         start_time, cpu_start_time = before_call()
76 |         result = method(*args, **kwargs)
77 |         after_call(start_time, cpu_start_time)
78 |         return result
79 |
80 |     @functools.wraps(method)
81 |     async def timeit_wrapper_async(*args, **kwargs):
82 |         start_time, cpu_start_time = before_call()
83 |         result = await method(*args, **kwargs)
84 |         after_call(start_time, cpu_start_time)
85 |         return result
86 |
87 |     return timeit_wrapper_async if iscoroutinefunction(method) else timeit_wrapper
88 |
89 |
90 | def set_timeout(timeout: int, use_signal=False):
91 |     """
92 |     Timeout decorator
93 |     Args:
94 |         timeout: timeout in seconds
95 |         use_signal: use the SIGALRM signal mechanism (Unix only, main thread only), default False
96 |
97 |     Raises:
98 |         MaxTimeoutException
99 |
100 |     """
101 |
102 |     def _timeout(func: Callable):
103 |         def _handle_timeout(signum, frame):
104 |             raise MaxTimeoutException(f"Function timed out after {timeout} seconds")
105 |
106 |         @functools.wraps(func)
107 |         def sync_wrapper(*args, **kwargs):
108 |             # Timeout handling for sync functions
109 |             if use_signal:
110 |                 # SIGALRM interrupts the call after `timeout` seconds
111 |                 signal.signal(signal.SIGALRM, _handle_timeout)
112 |                 signal.alarm(timeout)
113 |                 try:
114 |                     return func(*args, **kwargs)
115 |                 finally:
116 |                     signal.alarm(0)
117 |             else:
118 |                 executor = ThreadPoolExecutor(max_workers=1)  # no `with`: its exit would wait for func, defeating the timeout
119 |                 try:
120 |                     return executor.submit(func, *args, **kwargs).result(timeout)
121 |                 except TimeoutError:
122 |                     raise MaxTimeoutException(f"Function timed out after {timeout} seconds")
123 |                 finally:
124 |                     executor.shutdown(wait=False)  # a timed-out call cannot be killed; it finishes in the background
125 |
126 |         @functools.wraps(func)
127 |         async def async_wrapper(*args, **kwargs):
128 |             # Timeout handling for async functions
129 |             try:
130 |                 ret = await asyncio.wait_for(func(*args, **kwargs), timeout)
131 |                 return ret
132 |             except asyncio.TimeoutError:
133 |                 raise MaxTimeoutException(f"Function timed out after {timeout} seconds")
134 |
135 |         return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
136 |
137 |     return _timeout
138 |
139 |
140 | def retry(max_count: int = 5, interval: int = 2, catch_exc: Type[BaseException] = Exception):
141 |     """
142 |     Retry decorator
143 |     Args:
144 |         max_count: maximum number of attempts, default 5
145 |         interval: seconds to wait after each failed attempt, default 2
146 |         catch_exc: exception class that triggers a retry, default Exception
147 |
148 |     Raises:
149 |         MaxRetryException
150 |     """
151 |
152 |     def _retry(task_func):
153 |         @functools.wraps(task_func)
154 |         def sync_wrapper(*args, **kwargs):
155 |             # Call the function in a retry loop
156 |
157 |             for retry_count in range(max_count):
158 |                 logger.info(f"{task_func} execute count {retry_count + 1}")
159 |                 try:
160 |                     return task_func(*args, **kwargs)
161 |                 except catch_exc:
162 |                     logger.error(f"fail {traceback.format_exc()}")
163 |                     if retry_count < max_count - 1:
164 |                         # Don't sleep after the last failed attempt
165 |                         time.sleep(interval)
166 |
167 |             # Maximum attempts exhausted: raise to stop
168 |             raise MaxRetryException(f"超过最大重试次数失败, max_retry_count {max_count}")
169 |
170 |         @functools.wraps(task_func)
171 |         async def async_wrapper(*args, **kwargs):
172 |             # Async retry loop
173 |             for retry_count in range(max_count):
174 |                 logger.info(f"{task_func} execute count {retry_count + 1}")
175 |
176 |                 try:
177 |                     return await task_func(*args, **kwargs)
178 |                 except catch_exc as e:
179 |                     logger.error(f"fail {str(e)}")
180 |                     if retry_count < max_count - 1:
181 |                         await asyncio.sleep(interval)
182 |
183 |             # Maximum attempts exhausted: raise to stop
184 |             raise MaxRetryException(f"超过最大重试次数失败, max_retry_count {max_count}")
185 |
186 |         # Use the async wrapper for coroutine functions
187 |         wrapper_func = async_wrapper if asyncio.iscoroutinefunction(task_func) else sync_wrapper
188 |         return wrapper_func
189 |
190 |     return _retry
191 |
192 |
193 | def run_on_executor(executor: Executor = None, background: bool = False):
194 |     """
195 |     Async decorator
196 |     - Sync functions are run via loop.run_in_executor(executor, ...), so they don't block the event loop
197 |     - Both async and sync functions can be awaited for their result
198 |     - Both async and sync functions can run as background tasks without waiting
199 |     Args:
200 |         executor: executor used when decorating sync functions (None = the loop's default executor)
201 |         background: for async functions, schedule with create_task and return the task, default False
202 |
203 |     Returns:
204 |     """
205 |
206 |     def _run_on_executor(func):
207 |         @functools.wraps(func)
208 |         async def async_wrapper(*args, **kwargs):
209 |             if background:
210 |                 return asyncio.create_task(func(*args, **kwargs))
211 |             else:
212 |                 return await func(*args, **kwargs)
213 |
214 |         @functools.wraps(func)
215 |         def sync_wrapper(*args, **kwargs):
216 |             loop = asyncio.get_event_loop()
217 |             task_func = functools.partial(func, *args, **kwargs)  # bind kwargs too: run_in_executor only forwards positionals
218 |             return loop.run_in_executor(executor, task_func)
219 |
220 |         # Use the async wrapper for coroutine functions
221 |         wrapper_func = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
222 |         return wrapper_func
223 |
224 |     return _run_on_executor
225 |
--------------------------------------------------------------------------------
/py_tools/decorators/cache.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Caching decorators module }
5 | # @Date: 2023/05/03 19:23
6 | import asyncio
7 | import functools
8 | import hashlib
9 | import json
10 | from datetime import timedelta
11 | from typing import Union
12 |
13 | import cacheout
14 | import memcache
15 | from pydantic import BaseModel, Field
16 | from redis import Redis
17 | from redis import asyncio as aioredis
18 |
19 | from py_tools import constants
20 |
21 |
22 | class CacheMeta(BaseModel):
23 |     """Cache metadata"""
24 |
25 |     key: str = Field(description="缓存的key")  # cache key
26 |     ttl: Union[int, timedelta] = Field(description="缓存有效期")  # time to live
27 |     cache_client: str = Field(description="缓存的客户端(Redis、Memcached等)")  # cache client (Redis, Memcached, ...)
28 |     data_type: str = Field(description="缓存的数据类型(str、list、hash、set)")  # cached data type (str, list, hash, set)
29 |
30 |
31 | class BaseCacheProxy(object):
32 |     """Base class for cache proxies"""
33 |
34 |     def __init__(self, cache_client):
35 |         self.cache_client = cache_client  # the concrete cache client, e.g. Redis, Memcached
36 |
37 |     def set(self, key: str, value: str, ttl: int):
38 |         raise NotImplementedError
39 |
40 |     def get(self, key):
41 |         cache_data = self.cache_client.get(key)
42 |         return cache_data
43 |
44 |
45 | class RedisCacheProxy(BaseCacheProxy):
46 |     """Synchronous Redis cache proxy"""
47 |
48 |     def __init__(self, cache_client: Redis):
49 |         super().__init__(cache_client)
50 |
51 |     def set(self, key, value, ttl):
52 |         self.cache_client.setex(name=key, value=value, time=ttl)
53 |
54 |
55 | class AsyncRedisCacheProxy(BaseCacheProxy):
56 |     """Asynchronous Redis cache proxy"""
57 |
58 |     def __init__(self, cache_client: aioredis.Redis):
59 |         super().__init__(cache_client)
60 |
61 |     async def set(self, key, value, ttl):
62 |         await self.cache_client.setex(name=key, value=value, time=ttl)
63 |
64 |     async def get(self, key):
65 |         cache_data = await self.cache_client.get(key)
66 |         return cache_data
67 |
68 |
69 | class MemoryCacheProxy(BaseCacheProxy):
70 |     """In-process memory cache proxy"""
71 |
72 |     def __init__(self, cache_client: cacheout.Cache):
73 |         super().__init__(cache_client)
74 |
75 |     def set(self, key, value, ttl):
76 |         self.cache_client.set(key=key, value=value, ttl=ttl)
77 |
78 |
79 | MEMORY_PROXY = MemoryCacheProxy(cache_client=cacheout.Cache(maxsize=1024))
80 |
81 |
82 | class MemcacheCacheProxy(BaseCacheProxy):
83 |     """Memcached cache proxy"""
84 |     def __init__(self, cache_client: memcache.Client):
85 |         super().__init__(cache_client)
86 |
87 |     def set(self, key, value, ttl):
88 |         self.cache_client.set(key, value, time=ttl)
89 |
90 |
91 | def cache_json(
92 |     cache_proxy: BaseCacheProxy = MEMORY_PROXY,
93 |     key_prefix: str = constants.CACHE_KEY_PREFIX,
94 |     ttl: Union[int, timedelta] = 60,
95 | ):
96 |     """
97 |     Caching decorator (only for results that can be JSON-serialised)
98 |     Args:
99 |         cache_proxy: cache proxy client, default in-process memory; async functions accept sync or async proxies
100 |         ttl: expiry, default 60s
101 |         key_prefix: key prefix
102 |
103 |     Returns:
104 |     """
105 |     key_prefix = f"{key_prefix}:cache_json"
106 |     if isinstance(ttl, timedelta):
107 |         ttl = int(ttl.total_seconds())
108 |
109 |     def _cache(func):
110 |         def _gen_key(*args, **kwargs):
111 |             """Build the cache key"""
112 |
113 |             # Derived from the function and its arguments:
114 |             # key => sha256 of module:function name:positional args:keyword args
115 |             param_args_str = ",".join([str(arg) for arg in args])
116 |             param_kwargs_str = ",".join(sorted([f"{k}:{v}" for k, v in kwargs.items()]))
117 |             hash_str = f"{func.__module__}:{func.__name__}:{param_args_str}:{param_kwargs_str}"
118 |             hash_ret = hashlib.sha256(hash_str.encode()).hexdigest()
119 |
120 |             # Final key => prefix:module:function name:hash
121 |             hash_key = f"{key_prefix}:{func.__module__}:{func.__name__}:{hash_ret}"
122 |             return hash_key
123 |
124 |         @functools.wraps(func)
125 |         def sync_wrapper(*args, **kwargs):
126 |             """Synchronous path"""
127 |
128 |             # Build the cache key
129 |             hash_key = _gen_key(*args, **kwargs)
130 |
131 |             # Look in the cache first
132 |             cache_data = cache_proxy.get(hash_key)
133 |             if cache_data:
134 |                 # Cache hit: return it directly
135 |                 print(f"命中缓存: {hash_key}")
136 |                 return json.loads(cache_data)
137 |
138 |             # Cache miss: call the function
139 |             ret = func(*args, **kwargs)
140 |
141 |             # Cache the result
142 |             cache_proxy.set(key=hash_key, value=json.dumps(ret), ttl=ttl)
143 |             return ret
144 |
145 |         @functools.wraps(func)
146 |         async def async_wrapper(*args, **kwargs):
147 |             """Asynchronous path"""
148 |
149 |             # Build the cache key
150 |             hash_key = _gen_key(*args, **kwargs)
151 |
152 |             cache_data = cache_proxy.get(hash_key)  # sync proxies (e.g. the default) return the value, async ones a coroutine
153 |             cache_data = (await cache_data) if asyncio.iscoroutine(cache_data) else cache_data
154 |             if cache_data:
155 |                 return json.loads(cache_data)  # cache hit
156 |
157 |             ret = await func(*args, **kwargs)  # cache miss: run the function
158 |
159 |             # Cache the result, awaiting only when the proxy's set() returned a coroutine
160 |             set_ret = cache_proxy.set(key=hash_key, value=json.dumps(ret), ttl=ttl)
161 |             if asyncio.iscoroutine(set_ret):
162 |                 await set_ret
163 |             return ret
164 |
165 |         return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
166 |
167 |     return _cache
168 |
--------------------------------------------------------------------------------
/py_tools/enums/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Enums package }
5 | # @Date: 2022/11/26 18:10
6 | from py_tools.enums.base import BaseEnum, StrEnum, IntEnum
7 |
8 | __all__ = ["BaseEnum", "StrEnum", "IntEnum"]
9 |
--------------------------------------------------------------------------------
/py_tools/enums/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Enum base classes }
5 | # @Date: 2023/04/30 20:54
6 | from enum import Enum
7 |
8 |
9 | class BaseEnum(Enum):
10 |     """Enum base class; members may be declared as `value` or `(value, desc)` and expose `.desc`"""
11 |
12 |     def __new__(cls, value, desc=None):
13 |         """
14 |         Build an enum member instance
15 |         Args:
16 |             value: the member's value
17 |             desc: the member's description, default None
18 |         """
19 |         if issubclass(cls, int):
20 |             obj = int.__new__(cls, value)
21 |         elif issubclass(cls, str):
22 |             obj = str.__new__(cls, value)
23 |         else:
24 |             obj = object.__new__(cls)
25 |         obj._value_ = value
26 |         obj.desc = desc
27 |         return obj
28 |
29 |     @classmethod
30 |     def get_members(cls, exclude_enums: list = None, only_value: bool = False, only_desc: bool = False) -> list:
31 |         """
32 |         Get all members of the enum
33 |         Args:
34 |             exclude_enums: members to exclude
35 |             only_value: return only the members' values, default False
36 |             only_desc: return only the members' desc, default False
37 |
38 |         Returns: list of members, values or descs (only_value takes precedence over only_desc)
39 |
40 |         """
41 |         members = list(cls)
42 |         if exclude_enums:
43 |             # Drop the excluded members
44 |             members = [member for member in members if member not in exclude_enums]
45 |
46 |         if only_value:
47 |             # Values only
48 |             members = [member.value for member in members]
49 |             return members
50 |
51 |         if only_desc:
52 |             # Descriptions only
53 |             members = [member.desc for member in members]
54 |             return members
55 |
56 |         return members
57 |
58 |     @classmethod
59 |     def get_values(cls, exclude_enums: list = None):  # values of all non-excluded members
60 |         return cls.get_members(exclude_enums=exclude_enums, only_value=True)
61 |
62 |     @classmethod
63 |     def get_names(cls):  # member names in definition order
64 |         return list(cls._member_names_)
65 |
66 |     @classmethod
67 |     def get_desc(cls, exclude_enums: list = None):  # descs of all non-excluded members
68 |         return cls.get_members(exclude_enums=exclude_enums, only_desc=True)
69 |
70 |     @classmethod
71 |     def get_member_by_desc(cls, enum_desc, only_value: bool = False):
72 |         """Return the member whose desc equals `enum_desc` (its value if only_value); None if none matches"""
73 |         member_dict = {member.desc: member for member in cls.get_members()}
74 |         member = member_dict.get(enum_desc)
75 |         return member.value if (only_value and member is not None) else member
76 |
77 |
78 | class StrEnum(str, BaseEnum):
79 |     """String enum"""
80 |
81 |     pass
82 |
83 |
84 | class IntEnum(int, BaseEnum):
85 |     """Integer enum"""
86 |
87 |     pass
88 |
--------------------------------------------------------------------------------
/py_tools/enums/error.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Error-code enums }
5 | # @Date: 2023/09/09 14:45
6 |
7 |
8 | class BaseErrCode:  # an error code paired with its message
9 |     def __init__(self, code, msg):
10 |         self.code = code
11 |         self.msg = msg
12 |
13 |
14 | class BaseErrCodeEnum:
15 |     """
16 |     Notes: enum members cannot be inherited, so this is a plain class instead
17 |     Error-code prefixes
18 |     - 000 - common base error codes
19 |     - 100 - reserved
20 |     - 200 - common business error codes
21 |         e.g.:
22 |         - 201 - user module
23 |         - 202 - order module
24 |     - 300 - reserved
25 |     - 400 - common request errors
26 |     - 500 - common system error codes
27 |     """
28 |
29 |     OK = BaseErrCode("000-0000", "SUCCESS")
30 |     FAILED = BaseErrCode("000-0001", "FAILED")
31 |     FUNC_TIMEOUT_ERR = BaseErrCode("000-0002", "函数最大超时错误")  # function exceeded max timeout
32 |     FUNC_RETRY_ERR = BaseErrCode("000-0003", "函数最大重试错误")  # function exceeded max retries
33 |     SEND_SMS_ERR = BaseErrCode("000-0004", "发送短信错误")  # failed to send SMS
34 |     SEND_EMAIL_ERR = BaseErrCode("000-0005", "发送邮件错误")  # failed to send email
35 |
36 |     AUTH_ERR = BaseErrCode("400-0401", "权限认证错误")  # authentication error
37 |     FORBIDDEN_ERR = BaseErrCode("400-0403", "无权限访问")  # access forbidden
38 |     NOT_FOUND_ERR = BaseErrCode("400-0404", "未找到资源错误")  # resource not found
39 |     PARAM_ERR = BaseErrCode("400-0422", "参数错误")  # invalid parameters
40 |
41 |     SYSTEM_ERR = BaseErrCode("500-0500", "系统异常")  # system error
42 |     SOCKET_ERR = BaseErrCode("500-0501", "网络异常")  # network error
43 |     GATEWAY_ERR = BaseErrCode("500-0502", "网关异常")  # gateway error
44 |
--------------------------------------------------------------------------------
/py_tools/enums/feishu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Feishu (Lark) related enums }
5 | # @Date: 2023/05/03 18:52
6 | from py_tools.enums.base import BaseEnum
7 |
8 |
9 | class FeishuReceiveType(BaseEnum):
10 |     """ID type used to identify a message receiver"""
11 |
12 |     OPEN_ID = "open_id"  # identifies a user within a specific app
13 |     USER_ID = "user_id"  # identifies a user within a specific tenant
14 |     UNION_ID = "union_id"  # identifies a user across one app developer's apps
15 |     EMAIL = "email"  # identifies a user by their real email address
16 |     CHAT_ID = "chat_id"  # identifies a group chat by its chat ID
17 |
--------------------------------------------------------------------------------
/py_tools/enums/http.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { HTTP related enums }
5 | # @Date: 2023/08/10 09:37
6 |
7 | from py_tools.enums.base import BaseEnum
8 |
9 |
10 | class HttpMethod(BaseEnum):  # HTTP request methods
11 |     GET = "GET"
12 |     POST = "POST"
13 |     PATCH = "PATCH"
14 |     PUT = "PUT"
15 |     DELETE = "DELETE"
16 |     HEAD = "HEAD"
17 |     OPTIONS = "OPTIONS"
18 |
--------------------------------------------------------------------------------
/py_tools/enums/pub_biz.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Shared business enums }
5 | # @Date: 2023/09/09 23:54
6 | from py_tools.enums import BaseEnum
7 |
8 |
9 | class SwitchEnum(BaseEnum):
10 |     """On/off switch enum"""
11 |     OFF = 0  # off
12 |     ON = 1  # on
13 |
14 |
15 | class YesNoEnum(BaseEnum):
16 |     """Yes/no enum"""
17 |     NO = 0  # no
18 |     YES = 1  # yes
19 |
20 |
21 | class RedisTypeEnum(BaseEnum):
22 |     """Redis data types"""
23 |
24 |     String = "String"
25 |     List = "List"
26 |     Hash = "Hash"
27 |     Set = "Set"
28 |     ZSet = "ZSet"
29 |
--------------------------------------------------------------------------------
/py_tools/enums/time.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Time related enums }
5 | # @Date: 2023/02/12 21:38
6 | from py_tools.enums.base import StrEnum
7 |
8 |
9 | class TimeFormatEnum(StrEnum):
10 |     """Time format strings (strftime/strptime directives)"""
11 |     DateTime = "%Y-%m-%d %H:%M:%S"
12 |     DateOnly = "%Y-%m-%d"
13 |     TimeOnly = "%H:%M:%S"
14 |
15 |     DateTime_CN = "%Y年%m月%d日 %H时%M分%S秒"
16 |     DateOnly_CN = "%Y年%m月%d日"
17 |     TimeOnly_CN = "%H时%M分%S秒"
18 |
19 |
20 | class TimeUnitEnum(StrEnum):
21 |     """Time unit names"""
22 |     DAYS = "days"
23 |     HOURS = "hours"
24 |     MINUTES = "minutes"
25 |     SECONDS = "seconds"
26 |
--------------------------------------------------------------------------------
/py_tools/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Custom exceptions package }
5 | # @Date: 2023/02/12 22:07
6 | from py_tools.exceptions.base import (
7 |     MaxTimeoutException,
8 |     SendMsgException,
9 |     MaxRetryException,
10 |     BizException,
11 |     CommonException,
12 | )
13 |
14 | __all__ = ["MaxTimeoutException", "SendMsgException", "MaxRetryException", "BizException", "CommonException"]
15 |
--------------------------------------------------------------------------------
/py_tools/exceptions/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @Desc: { Exception classes }
5 | # @Date: 2023/02/12 22:09
6 | from py_tools.enums.error import BaseErrCode, BaseErrCodeEnum
7 |
8 |
9 | class CommonException(Exception):
10 |     """Common base exception"""
11 |
12 |     pass
13 |
14 |
15 | class BizException(CommonException):
16 |     """Business exception carrying an error `code` and `msg`"""
17 |
18 |     def __init__(self, msg: str = "", code: str = BaseErrCodeEnum.FAILED.code, err_code: BaseErrCode = None):
19 |         self.code = code
20 |         self.msg = msg
21 |         if err_code:  # err_code overrides `code`; its msg is only a fallback when `msg` is empty
22 |             self.code = err_code.code
23 |             self.msg = self.msg or err_code.msg
24 |         super().__init__(self.msg)  # populate args so str(exc) and tracebacks show the message
25 |
26 |
27 | class MaxRetryException(BizException):
28 |     """Raised when the maximum number of retries is exceeded"""
29 |
30 |     def __init__(self, msg: str = BaseErrCodeEnum.FUNC_RETRY_ERR.msg):
31 |         super().__init__(msg=msg, err_code=BaseErrCodeEnum.FUNC_RETRY_ERR)
32 |
33 |
34 | class MaxTimeoutException(BizException):
35 |     """Raised when the maximum timeout is exceeded"""
36 |
37 |     def __init__(self, msg: str = BaseErrCodeEnum.FUNC_TIMEOUT_ERR.msg):
38 |         super().__init__(msg=msg, err_code=BaseErrCodeEnum.FUNC_TIMEOUT_ERR)
39 |
40 |
41 | class SendMsgException(BizException):
42 |     """Raised when sending a message fails"""
43 |
44 |     pass
45 |
--------------------------------------------------------------------------------
/py_tools/logging/__init__.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | from py_tools.logging.base import setup_logging
3 |
4 | __all__ = ["logger", "setup_logging"]
5 |
--------------------------------------------------------------------------------
/py_tools/logging/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @File: base.py
5 | # @Desc: { Logging setup helpers }
6 | # @Date: 2024/08/12 11:12
7 | import logging
8 | from pathlib import Path
9 | from typing import Type, Union
10 |
11 | from py_tools.logging import logger
12 | from py_tools.logging.default_logging_conf import (
13 |     default_logging_conf,
14 |     server_logging_retention,
15 |     server_logging_rotation,
16 | )
17 | from py_tools.utils.func_util import add_param_if_true
18 |
19 |
20 | def setup_logging(
21 |     log_dir: Union[str, Path] = None,
22 |     *,
23 |     log_conf: dict = None,
24 |     sink: Union[str, Path] = None,
25 |     log_level: Union[str, int] = None,
26 |     console_log_level: Union[str, int] = logging.DEBUG,
27 |     log_format: str = None,
28 |     log_filter: Type[callable] = None,
29 |     log_rotation: str = server_logging_rotation,
30 |     log_retention: str = server_logging_retention,
31 |     **kwargs,
32 | ):
33 |     """
34 |     Configure the project's loguru logging (console, server and error handlers)
35 |     Args:
36 |         log_dir (Union[str, Path]): directory in which server.log and error.log are written.
37 |         log_conf (dict): detailed server-handler config (original doc: may override the other params; actual precedence depends on add_param_if_true - TODO confirm).
38 |         sink (Union[str, Path]): log file sink.
39 |         log_level (Union[str, int]): global log level, e.g. 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' or the matching int.
40 |         console_log_level (Union[str, int]): console output level, default logging.DEBUG.
41 |         log_format (str): log format string.
42 |         log_filter (object): callable used to filter log records; applied to every handler.
43 |         log_rotation (str): rotation policy (by time or size); by default a new log file is started every day at 00:00.
44 |         log_retention (str): retention policy (time or count); by default logs are kept for at most 7 days.
45 |         **kwargs: extra handler options, merged into (and overriding) log_conf.
46 |
47 |     Returns:
48 |         None (raises ValueError when neither log_dir nor a sink is provided)
49 |     """
50 |     logger.remove()
51 |     logging_conf = {name: dict(conf) for name, conf in default_logging_conf.items()}  # per-handler copies: never mutate the module defaults
52 |     logging_conf["console_handler"]["level"] = console_log_level
53 |
54 |     # Build a fresh dict so the caller's log_conf is never mutated; extra kwargs take precedence
55 |     log_conf = {**(log_conf or {}), **kwargs}
56 |
57 |     conf_mappings = {
58 |         "sink": sink,
59 |         "level": log_level,
60 |         "format": log_format,
61 |         "rotation": log_rotation,
62 |         "retention": log_retention,
63 |     }
64 |     for key, val in conf_mappings.items():
65 |         add_param_if_true(log_conf, key, val)
66 |
67 |     if log_dir:
68 |         log_dir = Path(log_dir)
69 |         server_log_file = log_dir / "server.log"
70 |         error_log_file = log_dir / "error.log"
71 |         log_conf["sink"] = log_conf.get("sink") or server_log_file
72 |         logging_conf["error_handler"]["sink"] = error_log_file
73 |     else:
74 |         if not log_conf.get("sink"):
75 |             raise ValueError("log_conf must have `sink` key")
76 |
77 |         sink_file = log_conf.get("sink")
78 |         sink_file = Path(sink_file)
79 |         error_log_file = sink_file.parent / "error.log"
80 |         logging_conf["error_handler"]["sink"] = error_log_file
81 |     # NOTE(review): if add_param_if_true simply assigns, this replaces the default server_handler wholesale (its level/enqueue defaults are dropped) - confirm intent
82 |     add_param_if_true(logging_conf, "server_handler", log_conf)
83 |     for _log_conf in logging_conf.values():
84 |         _log_conf["filter"] = log_filter
85 |         logger.add(**_log_conf)
86 |
87 |     logger.info("setup logging success")
88 |
--------------------------------------------------------------------------------
/py_tools/logging/default_logging_conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Author: Hui
4 | # @File: default_logging_conf.py
5 | # @Desc: { Default logging configuration }
6 | # @Date: 2024/08/12 10:57
7 | import logging
8 | import sys
9 |
10 | from py_tools.constants import BASE_DIR
11 |
12 | # Project log directory
13 | logging_dir = BASE_DIR / "logs"
14 |
15 | # Log file for everything the project logs at runtime
16 | server_log_file = logging_dir / "server.log"
17 |
18 | # Log file for errors
19 | error_log_file = logging_dir / "error.log"
20 |
21 | # Rotation for the combined server log (a new log file every day at 00:00)
22 | # Error log: start a new file once it exceeds 10 MB
23 | server_logging_rotation = "00:00"
24 | error_logging_rotation = "10 MB"
25 |
26 | # Keep the combined server log for at most 7 days, the error log for 30 days
27 | server_logging_retention = "7 days"
28 | error_logging_retention = "30 days"
29 |
30 | # Project logging configuration
31 | console_log_level = logging.DEBUG
32 | trace_msg_log_format = (
33 |     "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level:<8} | {trace_msg} | {name}:{function}:{line} - {message}"
34 | )
35 | default_log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level:<8} | {name}:{function}:{line} - {message}"
36 | console_log_format = (  # NOTE(review): not referenced by default_logging_conf below
37 |     "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
38 |     "{level:<8} | "
39 |     "{name}:{function}:{line} - {message}"
40 | )
41 |
42 | default_logging_conf = {
43 |     "console_handler": {
44 |         "sink": sys.stdout,
45 |         "level": console_log_level,
46 |     },
47 |     "server_handler": {
48 |         "sink": server_log_file,
49 |         "level": "INFO",
50 |         "rotation": server_logging_rotation,
51 |         "retention": server_logging_retention,
52 |         "enqueue": True,
53 |         "backtrace": False,
54 |         "diagnose": False,
55 |     },
56 |     "error_handler": {
57 |         "sink": error_log_file,
58 |         "level": "ERROR",
59 |         "rotation": error_logging_rotation,
60 |         "retention": error_logging_retention,
61 |         "enqueue": True,
62 |         "backtrace": True,
63 |         "diagnose": True,  # loguru docs advise False in production: tracebacks then include variable values
64 |     },
65 | }
66 |
-------------------------------------------------------------------------------- /py_tools/meta_cls/__init__.py: -------------------------------------------------------------------------------- 1 | from py_tools.meta_cls.base import SingletonMetaCls 2 | 3 | __all__ = ["SingletonMetaCls"] 4 | -------------------------------------------------------------------------------- /py_tools/meta_cls/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 元类模块 } 5 | # @Date: 2022/11/26 16:43 6 | import threading 7 | 8 | 9 | class SingletonMetaCls(type): 10 | """ 单例元类 """ 11 | _instance_lock = threading.Lock() 12 | 13 | def __init__(cls, *args, **kwargs): 14 | cls._instance = None 15 | super().__init__(*args, **kwargs) 16 | 17 | def _init_instance(cls, *args, **kwargs): 18 | if cls._instance: 19 | # 存在实例对象直接返回,减少锁竞争,提高性能 20 | return cls._instance 21 | 22 | with cls._instance_lock: 23 | if cls._instance is None: 24 | cls._instance = super().__call__(*args, **kwargs) 25 | return cls._instance 26 | 27 | def __call__(cls, *args, **kwargs): 28 | reinit = kwargs.pop("reinit", True) 29 | instance = cls._init_instance(*args, **kwargs) 30 | if reinit: 31 | # 重新初始化单例对象属性 32 | instance.__init__(*args, **kwargs) 33 | return instance 34 | -------------------------------------------------------------------------------- /py_tools/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 模块描述 } 5 | # @Date: 2022/11/26 16:07 6 | from py_tools.utils.excel_util import ExcelUtil 7 | from py_tools.utils.time_util import TimeUtil 8 | from py_tools.utils.file_util import FileUtil 9 | from py_tools.utils.async_util import AsyncUtil 10 | from py_tools.utils.mask_util import MaskUtil 11 | from py_tools.utils.re_util import RegexUtil 12 | from 
py_tools.utils.jwt_util import JWTUtil 13 | from py_tools.utils.serializer_util import SerializerUtil 14 | 15 | __all__ = ["ExcelUtil", "TimeUtil", "FileUtil", "AsyncUtil", "MaskUtil", "RegexUtil", "JWTUtil", "SerializerUtil"] 16 | -------------------------------------------------------------------------------- /py_tools/utils/async_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: async.py 5 | # @Desc: { 异步工具模块 } 6 | # @Date: 2024/04/24 15:20 7 | import asyncio 8 | import functools 9 | from concurrent.futures import Executor, ThreadPoolExecutor 10 | from typing import Any, Coroutine, List 11 | 12 | from asgiref.sync import async_to_sync, sync_to_async 13 | 14 | 15 | class AsyncUtil: 16 | # 线程池 17 | BASE_EXECUTOR = ThreadPoolExecutor() 18 | 19 | # 默认并发限制 20 | DEFAULT_NUM_WORKERS = 5 21 | 22 | # 同步异步互转 23 | AsyncToSync = async_to_sync 24 | SyncToAsync = sync_to_async 25 | 26 | @staticmethod 27 | def get_asyncio_module(show_progress: bool = False) -> Any: 28 | if show_progress: 29 | from tqdm.asyncio import tqdm_asyncio 30 | 31 | module = tqdm_asyncio 32 | else: 33 | module = asyncio 34 | 35 | return module 36 | 37 | @classmethod 38 | def run_bg_task(cls, func, *args, executor: Executor = None, **kwargs): 39 | """运行后台任务""" 40 | if asyncio.iscoroutine(func): 41 | # 协程对象处理 42 | return asyncio.create_task(func) 43 | 44 | executor = executor or cls.BASE_EXECUTOR 45 | 46 | return executor.submit(func, *args, **kwargs) 47 | 48 | @classmethod 49 | async def async_run(cls, func, *args, executor: Executor = None, **kwargs): 50 | """同步方法使用线程池异步运行""" 51 | loop = asyncio.get_event_loop() 52 | task_func = functools.partial(func, *args, **kwargs) # 支持关键字参数 53 | executor = executor or cls.BASE_EXECUTOR 54 | return await loop.run_in_executor(executor, task_func) 55 | 56 | @staticmethod 57 | def sync_run(coro_obj: Coroutine): 58 | """同步环境运行异步方法""" 59 | return 
asyncio.run(coro_obj) 60 | 61 | @classmethod 62 | async def run_jobs( 63 | cls, 64 | jobs: List[Coroutine], 65 | show_progress: bool = False, 66 | workers: int = DEFAULT_NUM_WORKERS, 67 | ) -> List[Any]: 68 | """Run jobs. 69 | 70 | Args: 71 | jobs (List[Coroutine]): 72 | List of jobs to run. 73 | show_progress (bool): 74 | Whether to show progress bar. 75 | workers: 默认并发数 76 | 77 | Returns: 78 | List[Any]: 79 | List of results. 80 | """ 81 | asyncio_mod = cls.get_asyncio_module(show_progress=show_progress) 82 | semaphore = asyncio.Semaphore(workers) 83 | 84 | async def worker(job: Coroutine) -> Any: 85 | async with semaphore: 86 | return await job 87 | 88 | pool_jobs = [worker(job) for job in jobs] 89 | 90 | return await asyncio_mod.gather(*pool_jobs) 91 | 92 | 93 | class NestAsyncio: 94 | """Make asyncio event loop reentrant.""" 95 | 96 | is_applied = False 97 | 98 | @classmethod 99 | def apply_once(cls): 100 | """Ensures `nest_asyncio.apply()` is called only once.""" 101 | if not cls.is_applied: 102 | import nest_asyncio 103 | 104 | nest_asyncio.apply() 105 | cls.is_applied = True 106 | -------------------------------------------------------------------------------- /py_tools/utils/excel_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { Excel文件操作工具模块 } 5 | # @Date: 2022/04/03 19:34 6 | import os 7 | from typing import IO, List, Union 8 | 9 | try: 10 | import pandas 11 | except: 12 | pass 13 | 14 | from pydantic import BaseModel, Field 15 | 16 | 17 | class ColumnMapping(BaseModel): 18 | """列名映射""" 19 | 20 | column_name: str = Field(description="列名") 21 | column_alias: str = Field(description="列名别名") 22 | 23 | 24 | class SheetMapping(BaseModel): 25 | """sheet映射""" 26 | 27 | file_name: str = Field(description="文件名") 28 | sheet_name: str = Field(description="sheet名") 29 | 30 | 31 | class DataCollect(BaseModel): 32 | """多sheet的数据集合""" 33 | 34 | 
data_list: List[dict] = Field(description="数据列表") 35 | col_mappings: List[ColumnMapping] = Field(default=[], description="列名映射列表") 36 | sheet_name: str = Field(description="sheet名称") 37 | 38 | 39 | class ExcelUtil(object): 40 | """Excel文件操作工具类""" 41 | 42 | DEFAULT_SHEET_NAME = "Sheet1" 43 | 44 | @classmethod 45 | def _to_excel( 46 | cls, 47 | data_list: List[dict], 48 | col_mappings: List[ColumnMapping], 49 | sheet_name: str, 50 | writer: "pandas.ExcelWriter", 51 | **kwargs, 52 | ): 53 | """ 54 | 将列表数据写入excel文件 55 | Args: 56 | path_or_buffer: 文件路径或者字节缓冲流 57 | data_list: 数据集 List[dict] 58 | col_mappings: 表头列字段映射 59 | sheet_name: sheet名称 默认 Sheet1 60 | writer: ExcelWriter 61 | """ 62 | col_dict = {cm.column_name: cm.column_alias for cm in col_mappings} if col_mappings else None 63 | df = pandas.DataFrame(data=data_list) 64 | if col_dict: 65 | df.rename(columns=col_dict, inplace=True) 66 | df.to_excel(writer, sheet_name=sheet_name, index=False, **kwargs) 67 | 68 | @classmethod 69 | def list_to_excel( 70 | cls, 71 | path_or_buffer: Union[str, IO], 72 | data_list: List[dict], 73 | col_mappings: List[ColumnMapping] = None, 74 | sheet_name: str = None, 75 | **kwargs, 76 | ): 77 | """ 78 | 列表转 excel文件 79 | Args: 80 | path_or_buffer: 文件路径或者字节缓冲流 81 | data_list: 数据集 List[dict] 82 | col_mappings: 表头列字段映射 83 | sheet_name: sheet名称 默认 Sheet1 84 | writer: ExcelWriter 85 | 86 | Examples: 87 | data_list = [{"id": 1, "name": "hui", "age": 18}] 88 | user_col_mapping = [ 89 | ColumnMapping('id', '用户id'), 90 | ColumnMapping('name', '用户名'), 91 | ColumnMapping('age', '年龄'), 92 | ] 93 | ExcelUtil.list_to_excel('path_to_file', data_list, user_col_mapping) 94 | 95 | Returns: 96 | """ 97 | sheet_name = sheet_name or cls.DEFAULT_SHEET_NAME 98 | with pandas.ExcelWriter(path_or_buffer) as writer: 99 | cls._to_excel(data_list, col_mappings, sheet_name, writer, **kwargs) 100 | 101 | @classmethod 102 | def multi_list_to_excel(cls, path_or_buffer: Union[str, IO], data_collects: List[DataCollect], 
**kwargs): 103 | """ 104 | 多列表转带不同 sheet的excel文件 105 | Args: 106 | path_or_buffer: 文件路径或者字节缓冲流 107 | data_collects: 数据集列表 108 | 109 | Returns: 110 | """ 111 | with pandas.ExcelWriter(path_or_buffer) as writer: 112 | for data_collect in data_collects: 113 | cls._to_excel( 114 | data_list=data_collect.data_list, 115 | col_mappings=data_collect.col_mappings, 116 | sheet_name=data_collect.sheet_name, 117 | writer=writer, 118 | **kwargs, 119 | ) 120 | 121 | @classmethod 122 | def read_excel( 123 | cls, 124 | path_or_buffer: Union[str, IO], 125 | sheet_name: str = None, 126 | col_mappings: List[ColumnMapping] = None, 127 | all_col: bool = True, 128 | header: int = 0, 129 | nan_replace=None, 130 | **kwargs, 131 | ) -> List[dict]: 132 | """ 133 | 读取excel表格数据,根据col_mapping替换列名 134 | Args: 135 | path_or_buffer: 文件路径或者缓冲流 136 | sheet_name: 读书excel表的sheet名称 137 | col_mappings: 列字段映射 138 | all_col: True返回所有列信息,False则返回col_mapping对应的字段信息 139 | header: 默认0从第一行开启读取,用于指定从第几行开始读取 140 | nan_replace: nan值替换 141 | 142 | Returns: 143 | """ 144 | sheet_name = sheet_name or cls.DEFAULT_SHEET_NAME 145 | col_dict = {cm.column_name: cm.column_alias for cm in col_mappings} if col_mappings else None 146 | use_cols = None 147 | if not all_col: 148 | # 获取excel表指定列数据 149 | use_cols = list(col_dict) if col_dict else None 150 | 151 | df = pandas.read_excel(path_or_buffer, sheet_name=sheet_name, usecols=use_cols, header=header, **kwargs) 152 | if nan_replace is not None: 153 | df.fillna(nan_replace, inplace=True) 154 | 155 | if col_dict: 156 | df.rename(columns=col_dict, inplace=True) 157 | 158 | return df.to_dict("records") 159 | 160 | @classmethod 161 | def merge_excel_files( 162 | cls, input_files: List[str], output_file: str, sheet_mappings: List[SheetMapping] = None, **kwargs 163 | ): 164 | """ 165 | 合并多个Excel文件到一个文件中(每个文件对应一个工作表) 166 | 如果Excel文件有多个作表,则默认取第一个工作表 167 | Args: 168 | input_files: 待合并的excel文件列表 169 | output_file: 输出文件路径 170 | sheet_mappings: 文件工作表映射,默认为文件名 171 | 172 | Returns: 173 | 
""" 174 | sheet_mappings = sheet_mappings or [] 175 | sheet_dict = { 176 | os.path.basename(sheet_mapping.file_name): sheet_mapping.sheet_name for sheet_mapping in sheet_mappings 177 | } 178 | with pandas.ExcelWriter(output_file, engine_kwargs=kwargs) as writer: 179 | for file in input_files: 180 | df = pandas.read_excel(file) 181 | file_name = os.path.basename(file) 182 | sheet_name = sheet_dict.get(file_name, file_name) 183 | df.to_excel(writer, sheet_name=sheet_name, index=False) 184 | -------------------------------------------------------------------------------- /py_tools/utils/file_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: file.py 5 | # @Desc: { 文件相关工具模块 } 6 | # @Date: 2024/07/19 15:44 7 | import mimetypes 8 | import os 9 | from pathlib import Path 10 | from typing import AsyncGenerator, Generator, Union 11 | 12 | import aiofiles 13 | 14 | 15 | class FileUtil: 16 | @staticmethod 17 | def get_file_info( 18 | file_input: Union[str, bytes, Path], 19 | filename: str = None, 20 | ) -> Union[bytes, tuple]: 21 | """ 22 | 获取文件字节信息 23 | Args: 24 | file_input: 文件数据 25 | filename: 通过字节数据上传需要指定文件名称,方便获取mime_type 26 | 27 | Raises: 28 | ValueError 29 | 30 | Notes: 31 | 上传文件时指定文件的mime_type 32 | 33 | Returns: 34 | tuple(filename, file_bytes, mime_type) 35 | """ 36 | if isinstance(file_input, (str, Path)): 37 | filename = os.path.basename(str(file_input)) 38 | with open(file_input, "rb") as file: 39 | file_bytes = file.read() 40 | 41 | mime_type = mimetypes.guess_type(file_input)[0] 42 | return filename, file_bytes, mime_type 43 | elif isinstance(file_input, bytes): 44 | if not filename: 45 | raise ValueError("filename must be set when passing bytes") 46 | 47 | mime_type = mimetypes.guess_type(filename)[0] 48 | return filename, file_input, mime_type 49 | else: 50 | raise ValueError("file_input must be a string (file path) or bytes.") 51 | 52 | 
@staticmethod 53 | def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): 54 | """ 55 | 校验文件后缀 56 | Args: 57 | file_input: 文件路径 or 文件字节数据 58 | allowed_file_extensions: 允许的文件扩展名 59 | bytes_filename: 当字节数据时使用这个参数校验 60 | 61 | Raises: 62 | ValueError 63 | 64 | Returns: 65 | """ 66 | verify_file_path = None 67 | if isinstance(file_input, (str, Path)): 68 | verify_file_path = str(file_input) 69 | elif isinstance(file_input, bytes) and bytes_filename: 70 | verify_file_path = bytes_filename 71 | 72 | if not verify_file_path: 73 | # 仅传字节数据数据时不校验 74 | return 75 | 76 | file_ext = os.path.splitext(verify_file_path)[1].lower() 77 | if file_ext not in allowed_file_extensions: 78 | raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}") 79 | 80 | @staticmethod 81 | async def aread_bytes(file_path: Union[str, Path]) -> bytes: 82 | """ 83 | 异步读取文件的字节数据。 84 | 85 | Args: 86 | file_path: 要读取的文件路径。 87 | 88 | Raises: 89 | ValueError: 如果提供的路径不是字符串或Path对象。 90 | 91 | Returns: 92 | bytes: 文件的全部字节数据。 93 | """ 94 | if not isinstance(file_path, (str, Path)): 95 | raise ValueError("file_path必须是字符串或Path对象") 96 | 97 | async with aiofiles.open(file_path, "rb") as file: 98 | file_bytes = await file.read() 99 | return file_bytes 100 | 101 | @staticmethod 102 | async def awrite(file_path: Union[str, Path], data: bytes) -> str: 103 | """ 104 | 异步写入文件。 105 | 106 | Args: 107 | file_path: 要读取的文件路径。 108 | data: 要写入的字节数据。 109 | 110 | Raises: 111 | ValueError: 如果提供的路径不是字符串或Path对象。 112 | 113 | Returns: 114 | file_path: 文件路径。 115 | """ 116 | if not isinstance(file_path, (str, Path)): 117 | raise ValueError("file_path必须是字符串或Path对象") 118 | 119 | async with aiofiles.open(file_path, "wb") as file: 120 | await file.write(data) 121 | return file_path 122 | 123 | @staticmethod 124 | def read_bytes_chunked(file_path: Union[str, Path], chunk_size: int = 1024) -> Generator: 125 | """ 126 | 同步分块读取文件字节数据。 127 | 128 | 
Args: 129 | file_path: 要读取的文件路径。 130 | chunk_size: 每块读取的字节数,默认1024字节。 131 | 132 | Raises: 133 | ValueError: 如果提供的路径不是字符串或Path对象。 134 | 135 | Returns: 136 | 生成器: 每次迭代返回一个字节数据块。 137 | """ 138 | if not isinstance(file_path, (str, Path)): 139 | raise ValueError("file_path必须是字符串或Path对象") 140 | 141 | with open(file_path, "rb") as file: 142 | while True: 143 | chunk = file.read(chunk_size) 144 | if not chunk: 145 | break 146 | yield chunk 147 | 148 | @staticmethod 149 | async def aread_bytes_chunked(file_path: Union[str, Path], chunk_size: int = 1024) -> AsyncGenerator[bytes, None]: 150 | """ 151 | 异步分块读取文件字节数据。 152 | 153 | Args: 154 | file_path: 要读取的文件路径。 155 | chunk_size: 每块读取的字节数,默认1024字节。 156 | 157 | Raises: 158 | ValueError: 如果提供的路径不是字符串或Path对象。 159 | 160 | Returns: 161 | 异步生成器: 每次迭代返回一个字节数据块。 162 | """ 163 | if not isinstance(file_path, (str, Path)): 164 | raise ValueError("file_path必须是字符串或Path对象") 165 | 166 | async with aiofiles.open(file_path, "rb") as file: 167 | while True: 168 | chunk = await file.read(chunk_size) 169 | if not chunk: 170 | break 171 | yield chunk 172 | 173 | 174 | async def main(): 175 | file_bytes = await FileUtil.aread_bytes(Path(__file__)) 176 | print(file_bytes) 177 | 178 | for chunk in FileUtil.read_bytes_chunked(Path(__file__), chunk_size=1024): 179 | print(chunk) 180 | 181 | async for chunk in FileUtil.aread_bytes_chunked(Path(__file__), chunk_size=1024): 182 | print(chunk) 183 | 184 | 185 | if __name__ == "__main__": 186 | import asyncio 187 | 188 | asyncio.run(main()) 189 | -------------------------------------------------------------------------------- /py_tools/utils/func_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 实用小函数模块 } 5 | # @Date: 2023/09/10 00:07 6 | 7 | 8 | def chunk_list(data_list: list, chunk_size: int) -> list: 9 | """ 10 | 等份切分列表 11 | Args: 12 | data_list: 数据列表 13 | chunk_size: 每份大小 14 | 15 | 
Returns: list 16 | """ 17 | return [data_list[i : i + chunk_size] for i in range(0, len(data_list), chunk_size)] 18 | 19 | 20 | def add_param_if_true(params, key, value, is_check_none=True): 21 | """ 22 | 值不为空则添加到参数字典中 23 | Args: 24 | params: 要加入元素的字典 25 | key: 要加入字典的key值 26 | value: 要加入字典的value值 27 | is_check_none: 是否只检查空值None, 默认True 28 | - True: 不允许None, 但允许 0、False、空串、空列表、空字典等是有意义的 29 | - False: 则不允许所有空值 30 | """ 31 | if value or (is_check_none and value is not None): 32 | params[key] = value 33 | -------------------------------------------------------------------------------- /py_tools/utils/jwt_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: jwt_util.py 5 | # @Desc: { jwt 工具模块 } 6 | # @Date: 2024/11/04 15:05 7 | import datetime 8 | from typing import Any, Dict, Optional 9 | 10 | from jose import JWTError, jwt 11 | from loguru import logger 12 | 13 | 14 | class JWTUtil: 15 | """JWT 工具类,用于生成和验证 JWT 令牌。 16 | 17 | Attributes: 18 | secret_key (str): 用于签名 JWT 的密钥。 19 | algorithm (str): 使用的加密算法,默认是 HS256。 20 | expiration_minutes (int): 令牌的默认过期时间,以分钟为单位。 21 | """ 22 | 23 | def __init__(self, secret_key: str, algorithm: str = "HS256", expiration_minutes: int = 60 * 2): 24 | """初始化 JWTUtil 实例。 25 | 26 | Args: 27 | secret_key (str): 用于签名 JWT 的密钥。 28 | algorithm (str): 使用的加密算法,默认为 'HS256'。 29 | expiration_minutes (int): 令牌的默认过期时间(分钟)。 30 | """ 31 | self.secret_key = secret_key 32 | self.algorithm = algorithm 33 | self.expiration_minutes = expiration_minutes 34 | 35 | def generate_token(self, data: Dict[str, Any], expires_delta: Optional[datetime.timedelta] = None) -> str: 36 | """生成 JWT 令牌。 37 | 38 | Args: 39 | data (Dict[str, Any]): 令牌中包含的数据 (payload)。 40 | expires_delta (Optional[datetime.timedelta], optional): 自定义的过期时间。如果没有指定,则使用默认的过期时间。 41 | 42 | Returns: 43 | str: 生成的 JWT 字符串。 44 | """ 45 | to_encode = data.copy() 46 | expire = 
datetime.datetime.utcnow() + (expires_delta or datetime.timedelta(minutes=self.expiration_minutes)) 47 | to_encode.update({"exp": expire}) 48 | token = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm) 49 | return token 50 | 51 | def verify_token(self, token: str) -> Optional[Dict[str, Any]]: 52 | """验证 JWT 令牌并返回其中的数据。 53 | 54 | Args: 55 | token (str): 要验证的 JWT 字符串。 56 | 57 | Returns: 58 | Optional[Dict[str, Any]]: 如果验证成功,返回解码后的数据;如果验证失败,返回 None。 59 | """ 60 | try: 61 | decoded_data = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) 62 | return decoded_data 63 | except JWTError as e: 64 | logger.error(f"Token verification failed, {e}") 65 | return None 66 | 67 | def refresh_token(self, token: str, expires_delta: Optional[datetime.timedelta] = None) -> Optional[str]: 68 | """刷新 JWT 令牌。 69 | 70 | Args: 71 | token (str): 旧的 JWT 字符串。 72 | expires_delta (Optional[datetime.timedelta], optional): 自定义的过期时间。如果没有指定,则使用默认过期时间。 73 | 74 | Returns: 75 | Optional[str]: 新生成的 JWT 字符串;如果旧的令牌无效或已过期,返回 None。 76 | """ 77 | decoded_data = self.verify_token(token) 78 | if not decoded_data: 79 | return None 80 | decoded_data.pop("exp", None) 81 | return self.generate_token(decoded_data, expires_delta) 82 | -------------------------------------------------------------------------------- /py_tools/utils/mask_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 字符串掩码工具类模块 } 5 | # @Date: 2022/11/26 18:09 6 | import re 7 | from typing import Union 8 | 9 | 10 | class MaskUtil(object): 11 | """掩码工具类""" 12 | 13 | # 元素掩码的格式 (匹配规则, 替换后的内容) 14 | # \1, \3 指的是取第几个分组数据相当于 group(1)、group(3) 15 | ADDRESS = (r"(\w)", r"*") # 地址 16 | NAME = (r"(.{1})(.{1})(.*)", r"\1*\3") # 名字 17 | PHONE = (r"(\d{3})(.*)(\d{4})", r"\1****\3") # 电话号码 18 | ID_CARD = (r"(\d{6})(.*)(\d{4})", r"\1****\3") # 身份证 19 | WECHAT_NUM = (r"(.{1})(.*)(.{1})", r"\1****\3") # 微信号 20 | 21 | 
@classmethod 22 | def mask(cls, origin_text: str, mask_type: Union[tuple, str] = NAME): 23 | """数据掩码""" 24 | if isinstance(mask_type, tuple): 25 | return re.sub(*mask_type, str(origin_text)) 26 | elif isinstance(mask_type, str): 27 | mark_rule_tuple = getattr(cls, mask_type.upper()) 28 | return cls.mask(origin_text, mark_rule_tuple) 29 | return origin_text 30 | 31 | @classmethod 32 | def mask_phone(cls, origin_text: str, mask_type: Union[tuple, str] = PHONE): 33 | return cls.mask(origin_text, mask_type) 34 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/__init__.py: -------------------------------------------------------------------------------- 1 | from py_tools.utils.project_templates.make_pro import make_project 2 | 3 | __all__ = ["make_project"] 4 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/make_pro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: project_template.py 5 | # @Desc: { 项目模板工具模块 } 6 | # @Date: 2024/04/26 17:46 7 | import argparse 8 | import os 9 | import shutil 10 | from pathlib import Path 11 | 12 | from loguru import logger 13 | 14 | # 项目基准目录 15 | BASE_DIR = Path(__file__).parent.parent.parent.parent 16 | 17 | template_dir = os.path.dirname(__file__) 18 | py_template_dir = os.path.join(template_dir, "python_project") 19 | 20 | 21 | def gen_py_project(project_name): 22 | logger.info(f"Generating Python project [{project_name}] structure...") 23 | 24 | # 创建项目目录 25 | os.makedirs(project_name, exist_ok=True) 26 | 27 | # 创建 README.md 文件 28 | with open(os.path.join(project_name, "README.md"), "w") as readme: 29 | readme.write("# Project: " + project_name) 30 | 31 | # 创建 main.py 文件(示例中简单创建一个空文件) 32 | with open(os.path.join(project_name, "main.py"), "w") as main_file: 33 | main_file.write("# 主入口模块") 
34 | 35 | # 创建 requirements.txt 文件(示例中简单创建一个空文件) 36 | with open(os.path.join(project_name, "requirements.txt"), "w") as requirements: 37 | requirements.write("hui-tools") 38 | 39 | # 创建 pre-commit-config.yaml、ruff.toml 文件 40 | shutil.copy(src=BASE_DIR / ".pre-commit-config.yaml", dst=os.path.join(project_name, ".pre-commit-config.yaml")) 41 | shutil.copy(src=BASE_DIR / "ruff.toml", dst=os.path.join(project_name, "ruff.toml")) 42 | 43 | # 创建 docs 目录 44 | os.makedirs(os.path.join(project_name, "docs"), exist_ok=True) 45 | 46 | # 创建 src 目录及其子目录 47 | target_dir = os.path.join(project_name, "src") 48 | src_dir = os.path.join(py_template_dir, "src") 49 | shutil.copytree(src_dir, target_dir, dirs_exist_ok=True) 50 | 51 | # 创建 tests 目录 52 | os.makedirs(os.path.join(project_name, "tests"), exist_ok=True) 53 | 54 | logger.info(f"Python project [{project_name}] generated successfully.") 55 | 56 | 57 | def make_project_python(args): 58 | project_name = args.project_name 59 | try: 60 | gen_py_project(project_name) 61 | except Exception: 62 | logger.exception("Failed to generate Python project.") 63 | shutil.rmtree(project_name, ignore_errors=True) 64 | 65 | 66 | def make_project_java(args): 67 | print(f"Generating Java project [{args.project_name}] structure...") 68 | # Add code to generate Java project structure 69 | 70 | 71 | def make_project(): 72 | parser = argparse.ArgumentParser(description="Generate project structure.") 73 | subparsers = parser.add_subparsers(dest="subcommand") 74 | 75 | project_parser = subparsers.add_parser("make_project") 76 | project_parser.add_argument("project_name", help="Name of the project") 77 | project_parser.add_argument("--python", action="store_true", help="Generate Python project structure") 78 | project_parser.add_argument("--java", action="store_true", help="Generate Java project structure") 79 | 80 | args = parser.parse_args() 81 | 82 | if args.subcommand == "make_project": 83 | if args.python: 84 | make_project_python(args) 85 | elif 
args.java: 86 | make_project_java(args) 87 | else: 88 | make_project_python(args) 89 | else: 90 | parser.print_help() 91 | 92 | 93 | if __name__ == "__main__": 94 | make_project() 95 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/constants/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/constants/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/__init__.py: -------------------------------------------------------------------------------- 1 | from src import settings 2 | from src.dao.redis import RedisManager 3 | 4 | from py_tools.connections.db.mysql import DBManager, SQLAlchemyManager 5 | 6 | 7 | async def init_orm(): 8 | """初始化mysql的ORM""" 9 | db_client = SQLAlchemyManager( 10 | host=settings.mysql_host, 11 | port=settings.mysql_port, 12 | user=settings.mysql_user, 13 | password=settings.mysql_password, 14 | db_name=settings.mysql_dbname, 15 | ) 16 | db_client.init_mysql_engine() 17 | DBManager.init_db_client(db_client) 18 | return db_client 19 | 20 | 21 | async def init_redis(): 22 | RedisManager.init_redis_client( 23 | async_client=True, 24 | host=settings.redis_host, 25 | port=settings.redis_port, 26 | password=settings.redis_password, 27 | db=settings.redis_db, 28 | ) 29 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/orm/manage/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/orm/manage/user.py: 
-------------------------------------------------------------------------------- 1 | from src.dao.orm.table import UserTable 2 | 3 | from py_tools.connections.db.mysql import DBManager 4 | 5 | 6 | class UserManager(DBManager): 7 | orm_table = UserTable 8 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/orm/table/__init__.py: -------------------------------------------------------------------------------- 1 | from src.dao.orm.table.user import UserTable 2 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/orm/table/user.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import Mapped, mapped_column 2 | 3 | from py_tools.connections.db.mysql import BaseOrmTableWithTS 4 | 5 | 6 | class UserTable(BaseOrmTableWithTS): 7 | """用户表""" 8 | 9 | __tablename__ = "user" 10 | username: Mapped[str] = mapped_column(comment="用户昵称") 11 | password: Mapped[str] = mapped_column(comment="用户密码") 12 | phone: Mapped[str] = mapped_column(comment="手机号") 13 | email: Mapped[str] = mapped_column(comment="邮箱") 14 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/redis/__init__.py: -------------------------------------------------------------------------------- 1 | from src.dao.redis.cache_info import RedisKey 2 | from src.dao.redis.client import RedisManager 3 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/dao/redis/cache_info.py: -------------------------------------------------------------------------------- 1 | class RedisKey(object): 2 | """Redis Key 统一管理""" 3 | 4 | prefix_key = "tf" 5 | -------------------------------------------------------------------------------- 
/py_tools/utils/project_templates/python_project/src/dao/redis/client.py: -------------------------------------------------------------------------------- 1 | from py_tools.connections.db.redis_client import BaseRedisManager 2 | 3 | 4 | class RedisManager(BaseRedisManager): 5 | cache_key_prefix = "" 6 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/data_schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/data_schemas/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/data_schemas/api_schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/data_schemas/api_schemas/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/data_schemas/logic_schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/data_schemas/logic_schemas/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/enums/__init__.py: -------------------------------------------------------------------------------- 1 | from src.enums.base import BizErrCodeEnum 2 | -------------------------------------------------------------------------------- 
/py_tools/utils/project_templates/python_project/src/enums/base.py: -------------------------------------------------------------------------------- 1 | from py_tools.enums import BaseErrCodeEnum 2 | 3 | 4 | class BizErrCodeEnum(BaseErrCodeEnum): 5 | """ 6 | 错误码前缀 7 | - 000-通用基础错误码前缀 8 | - 100-待定 9 | - 200-通用业务错误码前缀 10 | eg: 11 | - 201-用户模块 12 | - 202-订单模块 13 | - 300-待定 14 | - 400-通用请求错误 15 | - 500-通用系统错误码前缀 16 | """ 17 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/handlers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/handlers/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/middlewares/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/middlewares/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/routes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/routes/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/server.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/server.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/services/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiDBK/py-tools/c879d0d111dcc11c11d125e0b47b358388bd0d66/py_tools/utils/project_templates/python_project/src/services/__init__.py -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/services/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Desc: { 通用逻辑服务 } 4 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/settings/__init__.py: -------------------------------------------------------------------------------- 1 | from src.settings.base_setting import server_host, server_log_level, server_port 2 | from src.settings.db_setting import ( 3 | mysql_dbname, 4 | mysql_host, 5 | mysql_password, 6 | mysql_port, 7 | mysql_user, 8 | redis_db, 9 | redis_host, 10 | redis_password, 11 | redis_port, 12 | ) 13 | 14 | from src.settings.log_setting import console_log_level, logging_conf 15 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/settings/base_setting.py: -------------------------------------------------------------------------------- 1 | server_host = "127.0.0.1" 2 | server_port = 8000 3 | server_log_level = "warning" 4 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/settings/db_setting.py: 
-------------------------------------------------------------------------------- 1 | # mysql服务配置 2 | mysql_host = "127.0.0.1" 3 | mysql_port = 3306 4 | mysql_user = "root" 5 | mysql_password = "123456" 6 | mysql_dbname = "task_flow" 7 | 8 | 9 | # redis服务配置 10 | redis_host = "127.0.0.1" 11 | redis_port = 6379 12 | redis_password = "" 13 | redis_db = 0 14 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/settings/log_setting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | # 项目基准路径 6 | base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | # 项目日志目录 9 | logging_dir = os.path.join(base_dir, "logs/") 10 | 11 | # 项目运行时所有的日志文件 12 | server_log_file = os.path.join(logging_dir, "server.log") 13 | 14 | # 错误时的日志文件 15 | error_log_file = os.path.join(logging_dir, "error.log") 16 | 17 | # 项目服务综合日志滚动配置(每天 0 点新创建一个 log 文件) 18 | # 错误日志 超过10 MB就自动新建文件扩充 19 | server_logging_rotation = "00:00" 20 | error_logging_rotation = "10 MB" 21 | 22 | # 服务综合日志文件最长保留 7 天,错误日志 30 天 23 | server_logging_retention = "7 days" 24 | error_logging_retention = "30 days" 25 | 26 | # 项目日志配置 27 | console_log_level = logging.DEBUG 28 | log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {trace_msg} | {name}:{function}:{line} - {message}" 29 | 30 | logging_conf = { 31 | "console_handler": { 32 | "sink": sys.stdout, 33 | "level": console_log_level, 34 | # "format": log_format, # 开启控制台也会输出 trace_msg 信息但日志没有颜色了 35 | }, 36 | "server_handler": { 37 | "sink": server_log_file, 38 | "level": "INFO", 39 | "rotation": server_logging_rotation, 40 | "retention": server_logging_retention, 41 | "enqueue": True, 42 | "backtrace": False, 43 | "diagnose": False, 44 | "format": log_format, 45 | }, 46 | "error_handler": { 47 | "sink": error_log_file, 48 | "level": "ERROR", 49 | "rotation": 
error_logging_rotation, 50 | "retention": error_logging_retention, 51 | "enqueue": True, 52 | "backtrace": True, 53 | "diagnose": True, 54 | "format": log_format, 55 | }, 56 | } 57 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.trace_util import TraceUtil 2 | from src.utils.web_util import APIUtil 3 | from src.utils.log_util import LogUtil 4 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/utils/context_util.py: -------------------------------------------------------------------------------- 1 | import contextvars 2 | 3 | # 请求唯一id 4 | REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="") 5 | 6 | # 任务追踪唯一id 7 | TRACE_ID: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="") 8 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/utils/log_util.py: -------------------------------------------------------------------------------- 1 | from src.utils import context_util 2 | 3 | 4 | class LogUtil: 5 | @staticmethod 6 | def logger_filter(record): 7 | """日志过滤器补充request_id或trace_id""" 8 | req_id = context_util.REQUEST_ID.get() 9 | trace_id = context_util.TRACE_ID.get() 10 | 11 | trace_msg = f"{req_id} | {trace_id}" 12 | record["trace_msg"] = trace_msg 13 | return record 14 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/utils/trace_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 日志链路追踪工具模块 } 5 | # @Date: 2023/10/30 15:51 6 | import 
uuid 7 | 8 | from src.utils import context_util 9 | 10 | 11 | class TraceUtil(object): 12 | @staticmethod 13 | def set_req_id(req_id: str = None, title="req-id") -> str: 14 | """ 15 | 设置请求唯一ID 16 | Args: 17 | req_id: 请求ID 默认None取uuid 18 | title: 标题 默认req-id 19 | 20 | Returns: 21 | title:req_id 22 | """ 23 | req_id = req_id or uuid.uuid4().hex 24 | req_id = f"{title}:{req_id}" 25 | 26 | context_util.REQUEST_ID.set(req_id) 27 | return req_id 28 | 29 | @staticmethod 30 | def set_trace_id(trace_id: str = None, title="trace-id") -> str: 31 | """ 32 | 设置追踪ID, 可用于一些脚本等场景进行链路追踪 33 | Args: 34 | trace_id: 追踪唯一ID 默认None取uuid 35 | title: 标题 默认 trace-id, 可以用于标识业务 36 | 37 | Returns: 38 | title:trace_id 39 | """ 40 | trace_id = trace_id or uuid.uuid4().hex 41 | trace_id = f"{title}:{trace_id}" 42 | 43 | context_util.TRACE_ID.set(trace_id) 44 | return trace_id 45 | -------------------------------------------------------------------------------- /py_tools/utils/project_templates/python_project/src/utils/web_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: hui 4 | # @Desc: { web相关工具类 } 5 | # @Date: 2023/09/06 16:49 6 | from src.enums import BizErrCodeEnum 7 | 8 | 9 | class APIUtil: 10 | @staticmethod 11 | def success_resp(data=None): 12 | """成功的响应""" 13 | data = data or {} 14 | resp_content = {"code": BizErrCodeEnum.OK.value, "message": BizErrCodeEnum.OK.desc, "data": data or {}} 15 | return resp_content 16 | 17 | @staticmethod 18 | def fail_resp_with_err_enum(err_enum: BizErrCodeEnum, err_msg: str = None, data=None): 19 | """失败的响应携带错误码""" 20 | resp_content = { 21 | "code": err_enum.code, 22 | "message": err_msg or err_enum.msg, 23 | "data": data or {}, 24 | } 25 | return resp_content 26 | 27 | @staticmethod 28 | def fail_resp(err_msg: str = None, data=None): 29 | """失败的响应 默认Failed错误码""" 30 | resp_content = { 31 | "code": BizErrCodeEnum.FAILED.code, 32 | "message": err_msg or 
BizErrCodeEnum.FAILED.msg, 33 | "data": data or {}, 34 | } 35 | return resp_content 36 | -------------------------------------------------------------------------------- /py_tools/utils/re_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: re_util.py 5 | # @Desc: { 正则工具模块 } 6 | # @Date: 2024/08/24 11:31 7 | import re 8 | from typing import List 9 | 10 | 11 | class RegexUtil: 12 | """正则工具类""" 13 | 14 | # 匹配中文字符 15 | CHINESE_CHARACTER_PATTERN = re.compile(r"[\u4e00-\u9fa5]") 16 | 17 | # 匹配双字节字符(包括汉字以及其他全角字符) 18 | DOUBLE_BYTE_CHARACTER_PATTERN = re.compile(r"[^\x00-\xff]") 19 | 20 | # 匹配Email地址 21 | EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}") 22 | 23 | # 匹配http网址URL 24 | HTTP_LINK_PATTERN = re.compile(r"(https?:\/\/\S+)") 25 | 26 | # 匹配中国大陆手机号码 27 | CHINESE_PHONE_PATTERN = re.compile(r"1[3-9]\d{9}") 28 | 29 | # 匹配电话号码(包括座机) 30 | TELEPHONE_PATTERN = re.compile(r"[0-9-()()]{7,18}") 31 | 32 | # 匹配负浮点数 33 | NEGATIVE_FLOAT_PATTERN = re.compile(r"-?\d+\.\d+") 34 | 35 | # 匹配整数(包括正负) 36 | INTEGER_PATTERN = re.compile(r"-?[1-9]\d*") 37 | 38 | # 匹配正浮点数 39 | POSITIVE_FLOAT_PATTERN = re.compile(r"[1-9]\d*\.\d*|0\.\d*[1-9]\d*") 40 | 41 | # 匹配腾讯QQ号 42 | QQ_PATTERN = re.compile(r"\d{5,11}") 43 | 44 | # 匹配中国邮政编码 45 | POSTAL_CODE_PATTERN = re.compile(r"\d{6}") 46 | 47 | # 匹配中国身份证号码 48 | ID_CARD_PATTERN = re.compile(r"\d{17}[\d|x]|\d{15}") 49 | 50 | # 匹配日期格式(如YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD) 51 | DATE_PATTERN = re.compile(r"\d{4}[-/.]\d{2}[-/.]\d{2}") 52 | 53 | # 匹配正整数 54 | POSITIVE_INTEGER_PATTERN = re.compile(r"[1-9]\d*") 55 | 56 | # 匹配负整数 57 | NEGATIVE_INTEGER_PATTERN = re.compile(r"-[1-9]\d*") 58 | 59 | # 匹配用户名(支持中英文、数字、下划线、减号) 60 | USERNAME_PATTERN = re.compile(r"[A-Za-z0-9_\-\u4e00-\u9fa5]+") 61 | 62 | @classmethod 63 | def find_http_links(cls, text: str) -> List[str]: 64 | """查找文本中的所有HTTP/HTTPS链接""" 65 | return 
cls.HTTP_LINK_PATTERN.findall(text) 66 | 67 | @classmethod 68 | def find_chinese_characters(cls, text: str) -> List[str]: 69 | """查找文本中的所有中文字符""" 70 | return cls.CHINESE_CHARACTER_PATTERN.findall(text) 71 | 72 | @classmethod 73 | def find_double_byte_characters(cls, text: str) -> List[str]: 74 | """查找文本中的所有双字节字符""" 75 | return cls.DOUBLE_BYTE_CHARACTER_PATTERN.findall(text) 76 | 77 | @classmethod 78 | def find_emails(cls, text: str) -> List[str]: 79 | """查找文本中的所有Email地址""" 80 | return cls.EMAIL_PATTERN.findall(text) 81 | 82 | @classmethod 83 | def find_chinese_phone_numbers(cls, text: str) -> List[str]: 84 | """查找文本中的所有中国大陆手机号码""" 85 | return cls.CHINESE_PHONE_PATTERN.findall(text) 86 | 87 | # 查找所有匹配的电话号码(包括座机) 88 | @classmethod 89 | def find_telephone_numbers(cls, text: str) -> List[str]: 90 | """查找文本中的所有电话号码(包括座机)""" 91 | return cls.TELEPHONE_PATTERN.findall(text) 92 | 93 | @classmethod 94 | def find_negative_floats(cls, text: str) -> List[str]: 95 | """查找文本中的所有负浮点数""" 96 | return cls.NEGATIVE_FLOAT_PATTERN.findall(text) 97 | 98 | @classmethod 99 | def find_integers(cls, text: str) -> List[str]: 100 | """查找文本中的所有整数(包括正负)""" 101 | return cls.INTEGER_PATTERN.findall(text) 102 | 103 | @classmethod 104 | def find_positive_floats(cls, text: str) -> List[str]: 105 | """查找文本中的所有正浮点数""" 106 | return cls.POSITIVE_FLOAT_PATTERN.findall(text) 107 | 108 | @classmethod 109 | def find_qq_numbers(cls, text: str) -> List[str]: 110 | """查找文本中的所有腾讯QQ号""" 111 | return cls.QQ_PATTERN.findall(text) 112 | 113 | @classmethod 114 | def find_postal_codes(cls, text: str) -> List[str]: 115 | """查找文本中的所有邮政编码""" 116 | return cls.POSTAL_CODE_PATTERN.findall(text) 117 | 118 | @classmethod 119 | def find_id_cards(cls, text: str) -> List[str]: 120 | """查找文本中的所有身份证号码""" 121 | return cls.ID_CARD_PATTERN.findall(text) 122 | 123 | @classmethod 124 | def find_dates(cls, text: str) -> List[str]: 125 | """查找文本中的所有日期格式(YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD)""" 126 | return cls.DATE_PATTERN.findall(text) 
127 | 128 | @classmethod 129 | def find_positive_integers(cls, text: str) -> List[str]: 130 | """查找文本中的所有正整数""" 131 | return cls.POSITIVE_INTEGER_PATTERN.findall(text) 132 | 133 | @classmethod 134 | def find_negative_integers(cls, text: str) -> List[str]: 135 | """查找文本中的所有负整数""" 136 | return cls.NEGATIVE_INTEGER_PATTERN.findall(text) 137 | 138 | @classmethod 139 | def find_usernames(cls, text: str) -> List[str]: 140 | """查找文本中的所有用户名(支持中英文、数字、下划线、减号)""" 141 | return cls.USERNAME_PATTERN.findall(text) 142 | -------------------------------------------------------------------------------- /py_tools/utils/serializer_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 序列化器模块 } 5 | # @Date: 2023/09/10 00:15 6 | import dataclasses 7 | from dataclasses import asdict, dataclass 8 | from typing import List, Type, Union 9 | 10 | from pydantic import BaseModel 11 | from sqlalchemy import RowMapping 12 | 13 | from py_tools.connections.db.mysql import BaseOrmTable 14 | 15 | 16 | class SerializerUtil: 17 | @classmethod 18 | def data_to_model( 19 | cls, 20 | data_obj: Union[ 21 | dict, 22 | BaseOrmTable, 23 | BaseModel, 24 | dataclass, 25 | List[dict], 26 | List[BaseOrmTable], 27 | List[BaseModel], 28 | ], 29 | to_model: Type[Union[BaseModel, BaseOrmTable, dataclass]], 30 | ) -> Union[BaseModel, List[BaseModel], List[BaseOrmTable], List[dataclass], None]: 31 | """ 32 | 将数据对象转换成 pydantic 或 sqlalchemy 模型对象, 如果是数据库库表模型对象则调用to_dict()后递归 33 | Args: 34 | data_obj: 支持 字典对象, pydantic、sqlalchemy模型对象, 列表对象 35 | to_model: 转换后数据模型 36 | 37 | Notes: 38 | - 对于实现了 to_dict() 方法的模型对象,将调用该方法返回字典。 39 | 40 | returns: 41 | 转换后的对象 42 | """ 43 | 44 | if isinstance(data_obj, dict): 45 | # 字典处理 46 | return to_model(**data_obj) 47 | 48 | elif isinstance(data_obj, BaseOrmTable): 49 | # 数据库表模型对象处理, to_dict()后递归调用 50 | return cls.data_to_model(data_obj.to_dict(), to_model=to_model) 51 | 52 
| elif isinstance(data_obj, BaseModel): 53 | # pydantic v2 模型对象处理, model_dump 后递归调用 54 | return cls.data_to_model(data_obj.model_dump(), to_model=to_model) 55 | 56 | elif dataclasses.is_dataclass(data_obj): 57 | # dataclass 模型对象处理, asdict() 后递归调用 58 | return cls.data_to_model(asdict(data_obj), to_model=to_model) 59 | 60 | elif hasattr(data_obj, "to_dict"): 61 | # 如果模型对象有 to_dict 方法,调用该方法返回字典 62 | return cls.data_to_model(data_obj.to_dict(), to_model=to_model) 63 | 64 | elif isinstance(data_obj, list): 65 | # 列表处理 66 | return [cls.data_to_model(item, to_model=to_model) for item in data_obj] 67 | 68 | else: 69 | raise ValueError(f"不支持此{data_obj}类型的序列化转换") 70 | 71 | @classmethod 72 | def model_to_data( 73 | cls, 74 | model_obj: Union[ 75 | BaseModel, 76 | BaseOrmTable, 77 | dataclass, 78 | List[BaseModel], 79 | List[BaseOrmTable], 80 | List[dataclass], 81 | ], 82 | ) -> Union[dict, List[dict], None]: 83 | """ 84 | 将 Pydantic 模型或 SQLAlchemy 模型对象转换回原始字典或列表对象。 85 | 86 | Args: 87 | model_obj: 支持 Pydantic 模型对象、SQLAlchemy 模型、dataclass 对象,或者它们的列表 88 | 89 | Notes: 90 | - 对于实现了 to_dict() 方法的模型对象,将调用该方法返回字典。 91 | 92 | Returns: 93 | 转换后的字典或列表 94 | """ 95 | 96 | if isinstance(model_obj, dict): 97 | return model_obj 98 | 99 | if isinstance(model_obj, RowMapping): 100 | return dict(model_obj) 101 | 102 | elif isinstance(model_obj, BaseModel): 103 | # Pydantic 模型对象处理,model_dump() 返回字典 104 | return model_obj.model_dump() 105 | 106 | elif isinstance(model_obj, BaseOrmTable): 107 | # SQLAlchemy 模型对象处理,to_dict() 返回字典 108 | return model_obj.to_dict() 109 | 110 | elif dataclasses.is_dataclass(model_obj): 111 | # dataclass 模型对象处理, asdict() 返回字典 112 | return asdict(model_obj) 113 | 114 | elif hasattr(model_obj, "to_dict"): 115 | # 如果模型对象有 to_dict 方法,调用该方法返回字典 116 | return model_obj.to_dict() 117 | 118 | elif isinstance(model_obj, list): 119 | # 列表处理,递归转换每个元素 120 | return [cls.model_to_data(item) for item in model_obj] 121 | 122 | else: 123 | raise ValueError(f"不支持此{model_obj}类型的反序列化转换") 124 
| -------------------------------------------------------------------------------- /py_tools/utils/time_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 时间工具类模块 } 5 | # @Date: 2022/11/26 16:08 6 | 7 | import time 8 | from datetime import datetime 9 | 10 | from dateutil.relativedelta import relativedelta 11 | 12 | from py_tools.enums.time import TimeFormatEnum 13 | from py_tools.meta_cls import SingletonMetaCls 14 | 15 | 16 | class TimeUtil(metaclass=SingletonMetaCls): 17 | """ 18 | 时间工具类 19 | """ 20 | 21 | @classmethod 22 | def instance(cls, reinit=True, *args, **kwargs): 23 | instance = cls(*args, reinit=reinit, **kwargs) 24 | return instance 25 | 26 | def __init__(self, datetime_obj: datetime = None, format_str: str = TimeFormatEnum.DateTime.value): 27 | """ 28 | 时间工具类初始化 29 | Args: 30 | datetime_obj: 待处理的datetime对象,不传时默认取当前时间 31 | format_str: 时间格式化字符串 32 | """ 33 | self.datetime_obj = datetime_obj or datetime.now() 34 | self.format_str = format_str 35 | 36 | @property 37 | def yesterday(self) -> datetime: 38 | """获取昨天的日期""" 39 | return self.sub_time(days=1) 40 | 41 | @property 42 | def tomorrow(self) -> datetime: 43 | """获取明天的日期""" 44 | return self.add_time(days=1) 45 | 46 | @property 47 | def week_later(self) -> datetime: 48 | """获取一周后的日期""" 49 | return self.add_time(days=7) 50 | 51 | @property 52 | def month_later(self) -> datetime: 53 | """获取一个月后的日期""" 54 | return self.add_time(months=1) 55 | 56 | def add_time(self, years=0, months=0, days=0, hours=0, minutes=0, seconds=0, **kwargs) -> datetime: 57 | """增加指定时间""" 58 | return self.datetime_obj + relativedelta( 59 | years=years, months=months, days=days, hours=hours, minutes=minutes, seconds=seconds, **kwargs 60 | ) 61 | 62 | def sub_time(self, years=0, months=0, days=0, hours=0, minutes=0, seconds=0, **kwargs) -> datetime: 63 | """减去指定时间""" 64 | return self.datetime_obj - relativedelta( 
65 | years=years, months=months, days=days, hours=hours, minutes=minutes, seconds=seconds, **kwargs 66 | ) 67 | 68 | def str_to_datetime(self, date_str: str, format_str: str = None) -> datetime: 69 | """将时间字符串转换为 datetime 对象""" 70 | format_str = format_str or self.format_str 71 | return datetime.strptime(date_str, format_str) 72 | 73 | def datetime_to_str(self, format_str: str = None) -> str: 74 | """将 datetime 对象转换为时间字符串""" 75 | format_str = format_str or self.format_str 76 | return self.datetime_obj.strftime(format_str) 77 | 78 | def timestamp_to_str(self, timestamp: float, format_str: str = None) -> str: 79 | """将时间戳转换为时间字符串""" 80 | format_str = format_str or self.format_str 81 | return datetime.fromtimestamp(timestamp).strftime(format_str) 82 | 83 | def str_to_timestamp(self, time_str: str, format_str: str = None) -> float: 84 | """将时间字符串转换为时间戳""" 85 | format_str = format_str or self.format_str 86 | return time.mktime(time.strptime(time_str, format_str)) 87 | 88 | @staticmethod 89 | def timestamp_to_datetime(timestamp: float) -> datetime: 90 | """将时间戳转换为 datetime 对象""" 91 | return datetime.fromtimestamp(timestamp) 92 | 93 | @property 94 | def timestamp(self) -> float: 95 | """获取 datetime 对象的时间戳""" 96 | return self.datetime_obj.timestamp() 97 | 98 | def date_diff(self, datetime_obj: datetime): 99 | """ 100 | 计算两个日期之间的差值详情 101 | Args: 102 | datetime_obj: 时间对象 103 | 104 | Returns: DateDiff 105 | """ 106 | delta = relativedelta(self.datetime_obj, datetime_obj) 107 | return delta 108 | 109 | def start_of_week(self) -> datetime: 110 | """获取本周的开始日期(周一)""" 111 | return self.datetime_obj - relativedelta(days=self.datetime_obj.weekday()) 112 | 113 | def end_of_week(self) -> datetime: 114 | """获取本周的结束日期(周日)""" 115 | return self.start_of_week() + relativedelta(days=6) 116 | 117 | def start_of_month(self) -> datetime: 118 | """获取本月的第一天""" 119 | return self.datetime_obj.replace(day=1) 120 | 121 | def end_of_month(self) -> datetime: 122 | """获取本月的最后一天""" 123 | next_month = 
self.add_time(months=1) 124 | return next_month.replace(day=1) - relativedelta(days=1) 125 | 126 | def start_of_quarter(self) -> datetime: 127 | """获取本季度的第一天""" 128 | quarter_month_start = (self.datetime_obj.month - 1) // 3 * 3 + 1 129 | return self.datetime_obj.replace(month=quarter_month_start, day=1) 130 | 131 | def end_of_quarter(self) -> datetime: 132 | """获取本季度的最后一天""" 133 | next_quarter_start = self.start_of_quarter().replace(month=self.datetime_obj.month + 3) 134 | return next_quarter_start - relativedelta(days=1) 135 | 136 | def start_of_year(self) -> datetime: 137 | """获取本年度的第一天""" 138 | return self.datetime_obj.replace(month=1, day=1) 139 | 140 | def end_of_year(self) -> datetime: 141 | """获取本年度的最后一天""" 142 | return self.datetime_obj.replace(month=12, day=31) 143 | 144 | def is_weekday(self) -> bool: 145 | """判断当前日期是否是工作日(星期一到星期五)""" 146 | return self.datetime_obj.weekday() < 5 147 | 148 | def count_weekdays_between(self, datetime_obj: datetime, include_end_date: bool = True) -> int: 149 | """计算两个日期之间的工作日数量 150 | 151 | Args: 152 | datetime_obj: datetime 对象 153 | include_end_date: 是否包含结束日期(默认为 True) 154 | 155 | Returns: 156 | 两个日期之间的工作日数量 157 | """ 158 | # 确保 start_date 是较小的日期,end_date 是较大的日期 159 | start_date = min(self.datetime_obj, datetime_obj) 160 | end_date = max(self.datetime_obj, datetime_obj) 161 | 162 | # 如果不包含结束日期,将 end_date 减去一天 163 | if not include_end_date: 164 | end_date = end_date - relativedelta(days=1) 165 | 166 | # 计算两个日期之间的天数 167 | days_between = abs((end_date - start_date).days) 168 | 169 | # 计算完整周数,每周有5个工作日 170 | weeks_between = days_between // 7 171 | weekdays = weeks_between * 5 172 | 173 | # 计算剩余的天数 174 | remaining_days = days_between % 7 175 | # 遍历剩余的天数,检查每天是否为工作日,如果是,则累加工作日数量 176 | for day_offset in range(remaining_days + 1): 177 | if (start_date + relativedelta(days=day_offset)).weekday() < 5: 178 | weekdays += 1 179 | 180 | return weekdays 181 | -------------------------------------------------------------------------------- 
/py_tools/utils/tree_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: tree.py 5 | # @Desc: { 树形结构相关工具函数 } 6 | # @Date: 2024/04/24 11:54 7 | from typing import List 8 | 9 | 10 | def list_to_tree_dfs( 11 | data_list: List[dict], 12 | root_pid: int = 0, 13 | pid_field: str = "pid", 14 | sub_field: str = "children", 15 | relate_field: str = "id", 16 | level: int = 0, 17 | need_level: bool = False, 18 | ): 19 | """ 20 | 递归构造树形列表(深度优先) 21 | 22 | Args: 23 | data_list: 待转换为树形结构的字典列表 24 | root_pid: 根节点的父节点标识符,默认为 0 25 | pid_field: 字典中表示父节点标识符的字段名,默认为 "pid" 26 | sub_field: 子节点列表的字段名,默认为 "children" 27 | relate_field: 父子级关联字段,默认为 "id",例如 pid 与 id 关联 28 | level: 当前节点的层级,默认为 0 29 | need_level: 是否需要记录节点的层级,默认为 False 30 | 31 | Returns: 树形列表 32 | """ 33 | children = [] 34 | level = level + 1 # 记录层级 35 | for node in data_list: 36 | if node[pid_field] == root_pid: 37 | # 递归调用 38 | node[sub_field] = list_to_tree_dfs( 39 | data_list, node[relate_field], pid_field, sub_field, relate_field, level, need_level 40 | ) 41 | if need_level: 42 | node["level"] = level 43 | children.append(node) 44 | return children 45 | 46 | 47 | def list_to_tree_bfs( 48 | data_list: List[dict], 49 | root_pid: int = 0, 50 | pid_field: str = "pid", 51 | sub_field: str = "children", 52 | relate_field: str = "id", 53 | level: int = 0, 54 | need_level: bool = False, 55 | ): 56 | """ 57 | 构造树形列表(广度优先) 58 | 59 | Args: 60 | data_list: 待转换为树形结构的字典列表 61 | root_pid: 根节点的父节点标识符,默认为 0 62 | pid_field: 字典中表示父节点标识符的字段名,默认为 "pid" 63 | sub_field: 子节点列表的字段名,默认为 "children" 64 | relate_field: 父子级关联字段,默认为 "id",例如 pid 与 id 关联 65 | level: 当前节点的层级,默认为 0 66 | need_level: 是否需要记录节点的层级,默认为 False 67 | 68 | Returns: 树形列表 69 | """ 70 | queue = [(node, level + 1) for node in data_list if node[pid_field] == root_pid] 71 | 72 | tree_list = [] 73 | while queue: 74 | node, node_level = queue.pop(0) 75 | 76 | # 所有的子节点加入队列 77 
| children = [] 78 | for child in data_list: 79 | if child[pid_field] == node[relate_field]: 80 | queue.append((child, node_level + 1)) 81 | children.append(child) 82 | 83 | node[sub_field] = children 84 | if need_level: 85 | node["level"] = node_level 86 | 87 | if node[pid_field] == root_pid: 88 | # 只有顶级节点才添加 89 | tree_list.append(node) 90 | 91 | return tree_list 92 | 93 | 94 | def tree_to_list_dfs(tree_list, sub_field="children", result_list=None, level=0, need_level=False): 95 | """ 96 | 将树形结构列表扁平化成一层列表(深度优先) 97 | 98 | Args: 99 | tree_list: 树形结构列表 100 | sub_field: 子节点列表的字段名,默认为 "children" 101 | result_list: 保存结果的列表 102 | level: 当前节点的层级,默认为 0 103 | need_level: 是否需要记录节点的层级,默认为 False 104 | 105 | Returns: 扁平化后的一层列表 106 | """ 107 | result_list = result_list or [] 108 | level = level + 1 109 | 110 | for node in tree_list: 111 | if need_level: 112 | node["level"] = level 113 | result_list.append(node) 114 | sub_list = node.pop(sub_field, None) 115 | if sub_list: 116 | tree_to_list_dfs(sub_list, sub_field, result_list, level, need_level) 117 | 118 | return result_list 119 | 120 | 121 | def tree_to_list_bfs(tree_list, sub_field="children", level=0, need_level=False): 122 | """ 123 | 将树形结构列表扁平化成一层列表(广度优先) 124 | 125 | Args: 126 | tree_list: 树形结构列表 127 | sub_field: 子节点列表的字段名,默认为 "children" 128 | level: 当前节点的层级,默认为 0 129 | need_level: 是否需要记录节点的层级,默认为 False 130 | 131 | Returns: 扁平化后的一层列表 132 | """ 133 | result_list = [] 134 | queue = [(node, level + 1) for node in tree_list] 135 | 136 | while queue: 137 | node, cur_level = queue.pop(0) # 取出队首节点 138 | children = node.pop(sub_field, []) 139 | if need_level: 140 | node["level"] = cur_level 141 | result_list.append(node) 142 | queue.extend((child, cur_level + 1) for child in children) 143 | 144 | return result_list 145 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openpyxl==3.0.10 2 | 
pandas==2.0.3 3 | requests==2.31.0 4 | aiohttp==3.9.5 5 | python-dateutil==2.8.2 6 | loguru==0.7.2 7 | cacheout==0.14.1 8 | redis==5.0.1 9 | python-memcached==1.62 10 | pytest==7.3.1 11 | pydantic==2.1.1 12 | sqlalchemy[asyncio]==2.0.20 13 | aiomysql==0.2.0 14 | minio==7.1.17 15 | asgiref==3.8.1 16 | nest_asyncio==1.6.0 17 | tqdm==4.66.4 18 | aiofiles==24.1.0 19 | python-jose==3.3.0 -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | select = ["E", "F"] 2 | ignore = ["E501", "E712", "E722", "F821", "E731"] 3 | 4 | ignore-init-module-imports = true 5 | 6 | # Allow autofix for all enabled rules (when `--fix`) is provided. 7 | fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] 8 | unfixable = [] 9 | 10 | # Exclude a variety of commonly ignored directories. 11 | exclude = [ 12 | ".bzr", 13 | ".direnv", 14 | ".eggs", 15 | ".git", 16 | ".git-rewrite", 17 | ".hg", 18 | ".mypy_cache", 19 | ".nox", 20 | ".pants.d", 21 | ".pytype", 22 | ".ruff_cache", 23 | ".svn", 24 | ".tox", 25 | ".venv", 26 | "__pypackages__", 27 | "_build", 28 | "buck-out", 29 | "build", 30 | "dist", 31 | "node_modules", 32 | "venv", 33 | ] 34 | 35 | # Same as Black. 36 | line-length = 120 37 | 38 | # Allow unused variables when underscore-prefixed. 
39 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 40 | 41 | # Assume Python 3.8 42 | target-version = "py38" 43 | 44 | [per-file-ignores] 45 | "__init__.py" = ["F401"] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { pypi打包模块 } 5 | # @Date: 2023/9/04 19:59 6 | import operator 7 | from functools import reduce 8 | 9 | from setuptools import find_packages, setup 10 | 11 | 12 | class PKGManager: 13 | name = "huidevkit" 14 | version = "0.6.0" 15 | author = "hui" 16 | author_email = "huidbk@163.com" 17 | 18 | @classmethod 19 | def get_pkg_desc(cls): 20 | """获取包描述""" 21 | with open("README.md", "r") as f: 22 | long_description = f.read() 23 | return long_description 24 | 25 | @classmethod 26 | def get_install_requires(cls): 27 | """获取必须安装依赖""" 28 | requires = [ 29 | "loguru>=0.7.0,<0.8", 30 | "pydantic>=2.1.1,<3", 31 | "asgiref==3.8.1", 32 | "nest_asyncio==1.6.0", 33 | "tqdm==4.66.4", 34 | "python-dateutil==2.8.2", 35 | "requests==2.31.0", 36 | "aiohttp==3.9.5", 37 | "cacheout==0.14.1", 38 | "aiofiles==24.1.0", 39 | "python-jose==3.3.0", 40 | ] 41 | return requires 42 | 43 | @classmethod 44 | def get_extras_require(cls): 45 | """ 46 | 可选的依赖 47 | """ 48 | extras_require = { 49 | "db-orm": ["sqlalchemy[asyncio]==2.0.20", "aiomysql==0.2.0"], 50 | "db-redis": ["redis>=4.5.4"], 51 | "cache-proxy": ["redis>=4.5.4", "python-memcached==1.62", "cacheout==0.14.1"], 52 | "minio": ["minio==7.1.17"], 53 | "excel-tools": ["pandas==2.0.3", "openpyxl==3.0.10"], 54 | } 55 | 56 | extras_require["all"] = list(set(reduce(operator.add, [cls.get_install_requires(), *extras_require.values()]))) 57 | extras_require["test"] = ["pytest==7.3.1", "pytest-mock==3.14.0", "pytest-asyncio==0.23.8"] 58 | 59 | return extras_require 60 | 61 | 62 | def main(): 63 | setup( 64 | 
name=PKGManager.name, 65 | author=PKGManager.author, 66 | author_email=PKGManager.author_email, 67 | version=PKGManager.version, 68 | packages=find_packages(), 69 | url="https://github.com/HuiDBK/py-tools", 70 | license="Apache", 71 | description="Practical Python development tools", 72 | long_description=PKGManager.get_pkg_desc(), 73 | long_description_content_type="text/markdown", 74 | install_requires=PKGManager.get_install_requires(), 75 | classifiers=[ 76 | "Programming Language :: Python :: 3", 77 | "License :: OSI Approved :: Apache Software License", 78 | "Operating System :: OS Independent", 79 | ], 80 | extras_require=PKGManager.get_extras_require(), 81 | python_requires=">=3.8", 82 | entry_points={"console_scripts": ["py_tools = py_tools.utils.project_templates:make_project"]}, 83 | include_package_data=True, 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | # python3 setup.py sdist bdist_wheel 89 | # twine upload --repository testpypi dist/* 90 | # twine upload dist/* 91 | main() 92 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 模块描述 } 5 | # @Date: 2023/05/07 13:07 6 | 7 | 8 | def main(): 9 | pass 10 | 11 | 12 | if __name__ == '__main__': 13 | main() 14 | -------------------------------------------------------------------------------- /tests/chatbot/test_chatbot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: test_chatbot.py 5 | # @Desc: { webhook机器人单测 } 6 | # @Date: 2024/08/08 10:11 7 | from unittest.mock import MagicMock, patch 8 | 9 | import pytest 10 | 11 | from py_tools.chatbot import ( 12 | ChatBotFactory, 13 | ChatBotType, 14 | DingTalkChatBot, 15 | FeiShuAppServer, 16 | FeiShuChatBot, 17 | 
FeiShuTaskChatBot, 18 | WeComChatbot, 19 | ) 20 | from py_tools.enums.feishu import FeishuReceiveType 21 | from py_tools.exceptions import SendMsgException 22 | 23 | 24 | class TestChatBot: 25 | feishu_bot = FeiShuChatBot(webhook_url="test_url", secret="test_secret") 26 | dingtalk_bot = DingTalkChatBot(webhook_url="test_url", secret="test_secret") 27 | wecom_bot = WeComChatbot(webhook_url="test_url", secret="test_secret") 28 | 29 | @classmethod 30 | def chatbots(cls): 31 | return [cls.feishu_bot, cls.dingtalk_bot, cls.wecom_bot] 32 | 33 | def test_get_sign(self): 34 | timestamp = "1609459200" 35 | secret = "test_secret" 36 | for bot in self.chatbots(): 37 | assert bot._get_sign(timestamp, secret) == bot._get_sign(timestamp, secret) 38 | 39 | @pytest.fixture 40 | def mock_request_post(self, mocker): 41 | return mocker.patch("requests.post") 42 | 43 | def get_bot_mock_post_data(self, mock_request_post, bot): 44 | code_key = "" 45 | if isinstance(bot, FeiShuChatBot): 46 | code_key = "code" 47 | mock_request_post.return_value.json.return_value = {code_key: 0, "message": "ok"} 48 | elif isinstance(bot, DingTalkChatBot): 49 | code_key = "errcode" 50 | mock_request_post.return_value.json.return_value = {code_key: 0, "message": "ok"} 51 | elif isinstance(bot, WeComChatbot): 52 | code_key = MagicMock() 53 | code_key.status_code = 200 54 | mock_request_post.return_value = code_key 55 | 56 | return code_key 57 | 58 | def test_send_msg(self, mock_request_post): 59 | for bot in self.chatbots(): 60 | code_key = self.get_bot_mock_post_data(mock_request_post, bot) 61 | ret = bot.send_msg("test message") 62 | if isinstance(ret, dict): 63 | assert ret.get(code_key) == 0 64 | else: 65 | assert ret.status_code == 200 66 | 67 | def test_send_msg_failure(self, mock_request_post): 68 | mock_request_post.return_value.json.return_value = {"code": 1, "message": "error"} 69 | with pytest.raises(SendMsgException): 70 | for bot in self.chatbots(): 71 | bot.send_msg("test message") 72 | 73 | 74 
| class TestChatBotFactory(TestChatBot): 75 | def test_bot_factory_send_msg(self, mock_request_post): 76 | mock_request_post.return_value.json.return_value = {"code": 0, "message": "ok"} 77 | bot = ChatBotFactory(chatbot_type=ChatBotType.FEISHU_CHATBOT).build( 78 | webhook_url="test_url", secret="test_secret" 79 | ) 80 | bot.send_msg("test message") 81 | 82 | mock_request_post.return_value.json.return_value = {"code": -1, "message": "ok"} 83 | with pytest.raises(SendMsgException): 84 | bot.send_msg("test message") 85 | 86 | 87 | class TestFeiShuAppServer: 88 | @pytest.fixture 89 | def app_server(self): 90 | return FeiShuAppServer(app_id="test_app_id", app_secret="test_app_secret") 91 | 92 | @pytest.fixture 93 | def mock_tenant_access_token(self, mocker): 94 | mock_post = mocker.patch("requests.post") 95 | mock_response = MagicMock() 96 | mock_response.json.return_value = {"code": 0, "tenant_access_token": "mock_token", "expire": 3600} 97 | mock_post.return_value = mock_response 98 | return mock_post 99 | 100 | @patch("requests.post") 101 | def test_get_tenant_access_token(self, mock_post, app_server): 102 | # Mock the response from requests 103 | mock_response = MagicMock() 104 | mock_response.json.return_value = {"code": 0, "tenant_access_token": "mock_token", "expire": 3600} 105 | mock_post.return_value = mock_response 106 | 107 | token = app_server._get_tenant_access_token() 108 | assert token == "mock_token" 109 | 110 | @patch("requests.post") 111 | def test_get_user_open_id(self, mock_post, app_server): 112 | # Mock the response from requests 113 | mock_response = MagicMock() 114 | mock_response.json.return_value = { 115 | "code": 0, 116 | "data": {"user_list": [{"mobile": "130xxxx1752", "user_id": "ou_xxx"}]}, 117 | } 118 | mock_post.return_value = mock_response 119 | 120 | user_list = app_server._get_user_open_id(mobiles=["130xxxx1752"]) 121 | assert user_list == [{"mobile": "130xxxx1752", "user_id": "ou_xxx"}] 122 | 123 | @patch("requests.get") 124 | def 
test_get_user_or_bot_groups(self, mock_get, app_server, mock_tenant_access_token): 125 | # Mock the response from requests 126 | mock_response = MagicMock() 127 | mock_response.json.return_value = { 128 | "code": 0, 129 | "data": {"items": [{"name": "test_group", "chat_id": "test_chat_id"}], "has_more": False}, 130 | } 131 | mock_get.return_value = mock_response 132 | 133 | groups = app_server._get_user_or_bot_groups() 134 | assert groups == [{"name": "test_group", "chat_id": "test_chat_id"}] 135 | 136 | @patch("requests.post") 137 | def test_send_msg(self, mock_post, app_server): 138 | # Mock the response from requests 139 | mock_response = MagicMock() 140 | mock_response.json.return_value = {"code": 0} 141 | mock_post.return_value = mock_response 142 | 143 | app_server.send_msg("test message", FeishuReceiveType.OPEN_ID, "test_open_id") 144 | 145 | 146 | class TestFeiShuTaskChatBot: 147 | @pytest.fixture 148 | def chat_bot(self): 149 | return FeiShuTaskChatBot(app_id="test_app_id", app_secret="test_app_secret") 150 | 151 | @patch("requests.post") 152 | def test_user_task_notify(self, mock_post, chat_bot): 153 | # Mock the response from requests 154 | mock_response = MagicMock() 155 | mock_response.json.return_value = { 156 | "code": 0, 157 | "data": {"user_list": [{"mobile": "130xxxx1752", "user_id": "ou_xxx"}]}, 158 | } 159 | mock_post.return_value = mock_response 160 | 161 | with patch.object(chat_bot, "send_msg") as mock_send_msg: 162 | chat_bot.user_task_notify("test content", receive_mobiles=["130xxxx1752"]) 163 | mock_send_msg.assert_called_with( 164 | "test content", receive_id_type=FeishuReceiveType.OPEN_ID, receive_id="ou_xxx" 165 | ) 166 | 167 | @patch("requests.get") 168 | @patch("requests.post") 169 | def test_user_group_task_notify(self, mock_post, mock_get, chat_bot): 170 | # Mock the response from requests 171 | mock_get_response = MagicMock() 172 | mock_get_response.json.return_value = { 173 | "code": 0, 174 | "data": {"items": [{"name": 
"test_group", "chat_id": "test_chat_id"}], "has_more": False}, 175 | } 176 | mock_get.return_value = mock_get_response 177 | 178 | mock_post_response = MagicMock() 179 | mock_post_response.json.return_value = { 180 | "code": 0, 181 | "data": {"user_list": [{"mobile": "130xxxx1752", "user_id": "ou_xxx"}]}, 182 | } 183 | mock_post.return_value = mock_post_response 184 | 185 | with patch.object(chat_bot, "send_msg") as mock_send_msg: 186 | chat_bot.user_group_task_notify("test content", "test_group", receive_mobiles=["130xxxx1752"]) 187 | mock_send_msg.assert_called() 188 | -------------------------------------------------------------------------------- /tests/connections/test_http_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: test_http_client.py 5 | # @Desc: { http客户端单测 } 6 | # @Date: 2024/08/08 15:20 7 | from unittest.mock import AsyncMock, MagicMock 8 | 9 | import pytest 10 | 11 | from py_tools.connections.http import AsyncHttpClient, HttpClient 12 | 13 | 14 | class TestAsyncHttpClient: 15 | test_url = "http://example.com" 16 | 17 | text_func_ret = "test_response" 18 | bytes_func_ret = b"test_bytes" 19 | json_func_ret = {"key": "value"} 20 | 21 | @pytest.fixture 22 | def mock_request(self, mocker): 23 | mocker_request = mocker.patch.object(AsyncHttpClient, "_request") 24 | 25 | mock_response = AsyncMock() 26 | mock_response.__aenter__.return_value = mock_response 27 | mock_response.__aexit__.return_value = AsyncMock() 28 | 29 | mock_response.text.return_value = self.text_func_ret 30 | mock_response.json.return_value = self.json_func_ret 31 | mock_response.read.return_value = self.bytes_func_ret 32 | 33 | mocker_request.return_value = mock_response 34 | 35 | return mock_response 36 | 37 | @pytest.mark.asyncio 38 | async def test_get_text(self, mock_request): 39 | resp = await AsyncHttpClient().get(url=self.test_url).text() 40 | assert resp 
== self.text_func_ret 41 | 42 | @pytest.mark.asyncio 43 | async def test_post_put_json(self, mock_request): 44 | resp = await AsyncHttpClient().post(url=self.test_url, data={"method": "post"}).json() 45 | assert resp == self.json_func_ret 46 | 47 | resp = await AsyncHttpClient().put(url=self.test_url, data={"method": "put"}).json() 48 | assert resp == self.json_func_ret 49 | 50 | @pytest.mark.asyncio 51 | async def test_get_bytes(self, mock_request): 52 | resp = await AsyncHttpClient().get(url=self.test_url).bytes() 53 | assert resp == self.bytes_func_ret 54 | 55 | @pytest.mark.asyncio 56 | async def test_upload_file(self, mock_request): 57 | resp = await AsyncHttpClient().upload_file(url=self.test_url, file=__file__).json() 58 | assert resp == self.json_func_ret 59 | 60 | 61 | class TestHttpClient: 62 | test_url = "http://example.com" 63 | 64 | text_func_ret = "test_response" 65 | bytes_func_ret = b"test_bytes" 66 | json_func_ret = {"key": "value"} 67 | 68 | @pytest.fixture 69 | def mock_request(self, mocker): 70 | mocker_request = mocker.patch("requests.Session.request") 71 | 72 | mock_response = MagicMock() 73 | mock_response.text = self.text_func_ret 74 | mock_response.json.return_value = self.json_func_ret 75 | mock_response.content = self.bytes_func_ret 76 | 77 | mocker_request.return_value = mock_response 78 | 79 | return mock_response 80 | 81 | def test_get_text(self, mock_request): 82 | resp = HttpClient().get(url=self.test_url).text 83 | assert resp == self.text_func_ret 84 | 85 | def test_post_put_json(self, mock_request): 86 | resp = HttpClient().post(url=self.test_url, data={"method": "post"}).json 87 | assert resp == self.json_func_ret 88 | 89 | resp = HttpClient().put(url=self.test_url, data={"method": "put"}).json 90 | assert resp == self.json_func_ret 91 | 92 | @pytest.mark.asyncio 93 | async def test_get_bytes(self, mock_request): 94 | resp = HttpClient().get(url=self.test_url).bytes 95 | assert resp == self.bytes_func_ret 96 | 
-------------------------------------------------------------------------------- /tests/decorators/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | 5 | from py_tools.decorators import retry, set_timeout 6 | from py_tools.exceptions import MaxRetryException, MaxTimeoutException 7 | 8 | 9 | class TestBaseDecorator: 10 | """通用装饰器测试""" 11 | 12 | @retry(max_count=3) 13 | def user_place_order_success(self): 14 | """用户下单成功模拟""" 15 | return {"code": 0, "msg": "ok"} 16 | 17 | @retry(max_count=3, interval=3) 18 | def user_place_order_fail(self): 19 | """用户下单失败重试模拟""" 20 | a = 1 / 0 21 | return {"code": 0, "msg": "ok"} 22 | 23 | def test_retry(self): 24 | """重试装饰器单测""" 25 | ret = self.user_place_order_success() 26 | assert ret["code"] == 0 27 | 28 | # 超过最大重试次数模拟 29 | with pytest.raises(MaxRetryException): 30 | self.user_place_order_fail() 31 | 32 | @set_timeout(3) 33 | def user_place_order(self): 34 | """用户下单超时模拟""" 35 | time.sleep(1) # 模拟业务超时 36 | return {"code": 0, "msg": "ok"} 37 | 38 | @set_timeout(2) 39 | def user_place_order_timeout(self): 40 | """用户下单模拟""" 41 | time.sleep(3) # 模拟业务超时 42 | return {"code": 0, "msg": "ok"} 43 | 44 | def test_timeout(self): 45 | """超时装饰器单测""" 46 | 47 | ret = self.user_place_order() 48 | assert ret.get("code") == 0 49 | 50 | # 超时 51 | with pytest.raises(MaxTimeoutException): 52 | self.user_place_order_timeout() 53 | -------------------------------------------------------------------------------- /tests/meta_cls/singleton.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @Desc: { 单例元类测试 } 5 | # @Date: 2023/08/28 11:01 6 | import threading 7 | 8 | from py_tools.meta_cls import SingletonMetaCls 9 | 10 | 11 | class TestSingletonMetaCls: 12 | """ 单例元类测试 """ 13 | singleton_set = set() 14 | 15 | class Foo(metaclass=SingletonMetaCls): 16 | 17 | def 
bar(self): 18 | pass 19 | 20 | def create_singleton(self): 21 | self.singleton_set.add(id(self.Foo())) 22 | 23 | def test_singleton_meta_cls(self): 24 | assert self.Foo() is self.Foo() 25 | 26 | # 多线程测试单例 27 | thread_list = list() 28 | for i in range(10): 29 | t = threading.Thread(target=self.create_singleton) 30 | t.start() 31 | thread_list.append(t) 32 | 33 | for thread in thread_list: 34 | thread.join() 35 | 36 | # 判断单例集合中只有一个对象说明地址全部一样 37 | assert len(self.singleton_set) == 1 38 | -------------------------------------------------------------------------------- /tests/utils/test_jwt_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: test_jwt_util.py 5 | # @Desc: { test jwt util } 6 | # @Date: 2024/11/04 15:32 7 | import datetime 8 | import time 9 | 10 | import pytest 11 | 12 | from py_tools.utils import JWTUtil 13 | 14 | # 设置测试用的密钥和算法 15 | SECRET_KEY = "test_secret_key" 16 | ALGORITHM = "HS256" 17 | 18 | 19 | class TestJWTUtil: 20 | """JWTUtil 工具类的测试用例。""" 21 | 22 | @pytest.fixture(autouse=True) 23 | def setup(self): 24 | """初始化 JWTUtil 实例,用于每个测试方法。""" 25 | self.jwt_util = JWTUtil(secret_key=SECRET_KEY, algorithm=ALGORITHM, expiration_minutes=1) 26 | 27 | def test_generate_token(self): 28 | """测试生成 JWT 令牌。""" 29 | data = {"user_id": "12345", "role": "admin"} 30 | token = self.jwt_util.generate_token(data) 31 | assert isinstance(token, str), "生成的令牌应该是字符串" 32 | 33 | def test_verify_token(self): 34 | """测试验证有效的 JWT 令牌。""" 35 | data = {"user_id": "12345", "role": "admin"} 36 | token = self.jwt_util.generate_token(data) 37 | decoded_data = self.jwt_util.verify_token(token) 38 | assert decoded_data is not None, "验证后的数据不应为空" 39 | assert decoded_data["user_id"] == "12345", "解码数据中的 user_id 应该与输入数据匹配" 40 | assert decoded_data["role"] == "admin", "解码数据中的 role 应该与输入数据匹配" 41 | 42 | def test_verify_token_expired(self): 43 | """测试过期的 JWT 令牌验证。""" 44 | data 
= {"user_id": "12345", "role": "admin"} 45 | token = self.jwt_util.generate_token(data, expires_delta=datetime.timedelta(seconds=1)) 46 | 47 | # 等待令牌过期 48 | time.sleep(2) 49 | decoded_data = self.jwt_util.verify_token(token) 50 | assert decoded_data is None, "过期的令牌应返回 None" 51 | 52 | def test_refresh_token(self): 53 | """测试刷新 JWT 令牌。""" 54 | data = {"user_id": "12345", "role": "admin"} 55 | token = self.jwt_util.generate_token(data, expires_delta=datetime.timedelta(seconds=5)) 56 | 57 | # 在原令牌过期前刷新令牌 58 | refreshed_token = self.jwt_util.refresh_token(token, expires_delta=datetime.timedelta(minutes=1)) 59 | assert refreshed_token is not None, "刷新后的令牌不应为空" 60 | assert refreshed_token != token, "刷新后的令牌应不同于原令牌" 61 | 62 | # 验证刷新后的令牌有效性 63 | decoded_data = self.jwt_util.verify_token(refreshed_token) 64 | assert decoded_data is not None, "刷新后的令牌验证应成功" 65 | assert decoded_data["user_id"] == "12345", "刷新令牌的解码数据应与原数据相同" 66 | assert decoded_data["role"] == "admin", "刷新令牌的解码数据应与原数据相同" 67 | -------------------------------------------------------------------------------- /tests/utils/test_re_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Hui 4 | # @File: test_re_util.py 5 | # @Desc: { RegexUtil unitest } 6 | # @Date: 2024/09/11 15:47 7 | from py_tools.utils import RegexUtil 8 | 9 | 10 | class TestRegexUtil: 11 | """RegexUtil 单元测试类""" 12 | 13 | def test_find_http_links(self): 14 | """测试 HTTP 链接的匹配""" 15 | text = "访问 https://www.juejin.cn 或 http://example.com 了解更多信息。" 16 | expected = ["https://www.juejin.cn", "http://example.com"] 17 | assert RegexUtil.find_http_links(text) == expected 18 | 19 | def test_find_chinese_characters(self): 20 | """测试中文字符的匹配""" 21 | text = "这是一个测试" 22 | expected = ["这", "是", "一", "个", "测", "试"] 23 | assert RegexUtil.find_chinese_characters(text) == expected 24 | 25 | def test_find_double_byte_characters(self): 26 | """测试双字节字符的匹配""" 27 | text = 
"test这是测试" 28 | expected = ["t", "e", "s", "t", "这", "是", "测", "试"] 29 | assert RegexUtil.find_double_byte_characters(text) == expected 30 | 31 | def test_find_emails(self): 32 | """测试 Email 地址的匹配""" 33 | text = "联系我: huidbk@example.com, hui@domain.cn" 34 | expected = ["huidbk@example.com", "hui@domain.cn"] 35 | assert RegexUtil.find_emails(text) == expected 36 | 37 | def test_find_chinese_phone_numbers(self): 38 | """测试中国大陆手机号码的匹配""" 39 | text = "我的手机号是13800138000,朋友的手机号是14712345678" 40 | expected = ["13800138000", "14712345678"] 41 | assert RegexUtil.find_chinese_phone_numbers(text) == expected 42 | 43 | def test_find_qq_numbers(self): 44 | """测试腾讯QQ号的匹配""" 45 | text = "我的QQ号是123456789,朋友的QQ号是987654321" 46 | expected = ["123456789", "987654321"] 47 | assert RegexUtil.find_qq_numbers(text) == expected 48 | 49 | def test_find_postal_codes(self): 50 | """测试邮政编码的匹配""" 51 | text = "我的邮政编码是123456" 52 | expected = ["123456"] 53 | assert RegexUtil.find_postal_codes(text) == expected 54 | 55 | def test_find_dates(self): 56 | """测试日期格式的匹配""" 57 | text = "今天的日期是2024-08-24,昨天是2024/08/23" 58 | expected = ["2024-08-24", "2024/08/23"] 59 | assert RegexUtil.find_dates(text) == expected 60 | 61 | def test_find_integers(self): 62 | """测试整数的匹配""" 63 | text = "正数: 123, 负数: -456" 64 | expected = ["123", "-456"] 65 | assert RegexUtil.find_integers(text) == expected 66 | 67 | def test_find_positive_floats(self): 68 | """测试正浮点数的匹配""" 69 | text = "正浮点数: 12.34, 0.56" 70 | expected = ["12.34", "0.56"] 71 | assert RegexUtil.find_positive_floats(text) == expected 72 | 73 | def test_find_negative_floats(self): 74 | """测试负浮点数的匹配""" 75 | text = "负浮点数: -12.34, -0.56,负整数 -100" 76 | expected = ["-12.34", "-0.56"] 77 | assert RegexUtil.find_negative_floats(text) == expected 78 | 79 | def test_find_usernames(self): 80 | """测试用户名的匹配""" 81 | text = "user123, 张三_001" 82 | expected = ["user123", "张三_001"] 83 | assert RegexUtil.find_usernames(text) == expected 84 | 
--------------------------------------------------------------------------------