├── .gitignore ├── LICENSE ├── README.md ├── blueprints ├── __pycache__ │ ├── blueprints_image_generation.cpython-38.pyc │ ├── blueprints_midjourney.cpython-38.pyc │ ├── blueprints_mj.cpython-38.pyc │ ├── blueprints_openai.cpython-38.pyc │ ├── blueprints_text_generation.cpython-38.pyc │ └── blueprints_user.cpython-38.pyc └── blueprints_image_generation.py ├── entity ├── __pycache__ │ ├── mj_scheme.cpython-38.pyc │ ├── models.cpython-38.pyc │ ├── openai_scheme.cpython-38.pyc │ └── user_scheme.cpython-38.pyc ├── gen.sh ├── mj_scheme.py ├── models.py ├── openai_scheme.py └── user_scheme.py ├── main.py ├── requirements.txt ├── resources ├── api_params │ ├── blend.json │ ├── describe.json │ ├── imagine.json │ ├── info.json │ ├── message.json │ ├── reroll.json │ ├── shorten.json │ ├── upscale.json │ └── variation.json └── config │ └── prod_mj_config.yaml ├── service ├── __pycache__ │ ├── celery_client.cpython-38.pyc │ ├── celery_service.cpython-38.pyc │ ├── discord_http_service.cpython-38.pyc │ ├── mj_data_service.cpython-38.pyc │ ├── notify_service.cpython-38.pyc │ └── template_controller.cpython-38.pyc ├── discord_http_service.py ├── discord_ws_service.py ├── mail_service.py ├── mj_data_service.py ├── notify_service.py └── template_controller.py ├── support ├── Injector.py ├── __init__.py ├── __pycache__ │ ├── Injector.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── image_controller.cpython-38.pyc │ ├── load_balancer.cpython-38.pyc │ ├── mj_account.cpython-38.pyc │ ├── mj_config.cpython-38.pyc │ ├── mj_task.cpython-38.pyc │ ├── task_controller.cpython-38.pyc │ └── template_controller.cpython-38.pyc ├── image_controller.py ├── load_balancer.py ├── mj_account.py ├── mj_config.py └── task_controller.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── common_util.cpython-38.pyc │ ├── logger_util.cpython-38.pyc │ └── wx_util.cpython-38.pyc ├── bootstrap.py ├── common_util.py ├── logger_util.py └── wx_util.py └── wss ├── __pycache__ ├── 
mj_wss_manager.cpython-38.pyc └── mj_wss_proxy.cpython-38.pyc ├── handler ├── __pycache__ │ ├── base_message_handler.cpython-38.pyc │ ├── describe_success_handler.cpython-38.pyc │ ├── imagine_handler.cpython-38.pyc │ ├── imagine_hanlder.cpython-38.pyc │ ├── message_create_handler.cpython-38.pyc │ ├── upscale_handler.cpython-38.pyc │ └── variation_handler.cpython-38.pyc ├── base_message_handler.py ├── describe_success_handler.py ├── imagine_hanlder.py ├── shorten_success_handler.py ├── upscale_handler.py └── variation_handler.py ├── mj_wss_manager.py └── mj_wss_proxy.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | */**/*.pyc 3 | .cache_* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

py-midjourney-proxy

4 | 5 | English | [中文](./README_CN.md) 6 | 7 | An unofficial Python implementation of the Discord proxy for MidJourney. 8 | 9 | Build your own MidJourney proxy server with a single command. 10 |
11 | 12 | 13 | ## Main Features 14 | 15 | - [x] Supports the Imagine command and its related actions 16 | - [x] Supports passing reference images (base64 or URL) with the Imagine command 17 | - [x] Supports the Blend (image blending) and Describe (image to text) commands 18 | - [x] Supports real-time task progress tracking 19 | - [x] Pre-checks prompts for sensitive words; the word list can be overridden and adjusted 20 | - [x] Connects to WSS (WebSocket Secure) with a user token, giving access to error messages and full functionality 21 | - [x] Supports multi-account configuration, with a separate task queue configurable for each account 22 | 23 | ## Prerequisites 24 | 25 | 1. Register and subscribe to MidJourney, then create your own server and channel; see 26 | https://docs.midjourney.com/docs/quick-start 27 | 2. Obtain your user token, server ID, and channel ID: [How to obtain them](./docs/discord-params.md) 28 | 29 | 30 | ## Local development 31 | 32 | - Requires Python and FastAPI 33 | - Configuration: edit resources/config/prod_mj_config.yaml 34 | - Running: start main.py 35 | 36 | ## Configuration items 37 | 38 | - mj.accounts: see 39 | [Account pool configuration](./docs/config.md#%E8%B4%A6%E5%8F%B7%E6%B1%A0%E9%85%8D%E7%BD%AE%E5%8F%82%E8%80%83) 40 | - mj.task-store.type: Task storage method; the default is in_memory (kept in memory and lost on restart), with Redis as an 41 | alternative. 42 | - mj.task-store.timeout: Task storage expiration time; tasks are deleted once it elapses. The default is 30 days. 43 | - mj.api-secret: API key. If left empty, authentication is disabled; otherwise, API calls must include the 44 | request header 'mj-api-secret'. 45 | - mj.translate-way: How Chinese prompts are translated into English; options are null (default), Baidu, or 46 | GPT. 47 | - For more configuration options, see [Configuration items](./docs/config.md) 48 | 49 | ## Related documentation 50 | 51 | 1.
[API documentation](./docs/api.md) 52 | 2. [Changelog](https://github.com/novicezk/midjourney-proxy/wiki/%E6%9B%B4%E6%96%B0%E8%AE%B0%E5%BD%95) 53 | 54 | ## Precautions 55 | 56 | 1. Frequent image generation and similar activity may trigger warnings on your Midjourney account. Please use it with 57 | caution. 58 | 2. For common issues and solutions, see [Wiki / FAQ](https://github.com/novicezk/midjourney-proxy/wiki/FAQ) 59 | 3. Anyone interested is welcome to join the discussion group. If the group is full when you scan the QR code, you 60 | can add the administrator on WeChat to be invited. Please include the note: mj join group. 61 | 62 | 微信二维码 63 | 64 | ## Projects Using This Proxy 65 | 66 | If you have an open-source project that depends on this one, feel free to contact the author to have it listed 67 | here. 68 | 69 | - [wechat-midjourney](https://github.com/novicezk/wechat-midjourney) : A WeChat client proxy that connects to 70 | MidJourney; intended only as an example use case and no longer updated. 71 | - [chatgpt-web-midjourney-proxy](https://github.com/Dooy/chatgpt-web-midjourney-proxy) : A complete UI solution for ChatGPT web, Midjourney, 72 | GPTs, TTS, and Whisper 73 | - [chatnio](https://github.com/Deeptrain-Community/chatnio) : A next-generation one-stop AI solution for business and consumer users: an aggregated model platform with a polished UI and powerful features 74 | - [new-api](https://github.com/Calcium-Ion/new-api) : An API management and distribution system compatible with Midjourney Proxy 75 | - [stable-diffusion-mobileui](https://github.com/yuanyuekeji/stable-diffusion-mobileui) : SDUI, built on this API 76 | and SD (Stable Diffusion); can be packaged with one click into H5 pages and mini-programs.
77 | - [MidJourney-Web](https://github.com/ConnectAI-E/MidJourney-Web) : 🍎 Supercharged Experience For MidJourney On Web UI 78 | 79 | ## Open API 80 | 81 | Provides unofficial MJ/SD open API, add administrator WeChat for inquiries, please remark: api 82 | 83 | ## Others 84 | 85 | If you find this project helpful, please consider giving it a star. 86 | 87 | ## Star History 88 | 89 | [![Star History Chart](https://api.star-history.com/svg?repos=Anychnn/py-midjourney-proxy&type=Date)](https://star-history.com/#Anychnn/py-midjourney-proxy&Date) -------------------------------------------------------------------------------- /blueprints/__pycache__/blueprints_image_generation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_image_generation.cpython-38.pyc -------------------------------------------------------------------------------- /blueprints/__pycache__/blueprints_midjourney.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_midjourney.cpython-38.pyc -------------------------------------------------------------------------------- /blueprints/__pycache__/blueprints_mj.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_mj.cpython-38.pyc -------------------------------------------------------------------------------- /blueprints/__pycache__/blueprints_openai.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_openai.cpython-38.pyc -------------------------------------------------------------------------------- /blueprints/__pycache__/blueprints_text_generation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_text_generation.cpython-38.pyc -------------------------------------------------------------------------------- /blueprints/__pycache__/blueprints_user.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_user.cpython-38.pyc -------------------------------------------------------------------------------- /blueprints/blueprints_image_generation.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, UploadFile, Request, File 2 | 3 | from entity.mj_scheme import ( 4 | ImagineRequest, 5 | ImagineResponse, 6 | QueryImagineStatusRequest, 7 | QueryImagineStatusResponse, 8 | UpscaleRequest, 9 | QueryUpscaleStatusRequest, 10 | QueryVariationStatusRequest, 11 | ImageUploadRequest, 12 | DescribeRequest, 13 | BlendRequest, 14 | ShortenRequest 15 | ) 16 | from support.Injector import injector 17 | from support.task_controller import TaskController 18 | from service.mj_data_service import MjTask 19 | from service.mj_data_service import MjDataService 20 | from support.image_controller import ImageController 21 | from support.load_balancer import LoadBalancer 22 | import datetime 23 | from fastapi.responses import JSONResponse 24 | from utils import common_util 25 | import os 26 | import requests 27 | import json 28 | 29 | 
task_controller = injector.get(TaskController)
mj_data_service = injector.get(MjDataService)
load_balancer = injector.get(LoadBalancer)

image_router = APIRouter()

# Image URLs returned by the query endpoints are rewritten from the Discord CDN
# host to this mirror host.
_DISCORD_CDN_HOST = "https://cdn.discordapp.com"
_DISCORD_CDN_MIRROR = "http://discordcdn.aidomore.com"

# Directory (relative to the working directory) where /midjourney/upload_image
# stores received files before re-encoding them as base64.
_IMAGE_CACHE_DIR = ".cache_imgs"


def _proxy_image_url(image_url):
    """Return ``image_url`` with the Discord CDN host replaced by the mirror host.

    Falsy values (``None`` / ``""``) are returned unchanged.
    """
    if image_url:
        return image_url.replace(_DISCORD_CDN_HOST, _DISCORD_CDN_MIRROR)
    return image_url


def _task_not_found(task_id):
    """Build the error envelope returned when ``task_id`` matches no stored task."""
    return {
        "status": 10002,
        "msg": f"未查询到任务:task_id:{task_id}",
    }


def _upload_img_if_bs64(img):
    """Hand ``img`` (base64 data or a URL) to the Discord HTTP service's
    ``upload_img_if_bs64`` and return the resulting image URL."""
    return load_balancer.get_discord_http_service().upload_img_if_bs64(img)


@image_router.post("/midjourney/imagine", response_model=ImagineResponse, summary="提交Imagine任务")
async def submit_imagine(request: Request, imagine_request: ImagineRequest):
    """Queue a Midjourney Imagine task.

    Reference images in ``imagine_request.imgs`` are uploaded when given as
    base64 and joined with spaces into the task's image prompt. For modes
    FAST / RELAX the matching ``--fast`` / ``--relax`` flag is appended to the
    text prompt.

    Returns the ``{"status", "msg", "data"}`` envelope holding the new task id,
    or status 400 when the prompt is empty.
    """
    prompt = imagine_request.prompt
    if not prompt:
        # Sent as a raw JSONResponse: ``data`` is None here, which would fail
        # validation against ``response_model=ImagineResponse`` (``data`` is a
        # required field) and turn this client error into a server error.
        return JSONResponse(content={
            "status": 400,
            "msg": "prompt不能为空",
            "data": None,
        })

    # Reference images are passed on as space-separated URLs.
    image_urls = [_upload_img_if_bs64(img) for img in imagine_request.imgs or []]
    image_prompt = " ".join(image_urls)

    text_prompt = prompt
    if imagine_request.mode == "FAST":
        text_prompt += " --fast"
    elif imagine_request.mode == "RELAX":
        text_prompt += " --relax"

    task = MjTask()
    task.prompt = text_prompt
    task.image_prompt = image_prompt
    task.task_type = "imagine"
    task.task_status = "pending"
    task.nonce = str(mj_data_service.next_nonce())
    task.create_time = datetime.datetime.now()

    task_id = task_controller.submit_imagine(task)

    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": task.task_status,
        },
    }


@image_router.post("/midjourney/query_imagine_status", summary="查询Action任务")
async def query_imagine_status(request: Request, body: QueryImagineStatusRequest):
    """Return status, progress and mirrored image URL of a task, or 10002 if unknown."""
    task: MjTask = mj_data_service.get_task_by_task_id(body.task_id)
    if not task:
        return _task_not_found(body.task_id)

    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task.task_id,
            "task_status": task.task_status,
            "image_url": _proxy_image_url(task.image_url),
            "progress": task.progress,
        },
    }


@image_router.post("/midjourney/action", summary="提交Action任务")
async def submit_upscale(request: Request, body: UpscaleRequest):
    """Queue an action task that replies to the Discord message of ``body.task_id``.

    ``body.custom_id`` is forwarded unchanged as the action's ``custom_id``
    (presumably a component id taken from the parent task's ``buttons`` —
    confirm with the task controller). ``body.mode`` is currently not used.
    Returns 10002 when the parent task does not exist.
    """
    parent_task: MjTask = mj_data_service.get_task_by_task_id(body.task_id)
    if not parent_task:
        # Without this guard ``parent_task.message_id`` raises AttributeError.
        return _task_not_found(body.task_id)

    action_task = MjTask()
    action_task.task_type = "action"
    action_task.create_time = datetime.datetime.now()
    action_task.task_status = "pending"
    action_task.custom_id = body.custom_id
    action_task.reference_message_id = parent_task.message_id
    action_task.nonce = str(mj_data_service.next_nonce())
    task_id = task_controller.submit_upscale(action_task)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": action_task.task_status,
        }
    }


@image_router.post("/midjourney/query_action_status", summary="查询Upscale任务")
async def query_upscale_status(request: Request, body: QueryUpscaleStatusRequest):
    """Return the full state of a task (including description and buttons), or 10002 if unknown."""
    task: MjTask = mj_data_service.get_task_by_task_id(body.task_id)
    if not task:
        return _task_not_found(body.task_id)

    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task.task_id,
            "task_type": task.task_type,
            "task_status": task.task_status,
            "image_url": _proxy_image_url(task.image_url),
            "progress": task.progress,
            "description": task.description,
            "buttons": task.buttons,
        },
    }


@image_router.post("/midjourney/describe", summary="提交Describe任务")
async def submit_describe_func(request: Request, body: DescribeRequest):
    """Queue a Describe (image-to-text) task for ``body.img``; 10001 if no image is given."""
    if not body.img:
        return {
            "status": 10001,
            "msg": "需要传入参考图片"
        }

    task = MjTask()
    task.task_type = "describe"
    task.task_status = "pending"
    task.nonce = str(mj_data_service.next_nonce())
    task.image_prompt = body.img
    task.create_time = datetime.datetime.now()
    task_id = task_controller.submit_describe(task)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": task.task_status,
        }
    }


@image_router.post("/midjourney/blend", summary="提交Blend任务")
async def submit_blend_func(request: Request, body: BlendRequest):
    """Queue a Blend task for ``body.imgs``; 10001 if the image list is empty."""
    if not body.imgs:
        return {
            "status": 10001,
            "msg": "需要传入参考图片"
        }

    task = MjTask()
    task.task_type = "blend"
    task.blend_imgs = body.imgs
    task.task_status = "pending"
    task.nonce = str(mj_data_service.next_nonce())
    task.dimensions = body.dimensions
    task.create_time = datetime.datetime.now()
    task_id = task_controller.submit_blend(task)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": task.task_status,
        }
    }


@image_router.post("/midjourney/upload", summary="base64图片上传")
async def upload_image_base64(request: Request, body: ImageUploadRequest):
    """Upload a base64-encoded image and return its URL; 10001 if empty.

    Uses the same load-balanced Discord HTTP service as ``submit_imagine``; this
    module has no ``image_controller`` instance, so calling one raised NameError.
    """
    if not body.bs64:
        return {
            "status": 10001,
            "msg": "需要传入图片"
        }
    image_url = _upload_img_if_bs64(body.bs64)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "image_url": image_url
        }
    }


@image_router.post("/midjourney/upload_image", summary="图片上传")
async def upload_image(file: UploadFile = File(...)):
    """Cache an uploaded image file locally, re-encode it as base64 and upload it.

    Returns the uploaded image URL, or 10001 when the upload has no usable filename.
    """
    # The client-supplied filename is untrusted: keep only its final path
    # component so it cannot escape the cache directory via ``..`` or separators.
    file_name = os.path.basename(file.filename or "")
    if file_name in ("", ".", ".."):
        return {
            "status": 10001,
            "msg": "需要传入图片"
        }

    os.makedirs(_IMAGE_CACHE_DIR, exist_ok=True)
    local_path = os.path.join(_IMAGE_CACHE_DIR, file_name)
    contents = await file.read()
    with open(local_path, "wb") as f:
        f.write(contents)
    image = common_util.getImage(local_path)
    img_bs64 = common_util.image_to_base64(image)

    image_url = _upload_img_if_bs64(img_bs64)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "image_url": image_url
        }
    }


@image_router.post("/midjourney/shorten", summary="shorten")
async def shorten(request: Request, body: ShortenRequest):
    """Queue a Shorten task for ``body.prompt``; 10002 if the prompt is empty."""
    if not body.prompt:
        return {
            "status": 10002,
            "msg": "需要传入prompt"
        }

    task = MjTask()
    task.task_type = "shorten"
    task.prompt = body.prompt
    task.task_status = "pending"
    task.nonce = str(mj_data_service.next_nonce())
    task.create_time = datetime.datetime.now()
    task_id = task_controller.submit_shorten(task)

    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": task.task_status,
        }
    }

# --------------------------------------------------------------------------------
# /entity/__pycache__/mj_scheme.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/mj_scheme.cpython-38.pyc
-------------------------------------------------------------------------------- /entity/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /entity/__pycache__/openai_scheme.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/openai_scheme.cpython-38.pyc -------------------------------------------------------------------------------- /entity/__pycache__/user_scheme.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/user_scheme.cpython-38.pyc -------------------------------------------------------------------------------- /entity/gen.sh: -------------------------------------------------------------------------------- 1 | # NOTE(review): line 2 of this script and models.py hard-code MySQL credentials in plain text; move them to environment variables and rotate them. Previous remote command: python -m pwiz -e mysql -H 101.43.141.18 -u xiaohongdou -P <redacted> xiaohongdou > models.py 2 | python -m pwiz -e mysql -H localhost -u xiaohongdouapi -P DtcxSyP7aMW8Edws xiaohongdouapi > models.py -------------------------------------------------------------------------------- /entity/mj_scheme.py: -------------------------------------------------------------------------------- 1 | from pydoc import describe 2 | from typing import Optional 3 | 4 | from typing import List, Tuple, Set 5 | from pydantic import BaseModel, Field 6 | 7 | class ImagineRequest(BaseModel): 8 | # Either RELAX or FAST; defaults to RELAX 9 | mode: Optional[str] = "RELAX" 10 | # Callback URL notified automatically when the task status changes 11 | notify_hook: Optional[str] 12 | # Reference images; each entry may be base64 data or a URL 13 | imgs:
Optional[List[str]] 14 | prompt: str 15 | 16 | class ImagineResponseData(BaseModel): 17 | # Task ID 18 | task_id: str 19 | # Task status 20 | task_status: str 21 | # Image URL 22 | image_url: Optional[str] 23 | 24 | class ImagineResponse(BaseModel): 25 | status: int = Field(..., title="状态码", description="状态码") 26 | msg: str = Field(..., title="消息", description="消息") 27 | data: ImagineResponseData = Field(..., title="数据", description="数据") 28 | 29 | class QueryImagineStatusRequest(BaseModel): 30 | task_id: str = Field(..., title="任务ID", description="任务ID") 31 | # Format of the returned image, base64 or url (field currently disabled) 32 | # img_type: str = "base64" 33 | 34 | class QueryImagineStatusResponseData(BaseModel): 35 | task_id: str = Field(..., title="任务ID", description="任务ID") 36 | task_status: str = Field(..., title="状态码", description="状态码") 37 | image_url: Optional[str] 38 | 39 | 40 | class QueryImagineStatusResponse(BaseModel): 41 | status: int = Field(..., title="状态码", description="状态码") 42 | msg: str = Field(..., title="消息", description="消息") 43 | data: QueryImagineStatusResponseData = Field(..., title="数据", description="数据") 44 | 45 | 46 | class UpscaleRequest(BaseModel): 47 | # ID of the existing task the action applies to 48 | task_id: str 49 | # Index of the image to upscale (field currently disabled) 50 | # index: int 51 | # action_type: str 52 | mode: Optional[str] 53 | custom_id: str 54 | # Callback URL notified automatically when the task status changes 55 | notify_hook: Optional[str] 56 | 57 | 58 | class QueryUpscaleStatusRequest(BaseModel): 59 | task_id: str = Field(..., title="任务ID", description="任务ID") 60 | 61 | 62 | class QueryVariationStatusRequest(BaseModel): 63 | task_id: str = Field(..., title="任务ID", description="任务ID") 64 | 65 | 66 | class ImageUploadRequest(BaseModel): 67 | bs64: str 68 | 69 | 70 | class DescribeRequest(BaseModel): 71 | img: str 72 | 73 | class BlendRequest(BaseModel): 74 | imgs: List[str] 75 | dimensions: Optional[str] 76 | 77 | class ShortenRequest(BaseModel): 78 | prompt: str -------------------------------------------------------------------------------- /entity/models.py: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | from peewee import * 2 | 3 | database = MySQLDatabase('xiaohongdouapi', **{'charset': 'utf8', 'sql_mode': 'PIPES_AS_CONCAT', 'use_unicode': True, 'host': 'localhost', 'user': 'xiaohongdouapi', 'password': 'DtcxSyP7aMW8Edws'}) 4 | 5 | class UnknownField(object): 6 | def __init__(self, *_, **__): pass 7 | 8 | class BaseModel(Model): 9 | class Meta: 10 | database = database 11 | 12 | class User(BaseModel): 13 | create_time = DateTimeField(null=True) 14 | email = CharField(null=True) 15 | phone = CharField(null=True) 16 | update_time = DateTimeField(null=True) 17 | 18 | class Meta: 19 | table_name = 'User' 20 | 21 | -------------------------------------------------------------------------------- /entity/openai_scheme.py: -------------------------------------------------------------------------------- 1 | from pydoc import describe 2 | from typing import Optional 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class CompletionRequest(BaseModel): 7 | model: str 8 | prompt: str 9 | temperature: float 10 | top_p: float 11 | stream: bool 12 | 13 | class CompletionResponseData(BaseModel): 14 | text: str 15 | 16 | 17 | class CompletionResponse(BaseModel): 18 | status: int=Field(..., description="状态") 19 | msg: str=Field(..., description="消息") 20 | data: CompletionResponseData=Field(..., description="数据") -------------------------------------------------------------------------------- /entity/user_scheme.py: -------------------------------------------------------------------------------- 1 | from pydoc import describe 2 | from typing import Optional 3 | 4 | from typing import List, Tuple, Set 5 | from pydantic import BaseModel, Field 6 | 7 | # 用户注册请求 8 | class ShowQrcodeRequest(BaseModel): 9 | pass 10 | 11 | # 二维码请求 12 | class QrcodeResponseData(BaseModel): 13 | ticketUrl: str = Field(..., description="二维码地址") 14 | sceneStr: int = Field(..., description="二维码场景值") 15 | 16 | # 二维码展示响应 17 | 
class ShowQrcodeResponse(BaseModel): 18 | status: int = Field(..., description="状态码") 19 | msg: str = Field(..., description="消息") 20 | data: QrcodeResponseData = Field(..., description="数据") 21 | 22 | 23 | # 二维码状态请求 24 | class QrcodeStatusRequest(BaseModel): 25 | sceneStr: str = Field(..., description="二维码场景值") 26 | 27 | 28 | class QrcodeStatusData(BaseModel): 29 | status: int = Field(..., description="状态码") 30 | uid: int = Field(..., description="用户ID") 31 | 32 | # 二维码状态响应 33 | class QrcodeStatusResponse(BaseModel): 34 | status: int = Field(..., description="状态码") 35 | msg: str = Field(..., description="消息") 36 | data: QrcodeStatusData = Field(..., description="数据") 37 | 38 | 39 | # 用户秘钥生成 40 | class SecretGenerateRequest(BaseModel): 41 | pass 42 | 43 | 44 | # 用户秘钥生成响应 45 | class SecretGenerateResponse(BaseModel): 46 | status: int = Field(..., description="状态码") 47 | msg: str = Field(..., description="消息") 48 | data: str = Field(..., description="数据") 49 | 50 | # 秘钥查询 51 | class SecretQueryRequest(BaseModel): 52 | uid: int = Field(..., description="用户ID") 53 | 54 | class SecretQueryData(BaseModel): 55 | secret: str = Field(..., description="秘钥") 56 | 57 | # 秘钥查询响应 58 | class SecretQueryResponse(BaseModel): 59 | status: int = Field(..., description="状态码") 60 | msg: str = Field(..., description="消息") 61 | data: SecretQueryData = Field(..., description="数据") 62 | 63 | # 支付二维码生成 64 | class PayQrcodeGenerateRequest(BaseModel): 65 | uid: int = Field(..., description="用户ID") 66 | 67 | class PayQrcodeGenerateData(BaseModel): 68 | qrcode_img_bs64: str = Field(..., description="二维码图片base64") 69 | price: int = Field(..., description="价格,单位分") 70 | order_id: str = Field(..., description="订单ID") 71 | 72 | 73 | # 支付二维码生成响应 74 | class PayQrcodeGenerateResponse(BaseModel): 75 | status: int = Field(..., description="状态码") 76 | msg: str = Field(..., description="消息") 77 | data: str = Field(..., description="数据") 78 | 79 | # 支付二维码状态查询 80 | class PayQrcodeStatusRequest(BaseModel): 
81 | order_id: str = Field(..., description="订单ID") 82 | 83 | # 支付二维码状态查询响应 84 | class PayQrcodeStatusData(BaseModel): 85 | status: int = Field(..., description="状态码") 86 | price: int = Field(..., description="价格,单位分") 87 | payed: bool = Field(..., description="是否支付") 88 | 89 | # 支付二维码状态查询响应 90 | class PayQrcodeStatusResponse(BaseModel): 91 | status: int = Field(..., description="状态码") 92 | msg: str = Field(..., description="消息") 93 | data: PayQrcodeStatusData = Field(..., description="数据") 94 | 95 | 96 | 97 | 98 | class RegisterRequest(BaseModel): 99 | account: str = Field(..., description="账号") 100 | password: str = Field(..., description="密码") 101 | 102 | 103 | class LoginRequest(BaseModel): 104 | account: str = Field(..., description="账号") 105 | password: str = Field(..., description="密码") 106 | 107 | 108 | class LoginResponseData(BaseModel): 109 | token: str = Field(..., description="token") 110 | 111 | 112 | class LoginResponse(BaseModel): 113 | status: int = Field(..., description="状态码") 114 | msg: str = Field(..., description="消息") 115 | data: LoginResponseData = Field(..., description="数据") -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- ecoding: utf-8 -*- 2 | # @Author: anyang 3 | # @Time: 2024/4/5 4 | 5 | import asyncio 6 | import uvicorn 7 | from fastapi import FastAPI 8 | from blueprints.blueprints_image_generation import image_router 9 | import yaml 10 | # from app.app import app 11 | from support.mj_config import MjConfig 12 | # from union_api_server.Injector import injector 13 | from support.Injector import injector 14 | import subprocess 15 | from wss.mj_wss_proxy import MjWssSercice 16 | from wss.mj_wss_manager import MjWssManager 17 | 18 | router = FastAPI() 19 | router.include_router(image_router, prefix="/image", tags=["图片生成"]) 20 | 21 | @router.on_event("startup") 22 | async def startup_event(): 23 | pass 24 | 
25 | 26 | if __name__ == "__main__": 27 | import argparse 28 | parser = argparse.ArgumentParser( 29 | description='union_api_server server.') 30 | parser.add_argument('--port', type=int, help='server port', default=6013) 31 | 32 | # 从injector获取mj_config 33 | mj_config = injector.get(MjConfig) 34 | 35 | # 启动 Celery worker 36 | # subprocess.Popen(['celery', '-A', 'celery_module.celery_app', 'worker', '--loglevel=INFO','-c',"1"]) 37 | 38 | # 是否开启新的线程启动对discord_bot的wss监听 39 | # if mj_config.mj_config["common"]["launch_discord_bot"]: 40 | mj_wss_manager = injector.get(MjWssManager) 41 | mj_wss_manager.start_all() 42 | 43 | args = parser.parse_args() 44 | uvicorn.run(router, port=args.port, host="0.0.0.0") 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | toml 2 | numpy 3 | uvicorn 4 | fastapi 5 | Pillow 6 | requests 7 | pyyaml 8 | python-multipart 9 | tinydb 10 | ordered-set 11 | dill 12 | peewee 13 | celery 14 | redis 15 | qrcode 16 | BeautifulSoup4 17 | lxml 18 | oss2 19 | alibabacloud_imageseg20191230==3.0.0 20 | cachetools 21 | pymongo 22 | injector 23 | requests_toolbelt -------------------------------------------------------------------------------- /resources/api_params/blend.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 2, 3 | "guild_id": "$guild_id", 4 | "channel_id": "$channel_id", 5 | "application_id": "$app_id$", 6 | "session_id": "$session_id$", 7 | "nonce": "$nonce", 8 | "data": { 9 | "version": "1166847114203123796", 10 | "id": "1062880104792997970", 11 | "name": "blend", 12 | "type": 1, 13 | "options": [], 14 | "attachments": [] 15 | } 16 | } -------------------------------------------------------------------------------- /resources/api_params/describe.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 2, 
3 | "application_id": "$app_id$", 4 | "guild_id": "$guild_id$", 5 | "channel_id": "$channel_id$", 6 | "session_id": "$session_id$", 7 | "data": { 8 | "version": "1237876415471554625", 9 | "id": "1092492867185950852", 10 | "name": "describe", 11 | "type": 1, 12 | "options": [ 13 | { 14 | "type": 3, 15 | "name": "link", 16 | "value": "$prompt$" 17 | } 18 | ], 19 | "attachments": [] 20 | }, 21 | "nonce": "$nonce$", 22 | "analytics_location": "slash_ui" 23 | } -------------------------------------------------------------------------------- /resources/api_params/imagine.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 2, 3 | "application_id": "$app_id$", 4 | "guild_id": "$guild_id$", 5 | "channel_id": "$channel_id$", 6 | "session_id": "$session_id$", 7 | "nonce": "$nonce$", 8 | "data": { 9 | "version": "1237876415471554623", 10 | "id": "938956540159881230", 11 | "name": "imagine", 12 | "type": 1, 13 | "options": [ 14 | { 15 | "type": 3, 16 | "name": "prompt", 17 | "value": "$prompt$" 18 | } 19 | ], 20 | "attachments": [] 21 | } 22 | } -------------------------------------------------------------------------------- /resources/api_params/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 2, 3 | "application_id": "936929561302675456", 4 | "guild_id": "1233678696364245002", 5 | "channel_id": "1233678696364245005", 6 | "session_id": "c38d67b8f1eded0b5603f44c31765817", 7 | "data": { 8 | "version": "1237876415735660565", 9 | "id": "972289487818334209", 10 | "name": "info", 11 | "type": 1, 12 | "options": [], 13 | "application_command": { 14 | "id": "972289487818334209", 15 | "type": 1, 16 | "application_id": "936929561302675456", 17 | "version": "1237876415735660565", 18 | "name": "info", 19 | "description": "View information about your profile.", 20 | "dm_permission": true, 21 | "contexts": [ 22 | 0, 23 | 1, 24 | 2 25 | ], 26 | "integration_types": [ 27 | 0, 28 | 
1 29 | ], 30 | "global_popularity_rank": 2, 31 | "options": [], 32 | "description_localized": "View information about your profile.", 33 | "name_localized": "info" 34 | }, 35 | "attachments": [] 36 | }, 37 | "nonce": "1239288436951613440", 38 | "analytics_location": "slash_ui" 39 | } -------------------------------------------------------------------------------- /resources/api_params/message.json: -------------------------------------------------------------------------------- 1 | { 2 | "content": "", 3 | "channel_id": "$channel_id$", 4 | "type": 0, 5 | "sticker_ids": [], 6 | "attachments": [ 7 | { 8 | "id": "0", 9 | "filename": "$file_name$", 10 | "uploaded_filename": "$final_file_name$" 11 | } 12 | ] 13 | } -------------------------------------------------------------------------------- /resources/api_params/reroll.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 3, 3 | "guild_id": "$guild_id", 4 | "channel_id": "$channel_id", 5 | "message_id": "$message_id", 6 | "application_id": "936929561302675456", 7 | "session_id": "$session_id", 8 | "nonce": "$nonce", 9 | "message_flags": 0, 10 | "data": { 11 | "component_type": 2, 12 | "custom_id": "MJ::JOB::reroll::0::$message_hash::SOLO" 13 | } 14 | } -------------------------------------------------------------------------------- /resources/api_params/shorten.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 2, 3 | "application_id": "$app_id$", 4 | "guild_id": "$guild_id$", 5 | "channel_id": "$channel_id$", 6 | "session_id": "$session_id$", 7 | "data": { 8 | "version": "1237876415471554626", 9 | "id": "1121575372539039774", 10 | "name": "shorten", 11 | "type": 1, 12 | "options": [ 13 | { 14 | "type": 3, 15 | "name": "prompt", 16 | "value": "$prompt$" 17 | } 18 | ], 19 | "attachments": [] 20 | }, 21 | "nonce": "$nonce$", 22 | "analytics_location": "slash_ui" 23 | } 
-------------------------------------------------------------------------------- /resources/api_params/upscale.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 3, 3 | "guild_id": "$guild_id$", 4 | "channel_id": "$channel_id$", 5 | "message_id": "$message_id$", 6 | "application_id": "$app_id$", 7 | "session_id": "$session_id$", 8 | "nonce": "$nonce$", 9 | "message_flags": 0, 10 | "data": { 11 | "component_type": 2, 12 | "custom_id": "$custom_id$" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /resources/api_params/variation.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": 3, 3 | "guild_id": "$guild_id$", 4 | "channel_id": "$channel_id$", 5 | "message_id": "$message_id$", 6 | "application_id": "936929561302675456", 7 | "session_id": "$session_id$", 8 | "nonce": "$nonce$", 9 | "message_flags": 0, 10 | "data": { 11 | "component_type": 2, 12 | "custom_id": "$custom_id$" 13 | } 14 | } -------------------------------------------------------------------------------- /resources/config/prod_mj_config.yaml: -------------------------------------------------------------------------------- 1 | # midjourney的账号 2 | accounts: 3 | - user_token: "" 4 | bot_token: "" 5 | guild_id: "" 6 | channel_id: "" 7 | app_id: "" 8 | session_id: "" 9 | 10 | ng: 11 | discord_ws: https://gateway.discord.gg 12 | # discord_server: http://discordapi.aidomore.com/api/v9/interactions 13 | discord_server: https://discord.com 14 | discord_upload_server: https://discord-attachments-uploads-prd.storage.googleapis.com 15 | 16 | common: 17 | # 是否启动discord bot 18 | launch_discord_bot: false 19 | img_upload_url: "" 20 | 21 | redis: 22 | host: 0.0.0.0 23 | port: 6379 24 | password: "123" 25 | db: 0 -------------------------------------------------------------------------------- /service/__pycache__/celery_client.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/celery_client.cpython-38.pyc -------------------------------------------------------------------------------- /service/__pycache__/celery_service.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/celery_service.cpython-38.pyc -------------------------------------------------------------------------------- /service/__pycache__/discord_http_service.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/discord_http_service.cpython-38.pyc -------------------------------------------------------------------------------- /service/__pycache__/mj_data_service.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/mj_data_service.cpython-38.pyc -------------------------------------------------------------------------------- /service/__pycache__/notify_service.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/notify_service.cpython-38.pyc -------------------------------------------------------------------------------- /service/__pycache__/template_controller.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/template_controller.cpython-38.pyc -------------------------------------------------------------------------------- /service/discord_http_service.py: -------------------------------------------------------------------------------- 1 | from support.mj_config import MjConfig 2 | import redis 3 | import requests 4 | import json 5 | from support.task_controller import MjTask 6 | from service.template_controller import TemplateController 7 | import base64 8 | import io 9 | import hashlib 10 | from support.mj_account import MjAccount 11 | 12 | class DiscordHttpService: 13 | def __init__(self, mj_config: MjConfig, redis_client: redis.Redis, template_controller: TemplateController, mj_account: MjAccount) -> None: 14 | self.redis_client = redis_client 15 | self.mj_config = mj_config 16 | self.mj_account = mj_account 17 | self.template_controller = template_controller 18 | 19 | def get_interaction_url(self): 20 | host = self.mj_config['ng']['discord_server'] 21 | return f"{host}/api/v9/interactions" 22 | 23 | def get_messages_url(self): 24 | host=self.mj_config['ng']['discord_server'] 25 | CHANNEL_ID = self.mj_account.get_channel_id() 26 | return f"{host}/api/v9/channels/{CHANNEL_ID}/messages" 27 | def get_discord_upload_attachment_url(self): 28 | host=self.mj_config['ng']['discord_server'] 29 | CHANNEL_ID = self.mj_account.get_channel_id() 30 | upload_attachment_url = f"{host}/api/v9/channels/{CHANNEL_ID}/attachments" 31 | return upload_attachment_url 32 | 33 | 34 | 35 | def imageine(self, task: MjTask): 36 | url = self.get_interaction_url() 37 | 38 | app_id = self.mj_account.get_app_id() 39 | guild_id = self.mj_account.get_gulid_id() 40 | channel_id = self.mj_account.get_channel_id() 41 | session_id = self.mj_account.get_session_id() 42 | user_token = self.mj_account.get_user_token() 43 | 44 | advanced_prompt = "" 45 | 46 | if task.image_prompt: 47 | 
advanced_prompt = task.image_prompt+" "+task.prompt 48 | else: 49 | advanced_prompt = task.prompt 50 | 51 | template_map = { 52 | "app_id": app_id, 53 | "guild_id": guild_id, 54 | "channel_id": channel_id, 55 | "session_id": session_id, 56 | "prompt": advanced_prompt, 57 | "nonce": task.nonce 58 | } 59 | 60 | imagine_template = self.template_controller.get_imagine(template_map) 61 | 62 | boundary = "----WebKitFormBoundaryqznUd46iGT62TY0s" 63 | fields = {"payload_json": (None, json.dumps(imagine_template))} 64 | from requests_toolbelt.multipart.encoder import MultipartEncoder 65 | form_data = MultipartEncoder(fields=fields, boundary=boundary) 66 | 67 | HEADERS = { 68 | "Content-Type": f"multipart/form-data; boundary={boundary}", 69 | "Authorization": user_token 70 | } 71 | res = requests.post(url=url, data=form_data, headers=HEADERS) 72 | 73 | if res.status_code not in [200, 204]: 74 | raise Exception(json.loads(res.text)) 75 | 76 | print("res.content", res.content) 77 | 78 | def upscale(self, task: MjTask): 79 | url = self.get_interaction_url() 80 | 81 | 82 | app_id = self.mj_account.get_app_id() 83 | guild_id = self.mj_account.get_gulid_id() 84 | channel_id = self.mj_account.get_channel_id() 85 | session_id = self.mj_account.get_session_id() 86 | user_token = self.mj_account.get_user_token() 87 | 88 | 89 | template_map = { 90 | "app_id": app_id, 91 | "guild_id": guild_id, 92 | "channel_id": channel_id, 93 | "session_id": session_id, 94 | "nonce": task.nonce, 95 | "message_id": task.message_id, 96 | "custom_id": task.custom_upscale_id 97 | } 98 | 99 | upscale_template = self.template_controller.get_upscale(template_map) 100 | HEADERS = { 101 | "Content-Type": "application/json", 102 | "Authorization": user_token 103 | } 104 | res = requests.post(url=url, data=json.dumps( 105 | upscale_template), headers=HEADERS) 106 | print(res) 107 | 108 | def variation(self, task: MjTask): 109 | url = self.get_interaction_url() 110 | 111 | app_id = self.mj_account.get_app_id() 
112 | guild_id = self.mj_account.get_gulid_id() 113 | channel_id = self.mj_account.get_channel_id() 114 | session_id = self.mj_account.get_session_id() 115 | user_token = self.mj_account.get_user_token() 116 | 117 | template_map = { 118 | "app_id": app_id, 119 | "guild_id": guild_id, 120 | "channel_id": channel_id, 121 | "session_id": session_id, 122 | "nonce": task.nonce, 123 | "message_id": task.message_id, 124 | "custom_id": task.custom_variation_id 125 | } 126 | 127 | upscale_template = self.template_controller.get_upscale(template_map) 128 | HEADERS = { 129 | "Content-Type": "application/json", 130 | "Authorization": user_token 131 | } 132 | res = requests.post(url=url, data=json.dumps( 133 | upscale_template), headers=HEADERS) 134 | print(res) 135 | 136 | def describe(self, task: MjTask): 137 | url = self.get_interaction_url() 138 | 139 | app_id = self.mj_account.get_app_id() 140 | guild_id = self.mj_account.get_gulid_id() 141 | channel_id = self.mj_account.get_channel_id() 142 | session_id = self.mj_account.get_session_id() 143 | user_token = self.mj_account.get_user_token() 144 | 145 | template_map = { 146 | "app_id": app_id, 147 | "guild_id": guild_id, 148 | "channel_id": channel_id, 149 | "session_id": session_id, 150 | "nonce": task.nonce, 151 | "message_id": task.message_id, 152 | "custom_id": task.custom_variation_id, 153 | "prompt": task.image_prompt 154 | } 155 | 156 | upscale_template = self.template_controller.get_describe(template_map) 157 | HEADERS = { 158 | "Content-Type": "application/json", 159 | "Authorization": user_token 160 | } 161 | res = requests.post(url=url, data=json.dumps( 162 | upscale_template), headers=HEADERS) 163 | print(res) 164 | 165 | def blend(self, task: MjTask): 166 | url = self.get_interaction_url() 167 | 168 | app_id = self.mj_account.get_app_id() 169 | guild_id = self.mj_account.get_gulid_id() 170 | channel_id = self.mj_account.get_channel_id() 171 | session_id = self.mj_account.get_session_id() 172 | user_token = 
self.mj_account.get_user_token() 173 | 174 | advanced_prompt = "" 175 | 176 | if not task.blend_imgs: 177 | raise Exception("blend_imgs is empty") 178 | 179 | advanced_prompt = " ".join(task.blend_imgs) 180 | 181 | if task.dimensions: 182 | dimentsion_prompt = "" 183 | if task.dimensions == 'Square': 184 | dimentsion_prompt = "--ar 1:1" 185 | elif task.dimensions == 'Landscape': 186 | dimentsion_prompt = "--ar 3:2" 187 | elif task.dimensions == 'Portrait': 188 | dimentsion_prompt = "--ar 2:3" 189 | else: 190 | dimentsion_prompt = "" 191 | 192 | advanced_prompt = advanced_prompt + " " + dimentsion_prompt 193 | 194 | template_map = { 195 | "app_id": app_id, 196 | "guild_id": guild_id, 197 | "channel_id": channel_id, 198 | "session_id": session_id, 199 | "prompt": advanced_prompt, 200 | "nonce": task.nonce 201 | } 202 | 203 | imagine_template = self.template_controller.get_imagine(template_map) 204 | 205 | boundary = "----WebKitFormBoundaryqznUd46iGT62TY0s" 206 | fields = {"payload_json": (None, json.dumps(imagine_template))} 207 | from requests_toolbelt.multipart.encoder import MultipartEncoder 208 | form_data = MultipartEncoder(fields=fields, boundary=boundary) 209 | 210 | HEADERS = { 211 | "Content-Type": f"multipart/form-data; boundary={boundary}", 212 | "Authorization": user_token 213 | } 214 | res = requests.post(url=url, data=form_data, headers=HEADERS) 215 | 216 | if res.status_code not in [200, 204]: 217 | raise Exception(json.loads(res.text)) 218 | 219 | print("res.content", res.content) 220 | 221 | 222 | def action(self, task: MjTask): 223 | url = self.get_interaction_url() 224 | 225 | app_id = self.mj_account.get_app_id() 226 | guild_id = self.mj_account.get_gulid_id() 227 | channel_id = self.mj_account.get_channel_id() 228 | session_id = self.mj_account.get_session_id() 229 | user_token = self.mj_account.get_user_token() 230 | 231 | template_map = { 232 | "app_id": app_id, 233 | "guild_id": guild_id, 234 | "channel_id": channel_id, 235 | "session_id": 
session_id, 236 | "nonce": task.nonce, 237 | "message_id": task.reference_message_id, 238 | "custom_id": task.custom_id 239 | } 240 | 241 | upscale_template = self.template_controller.get_upscale(template_map) 242 | HEADERS = { 243 | "Content-Type": "application/json", 244 | "Authorization": user_token 245 | } 246 | res = requests.post(url=url, data=json.dumps( 247 | upscale_template), headers=HEADERS) 248 | print(res) 249 | 250 | def shorten(self, task: MjTask): 251 | url = self.get_interaction_url() 252 | 253 | app_id = self.mj_account.get_app_id() 254 | guild_id = self.mj_account.get_gulid_id() 255 | channel_id = self.mj_account.get_channel_id() 256 | session_id = self.mj_account.get_session_id() 257 | user_token = self.mj_account.get_user_token() 258 | 259 | template_map = { 260 | "app_id": app_id, 261 | "guild_id": guild_id, 262 | "channel_id": channel_id, 263 | "session_id": session_id, 264 | "nonce": task.nonce, 265 | "prompt": task.prompt 266 | } 267 | 268 | upscale_template = self.template_controller.get_shorten(template_map) 269 | HEADERS = { 270 | "Content-Type": "application/json", 271 | "Authorization": user_token 272 | } 273 | res = requests.post(url=url, data=json.dumps( 274 | upscale_template), headers=HEADERS) 275 | print(res) 276 | 277 | 278 | def message(self, final_file_name): 279 | url = self.get_messages_url() 280 | 281 | app_id = self.mj_account.get_app_id() 282 | guild_id = self.mj_account.get_gulid_id() 283 | channel_id = self.mj_account.get_channel_id() 284 | session_id = self.mj_account.get_session_id() 285 | user_token = self.mj_account.get_user_token() 286 | 287 | file_name = final_file_name.split("/")[-1] 288 | 289 | template_map = { 290 | "app_id": app_id, 291 | "guild_id": guild_id, 292 | "channel_id": channel_id, 293 | "session_id": session_id, 294 | "final_file_name": final_file_name, 295 | "file_name": file_name 296 | } 297 | 298 | message_template = self.template_controller.get_message(template_map) 299 | HEADERS = { 300 | 
"Content-Type": "application/json", 301 | "Authorization": user_token 302 | } 303 | res = requests.post(url=url, data=json.dumps( 304 | message_template), headers=HEADERS) 305 | print(res) 306 | if res.status_code==200: 307 | upload_resp = json.loads(res.text) 308 | image_url = upload_resp['attachments'][0]['url'] 309 | return image_url 310 | else: 311 | raise Exception("message upload failed") 312 | 313 | def upload_img_if_bs64(self, img: str): 314 | return self.upload2discord(img) 315 | 316 | def upload2discord(self, img: str): 317 | if not img: 318 | return 319 | if img.startswith("data:image/"): 320 | 321 | discord_upload_attachment_url = self.get_discord_upload_attachment_url() 322 | 323 | base64_data = img.split(",")[-1] 324 | image_data = base64.b64decode(base64_data) 325 | # 将字节数据转换为BytesIO对象 326 | image_stream = io.BytesIO(image_data) 327 | 328 | # 将image_streamg转为bytes 329 | image_bytes = image_stream.read() 330 | file_size = len(image_bytes) 331 | hash_value = hashlib.md5(image_data).hexdigest() 332 | file_name=f"{hash_value}.png" 333 | 334 | HEADERS = {"Content-Type": "application/json", "Authorization": self.USER_TOKEN} 335 | payload = { 336 | "files": [{"filename": file_name, "file_size": file_size, "id": "0"}]} 337 | res = requests.post(discord_upload_attachment_url, headers=HEADERS, data= json.dumps(payload)) 338 | if res.status_code == 200: 339 | res_data=json.loads(res.text) 340 | attach = res_data['attachments'][0] 341 | upload_url = attach['upload_url'] 342 | upload_filename = attach['upload_filename'] 343 | print("upload_url",upload_url) 344 | print("upload_filename",upload_filename) 345 | headers = { 346 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)", 347 | "Content-Type": "application/octet-stream", 348 | "Content-Length": str(file_size), 349 | } 350 | res = requests.put(upload_url, data=image_data, headers = headers) 351 | print(res.status_code) 352 | if res.status_code == 200: 353 | image_url= 
self.discord_http_service.message(upload_filename) 354 | return image_url 355 | else: 356 | raise Exception("上传图片失败") 357 | 358 | else: 359 | raise Exception("上传图片失败") -------------------------------------------------------------------------------- /service/discord_ws_service.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /service/mail_service.py: -------------------------------------------------------------------------------- 1 | import smtplib 2 | import random 3 | from email.mime.text import MIMEText 4 | from email.mime.multipart import MIMEMultipart 5 | 6 | def send_verification_code(email): 7 | # 生成6位随机验证码 8 | verification_code = ''.join(random.choices('0123456789', k=6)) 9 | 10 | # 设置发件人和收件人邮箱地址 11 | sender_email = "496812133@qq.com" 12 | receiver_email = email 13 | 14 | # 创建邮件内容 15 | msg = MIMEMultipart() 16 | msg['From'] = sender_email 17 | msg['To'] = receiver_email 18 | msg['Subject'] = "验证码" 19 | 20 | body = f"您的验证码是:{verification_code}" 21 | msg.attach(MIMEText(body, 'plain')) 22 | 23 | # 发送邮件 24 | with smtplib.SMTP('smtp.example.com', 587) as smtp: 25 | smtp.starttls() 26 | smtp.login(sender_email, "your_password") 27 | smtp.send_message(msg) 28 | 29 | print("验证码已发送至您的邮箱,请查收。") 30 | 31 | # 在此处调用函数,并传入接收验证码的邮箱地址 32 | 33 | send_verification_code("recipient_email@example.com") -------------------------------------------------------------------------------- /service/mj_data_service.py: -------------------------------------------------------------------------------- 1 | 2 | from queue import Queue 3 | import json 4 | from dataclasses import dataclass,field 5 | from datetime import datetime 6 | from typing import Optional 7 | import redis 8 | from typing import List 9 | 10 | def default_custom_ids(): 11 | return [] 12 | 13 | @dataclass(init=False) 14 | class MjTask: 15 | # 任务类型 16 | task_type: Optional[str] = None 17 | # 任务状态 18 
| task_status: Optional[str] = None 19 | # task_id 20 | task_id: Optional[str] = None 21 | # 提示词 22 | prompt: Optional[str] = None 23 | # 提示词-英文 24 | promptEn: Optional[str] = None 25 | # 任务描述 26 | description: Optional[str] = None 27 | # 提交时间 28 | create_time: Optional[datetime] = None 29 | # 开始执行时间 30 | start_time: Optional[datetime] = None 31 | # 结束时间 32 | end_time: Optional[datetime] = None 33 | # 图片url 34 | image_url: Optional[str] = None 35 | # 进度 36 | progress: Optional[str] = None 37 | # 失败原因 38 | fail_reason: Optional[str] = None 39 | # 回调地址 40 | notify_hook: Optional[str] = None 41 | 42 | nonce: Optional[str] = None 43 | 44 | image_prompt: Optional[str] = None 45 | 46 | # mj返回的任务创建的prompt 47 | submit_prompt: Optional[str] = None 48 | 49 | # 之前对应的message_id 50 | message_id: Optional[str] = None 51 | reference_message_id: Optional[str] = None 52 | 53 | 54 | finished_message_id: Optional[str] = None 55 | 56 | buttons: Optional[List[str]] = None 57 | 58 | # action 59 | custom_id: Optional[str] = None 60 | 61 | # upscale 62 | upscale_custom_ids: List[str] = field(default_factory=default_custom_ids) 63 | custom_upscale_id: Optional[str] = None 64 | upscale_index: Optional[int] = None 65 | 66 | # variation 67 | variation_custom_ids: List[str] = field(default_factory=default_custom_ids) 68 | custom_variation_id: Optional[str] = None 69 | variation_index: Optional[int] = None 70 | 71 | 72 | # description 73 | description: Optional[str] = None 74 | 75 | # blend 76 | blend_imgs: Optional[List[str]] = None 77 | dimensions: Optional[str] 78 | 79 | class MjDataService: 80 | def __init__(self, redis_client: redis.Redis) -> None: 81 | self.redis_client = redis_client 82 | self.mj_tasks = Queue() 83 | self.mj_tasks_list = [] 84 | self.task_id_map = {} 85 | self.nonce_map = {} 86 | self.message_id_map = {} 87 | 88 | def next_nonce(self): 89 | nonce = self.redis_client.get("mj.nonce") 90 | if nonce is None: 91 | nonce = 1 92 | self.redis_client.set("mj.nonce", nonce) 93 | # 
将redis中的nonce加1 94 | nonce = self.redis_client.incr("mj.nonce") 95 | return nonce 96 | def get_tasks_queue(self) -> Queue: 97 | return self.mj_tasks 98 | 99 | def get_tasks_list(self) -> list: 100 | return self.mj_tasks_list 101 | def add_task(self, task: MjTask) -> None: 102 | self.mj_tasks.put(task) 103 | self.task_id_map[task.task_id] = task 104 | self.nonce_map[task.nonce] = task 105 | self.mj_tasks_list.append(task) 106 | 107 | def get_task_by_nonce(self, nonce: str) -> Optional[MjTask]: 108 | return self.nonce_map.get(nonce) 109 | 110 | def get_task_by_task_id(self, task_id: str) -> Optional[MjTask]: 111 | return self.task_id_map.get(task_id) 112 | 113 | def get_task_by_message_id(self, message_id: str) -> Optional[MjTask]: 114 | return self.message_id_map.get(message_id) 115 | 116 | def update_task_nonce(self, task: MjTask, nonce): 117 | task.nonce = nonce 118 | 119 | def update_task_message_id(self, task: MjTask, message_id): 120 | task.message_id = message_id 121 | 122 | def update_task_progress(self, task: MjTask, progress: str) -> None: 123 | task.progress = progress 124 | 125 | def update_task_status(self, task: MjTask, status: str) -> None: 126 | task.task_status = status 127 | 128 | def update_task_image_url(self, task: MjTask, image_url: str) -> None: 129 | task.image_url = image_url 130 | 131 | def update_message_id_map(self,task:MjTask): 132 | self.message_id_map[task.message_id] = task 133 | 134 | 135 | def update_upsacle_custom_ids(self, task: MjTask, custom_ids: List[str]) -> None: 136 | task.upscale_custom_ids = custom_ids 137 | 138 | def update_variation_custom_ids(self, task: MjTask, custom_ids: List[str]) -> None: 139 | task.variation_custom_ids = custom_ids 140 | 141 | def update_finished_message_id(self, task: MjTask, finished_message_id: str) -> None: 142 | task.finished_message_id = finished_message_id 143 | 144 | def update_task_description(self, task: MjTask, description: str) -> None: 145 | task.description = description 146 | 147 | 
def update_buttons(self, task: MjTask, buttons: List[str]) -> None: 148 | task.buttons = buttons -------------------------------------------------------------------------------- /service/notify_service.py: -------------------------------------------------------------------------------- 1 | from support.mj_config import MjConfig 2 | import redis 3 | import requests 4 | import json 5 | import datetime 6 | from support.task_controller import MjTask 7 | from service.template_controller import TemplateController 8 | from service.mj_data_service import MjDataService 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class NotifyService: 14 | def __init__(self, redis_client: redis.Redis) -> None: 15 | self.redis_client = redis_client 16 | 17 | def notify_task_change(self, task: MjTask): 18 | if not task: 19 | return 20 | 21 | notify_hook = task.notify_hook 22 | if not notify_hook: 23 | return 24 | 25 | if not notify_hook.startswith("http"): 26 | return 27 | logger.info( 28 | f"notify_task_change, task_id: {task.task_id}, hook: {notify_hook}") 29 | 30 | try: 31 | req = { 32 | "task_id": task.task_id, 33 | "task_type": task.task_type, 34 | "task_status": task.task_status, 35 | "image_url": task.image_url, 36 | "progress": task.progress, 37 | "action_index": task.upscale_index, 38 | "description": task.description, 39 | "buttons": task.buttons, 40 | "fail_reason":task.fail_reason 41 | } 42 | res = requests.post(notify_hook, json=req) 43 | # 如果为2xx 44 | if res.status_code // 100 == 2: 45 | logger.info(f"notify_task_change res 2xx, task_id: {task.task_id}, hook: {notify_hook}, res: {res.text}") 46 | return 47 | else: 48 | logger.error( 49 | f"notify_task_change res not 2xx, task_id: {task.task_id}, hook: {notify_hook}, res: {res.text}") 50 | 51 | except Exception as e: 52 | logger.error(f"notify_task_change, task_id: {task.task_id}, hook: {notify_hook}, error: {e}") 53 | return 
-------------------------------------------------------------------------------- /service/template_controller.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class TemplateController: 4 | def __init__(self) -> None: 5 | blend_file = "./resources/api_params/blend.json" 6 | self.blend_template = json.load(open(blend_file, "r", encoding="utf-8")) 7 | describe_file = "./resources/api_params/describe.json" 8 | self.describe_template = json.load(open(describe_file, "r", encoding="utf-8")) 9 | variation_file = "./resources/api_params/variation.json" 10 | self.variation_template = json.load(open(variation_file, "r", encoding="utf-8")) 11 | upscale_file = "./resources/api_params/upscale.json" 12 | self.upscale_template = json.load(open(upscale_file, "r", encoding="utf-8")) 13 | self.imagine_template = json.load(open("./resources/api_params/imagine.json", "r", encoding="utf-8")) 14 | self.shorten_template = json.load(open("./resources/api_params/shorten.json", "r", encoding="utf-8")) 15 | self.message_template = json.load(open("./resources/api_params/message.json", "r", encoding="utf-8")) 16 | 17 | 18 | # 将tempalte中的${key}$替换为template_map中的值 19 | def replace_template(self, template, template_map): 20 | for key, value in template_map.items(): 21 | if "$"+key+"$" not in template: 22 | continue 23 | template = template.replace("$"+key+"$", value) 24 | return template 25 | 26 | def get_imagine(self, template_map): 27 | imagine_template_dumps=json.dumps(self.imagine_template) 28 | imagine_template_dumps = self.replace_template(imagine_template_dumps, template_map) 29 | return json.loads(imagine_template_dumps) 30 | 31 | def get_blend(self, template_map): 32 | blend_template_dumps=json.dumps(self.blend_template) 33 | blend_template_dumps = self.replace_template(blend_template_dumps, template_map) 34 | return json.loads(blend_template_dumps) 35 | 36 | def get_describe(self, template_map): 37 | 
describe_template_dumps=json.dumps(self.describe_template) 38 | describe_template_dumps = self.replace_template(describe_template_dumps, template_map) 39 | return json.loads(describe_template_dumps) 40 | 41 | def get_variation(self, template_map): 42 | variation_template_dumps=json.dumps(self.variation_template) 43 | variation_template_dumps = self.replace_template(variation_template_dumps, template_map) 44 | return json.loads(variation_template_dumps) 45 | 46 | def get_upscale(self, template_map): 47 | upscale_template_dumps=json.dumps(self.upscale_template) 48 | upscale_template_dumps = self.replace_template(upscale_template_dumps, template_map) 49 | return json.loads(upscale_template_dumps) 50 | 51 | def get_shorten(self, template_map): 52 | shorten_template_dumps=json.dumps(self.shorten_template) 53 | shorten_template_dumps = self.replace_template(shorten_template_dumps, template_map) 54 | return json.loads(shorten_template_dumps) 55 | 56 | 57 | def get_message(self, template_map): 58 | message_template_dumps = json.dumps(self.message_template) 59 | message_template_dumps = self.replace_template(message_template_dumps, template_map) 60 | return json.loads(message_template_dumps) -------------------------------------------------------------------------------- /support/Injector.py: -------------------------------------------------------------------------------- 1 | from injector import Injector, Module, provider, singleton 2 | import redis 3 | import yaml 4 | from support.task_controller import TaskController 5 | from support.mj_config import MjConfig 6 | from wss.mj_wss_proxy import MjWssSercice 7 | from service.discord_http_service import DiscordHttpService 8 | from service.mj_data_service import MjDataService 9 | from service.template_controller import TemplateController 10 | from service.notify_service import NotifyService 11 | from support.load_balancer import LoadBalancer 12 | from wss.mj_wss_manager import MjWssManager 13 | 14 | 15 | class 
SupportModule(Module): 16 | 17 | @provider 18 | @singleton 19 | def provide_mj_config(self) -> MjConfig: 20 | mj_config = MjConfig() 21 | return mj_config 22 | 23 | @provider 24 | @singleton 25 | def provide_redis(self, mj_config: MjConfig) -> redis.Redis: 26 | config = mj_config.mj_config 27 | redis_client = redis.Redis(host=config['redis']['host'], 28 | port=config['redis']['port'], 29 | db=config['redis']['db'], 30 | password=config['redis']['password']) 31 | return redis_client 32 | 33 | @provider 34 | @singleton 35 | def provide_task_controller(self, redis_client: redis.Redis, load_balancer: LoadBalancer, mj_data_service: MjDataService) -> TaskController: 36 | task_controller = TaskController( 37 | redis_client, load_balancer, mj_data_service) 38 | return task_controller 39 | 40 | @provider 41 | @singleton 42 | def provide_mj_data_service(self, redis_client: redis.Redis) -> MjDataService: 43 | mj_data_service = MjDataService(redis_client) 44 | return mj_data_service 45 | 46 | @provider 47 | @singleton 48 | def provide_template_controller(self,) -> TemplateController: 49 | template_controller = TemplateController() 50 | return template_controller 51 | 52 | 53 | @provider 54 | @singleton 55 | def provide_notify_service(self, redis_client: redis.Redis) -> NotifyService: 56 | notify_service = NotifyService(redis_client) 57 | return notify_service 58 | 59 | @provider 60 | @singleton 61 | def provide_load_balancer(self, mj_config: MjConfig, redis_client: redis.Redis, template_controller: TemplateController) -> LoadBalancer: 62 | load_balancer = LoadBalancer(mj_config, redis_client, template_controller) 63 | return load_balancer 64 | 65 | @provider 66 | @singleton 67 | def provide_mj_wss_manager(self, mj_config: MjConfig, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> MjWssManager: 68 | mj_wss_manager = MjWssManager(mj_config, redis_client, mj_data_service, notify_service) 69 | return mj_wss_manager 70 | 71 | # 创建注入器 72 | 
# Module-level injector holding the SupportModule bindings; it is created when
# this module is first imported.
injector = Injector([SupportModule()])

--------------------------------------------------------------------------------
/support/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__init__.py
--------------------------------------------------------------------------------
/support/__pycache__/Injector.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/Injector.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/image_controller.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/image_controller.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/load_balancer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/load_balancer.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/mj_account.cpython-38.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/mj_account.cpython-38.pyc -------------------------------------------------------------------------------- /support/__pycache__/mj_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/mj_config.cpython-38.pyc -------------------------------------------------------------------------------- /support/__pycache__/mj_task.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/mj_task.cpython-38.pyc -------------------------------------------------------------------------------- /support/__pycache__/task_controller.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/task_controller.cpython-38.pyc -------------------------------------------------------------------------------- /support/__pycache__/template_controller.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/template_controller.cpython-38.pyc -------------------------------------------------------------------------------- /support/image_controller.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import redis 3 | import uuid 4 | from support.mj_config import MjConfig 5 | from 
service.discord_http_service import DiscordHttpService 6 | import json 7 | import io 8 | from PIL import Image 9 | import hashlib 10 | import base64 11 | 12 | class ImageController: 13 | def __init__(self, mj_config: MjConfig, redis_client: redis.Redis, discord_http_service: DiscordHttpService) -> None: 14 | self.redis_client = redis_client 15 | self.mj_config = mj_config.mj_config 16 | self.discord_http_service = discord_http_service 17 | 18 | self.discord_upload_server = self.mj_config['ng']['discord_upload_server'] 19 | self.CHANNEL_ID = self.mj_config['account']['channel_id'] 20 | self.USER_TOKEN = self.mj_config['account']['user_token'] 21 | 22 | def upload_img_if_bs64(self, img: str): 23 | # return self.upload2fastapi(img) 24 | return self.upload2discord(img) 25 | 26 | def upload2discord(self, img: str): 27 | if not img: 28 | return 29 | if img.startswith("data:image/"): 30 | 31 | discord_upload_attachment_url = self.discord_http_service.get_discord_upload_attachment_url() 32 | 33 | base64_data = img.split(",")[-1] 34 | image_data = base64.b64decode(base64_data) 35 | # 将字节数据转换为BytesIO对象 36 | image_stream = io.BytesIO(image_data) 37 | 38 | # 将image_streamg转为bytes 39 | image_bytes = image_stream.read() 40 | file_size = len(image_bytes) 41 | hash_value = hashlib.md5(image_data).hexdigest() 42 | file_name=f"{hash_value}.png" 43 | 44 | HEADERS = {"Content-Type": "application/json", "Authorization": self.USER_TOKEN} 45 | payload = { 46 | "files": [{"filename": file_name, "file_size": file_size, "id": "0"}]} 47 | res = requests.post(discord_upload_attachment_url, headers=HEADERS, data= json.dumps(payload)) 48 | if res.status_code == 200: 49 | res_data=json.loads(res.text) 50 | attach = res_data['attachments'][0] 51 | upload_url = attach['upload_url'] 52 | upload_filename = attach['upload_filename'] 53 | print("upload_url",upload_url) 54 | print("upload_filename",upload_filename) 55 | headers = { 56 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)", 57 | 
"Content-Type": "application/octet-stream", 58 | "Content-Length": str(file_size), 59 | } 60 | res = requests.put(upload_url, data=image_data, headers = headers) 61 | print(res.status_code) 62 | if res.status_code == 200: 63 | image_url= self.discord_http_service.message(upload_filename) 64 | return image_url 65 | else: 66 | raise Exception("上传图片失败") 67 | 68 | else: 69 | raise Exception("上传图片失败") 70 | 71 | 72 | 73 | elif img.startswith("http"): 74 | return img 75 | 76 | 77 | def upload2fastapi(self, img: str): 78 | if not img: 79 | return 80 | if img.startswith("data:image/"): 81 | url = self.mj_config['common']['img_upload_url'] 82 | data = { 83 | "bs64": img 84 | } 85 | res = requests.post(url, data=json.dumps(data)) 86 | res_data = json.loads(res.text) 87 | filename = res_data["data"]["filename"] 88 | image_url = f"http://43.153.103.254/images/{filename}" 89 | return image_url 90 | elif img.startswith("http"): 91 | return img -------------------------------------------------------------------------------- /support/load_balancer.py: -------------------------------------------------------------------------------- 1 | from service.discord_http_service import DiscordHttpService 2 | import redis 3 | from support.mj_config import MjConfig 4 | from service.template_controller import TemplateController 5 | from support.mj_account import MjAccount 6 | import random 7 | 8 | class LoadBalancer: 9 | def __init__(self, mj_config: MjConfig, redis_client: redis.Redis, template_controller: TemplateController): 10 | 11 | self.mj_config = mj_config.mj_config 12 | self.redis_client = redis_client 13 | self.template_controller = template_controller 14 | 15 | self.discord_http_services = [] 16 | self.init_discord_http_services() 17 | 18 | def init_discord_http_services(self): 19 | discord_http_services = [] 20 | for i in range(len(self.mj_config['accounts'])): 21 | account = self.mj_config['accounts'][i] 22 | mj_account = MjAccount(account) 23 | 
discord_http_services.append(DiscordHttpService( 24 | self.mj_config, self.redis_client, self.template_controller,mj_account)) 25 | 26 | self.discord_http_services = discord_http_services 27 | 28 | def get_discord_http_service(self) -> DiscordHttpService: 29 | index = random.randint(0, len(self.discord_http_services)-1) 30 | return self.discord_http_services[index] 31 | -------------------------------------------------------------------------------- /support/mj_account.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class MjAccount: 4 | def __init__(self, account_config: dict) -> None: 5 | self.account_config = account_config 6 | 7 | def get_user_token(self): 8 | return self.account_config['user_token'] 9 | 10 | def get_bot_token(self): 11 | return self.account_config['bot_token'] 12 | 13 | def get_channel_id(self): 14 | return self.account_config['channel_id'] 15 | 16 | def get_gulid_id(self): 17 | return self.account_config['guild_id'] 18 | 19 | def get_app_id(self): 20 | return self.account_config['app_id'] 21 | 22 | def get_session_id(self): 23 | return self.account_config['session_id'] 24 | -------------------------------------------------------------------------------- /support/mj_config.py: -------------------------------------------------------------------------------- 1 | 2 | import yaml 3 | class MjConfig: 4 | def __init__(self) -> None: 5 | self.mj_config = yaml.load(open( 6 | "./resources/config/prod_mj_config.yaml", "r", encoding="utf-8"), Loader=yaml.FullLoader) 7 | 8 | 9 | def get_accounts(self): 10 | return self.mj_config['accounts'] -------------------------------------------------------------------------------- /support/task_controller.py: -------------------------------------------------------------------------------- 1 | 2 | from service.mj_data_service import MjTask 3 | from queue import Queue 4 | import time 5 | import threading 6 | import redis 7 | from service.discord_http_service import 
DiscordHttpService 8 | from service.mj_data_service import MjDataService 9 | from support.load_balancer import LoadBalancer 10 | import uuid 11 | 12 | 13 | def consume_tasks(mj_data_service: MjDataService, load_balancer: LoadBalancer): 14 | while True: 15 | task_queue = mj_data_service.get_tasks_queue() 16 | task: MjTask = task_queue.get() 17 | if task is None: 18 | continue 19 | 20 | # comsume任务 21 | if task.task_type is None: 22 | continue 23 | 24 | discord_http_service = load_balancer.get_discord_http_service() 25 | 26 | if task.task_type == "imagine": 27 | discord_http_service.imageine(task) 28 | elif task.task_type == "upscale": 29 | discord_http_service.upscale(task) 30 | elif task.task_type == "variation": 31 | discord_http_service.variation(task) 32 | elif task.task_type == "describe": 33 | discord_http_service.describe(task) 34 | elif task.task_type == "blend": 35 | discord_http_service.blend(task) 36 | elif task.task_type == "action": 37 | discord_http_service.action(task) 38 | elif task.task_type == "shorten": 39 | discord_http_service.shorten(task) 40 | else: 41 | print("unknown task type: " + task.task_type) 42 | 43 | class TaskController: 44 | def __init__(self, redis_client: redis.Redis,load_balancer: LoadBalancer, mj_data_service: MjDataService) -> None: 45 | self.mj_data_service: MjDataService = mj_data_service 46 | 47 | # discord_http_service: DiscordHttpService 48 | consume_thread = threading.Thread( 49 | target = consume_tasks, args=(self.mj_data_service, load_balancer)) 50 | consume_thread.daemon = True 51 | consume_thread.start() 52 | 53 | def submit_imagine(self, task_data: MjTask): 54 | self.base_submit(task_data) 55 | return task_data.task_id 56 | 57 | def submit_upscale(self, task_data: MjTask): 58 | self.base_submit(task_data) 59 | return task_data.task_id 60 | 61 | def submit_describe(self, task_data: MjTask): 62 | self.base_submit(task_data) 63 | return task_data.task_id 64 | 65 | def submit_blend(self, task_data: MjTask): 66 | 
self.base_submit(task_data) 67 | return task_data.task_id 68 | 69 | def submit_variation(self, task_data: MjTask): 70 | self.base_submit(task_data) 71 | return task_data.task_id 72 | 73 | def submit_action(self, task_data: MjTask): 74 | self.base_submit(task_data) 75 | return task_data.task_id 76 | 77 | def submit_shorten(self, task_data: MjTask): 78 | self.base_submit(task_data) 79 | return task_data.task_id 80 | 81 | def base_submit(self, task_data: MjTask): 82 | task_data.task_id = str(uuid.uuid4()) 83 | self.mj_data_service.add_task(task_data) 84 | return task_data.task_id 85 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # from bootstrap import Bootstrap -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/common_util.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/logger_util.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/wx_util.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/wx_util.cpython-38.pyc -------------------------------------------------------------------------------- /utils/bootstrap.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import asyncio 4 | import random 5 | import time 6 | import uuid 7 | import yaml 8 | from .logger_util import get_logger 9 | import consul 10 | import consul.base 11 | import consul.callback 12 | from prometheus_client import CollectorRegistry, Gauge, push_to_gateway, start_http_server 13 | 14 | class Bootstrap(): 15 | def __init__(self, file_yaml='config.yaml') -> None: 16 | config_yaml: dict = self.read_config(file_yaml) 17 | env_dict = os.environ 18 | # main 19 | self.RUN_NAME = config_yaml.get('RUN_NAME') 20 | self.LOG_PATH = config_yaml.get('LOG_PATH') 21 | self.ENV_IP = config_yaml.get('ENV_IP', env_dict.get('ENV_IP')) 22 | self.ENV_PORT = config_yaml.get('ENV_PORT', env_dict.get('ENV_PORT')) 23 | self.ENV_PORTS = config_yaml.get('ENV_PORTS', env_dict.get('ENV_PORTS')) 24 | self.DWONSTREAMS = config_yaml.get('DWONSTREAMS') 25 | 26 | # consul 27 | self.CONSUL_IP = config_yaml.get('CONSUL_IP') 28 | self.CONSUL_PORT = config_yaml.get('CONSUL_PORT') 29 | self.CONSUL_TOKEN = config_yaml.get('CONSUL_TOKEN') 30 | self.CONSUL_CHECK_IP = config_yaml.get('CONSUL_CHECK_IP') 31 | self.CONSUL_CHECK_TICK = config_yaml.get('CONSUL_CHECK_TICK', 5) 32 | self.CONSUL_CHECK_TIMEOUT = config_yaml.get('CONSUL_CHECK_TIMEOUT', 30) 33 | self.CONSUL_CHECK_DEREG = config_yaml.get('CONSUL_CHECK_DEREG', 30) 34 | 35 | # metrics 36 | self.METRICS_IP = config_yaml.get('METRICS_IP') 37 | self.METRICS_PORT = config_yaml.get('METRICS_PORT') 38 | self.METRICS_TOKEN = config_yaml.get('METRICS_TOKEN') 39 | self.METRICS_CHECK_TICK = 
config_yaml.get('METRICS_CHECK_TICK', 5) 40 | 41 | def lifespan(self, _): 42 | print("startup") 43 | self.background_start() 44 | yield 45 | print(f"shutdown wait {self.CONSUL_CHECK_DEREG}s for consul dereg") 46 | self.deregister_consul() 47 | # time.sleep(self.CONSUL_CHECK_DEREG) 48 | 49 | def init(self): 50 | self.logger = get_logger(self.RUN_NAME, self.LOG_PATH) 51 | self.consul_cli = consul.Consul(self.CONSUL_IP, self.CONSUL_PORT, self.CONSUL_TOKEN) 52 | self.register_consul() 53 | self.metrics_registry = CollectorRegistry() 54 | self.metrics_reg_base() 55 | self.metrics_reg() 56 | self.service_dict = {k: [] for k in self.DWONSTREAMS} 57 | self.background_service_update() 58 | self.not_shutdown = True 59 | 60 | def background_start(self): 61 | if self.service_dict: 62 | asyncio.create_task(self.background_service_update()) 63 | asyncio.create_task(self.background_metrics_update()) 64 | 65 | def read_config(self, file_yaml): 66 | with open(file=file_yaml, mode='r', encoding='utf-8') as f: 67 | return yaml.load(stream=f.read(), Loader=yaml.FullLoader) 68 | 69 | def register_consul(self): 70 | # 健康检查ip端口,检查时间:5,超时时间:30,注销时间:30s 71 | # curl --request PUT http://127.0.0.1:6011/v1/agent/service/register/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81 72 | # curl --request PUT http://127.0.0.1:6011/v1/agent/service/deregister/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81 73 | # curl --request GET http://127.0.0.1:6011/v1/agent/checks?token=6df8babf-1061-2469-79f0-4488640cba81 74 | # curl --request GET http://127.0.0.1:6011/v1/agent/check/test_service_up?token=6df8babf-1061-2469-79f0-4488640cba81 75 | 76 | self.consul_id = uuid.uuid4().hex 77 | checker = consul.Check().tcp(self.CONSUL_CHECK_IP, self.ENV_PORT, 78 | self.CONSUL_CHECK_TICK * 1000 * 1000, 79 | self.CONSUL_CHECK_TIMEOUT * 1000 * 1000, 80 | self.CONSUL_CHECK_DEREG * 1000 * 1000) 81 | # checker = consul.Check().http(f'http://127.0.0.1:8531/health', 82 | # 1 * 1000 * 1000, 2 * 1000 
* 1000, 3 * 1000 * 1000) 83 | # res = self.consul_cli.catalog.register( 84 | # self.RUN_NAME, 85 | # self.ENV_IP, 86 | # { 87 | # "Service": self.RUN_NAME, 88 | # "ID": self.consul_id, 89 | # "Tags": [], 90 | # "Port": self.ENV_PORT, 91 | # }, 92 | # check=checker, 93 | # token=self.CONSUL_TOKEN, 94 | # ) 95 | res = self.consul_cli.agent.service.register( 96 | self.RUN_NAME, 97 | service_id=self.consul_id, 98 | address=self.ENV_IP, 99 | port=self.ENV_PORT, 100 | token=self.CONSUL_TOKEN, 101 | check=checker) 102 | if not res: 103 | raise RuntimeError('register_consul fail!') 104 | 105 | def consul_token_param(self): 106 | params = [] 107 | if self.CONSUL_TOKEN: 108 | params.append(('token', self.CONSUL_TOKEN)) 109 | return params 110 | 111 | def deregister_consul(self): 112 | # curl --request PUT http://127.0.0.1:6011/v1/agent/service/deregister/test_service_down 113 | # curl --request PUT http://101.43.141.18:6011/v1/agent/service/register/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81 114 | # curl --request PUT http://101.43.141.18:6011/v1/agent/service/deregister/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81 115 | # curl --request GET http://101.43.141.18:6011/v1/agent/services?token=6df8babf-1061-2469-79f0-4488640cba81 116 | # curl --request GET http://101.43.141.18:6011/v1/agent/service/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81 117 | 118 | res = self.consul_cli.agent.service.deregister(self.consul_id, token=self.CONSUL_TOKEN) 119 | # res = self.consul_cli.catalog.deregister() 120 | # res = self.consul_cli.agent.agent.http.put(consul.callback.CB.bool(), '/v1/agent/service/deregister/%s' % self.RUN_NAME, params=self.consul_token_param()) 121 | if not res: 122 | self.logger.error('deregister_consul fail!') 123 | 124 | def get_server(self, service_name): 125 | service_list = self.service_dict[service_name] 126 | if not service_list: 127 | return None 128 | return random.choice(service_list) 129 | 130 | async def 
background_service_update(self): 131 | def cb(response): 132 | consul.callback.CB._status(response) 133 | if response.code == 404: 134 | return None 135 | return json.loads(response.body) 136 | while self.not_shutdown: 137 | for service_name in self.service_dict: 138 | filter = ('filter', f'Service == "{service_name}"') 139 | service_info = self.consul_cli.agent.agent.http.get(cb, '/v1/agent/services', params=(*self.consul_token_param(), filter)) 140 | if isinstance(service_info, dict): 141 | self.service_dict[service_name] = [{'ip': item.get('Address'), 'port': item.get('Port')} for item in service_info.values()] 142 | self.metrics_downstream.labels(src=self.RUN_NAME, dst=service_name).set(len(service_info)) 143 | else: 144 | self.logger.error('service_dict[%s] parse error! parsed to: %s', service_name, self.service_dict) 145 | self.metrics_downstream.labels(src=self.RUN_NAME, dst=service_name).set(-1) 146 | await asyncio.sleep(self.CONSUL_CHECK_TICK) 147 | 148 | async def background_metrics_update(self): 149 | while self.not_shutdown: 150 | push_to_gateway(f'{self.METRICS_IP}:{self.METRICS_PORT}', job=self.RUN_NAME, registry=self.metrics_registry) 151 | self.logger.info('push_to_gateway') 152 | await asyncio.sleep(self.METRICS_CHECK_TICK) 153 | 154 | def metrics_reg_base(self): 155 | self.metrics_downstream = Gauge('service_downstream', 'downstream service num', ['src', 'dst'], registry=self.metrics_registry) 156 | 157 | def metrics_reg(self): 158 | raise NotImplementedError() 159 | 160 | # c = Counter('my_requests_total', 'HTTP requests total', ['method', 'endpoint'], registry=self.metrics_registry) 161 | # c.labels(method='get', endpoint='/').inc() 162 | # c.labels(method='post', endpoint='/submit').inc() 163 | -------------------------------------------------------------------------------- /utils/common_util.py: -------------------------------------------------------------------------------- 1 | # -*- ecoding: utf-8 -*- 2 | import numpy as np 3 | from typing 
import List 4 | import os 5 | import shutil 6 | # import pickle 7 | import json 8 | import string 9 | from PIL import Image 10 | import math 11 | import requests 12 | import base64 13 | from urllib.parse import urlparse 14 | from io import BytesIO 15 | import time 16 | import dill as pickle 17 | 18 | 19 | def data_generator(data, batch_size): 20 | assert batch_size > 0 21 | current_pos = 0 22 | index = 0 23 | total = len(data) // batch_size + 1 24 | while current_pos < len(data): 25 | yield index, total, data[current_pos: current_pos + batch_size] 26 | current_pos += batch_size 27 | index += 1 28 | 29 | 30 | def write2file(filename, text): 31 | with open(filename, "w", encoding="utf-8") as f: 32 | f.write(text) 33 | 34 | 35 | def write2file_arr(filename, arr): 36 | init = True 37 | dirname = os.path.dirname(filename) 38 | if not os.path.exists(dirname): 39 | os.system(f"mkdir -p {dirname}") 40 | with open(filename, "w", encoding="utf-8") as f: 41 | for a in arr: 42 | if not init: 43 | f.write("\n") 44 | f.write(a) 45 | else: 46 | f.write(a) 47 | init = False 48 | 49 | 50 | def readfile2arr(filename, skip=None): 51 | with open(filename, "r", encoding="utf-8") as f: 52 | if skip == None: 53 | arr = [a.strip() for a in f] 54 | else: 55 | arr = [a.strip() for a in f][skip:] 56 | return arr 57 | 58 | 59 | # 字符串最小编辑距离 60 | def minDistance(word1: str, word2: str) -> int: 61 | m, n = len(word1), len(word2) 62 | dp = [[0 for i in range(n + 1)] for j in range(m + 1)] 63 | 64 | for i in range(m + 1): 65 | dp[i][0] = i 66 | 67 | for j in range(n + 1): 68 | dp[0][j] = j 69 | 70 | for i in range(1, m + 1): 71 | for j in range(1, n + 1): 72 | if word1[i - 1] == word2[j - 1]: 73 | dp[i][j] = dp[i - 1][j - 1] 74 | else: 75 | dp[i][j] = min(dp[i - 1][j] + 1, dp[i] 76 | [j - 1] + 1, dp[i - 1][j - 1] + 1) 77 | 78 | return dp[m][n] 79 | 80 | 81 | def get_words2ids_pad(texts, words2id, max_seq): 82 | lines = [] 83 | for line in texts: 84 | lines.append(line.strip()) 85 | 86 | pad_ids = 
[] 87 | for line in lines: 88 | ids = [words2id[i] for i in line] 89 | ids = ids[0:max_seq] 90 | if len(ids) < max_seq: 91 | ids = ids + [0] * (max_seq - len(ids)) 92 | pad_ids.append(ids) 93 | return np.array(pad_ids) 94 | 95 | 96 | def remove_dir_and_files(root_dir): 97 | # 删除目录及文件 98 | shutil.rmtree(root_dir, True) 99 | 100 | 101 | def remove_subdirs_and_files(root_dir): 102 | for filename in os.listdir(root_dir): 103 | file_path = os.path.join(root_dir, filename) 104 | if os.path.isfile(file_path) or os.path.islink(file_path): 105 | os.unlink(file_path) 106 | elif os.path.isdir(file_path): 107 | shutil.rmtree(file_path) 108 | 109 | 110 | def copy_files(file_list, ori_dir, des_dir): 111 | for file_name in file_list: 112 | file_dir = os.path.join(ori_dir, file_name) 113 | shutil.copy(file_dir, des_dir) 114 | 115 | 116 | def save_pickle(filename, obj): 117 | with open(filename, "wb") as f: 118 | pickle.dump(obj, f) 119 | 120 | 121 | def load_pickle(filename): 122 | if filename.endswith(".json"): 123 | return read_json(filename) 124 | with open(filename, "rb") as f: 125 | obj = pickle.load(f) 126 | return obj 127 | 128 | 129 | def arr2str_tab(arr): 130 | result = [] 131 | for i in arr: 132 | if isinstance(i, int): 133 | result.append(str(i)) 134 | elif isinstance(i, float): 135 | result.append("%.3f" % i) 136 | elif isinstance(i, str): 137 | result.append(i) 138 | # arr = [str(i) if isinstance(i, int) elif "%.3f" % i for i in arr] 139 | return "\t".join(result) 140 | 141 | 142 | def to_torch(data, type=None, device="cpu"): 143 | import torch 144 | 145 | if type == None: 146 | return torch.from_numpy(np.array(data)).to(device) 147 | else: 148 | return torch.from_numpy(np.array(data, dtype=type)).to(device) 149 | 150 | 151 | def read_json(filename): 152 | with open(filename, "r", encoding="utf-8") as f: 153 | data = json.load(f) 154 | return data 155 | 156 | 157 | def cached_pkl(cache_path: str, load_from_cache: bool, gen_func, *args, **kargs): 158 | if 
os.path.exists(cache_path) and load_from_cache: 159 | print("gen from cached pkl") 160 | return load_pickle(cache_path) 161 | else: 162 | print("gen from origin file") 163 | out = gen_func(*args, **kargs) 164 | save_pickle(cache_path, out) 165 | return out 166 | 167 | 168 | def alignment(arr: List, max_length: int, padding): 169 | arr_len = len(arr) 170 | mask = [1.0] * len(arr) 171 | if len(arr) >= max_length: 172 | arr = arr[0:max_length] 173 | mask = mask[0:max_length] 174 | else: 175 | arr += [padding] * (max_length - arr_len) 176 | mask += [0.0] * (max_length - arr_len) 177 | return arr, mask 178 | 179 | 180 | def collate_fn_from_map(data, keys=None, device="cpu", ignore_keys=[]): 181 | assert data 182 | assert isinstance(data, list) 183 | result = {} 184 | import torch 185 | 186 | ignore_keys = set(ignore_keys) 187 | if isinstance(data[0], dict): 188 | for i in data: 189 | for k, v in i.items(): 190 | if k not in ignore_keys: 191 | if k not in result: 192 | result[k] = [] 193 | result[k].append(v) 194 | if keys: 195 | for k, v in result.items(): 196 | if k in keys: 197 | result[k] = to_torch(v, type=keys[k], device=device) 198 | 199 | return result 200 | 201 | 202 | def save_json_datas(file: str, datas): 203 | # if not os.path.exists("../evaluate/python_result/"+testset_name+"/"): 204 | # os.mkdir("../evaluate/python_result/"+testset_name+"/") 205 | # paths=file.split("/") 206 | # tmp_path=paths[0]+"/" 207 | beg_index = 0 208 | while True: 209 | try: 210 | index = file.index("/", beg_index) 211 | except Exception as e: 212 | break 213 | 214 | recur_dir_path = file[: index + 1] 215 | if not os.path.exists(recur_dir_path): 216 | os.mkdir(recur_dir_path) 217 | beg_index = index + 1 218 | 219 | with open(file, "w", encoding="utf-8") as f: 220 | f.write(json.dumps(datas, indent=4, ensure_ascii=False)) 221 | 222 | 223 | def is_chinese(uchar): 224 | """判断一个unicode是否是汉字""" 225 | if uchar >= "\u4e00" and uchar <= "\u9fa5": 226 | return True 227 | else: 228 | return 
False 229 | 230 | 231 | def save_huggface_json_datas(file, datas): 232 | line_datas = [] 233 | for i in datas: 234 | line_datas.append(json.dumps(i, ensure_ascii=False)) 235 | write2file_arr(file, line_datas) 236 | 237 | 238 | def read_huggingface_json_datas(file): 239 | line_datas = [] 240 | datas = readfile2arr(file) 241 | for i in datas: 242 | r = json.loads(i) 243 | line_datas.append(r) 244 | return line_datas 245 | 246 | 247 | def get_punctuations(): 248 | punctuations = set(list(string.punctuation + "。!?”’“‘…·《》【】—-,、,")) 249 | return punctuations 250 | 251 | 252 | def random_select_unrepeat(items, size): 253 | # 不重复 254 | selected = np.random.choice(items, replace=False, size=size) 255 | return selected 256 | 257 | 258 | # 将文档拆分成句子,根据max_len合并 259 | def cut_to_sents(text, max_len=120, merge=False): 260 | import re 261 | 262 | sentences = re.split(r"(?|\?|。|!|\…\…|\r|\n)", text) 263 | 264 | clip_sents = [] 265 | for i in sentences: 266 | left = i 267 | while len(left) > 0: 268 | clip_sents.append(left[:max_len]) 269 | left = left[max_len:] 270 | return clip_sents 271 | 272 | for i in range(len(sentences)): 273 | if i % 2 == 1: 274 | _sentences.append(sentences[i - 1] + sentences[i]) 275 | elif i == len(sentences) - 1: 276 | # 最后一个 277 | _sentences.append(sentences[i]) 278 | 279 | clip_sents = [] 280 | for i in _sentences: 281 | left = i 282 | while len(left) > 0: 283 | clip_sents.append(left[:max_len]) 284 | left = left[max_len:] 285 | # 不拼接 286 | if not merge: 287 | return clip_sents 288 | 289 | l = "" 290 | merged_sents = [] 291 | for index, i in enumerate(clip_sents): 292 | if len(l) + len(i) > max_len: 293 | merged_sents.append(l) 294 | l = "" 295 | l += i 296 | else: 297 | l += i 298 | 299 | if len(l) > 0: 300 | merged_sents.append(l) 301 | return merged_sents 302 | 303 | 304 | def getImage(filename, bs64=False): 305 | if filename == None: 306 | return None 307 | 308 | if isinstance(filename, Image.Image): 309 | im = filename 310 | elif 
filename.startswith("data:image"): 311 | return filename 312 | elif is_url(filename): 313 | response = requests.get(filename, timeout=15) 314 | if response.ok: 315 | im = Image.open(BytesIO(response.content)) 316 | else: 317 | print(f"error get img from url:{filename}") 318 | return None 319 | elif os.path.isfile(filename): 320 | im = Image.open(filename) 321 | elif isinstance(filename, Image.Image): 322 | im = filename 323 | else: 324 | im = base64_to_image(filename) 325 | 326 | if bs64: 327 | return image_to_base64(im) 328 | return im 329 | 330 | 331 | def image_to_base64(image: Image.Image, fmt="png") -> str: 332 | from io import BytesIO 333 | import base64 334 | 335 | output_buffer = BytesIO() 336 | image.save(output_buffer, format=fmt) 337 | byte_data = output_buffer.getvalue() 338 | base64_str = base64.b64encode(byte_data).decode("utf-8") 339 | return f"data:image/{fmt};base64," + base64_str 340 | 341 | 342 | def base64_to_image(image_base64): 343 | import base64 344 | from PIL import Image 345 | from io import BytesIO 346 | 347 | datas = image_base64.split(",") 348 | if len(datas) == 1: 349 | image_base64 = datas[0] 350 | else: 351 | image_base64 = datas[1] 352 | #  353 | img_bs64 = base64.b64decode(image_base64) 354 | image = Image.open(BytesIO(img_bs64)) 355 | return image 356 | 357 | 358 | def url_to_base64(url): 359 | response = requests.get(url) 360 | image_data = response.content 361 | base64_data = base64.b64encode(image_data) 362 | base64_image = base64_data.decode("utf-8") 363 | return f"data:image/jpeg;base64,{base64_image}" 364 | 365 | 366 | def image_to_np(image): 367 | return np.array(image) 368 | 369 | 370 | def hash_lora_file(filename): 371 | import mmap 372 | import hashlib 373 | 374 | """Hashes a .safetensors file using the new hashing method. 
375 | Only hashes the weights of the model.""" 376 | hash_sha256 = hashlib.sha256() 377 | blksize = 1024 * 1024 378 | 379 | with open(filename, mode="r", encoding="utf8") as file_obj: 380 | with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m: 381 | header = m.read(8) 382 | n = int.from_bytes(header, "little") 383 | 384 | with open(filename, mode="rb") as file_obj: 385 | offset = n + 8 386 | file_obj.seek(offset) 387 | for chunk in iter(lambda: file_obj.read(blksize), b""): 388 | hash_sha256.update(chunk) 389 | file_hash = hash_sha256.hexdigest() 390 | legacy_hash = file_hash[:12] 391 | return file_hash, legacy_hash 392 | 393 | 394 | def hash_sd_model(filename): 395 | """old hash that only looks at a small part of the file and is prone to collisions""" 396 | 397 | try: 398 | with open(filename, "rb") as file: 399 | import hashlib 400 | 401 | m = hashlib.sha256() 402 | 403 | file.seek(0x100000) 404 | m.update(file.read(0x10000)) 405 | return m.hexdigest() 406 | except FileNotFoundError: 407 | return "NOFILE" 408 | 409 | 410 | def hash_file(filename): 411 | """old hash that only looks at a small part of the file and is prone to collisions""" 412 | 413 | try: 414 | with open(filename, "rb") as file: 415 | import hashlib 416 | 417 | m = hashlib.sha256() 418 | m.update(file.read(0x10000)) 419 | 420 | return m.hexdigest()[:10] 421 | except FileNotFoundError: 422 | return "NOFILE" 423 | 424 | 425 | def file_walk(root_path, excludes=[]): 426 | file_list = [] 427 | for dirPath, dirNames, fileNames in os.walk(root_path): 428 | for fileName in fileNames: 429 | if not fileName.split(".")[-1] in excludes: 430 | # if not fileName.lower().endswith('.png') and not fileName.lower().endswith('.jsonl'): 431 | # if not fileName.lower().endswith('.png') and not fileName.lower().endswith('.jsonl'): 432 | filePath = os.path.join(dirPath, fileName) 433 | file_list.append(filePath) 434 | return file_list 435 | 436 | 437 | def get_latest_file(root_path): 438 | file_list 
= file_walk(root_path) 439 | latest_file = max(file_list, key=os.path.getctime) 440 | return latest_file 441 | 442 | 443 | def webp2png(file_list): 444 | from PIL import Image, ImageOps 445 | 446 | for srcImagePath in file_list: 447 | image = Image.open(srcImagePath) 448 | image = ImageOps.exif_transpose(image) 449 | dstImagePath = os.path.splitext(srcImagePath)[0] + ".png" 450 | image.save(dstImagePath) 451 | print("%s ---> %s" % (srcImagePath, dstImagePath)) 452 | 453 | 454 | def get_file_prefix(fname): 455 | names = fname.split(".")[:-1] 456 | return ".".join(names) 457 | 458 | 459 | def download_file(src, dest): 460 | # import wget 461 | # import ssl 462 | # # 取消ssl全局验证 463 | # ssl._create_default_https_context = ssl._create_unverified_context 464 | # wget.download(src, dest) 465 | import os 466 | 467 | os.system(f"wget -O {dest} {src}") 468 | 469 | 470 | def img_compress(in_file, out_file, target_size=40): 471 | from PIL import Image, ImageFile 472 | 473 | def get_size(file): 474 | # 获取文件大小:KB 475 | size = os.path.getsize(file) 476 | return int(size / 1024) 477 | 478 | # 防止图片超过178956970 pixels 而报错 479 | ImageFile.LOAD_TRUNCATED_IMAGES = True 480 | Image.MAX_IMAGE_PIXELS = None 481 | # 读取img文件 482 | im = Image.open(in_file) 483 | if im.mode == "RGBA": 484 | # print(in_file) 485 | im = im.convert("RGB") 486 | 487 | o_size = get_size(in_file) 488 | if o_size > target_size: 489 | # scale=im.width*im.height 490 | scale = math.sqrt(target_size / o_size) 491 | height = int(im.height * scale) 492 | width = int(im.width * scale) 493 | im = im.resize((width, height)) 494 | im.save(out_file) 495 | 496 | 497 | def img_compress_bs64(bs64_in, target_size=1000): 498 | if target_size == None: 499 | return bs64_in 500 | bytes_count = len(bs64_in) 501 | size = bytes_count / 1024 502 | # print("img_size", size) 503 | image = base64_to_image(bs64_in) 504 | 505 | if size > target_size: 506 | scale = math.sqrt(target_size / size) 507 | height = int(image.height * scale) 508 | width 
= int(image.width * scale) 509 | image = image.resize((width, height)) 510 | output = image_to_base64(image) 511 | return output 512 | 513 | 514 | def is_url(data): 515 | # 判断字符串是否是IP地址 516 | def is_ip_address(string): 517 | parts = string.split(".") 518 | if len(parts) != 4: 519 | return False 520 | for part in parts: 521 | if not part.isdigit() or int(part) < 0 or int(part) > 255: 522 | return False 523 | return True 524 | 525 | if not data: 526 | return False 527 | 528 | if data.startswith("http://") or data.startswith("https://"): 529 | return True 530 | elif is_ip_address(data): 531 | return True 532 | else: 533 | return False 534 | 535 | # return bool(parsed_url.scheme) 536 | 537 | 538 | def fmt_data(time_obj): 539 | fmt_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time_obj)) 540 | 541 | return fmt_time 542 | 543 | 544 | def get_self_ip(): 545 | import socket 546 | try: 547 | # 创建一个UDP套接字 548 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 549 | 550 | # 连接到公共的 DNS 服务器 551 | sock.connect(('8.8.8.8', 80)) 552 | 553 | # 获取本地套接字的地址信息 554 | ip_address = sock.getsockname()[0] 555 | 556 | return ip_address 557 | except socket.error: 558 | return '无法获取IP地址' 559 | 560 | 561 | def get_remote_ip(): 562 | import requests 563 | res = requests.get("https://ipinfo.io/ip") 564 | return res.text 565 | 566 | 567 | def get_osr_create_event_loop(): 568 | import asyncio 569 | # 尝试获取当前线程的事件循环 570 | try: 571 | loop = asyncio.get_event_loop() 572 | except RuntimeError: # 当前线程没有事件循环时会引发这个异常 573 | loop = asyncio.new_event_loop() 574 | asyncio.set_event_loop(loop) 575 | 576 | return loop 577 | 578 | # 异地组网 579 | 580 | 581 | def get_vnc_ip(): 582 | import socket 583 | import fcntl 584 | import struct 585 | 586 | def get_ip_address(ifname): 587 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 588 | try: 589 | ip_address = socket.inet_ntoa(fcntl.ioctl( 590 | s.fileno(), 591 | 0x8915, # SIOCGIFADDR 592 | struct.pack('256s', ifname[:15].encode('utf-8')) 593 | 
)[20:24]) 594 | return ip_address 595 | except IOError: 596 | return None 597 | 598 | # 获取所有网卡名称 599 | nic_names = socket.if_nameindex() 600 | 601 | # 遍历每个网卡,并获取对应的IP地址 602 | for nic in nic_names: 603 | name = nic[1] 604 | ip_address = get_ip_address(name) 605 | if ip_address: 606 | if name == "oray_vnc": 607 | return ip_address 608 | return None 609 | 610 | 611 | def encode_base64(s): 612 | # 将字符串转换为字节 613 | byte_representation = s.encode('utf-8') 614 | # 使用base64库进行编码 615 | base64_bytes = base64.b64encode(byte_representation) 616 | # 将字节转换回字符串 617 | base64_string = base64_bytes.decode('utf-8') 618 | return base64_string 619 | 620 | 621 | 622 | 623 | 624 | if __name__ == "__main__": 625 | # articles = readfile2arr("./datas/corpus/raw_articles.csv") 626 | # print(len(articles)) 627 | # print(collate_fn_from_map( 628 | # [{'a': 'texta', 'b': 'textb'}, {'a': 'texta2', 'b': 'textb2'}], keys=[])) 629 | # data = {"a_f": "a", "b_f": "b"} 630 | # datas = [data for i in range(10)] 631 | # # print(json.dumps(data, ensure_ascii=False)) 632 | # # print(json.dumps(data, indent=4, ensure_ascii=False)) 633 | # save_huggface_json_datas("./test.json", datas) 634 | 635 | from PIL import Image, ImageOps 636 | 637 | image = Image.open("/home/ubuntu/projects/imgs/origin/0f3335e556.jpg") 638 | bs64 = image_to_base64(image) 639 | out = img_compress_bs64(bs64, target_size=100) 640 | base64_to_image(out).save("a.png") 641 | # print(len(out)) 642 | -------------------------------------------------------------------------------- /utils/logger_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | def get_logger(logger_name, file_path=None, level=logging.DEBUG, as_default_logger=True) -> logging.Logger: 4 | logger = logging.getLogger(logger_name) 5 | logger.propagate = False 6 | logger.setLevel(level) 7 | if not file_path: 8 | file_path = '../logs' 9 | if not os.path.exists(file_path): 10 | os.makedirs(file_path) 11 | 12 
| fh = logging.FileHandler(f'{file_path}/{logger_name}.log') 13 | fh.setLevel(level) 14 | logger.addHandler(fh) 15 | if as_default_logger: 16 | logging.root = logger 17 | return logger 18 | -------------------------------------------------------------------------------- /utils/wx_util.py: -------------------------------------------------------------------------------- 1 | # 定义XML转字典的函数 2 | def trans_xml_to_dict(data_xml): 3 | # soup = BeautifulSoup(data_xml, features='xml') 4 | soup = BeautifulSoup(data_xml, "html.parser") 5 | 6 | xml = soup.find('xml') # 解析XML 7 | if not xml: 8 | return {} 9 | data_dict = dict([(item.name, item.text) for item in xml.find_all()]) 10 | return data_dict 11 | 12 | 13 | # 定义字典转XML的函数 14 | def trans_dict_to_xml(data_dict): 15 | data_xml = [] 16 | for k in sorted(data_dict.keys()): # 遍历字典排序后的key 17 | v = data_dict.get(k) # 取出字典中key对应的value 18 | if k == 'detail' and not v.startswith(''.format(v) 20 | data_xml.append('<{key}>{value}'.format(key=k, value=v)) 21 | # return '{}'.format(''.join(data_xml)) # 返回XML 22 | # 返回XML,并转成utf-8,解决中文的问题 23 | return '{}'.format(''.join(data_xml)).encode('utf-8') 24 | -------------------------------------------------------------------------------- /wss/__pycache__/mj_wss_manager.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/__pycache__/mj_wss_manager.cpython-38.pyc -------------------------------------------------------------------------------- /wss/__pycache__/mj_wss_proxy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/__pycache__/mj_wss_proxy.cpython-38.pyc -------------------------------------------------------------------------------- 
/wss/handler/__pycache__/base_message_handler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/base_message_handler.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/__pycache__/describe_success_handler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/describe_success_handler.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/__pycache__/imagine_handler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/imagine_handler.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/__pycache__/imagine_hanlder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/imagine_hanlder.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/__pycache__/message_create_handler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/message_create_handler.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/__pycache__/upscale_handler.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/upscale_handler.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/__pycache__/variation_handler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/variation_handler.cpython-38.pyc -------------------------------------------------------------------------------- /wss/handler/base_message_handler.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class MessageHandler: 4 | def __init__(self,) -> None: 5 | pass 6 | 7 | def handle(self, message): 8 | pass -------------------------------------------------------------------------------- /wss/handler/describe_success_handler.py: -------------------------------------------------------------------------------- 1 | from wss.handler.base_message_handler import MessageHandler 2 | import redis 3 | import re 4 | import json 5 | from service.mj_data_service import MjDataService, MjTask 6 | from service.notify_service import NotifyService 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | class DescribeSuccessHandler(MessageHandler): 11 | 12 | def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None: 13 | self.redis_client = redis_client 14 | self.mj_data_service = mj_data_service 15 | self.notify_service = notify_service 16 | 17 | 18 | def handle(self, message): 19 | op = message['op'] 20 | t = message['t'] 21 | if t == 'MESSAGE_UPDATE': 22 | d = message['d'] 23 | # nonce = d['nonce'] 24 | content = d['content'] 25 | message_id = d['id'] 26 | d_type = d['type'] 27 | if "embeds" 
not in d or len(d['embeds'])==0: 28 | return 29 | 30 | if d_type == 20: 31 | embed = d['embeds'][0] 32 | description = embed['description'] 33 | image_url = embed['image']['url'] 34 | 35 | target_task = None 36 | task_list: list[MjTask] = self.mj_data_service.get_tasks_list() 37 | 38 | for task in task_list: 39 | if task.task_status == 'pending' and task.image_prompt == image_url and task.task_type == 'describe': 40 | target_task = task 41 | break 42 | 43 | if target_task: 44 | self.mj_data_service.update_task_status( 45 | target_task, "success") 46 | self.mj_data_service.update_task_description( 47 | target_task, description) 48 | self.mj_data_service.update_task_progress( 49 | target_task, "100%") 50 | self.notify_service.notify_task_change(target_task) 51 | d['mj_proxy_handled'] = True -------------------------------------------------------------------------------- /wss/handler/imagine_hanlder.py: -------------------------------------------------------------------------------- 1 | from wss.handler.base_message_handler import MessageHandler 2 | import redis 3 | import re 4 | import json 5 | from service.mj_data_service import MjDataService, MjTask 6 | import logging 7 | from service.notify_service import NotifyService 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class ImagineHandler(MessageHandler): 12 | 13 | def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None: 14 | self.redis_client = redis_client 15 | self.mj_data_service = mj_data_service 16 | self.notify_service = notify_service 17 | # "**a white haired boy --niji 5** - <@1233674918630264866> (Waiting to start)" 18 | self.prompt_pattern = re.compile("\*\*(.*?)\*\* - <@\d+> \((.*?)\)") 19 | 20 | def extract_prompt(self,prompt): 21 | prompt_pattern = re.compile("\*\*(.*?)\*\* - <@\d+> \((.*?)\)") 22 | matched = prompt_pattern.match(prompt) 23 | if matched: 24 | return matched.group(1) 25 | # if matched: 26 | # inner_prompt = 
matched.group(1) 27 | 28 | # image_pattern = re.compile("<(.*?)> (.*)") 29 | # inner_mathed = image_pattern.match(inner_prompt) 30 | # if inner_mathed: 31 | # image_url = inner_mathed.group(1) 32 | # prompt = inner_mathed.group(2) 33 | # return image_url, prompt 34 | # else: 35 | # return None, inner_prompt 36 | else: 37 | raise Exception("prompt format error") 38 | 39 | def get_buttons(self,d): 40 | buttons=[] 41 | for i in range(len(d['components'])): 42 | buttons.extend(d['components'][i]['components']) 43 | return buttons 44 | def handle(self, message): 45 | op = message['op'] 46 | t = message['t'] 47 | if t == 'MESSAGE_CREATE': 48 | d = message['d'] 49 | message_id = d['id'] 50 | d_type = d['type'] 51 | channel_id = d['channel_id'] 52 | content = d['content'] 53 | attachments = d['attachments'] 54 | matched = self.prompt_pattern.findall(content) 55 | prompt = matched[0][0] if matched else None 56 | progress = matched[0][1] if matched else None 57 | 58 | # ' a white haired boy --niji 5 --relax' 59 | 60 | if d_type == 0: 61 | # Imagine已经完成 62 | task_list: list[MjTask] = self.mj_data_service.get_tasks_list() 63 | # 查找条件: 64 | # 1. 任务状态为pending 65 | # 2. 任务的prompt相等 66 | # 3. 
task_type为imagine 67 | # print("finished") 68 | target_task = None 69 | for task in task_list: 70 | 71 | if task.task_status == 'pending' and task.submit_prompt == prompt: 72 | target_task = task 73 | break 74 | # if img_promot: 75 | # if task.task_status == 'pending' and task.prompt == prompt and task.task_type == 'imagine' and task.image_prompt!=None: 76 | # target_task = task 77 | # break 78 | # else: 79 | # if task.task_status == 'pending' and task.prompt == prompt and task.task_type == 'imagine': 80 | # target_task = task 81 | # break 82 | 83 | if target_task and len(attachments) > 0: 84 | 85 | buttons = self.get_buttons(d) 86 | self.mj_data_service.update_buttons(target_task, buttons) 87 | 88 | 89 | self.mj_data_service.update_finished_message_id( 90 | target_task, message_id) 91 | self.mj_data_service.update_task_message_id( 92 | target_task, message_id) 93 | 94 | self.mj_data_service.update_task_status( 95 | target_task, "success") 96 | self.mj_data_service.update_task_image_url( 97 | target_task, attachments[0]['url']) 98 | self.mj_data_service.update_task_progress( 99 | target_task, "100%") 100 | self.notify_service.notify_task_change(target_task) 101 | d['mj_proxy_handled'] = True 102 | else: 103 | raise Exception("no task found") 104 | 105 | elif d_type == 20: 106 | # Imagine创建成功 107 | nonce = d['nonce'] 108 | 109 | task = self.mj_data_service.get_task_by_nonce(nonce) 110 | if not task: 111 | logger.error(f"no task found, nonce: {nonce}") 112 | return 113 | 114 | task.submit_prompt = prompt 115 | self.mj_data_service.update_task_progress( 116 | task, "0%") 117 | self.mj_data_service.update_task_nonce(task, nonce) 118 | self.mj_data_service.update_task_message_id(task, message_id) 119 | self.mj_data_service.update_message_id_map(task) 120 | self.notify_service.notify_task_change(task) 121 | 122 | d['mj_proxy_handled'] = True 123 | 124 | elif t == 'MESSAGE_UPDATE': 125 | d = message['d'] 126 | # nonce = d['nonce'] 127 | content = d['content'] 128 | 
# NOTE(review): this collapsed span holds three things:
#   - the tail of ImagineHandler.handle from wss/handler/imagine_hanlder.py;
#   - all of wss/handler/shorten_success_handler.py;
#   - all of wss/handler/upscale_handler.py.
# Findings:
# - shorten_success_handler.py declares its class as `DescribeSuccessHandler`.
#   That is the same name mj_wss_proxy.py imports from
#   describe_success_handler.py, so this looks like a copy-paste slip.
#   Consider renaming it to ShortenSuccessHandler and keeping a
#   `DescribeSuccessHandler` alias for compatibility.
# - MjWssSercice.init_handlers (mj_wss_proxy.py) registers no handler from
#   shorten_success_handler.py, so as written this handler never runs.
# - The shorten handler reads embed['image']['url'] unconditionally.
#   It matches tasks with task.image_prompt == that URL and
#   task.task_type == 'shorten'.
#   This presumably mirrors the describe flow. Confirm that shorten replies
#   really carry an image embed before registering it, or it will raise
#   KeyError.
# - UpscaleHandler.handle computes `op`, `prompt`, `progress` and an
#   upscale_pattern `matched` that are never used afterwards.
#   Its regexes are non-raw strings containing `\*`, `\d` and `\(`;
#   prefer r"..." literals.
attachments = d['attachments'] 129 | if not content: 130 | return 131 | message_id = d['id'] 132 | d_type = d['type'] 133 | if d_type == 20: 134 | # '**a white haired girl --niji 5** - <@1233674918630264866> (15%) (relaxed)' 135 | # extract the progress from content 136 | matched = self.prompt_pattern.findall(content) 137 | prompt = matched[0][0] if matched else None 138 | progress = matched[0][1] if matched else None 139 | channel_id = d['channel_id'] 140 | task = self.mj_data_service.get_task_by_message_id(message_id) 141 | 142 | if not task: 143 | logger.error(f"no task found, message_id: {message_id}") 144 | return 145 | self.mj_data_service.update_task_progress(task, progress) 146 | if len(attachments): 147 | self.mj_data_service.update_task_image_url( 148 | task, attachments[0]['url']) 149 | 150 | self.notify_service.notify_task_change(task) 151 | d['mj_proxy_handled'] = True 152 | print( 153 | f'imagine message update :channel_id:{channel_id},content:{content}message_id:{message_id},progress:{progress}') 154 | -------------------------------------------------------------------------------- /wss/handler/shorten_success_handler.py: -------------------------------------------------------------------------------- 1 | from wss.handler.base_message_handler import MessageHandler 2 | import redis 3 | import re 4 | import json 5 | from service.mj_data_service import MjDataService, MjTask 6 | import logging 7 | from service.notify_service import NotifyService 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class DescribeSuccessHandler(MessageHandler): 12 | 13 | def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None: 14 | self.redis_client = redis_client 15 | self.mj_data_service = mj_data_service 16 | self.notify_service = notify_service 17 | 18 | 19 | def handle(self, message): 20 | op = message['op'] 21 | t = message['t'] 22 | if t == 'MESSAGE_UPDATE': 23 | d = message['d'] 24 | # nonce = d['nonce'] 25 | content =
d['content'] 26 | message_id = d['id'] 27 | d_type = d['type'] 28 | if "embeds" not in d or len(d['embeds'])==0: 29 | return 30 | 31 | if d_type == 19: 32 | embed = d['embeds'][0] 33 | description = embed['description'] 34 | image_url = embed['image']['url'] 35 | 36 | target_task = None 37 | task_list: list[MjTask] = self.mj_data_service.get_tasks_list() 38 | 39 | for task in task_list: 40 | if task.task_status == 'pending' and task.image_prompt == image_url and task.task_type == 'shorten': 41 | target_task = task 42 | break 43 | 44 | if target_task: 45 | self.mj_data_service.update_task_status( 46 | target_task, "success") 47 | self.mj_data_service.update_task_description( 48 | target_task, description) 49 | self.mj_data_service.update_task_progress( 50 | target_task, "100%") 51 | self.notify_service.notify_task_change(target_task) 52 | d['mj_proxy_handled'] = True -------------------------------------------------------------------------------- /wss/handler/upscale_handler.py: -------------------------------------------------------------------------------- 1 | from wss.handler.base_message_handler import MessageHandler 2 | import redis 3 | import re 4 | import json 5 | from service.mj_data_service import MjDataService,MjTask 6 | from service.notify_service import NotifyService 7 | 8 | class UpscaleHandler(MessageHandler): 9 | def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None: 10 | self.redis_client = redis_client 11 | self.mj_data_service = mj_data_service 12 | self.notify_service = notify_service 13 | self.prompt_pattern = re.compile("\*\*(.*?)\*\* - <@\d+> \((.*?)\)") 14 | # "**a white haired boy --niji 5** - Image #1 <@1233674918630264866>" 15 | self.upscale_pattern = re.compile("\*\*(.*?)\*\* - Image #(\d) <@\d+>") 16 | # \*\*(.*?)\*\* - Variations \(.*?\) by <@\d+> \((.*?)\) 17 | self.variation_pattern = re.compile("\*\*(.*?)\*\* - Variations \(.*?\) by <@\d+> \((.*?)\)") 18 | 19 | def
get_reference_message_id(self,d): 20 | return d['message_reference']['message_id'] 21 | 22 | def get_buttons(self,d): 23 | buttons=[] 24 | for i in range(len(d['components'])): 25 | buttons.extend(d['components'][i]['components']) 26 | return buttons 27 | 28 | def handle(self, message): 29 | op = message['op'] 30 | t = message['t'] 31 | if t == 'MESSAGE_CREATE': 32 | d = message['d'] 33 | message_id = d['id'] 34 | d_type = d['type'] 35 | channel_id = d['channel_id'] 36 | content = d['content'] 37 | attachments = d['attachments'] 38 | matched = self.prompt_pattern.findall(content) 39 | prompt = matched[0][0] if matched else None 40 | progress = matched[0][1] if matched else None 41 | 42 | if d_type == 19: 43 | # Upscale: enlarge image 44 | if "Image" in content: 45 | # Upscale 46 | reference_message_id = self.get_reference_message_id(d) 47 | target_task = None 48 | task_list: list[MjTask] = self.mj_data_service.get_tasks_list() 49 | matched = self.upscale_pattern.findall(content) 50 | # upscale_index= matched[0][1] if matched else None 51 | 52 | for task in task_list: 53 | if task.task_status == 'pending' and task.reference_message_id == reference_message_id and task.task_type == 'action': 54 | target_task = task 55 | break 56 | if target_task and len(attachments) > 0: 57 | buttons = self.get_buttons(d) 58 | self.mj_data_service.update_buttons(target_task, buttons) 59 | self.mj_data_service.update_task_message_id( 60 | target_task, message_id) 61 | 62 | self.mj_data_service.update_task_status( 63 | target_task, "success") 64 | self.mj_data_service.update_task_image_url( 65 | target_task, attachments[0]['url']) 66 | self.mj_data_service.update_task_progress( 67 | target_task, "100%") 68 | self.notify_service.notify_task_change(target_task) 69 | 70 | 71 | d['mj_proxy_handled'] = True 72 | print( 73 | f'channel_id:{channel_id},content:{content},message_id:{message_id}') 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- 
# ---------------------------------------------------------------------------
# /wss/handler/variation_handler.py
# ---------------------------------------------------------------------------
from wss.handler.base_message_handler import MessageHandler
import redis
import re
import json
from service.mj_data_service import MjDataService, MjTask
from service.notify_service import NotifyService


class VariationHandler(MessageHandler):
    """Complete a pending 'action' task when its variation result is posted.

    A variation result is a MESSAGE_CREATE of message type 19 whose content
    contains 'Variations'. It is matched to the task whose
    reference_message_id equals the id of the message it replies to.
    """

    def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # Raw strings: the patterns are unchanged, but Python no longer warns
        # about invalid escape sequences such as "\*" and "\d".
        self.prompt_pattern = re.compile(r"\*\*(.*?)\*\* - <@\d+> \((.*?)\)")
        # "**a white haired boy --niji 5** - Image #1 <@1233674918630264866>"
        self.upscale_pattern = re.compile(r"\*\*(.*?)\*\* - Image #(\d) <@\d+>")
        # \*\*(.*?)\*\* - Variations \(.*?\) by <@\d+> \((.*?)\)
        self.variation_pattern = re.compile(r"\*\*(.*?)\*\* - Variations \(.*?\) by <@\d+> \((.*?)\)")

    def get_reference_message_id(self, d):
        """Return the id of the message that the payload *d* replies to."""
        return d['message_reference']['message_id']

    def get_buttons(self, d):
        """Flatten the components of every component row in *d* into one list."""
        return [button for row in d['components'] for button in row['components']]

    def handle(self, message):
        """Mark the matching pending 'action' task successful.

        The task is updated only if it exists and the result carries at least
        one attachment. In that case this method:
        - stores the buttons, message id, status, image URL and progress;
        - sends a change notification;
        - sets 'mj_proxy_handled' on the payload so later handlers skip it.
        """
        if message['t'] != 'MESSAGE_CREATE':
            return
        d = message['d']
        if d['type'] != 19 or 'Variations' not in d['content']:
            return

        reference_message_id = self.get_reference_message_id(d)
        target_task = next(
            (
                task
                for task in self.mj_data_service.get_tasks_list()
                if task.task_status == 'pending'
                and task.reference_message_id == reference_message_id
                and task.task_type == 'action'
            ),
            None,
        )
        attachments = d['attachments']
        if not target_task or not attachments:
            return

        self.mj_data_service.update_buttons(target_task, self.get_buttons(d))
        self.mj_data_service.update_task_message_id(target_task, d['id'])
        self.mj_data_service.update_task_status(target_task, "success")
        self.mj_data_service.update_task_image_url(target_task, attachments[0]['url'])
        self.mj_data_service.update_task_progress(target_task, "100%")
        self.notify_service.notify_task_change(target_task)
        d['mj_proxy_handled'] = True


# ---------------------------------------------------------------------------
# /wss/mj_wss_manager.py
# ---------------------------------------------------------------------------
from support.mj_config import MjConfig
import redis
from wss.mj_wss_proxy import MjWssSercice
from service.mj_data_service import MjDataService
from service.notify_service import NotifyService
from support.mj_account import MjAccount
from typing import List


class MjWssManager:
    """Create one gateway service (MjWssSercice) per configured account."""

    def __init__(self, mj_config: MjConfig, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.mj_config = mj_config
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        self.wss_list: List[MjWssSercice] = []
        self.init_wss_list()

    def init_wss_list(self):
        """Build a service for every account returned by mj_config.get_accounts()."""
        mj_accounts = self.mj_config.get_accounts()
        for account in mj_accounts:
            mj_account = MjAccount(account)
            self.wss_list.append(MjWssSercice(
                self.mj_config.mj_config, self.redis_client, self.mj_data_service, self.notify_service, mj_account))

    def start_all(self):
        """Start every account's gateway service."""
        for wss in self.wss_list:
            wss.start()


# ---------------------------------------------------------------------------
# /wss/mj_wss_proxy.py
# ---------------------------------------------------------------------------
import threading
import websocket
import json
import ssl
import time
import asyncio
from utils.logger_util import get_logger
import logging
from typing import List
from wss.handler.base_message_handler import MessageHandler
from wss.handler.upscale_handler import UpscaleHandler
from wss.handler.imagine_hanlder import ImagineHandler
from wss.handler.describe_success_handler import DescribeSuccessHandler
from wss.handler.variation_handler import VariationHandler
import redis
from service.mj_data_service import MjDataService
from service.notify_service import NotifyService
from support.mj_account import MjAccount

logger = get_logger(__name__)

# Discord gateway opcodes handled by MjWssSercice.
OP_DISPATCH = 0
OP_HEARTBEAT = 1
OP_IDENTIFY = 2
OP_HELLO = 10
OP_HEARTBEAT_ACK = 11

# Seconds between periodic heartbeats, and pause before reconnecting.
HEARTBEAT_SECONDS = 30
RECONNECT_DELAY_SECONDS = 5

# Dispatch event types that are forwarded to the message handlers.
HANDLED_EVENT_TYPES = ('MESSAGE_CREATE', 'MESSAGE_UPDATE', 'MESSAGE_DELETE')


class MjWssSercice:
    """Discord gateway connection for one account.

    Forwards the account's channel messages to the Midjourney handlers.
    """

    # Payload keys stripped before logging. They are removed from the payload
    # itself, so handlers run afterwards do not see them either.
    _EXTRA_INFO_KEYS = ('author', 'mentions', 'member',
                        'interaction_metadata', 'interaction', 'referenced_message')

    def __init__(self, config, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService, account: MjAccount):
        self.config = config
        self.account = account
        # Last dispatch sequence number ("s") received, echoed back in every
        # heartbeat. None until the first dispatch arrives; the gateway
        # accepts null in that case.
        self.sequence = None
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # Discord WebSocket address
        self.ws_url = config['ng']['discord_ws']
        self.bot_token = account.get_bot_token()
        self.ws = None

        self.message_handlers: List[MessageHandler] = []
        self.init_handlers()

    def init_handlers(self):
        """Register handlers in dispatch order; each handler is registered once."""
        self.message_handlers.append(UpscaleHandler(
            self.redis_client, self.mj_data_service, self.notify_service))
        self.message_handlers.append(ImagineHandler(
            self.redis_client, self.mj_data_service, self.notify_service))
        self.message_handlers.append(VariationHandler(
            self.redis_client, self.mj_data_service, self.notify_service))
        self.message_handlers.append(DescribeSuccessHandler(
            self.redis_client, self.mj_data_service, self.notify_service))

    # Log in (IDENTIFY)
    def on_open(self, ws):
        """Identify as the bot as soon as the socket opens."""
        logger.info("### opened ###")
        # A fresh IDENTIFY starts a new session whose sequence restarts.
        self.sequence = None
        ws.send(json.dumps({
            'op': OP_IDENTIFY,
            'd': {
                'token': self.bot_token,
                # 513 == (1 << 0) | (1 << 9): GUILDS | GUILD_MESSAGES in
                # Discord's gateway intents table.
                'intents': 513,
                'properties': {
                    '$os': 'linux',
                    '$browser': 'my_library_name',
                    '$device': 'my_library_name'
                }
            }
        }))

    def send_heartbeat(self):
        """Send one heartbeat (op 1) carrying the last received sequence number."""
        self.ws.send(json.dumps({'op': OP_HEARTBEAT, 'd': self.sequence}))

    # Listen for messages
    def on_message(self, ws, message):
        """Route one gateway payload.

        - Heartbeat ACKs are ignored.
        - HELLO and server heartbeat requests are answered immediately.
        - MESSAGE_* dispatches for this account's channel go to the handlers,
          until one of them sets 'mj_proxy_handled'.
        """
        data = json.loads(message)
        op = data.get('op')

        # Remember the latest sequence number so heartbeats can echo it back.
        if data.get('s') is not None:
            self.sequence = data['s']

        # Opcodes are checked before the event-type filter. Non-dispatch
        # payloads carry t == None and would otherwise never reach these branches.
        if op == OP_HEARTBEAT_ACK:
            return
        if op in (OP_HELLO, OP_HEARTBEAT):
            logger.info("receive heart from message")
            self.send_heartbeat()
            return
        if op != OP_DISPATCH or data.get('t') not in HANDLED_EVENT_TYPES:
            return
        if self.ignoreMessage(data):
            return

        payload = data['d']
        for handler in self.message_handlers:
            if payload.get('mj_proxy_handled'):
                break
            handler.handle(data)

    def ignoreMessage(self, data: dict):
        """Return True unless the payload belongs to this account's channel.

        Kept payloads are logged after their bulky fields are stripped.
        """
        payload = data.get('d') or {}
        channel_id = payload.get('channel_id')
        if not channel_id or channel_id != self.account.get_channel_id():
            return True
        if data.get('op') in [0, 10]:
            self.remove_extra_info(data)
            logger.info(json.dumps(data))
        return False

    # Strip redundant information before printing
    def remove_extra_info(self, data: dict):
        """Drop verbose fields from data['d'] in place; no-op if 'd' is not a dict."""
        payload = data.get('d')
        if not isinstance(payload, dict):
            return
        for key in self._EXTRA_INFO_KEYS:
            payload.pop(key, None)

    def on_error(self, ws, error):
        logger.error(error)

    def on_close(self, ws, close_status_code, close_msg):
        logger.info("### closed ###")

    def run_heart(self):
        """Send a heartbeat every HEARTBEAT_SECONDS, forever.

        Send failures (socket not yet created, closed or reconnecting) are
        logged instead of raised, so this thread survives reconnects.
        """
        while True:
            time.sleep(HEARTBEAT_SECONDS)
            if self.ws is None:
                continue
            try:
                self.send_heartbeat()
            except (websocket.WebSocketException, OSError) as e:
                logger.warning(f"heartbeat not sent: {e}")

    def websocket_start_inner(self):
        """Connect and keep reconnecting, pausing RECONNECT_DELAY_SECONDS between attempts."""
        self.ws = websocket.WebSocketApp(
            self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close)
        self.ws.on_open = self.on_open

        # Reconnect on disconnect. run_forever returns when the connection
        # closes; pause after every return so a server that keeps closing
        # does not cause a tight reconnect loop.
        # NOTE(review): cert_reqs=CERT_NONE disables TLS certificate
        # verification while the bot token is sent; consider removing it.
        while True:
            try:
                self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
            except Exception as e:
                logger.error(e)
            time.sleep(RECONNECT_DELAY_SECONDS)

    def start(self):
        """Start the heartbeat thread and the websocket thread."""
        # websocket.enableTrace(True)
        heartbeat_thread = threading.Thread(target=self.run_heart)
        heartbeat_thread.start()

        websocket_thread = threading.Thread(target=self.websocket_start_inner)
        websocket_thread.start()


if __name__ == "__main__":
    pass
# ---------------------------------------------------------------------------