├── .gitignore
├── LICENSE
├── README.md
├── blueprints
├── __pycache__
│ ├── blueprints_image_generation.cpython-38.pyc
│ ├── blueprints_midjourney.cpython-38.pyc
│ ├── blueprints_mj.cpython-38.pyc
│ ├── blueprints_openai.cpython-38.pyc
│ ├── blueprints_text_generation.cpython-38.pyc
│ └── blueprints_user.cpython-38.pyc
└── blueprints_image_generation.py
├── entity
├── __pycache__
│ ├── mj_scheme.cpython-38.pyc
│ ├── models.cpython-38.pyc
│ ├── openai_scheme.cpython-38.pyc
│ └── user_scheme.cpython-38.pyc
├── gen.sh
├── mj_scheme.py
├── models.py
├── openai_scheme.py
└── user_scheme.py
├── main.py
├── requirements.txt
├── resources
├── api_params
│ ├── blend.json
│ ├── describe.json
│ ├── imagine.json
│ ├── info.json
│ ├── message.json
│ ├── reroll.json
│ ├── shorten.json
│ ├── upscale.json
│ └── variation.json
└── config
│ └── prod_mj_config.yaml
├── service
├── __pycache__
│ ├── celery_client.cpython-38.pyc
│ ├── celery_service.cpython-38.pyc
│ ├── discord_http_service.cpython-38.pyc
│ ├── mj_data_service.cpython-38.pyc
│ ├── notify_service.cpython-38.pyc
│ └── template_controller.cpython-38.pyc
├── discord_http_service.py
├── discord_ws_service.py
├── mail_service.py
├── mj_data_service.py
├── notify_service.py
└── template_controller.py
├── support
├── Injector.py
├── __init__.py
├── __pycache__
│ ├── Injector.cpython-38.pyc
│ ├── __init__.cpython-38.pyc
│ ├── image_controller.cpython-38.pyc
│ ├── load_balancer.cpython-38.pyc
│ ├── mj_account.cpython-38.pyc
│ ├── mj_config.cpython-38.pyc
│ ├── mj_task.cpython-38.pyc
│ ├── task_controller.cpython-38.pyc
│ └── template_controller.cpython-38.pyc
├── image_controller.py
├── load_balancer.py
├── mj_account.py
├── mj_config.py
└── task_controller.py
├── utils
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ ├── common_util.cpython-38.pyc
│ ├── logger_util.cpython-38.pyc
│ └── wx_util.cpython-38.pyc
├── bootstrap.py
├── common_util.py
├── logger_util.py
└── wx_util.py
└── wss
├── __pycache__
├── mj_wss_manager.cpython-38.pyc
└── mj_wss_proxy.cpython-38.pyc
├── handler
├── __pycache__
│ ├── base_message_handler.cpython-38.pyc
│ ├── describe_success_handler.cpython-38.pyc
│ ├── imagine_handler.cpython-38.pyc
│ ├── imagine_hanlder.cpython-38.pyc
│ ├── message_create_handler.cpython-38.pyc
│ ├── upscale_handler.cpython-38.pyc
│ └── variation_handler.cpython-38.pyc
├── base_message_handler.py
├── describe_success_handler.py
├── imagine_hanlder.py
├── shorten_success_handler.py
├── upscale_handler.py
└── variation_handler.py
├── mj_wss_manager.py
└── mj_wss_proxy.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | */**/*.pyc
3 | .cache_*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
py-midjourney-proxy
4 |
5 | English | [中文](./README_CN.md)
6 |
7 | An unofficial python implementation of the Discord proxy for MidJourney.
8 |
9 | Just one command to build your own MidJourney proxy server.
10 |
11 |
12 |
13 | ## Main Functions
14 |
15 | - [x] Supports Imagine instructions and related actions
16 | - [x] Supports adding image base64 as a placeholder when using the Imagine command
17 | - [x] Supports Blend (image blending) and Describe (image to text) commands
18 | - [x] Supports real-time progress tracking of tasks
19 | - [x] Prompt sensitive word pre-detection, supports override adjustment
20 | - [x] User-token connects to WSS (WebSocket Secure), allowing access to error messages and full functionality
21 | - [x] Supports multi-account configuration, with each account able to set up corresponding task queues
22 |
23 | ## Prerequisites for use
24 |
25 | 1. Register and subscribe to MidJourney, create `your own server and channel`, refer
26 | to https://docs.midjourney.com/docs/quick-start
27 | 2. Obtain your user token, server ID, and channel ID: [how to obtain them](./docs/discord-params.md)
28 |
29 |
30 | ## Local development
31 |
32 | - Depends on python and fastapi
33 | - Change configuration items: Edit resources/config/prod_mj_config.yaml
34 | - Run the project: start main.py
35 |
36 | ## Configuration items
37 |
38 | - mj.accounts: Refer
39 | to [Account pool configuration](./docs/config.md#%E8%B4%A6%E5%8F%B7%E6%B1%A0%E9%85%8D%E7%BD%AE%E5%8F%82%E8%80%83)
40 | - mj.task-store.type: Task storage method. The default is in_memory (tasks are lost on restart); Redis is an alternative
41 | option.
42 | - mj.task-store.timeout: Task storage expiration time, tasks are deleted after expiration, default is 30 days.
43 | - mj.api-secret: API key. If left empty, authentication is disabled; otherwise, include the request header
44 | 'mj-api-secret' when calling the API.
45 | - mj.translate-way: The method for translating Chinese prompts into English, options include null (default), Baidu, or
46 | GPT.
47 | - For more configuration options, see [Configuration items](./docs/config.md)
48 |
49 | ## Related documentation
50 |
51 | 1. [API Interface Description](./docs/api.md)
52 | 2. [Version Update Log](https://github.com/novicezk/midjourney-proxy/wiki/%E6%9B%B4%E6%96%B0%E8%AE%B0%E5%BD%95)
53 |
54 | ## Precautions
55 |
56 | 1. Frequent image generation and similar behaviors may trigger warnings on your Midjourney account. Please use with
57 | caution.
58 | 2. For common issues and solutions, see [Wiki / FAQ](https://github.com/novicezk/midjourney-proxy/wiki/FAQ)
59 | 3. Anyone interested is welcome to join the discussion group. If the group is full when you scan the QR code, you
60 | can add the administrator’s WeChat to be invited; please include the note "mj join group".
61 |
62 |
63 |
64 | ## Application Project
65 |
66 | If you have a project that depends on this one and is open source, feel free to contact the author to be added here for
67 | display.
68 |
69 | - [wechat-midjourney](https://github.com/novicezk/wechat-midjourney) : A proxy WeChat client that connects to
70 | MidJourney; intended only as an example application and no longer updated.
71 | - [chatgpt-web-midjourney-proxy](https://github.com/Dooy/chatgpt-web-midjourney-proxy) : A complete UI solution for ChatGPT web,
72 | Midjourney, GPTs, TTS, and Whisper
73 | - [chatnio](https://github.com/Deeptrain-Community/chatnio) : A next-generation one-stop AI solution for business (B2B) and consumer (B2C) users — an aggregated model platform with a polished UI and powerful features
74 | - [new-api](https://github.com/Calcium-Ion/new-api) : An API interface management and distribution system compatible with the Midjourney Proxy
75 | - [stable-diffusion-mobileui](https://github.com/yuanyuekeji/stable-diffusion-mobileui) : SDUI, built on this API
76 | and SD (Stable Diffusion); it can be packaged with one click into H5 pages and mini-programs.
77 | - [MidJourney-Web](https://github.com/ConnectAI-E/MidJourney-Web) : 🍎 Supercharged Experience For MidJourney On Web UI
78 |
79 | ## Open API
80 |
81 | An unofficial MJ/SD open API is also available. For inquiries, add the administrator on WeChat with the note "api".
82 |
83 | ## Others
84 |
85 | If you find this project helpful, please consider giving it a star.
86 |
87 | ## Star History
88 |
89 | [![Star History Chart](https://api.star-history.com/svg?repos=Anychnn/py-midjourney-proxy&type=Date)](https://star-history.com/#Anychnn/py-midjourney-proxy&Date)
--------------------------------------------------------------------------------
/blueprints/__pycache__/blueprints_image_generation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_image_generation.cpython-38.pyc
--------------------------------------------------------------------------------
/blueprints/__pycache__/blueprints_midjourney.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_midjourney.cpython-38.pyc
--------------------------------------------------------------------------------
/blueprints/__pycache__/blueprints_mj.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_mj.cpython-38.pyc
--------------------------------------------------------------------------------
/blueprints/__pycache__/blueprints_openai.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_openai.cpython-38.pyc
--------------------------------------------------------------------------------
/blueprints/__pycache__/blueprints_text_generation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_text_generation.cpython-38.pyc
--------------------------------------------------------------------------------
/blueprints/__pycache__/blueprints_user.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/blueprints/__pycache__/blueprints_user.cpython-38.pyc
--------------------------------------------------------------------------------
/blueprints/blueprints_image_generation.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, UploadFile, Request, File
2 |
3 | from entity.mj_scheme import (
4 | ImagineRequest,
5 | ImagineResponse,
6 | QueryImagineStatusRequest,
7 | QueryImagineStatusResponse,
8 | UpscaleRequest,
9 | QueryUpscaleStatusRequest,
10 | QueryVariationStatusRequest,
11 | ImageUploadRequest,
12 | DescribeRequest,
13 | BlendRequest,
14 | ShortenRequest
15 | )
16 | from support.Injector import injector
17 | from support.task_controller import TaskController
18 | from service.mj_data_service import MjTask
19 | from service.mj_data_service import MjDataService
20 | from support.image_controller import ImageController
21 | from support.load_balancer import LoadBalancer
22 | import datetime
23 | from fastapi.responses import JSONResponse
24 | from utils import common_util
25 | import os
26 | import requests
27 | import json
28 |
# Shared service singletons for this router, resolved from the DI container
# once at import time (in the same order as before: task controller, data
# service, load balancer).
task_controller, mj_data_service, load_balancer = (
    injector.get(TaskController),
    injector.get(MjDataService),
    injector.get(LoadBalancer),
)
# NOTE(review): the ImageController lookup
# (`image_controller = injector.get(ImageController)`) is disabled, so any
# handler in this module that still references `image_controller` will raise
# NameError when called.

# Router that every /midjourney/* endpoint in this module is registered on.
image_router = APIRouter()
35 |
36 |
@image_router.post("/midjourney/imagine", response_model=ImagineResponse, summary="提交Imagine任务")
async def submit_imagine(request: Request, imagine_request: ImagineRequest):
    """Queue a new Midjourney Imagine task.

    Each entry of ``imagine_request.imgs`` is passed through
    ``upload_img_if_bs64`` on a Discord HTTP service picked by the load
    balancer, and the resulting URLs are joined with single spaces into the
    task's ``image_prompt``. ``mode`` "FAST" / "RELAX" appends ``--fast`` /
    ``--relax`` to the text prompt; any other value appends nothing.

    Returns:
        ``{"status": 200, "msg": "success", "data": {"task_id", "task_status"}}``
        on success, or a status-400 envelope when the prompt is empty.
    """
    prompt = imagine_request.prompt
    if not prompt:
        # Sent as a raw JSONResponse so this error envelope (data=None) bypasses
        # validation against response_model, which would otherwise turn the
        # intended 400 envelope into an HTTP 500.
        return JSONResponse(
            content={
                "status": 400,
                "msg": "prompt不能为空",
                "data": None,
            }
        )

    # Upload each reference image (presumably plain URLs are returned as-is by
    # upload_img_if_bs64 — confirm) and space-join the resulting URLs.
    image_urls = [
        load_balancer.get_discord_http_service().upload_img_if_bs64(img)
        for img in imagine_request.imgs or []
    ]
    image_prompt = " ".join(image_urls)

    text_prompt = prompt
    if imagine_request.mode == "FAST":
        text_prompt += " --fast"
    elif imagine_request.mode == "RELAX":
        text_prompt += " --relax"

    task = MjTask()
    task.prompt = text_prompt
    task.image_prompt = image_prompt
    task.task_type = "imagine"
    task.task_status = "pending"
    task.nonce = str(mj_data_service.next_nonce())
    task.create_time = datetime.datetime.now()

    task_id = task_controller.submit_imagine(task)

    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": task.task_status,
        },
    }
84 |
85 |
@image_router.post("/midjourney/query_imagine_status", summary="查询Imagine任务")
async def query_imagine_status(request: Request, body: QueryImagineStatusRequest):
    """Report the current state of a previously submitted Imagine task.

    Returns status 10002 when no task with ``body.task_id`` exists. Otherwise
    returns the task id, status, progress and image URL, with any
    ``https://cdn.discordapp.com`` prefix rewritten to
    ``http://discordcdn.aidomore.com``.
    """
    task_id = body.task_id
    task: MjTask = mj_data_service.get_task_by_task_id(task_id)
    if not task:
        return {
            "status": 10002,
            "msg": f"未查询到任务:task_id:{task_id}",
        }

    # Point Discord CDN links at the aidomore mirror host instead.
    image_url = task.image_url
    if image_url:
        image_url = image_url.replace(
            "https://cdn.discordapp.com", "http://discordcdn.aidomore.com")

    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task.task_id,
            "task_status": task.task_status,
            "image_url": image_url,
            "progress": task.progress,
        },
    }
112 |
113 |
@image_router.post("/midjourney/action", summary="提交Action task")
async def submit_upscale(request: Request, body: UpscaleRequest):
    """Queue a button action against the Discord message of an existing task.

    The new "action" task carries ``body.custom_id`` and references the
    ``message_id`` of the task identified by ``body.task_id``.

    Returns:
        Status 10002 when ``body.task_id`` is unknown; otherwise the new
        action task's id and status.
    """
    parent_task: MjTask = mj_data_service.get_task_by_task_id(body.task_id)
    if not parent_task:
        # Previously this fell through to ``parent_task.message_id`` and
        # raised AttributeError (HTTP 500) for unknown task ids.
        return {
            "status": 10002,
            "msg": f"未查询到任务:task_id:{body.task_id}",
        }

    action_task = MjTask()
    action_task.task_type = "action"
    action_task.create_time = datetime.datetime.now()
    action_task.task_status = "pending"
    action_task.custom_id = body.custom_id
    action_task.reference_message_id = parent_task.message_id
    action_task.nonce = str(mj_data_service.next_nonce())
    task_id = task_controller.submit_upscale(action_task)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "task_id": task_id,
            "task_status": action_task.task_status,
        }
    }
140 |
141 |
# Query the status of an action (e.g. upscale) task.
@image_router.post("/midjourney/query_action_status", summary="查询Upscale任务")
async def query_upscale_status(request: Request, body: QueryUpscaleStatusRequest):
    """Report the full state of an action task, including its buttons.

    Returns status 10002 when no task with ``body.task_id`` exists. Otherwise
    the payload carries id, type, status, image URL (Discord CDN prefix
    rewritten to ``http://discordcdn.aidomore.com``), progress, description
    and buttons.
    """
    found: MjTask = mj_data_service.get_task_by_task_id(body.task_id)
    if not found:
        return {
            "status": 10002,
            "msg": f"未查询到任务:task_id:{body.task_id}"
        }

    public_url = found.image_url
    if public_url:
        public_url = public_url.replace(
            "https://cdn.discordapp.com", "http://discordcdn.aidomore.com")

    payload = {
        "task_id": found.task_id,
        "task_type": found.task_type,
        "task_status": found.task_status,
        "image_url": public_url,
        "progress": found.progress,
        "description": found.description,
        "buttons": found.buttons,
    }
    return {"status": 200, "msg": "success", "data": payload}
172 |
173 |
@image_router.post("/midjourney/describe", summary="提交Describe任务")
async def submit_describe_func(request: Request, body: DescribeRequest):
    """Queue a Describe (image-to-text) task for ``body.img``.

    Returns status 10001 when no image is supplied; otherwise the new task's
    id and status.
    """
    if not body.img:
        return {"status": 10001, "msg": "需要传入参考图片"}

    describe_task = MjTask()
    describe_task.task_type = "describe"
    describe_task.task_status = "pending"
    describe_task.nonce = str(mj_data_service.next_nonce())
    describe_task.image_prompt = body.img
    describe_task.create_time = datetime.datetime.now()
    new_id = task_controller.submit_describe(describe_task)

    return {
        "status": 200,
        "msg": "success",
        "data": {"task_id": new_id, "task_status": describe_task.task_status},
    }
199 |
200 |
# Blend: merge several images into one.
@image_router.post("/midjourney/blend", summary="提交Blend任务")
async def submit_blend_func(request: Request, body: BlendRequest):
    """Queue a Blend task over ``body.imgs``.

    ``body.dimensions`` is copied onto the task unchanged. Returns status
    10001 when no images are supplied; otherwise the new task's id and status.
    """
    if not body.imgs:
        return {"status": 10001, "msg": "需要传入参考图片"}

    requested_dimensions = body.dimensions
    blend_task = MjTask()
    blend_task.task_type = "blend"
    blend_task.blend_imgs = body.imgs
    blend_task.task_status = "pending"
    blend_task.nonce = str(mj_data_service.next_nonce())
    blend_task.dimensions = requested_dimensions
    blend_task.create_time = datetime.datetime.now()
    new_id = task_controller.submit_blend(blend_task)

    return {
        "status": 200,
        "msg": "success",
        "data": {"task_id": new_id, "task_status": blend_task.task_status},
    }
227 |
228 |
# Upload an image supplied as base64.
@image_router.post("/midjourney/upload", summary="base64图片上传")
async def upload_image(request: Request, body: ImageUploadRequest):
    """Upload a base64-encoded image and return its hosted URL.

    Returns status 10001 when ``body.bs64`` is empty; otherwise
    ``{"image_url": ...}`` in the ``data`` field.
    """
    if not body.bs64:
        return {
            "status": 10001,
            "msg": "需要传入图片"
        }
    # Upload through a load-balanced Discord HTTP service, as submit_imagine
    # does. The ``image_controller`` name this used to call is not defined in
    # this module (its injector lookup is commented out), so every request
    # raised NameError.
    image_url = load_balancer.get_discord_http_service().upload_img_if_bs64(body.bs64)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "image_url": image_url
        }
    }
245 |
246 |
# Upload an image supplied as a multipart file.
@image_router.post("/midjourney/upload_image", summary="图片上传")
async def upload_image(file: UploadFile = File(...)):
    """Accept a multipart image upload and return its hosted URL.

    The file is written to ``./.cache_imgs/<name>``, re-read through
    ``common_util.getImage``, converted to base64 with
    ``common_util.image_to_base64`` and uploaded via a load-balanced Discord
    HTTP service.

    Returns:
        Status 10001 when the upload has no usable filename; otherwise
        ``{"image_url": ...}`` in the ``data`` field.
    """
    # The filename is client-controlled: keep only its last path component so
    # a name such as "../../main.py" cannot write outside the cache directory.
    file_name = os.path.basename(file.filename or "")
    if file_name in ("", ".", ".."):
        return {
            "status": 10001,
            "msg": "需要传入图片"
        }

    cache_dir = ".cache_imgs"
    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(".", cache_dir, file_name)
    with open(cache_path, "wb") as f:
        f.write(file.file.read())
    image = common_util.getImage(cache_path)
    img_bs64 = common_util.image_to_base64(image)

    # ``image_controller`` is not defined in this module (its injector lookup
    # is commented out); upload the same way submit_imagine does instead.
    image_url = load_balancer.get_discord_http_service().upload_img_if_bs64(img_bs64)
    return {
        "status": 200,
        "msg": "success",
        "data": {
            "image_url": image_url
        }
    }
267 |
# Shorten: ask Midjourney to shorten a prompt.
@image_router.post("/midjourney/shorten", summary="shorten")
async def shorten(request: Request, body: ShortenRequest):
    """Queue a Shorten task for ``body.prompt``.

    Returns status 10002 when the prompt is empty; otherwise the new task's
    id and status.
    """
    if not body.prompt:
        return {"status": 10002, "msg": "需要传入prompt"}

    shorten_task = MjTask()
    shorten_task.task_type = "shorten"
    shorten_task.prompt = body.prompt
    shorten_task.task_status = "pending"
    shorten_task.nonce = str(mj_data_service.next_nonce())
    shorten_task.create_time = datetime.datetime.now()
    new_task_id = task_controller.submit_shorten(shorten_task)

    return {
        "status": 200,
        "msg": "success",
        "data": {"task_id": new_task_id, "task_status": shorten_task.task_status},
    }
--------------------------------------------------------------------------------
/entity/__pycache__/mj_scheme.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/mj_scheme.cpython-38.pyc
--------------------------------------------------------------------------------
/entity/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/entity/__pycache__/openai_scheme.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/openai_scheme.cpython-38.pyc
--------------------------------------------------------------------------------
/entity/__pycache__/user_scheme.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/entity/__pycache__/user_scheme.cpython-38.pyc
--------------------------------------------------------------------------------
/entity/gen.sh:
--------------------------------------------------------------------------------
# Regenerate models.py from a MySQL schema with peewee's pwiz.
#
# Connection details come from the environment rather than being committed:
#   DB_HOST  database host (default: localhost)
#   DB_USER  database user (required)
#   DB_NAME  database/schema name (required)
# `-P` makes pwiz prompt for the password interactively, so it never appears
# in this file, the shell history, or the process list.
# NOTE(review): credentials previously hard-coded here remain in version
# control history and should be rotated.
: "${DB_USER:?set DB_USER}"
: "${DB_NAME:?set DB_NAME}"
python -m pwiz -e mysql -H "${DB_HOST:-localhost}" -u "$DB_USER" -P "$DB_NAME" > models.py
--------------------------------------------------------------------------------
/entity/mj_scheme.py:
--------------------------------------------------------------------------------
1 | from pydoc import describe
2 | from typing import Optional
3 |
4 | from typing import List, Tuple, Set
5 | from pydantic import BaseModel, Field
6 |
class ImagineRequest(BaseModel):
    """Request body for submitting an Imagine task."""

    # Either "RELAX" or "FAST"; defaults to "RELAX".
    mode: Optional[str] = "RELAX"
    # Callback URL to notify automatically when the task status changes.
    # Explicit default: under pydantic v2 an Optional field without a default
    # is still required.
    notify_hook: Optional[str] = None
    # Reference images; each entry may be a base64 string or a URL.
    imgs: Optional[List[str]] = None
    prompt: str
15 |
16 | class ImagineResponseData(BaseModel):
17 | # 任务ID
18 | task_id: str
19 | # 任务状态
20 | task_status: str
21 | # 图片地址
22 | image_url: Optional[str]
23 |
24 | class ImagineResponse(BaseModel):
25 | status: int = Field(..., title="状态码", description="状态码")
26 | msg: str = Field(..., title="消息", description="消息")
27 | data: ImagineResponseData = Field(..., title="数据", description="数据")
28 |
29 | class QueryImagineStatusRequest(BaseModel):
30 | task_id: str = Field(..., title="任务ID", description="任务ID")
31 | # 返回图片的格式,base64或者url
32 | # img_type: str = "base64"
33 |
34 | class QueryImagineStatusResponseData(BaseModel):
35 | task_id: str = Field(..., title="任务ID", description="任务ID")
36 | task_status: str = Field(..., title="状态码", description="状态码")
37 | image_url: Optional[str]
38 |
39 |
40 | class QueryImagineStatusResponse(BaseModel):
41 | status: int = Field(..., title="状态码", description="状态码")
42 | msg: str = Field(..., title="消息", description="消息")
43 | data: QueryImagineStatusResponseData = Field(..., title="数据", description="数据")
44 |
45 |
46 | class UpscaleRequest(BaseModel):
47 | # 任务ID
48 | task_id: str
49 | # 缩放的索引
50 | # index: int
51 | # action_type: str
52 | mode: Optional[str]
53 | custom_id: str
54 | # 任务状态发生变化自动回调的地址
55 | notify_hook: Optional[str]
56 |
57 |
58 | class QueryUpscaleStatusRequest(BaseModel):
59 | task_id: str = Field(..., title="任务ID", description="任务ID")
60 |
61 |
62 | class QueryVariationStatusRequest(BaseModel):
63 | task_id: str = Field(..., title="任务ID", description="任务ID")
64 |
65 |
66 | class ImageUploadRequest(BaseModel):
67 | bs64: str
68 |
69 |
70 | class DescribeRequest(BaseModel):
71 | img: str
72 |
73 | class BlendRequest(BaseModel):
74 | imgs: List[str]
75 | dimensions: Optional[str]
76 |
77 | class ShortenRequest(BaseModel):
78 | prompt: str
--------------------------------------------------------------------------------
/entity/models.py:
--------------------------------------------------------------------------------
1 | from peewee import *
2 |
import os

# MySQL connection used by every model below.
# Settings come from the environment so credentials are not committed to
# source control. Host, user and database name default to the values
# originally generated by pwiz. The password has no default and must be
# supplied via MYSQL_PASSWORD; the previously committed password should be
# considered leaked and rotated. peewee connects lazily, so a missing
# password only surfaces on the first query.
database = MySQLDatabase(
    os.environ.get('MYSQL_DATABASE', 'xiaohongdouapi'),
    charset='utf8',
    sql_mode='PIPES_AS_CONCAT',
    use_unicode=True,
    host=os.environ.get('MYSQL_HOST', 'localhost'),
    user=os.environ.get('MYSQL_USER', 'xiaohongdouapi'),
    password=os.environ.get('MYSQL_PASSWORD', ''),
)
4 |
# Generated by pwiz (see entity/gen.sh). Placeholder field type that accepts
# and ignores any constructor arguments.
class UnknownField(object):
    def __init__(self, *_, **__): pass

# Common base class: binds every model to the module-level `database`.
class BaseModel(Model):
    class Meta:
        database = database

# Model for the `User` table; every column is nullable.
class User(BaseModel):
    create_time = DateTimeField(null=True)
    email = CharField(null=True)
    phone = CharField(null=True)
    update_time = DateTimeField(null=True)

    class Meta:
        # Explicit table name (case-sensitive on some MySQL setups).
        table_name = 'User'
20 |
21 |
--------------------------------------------------------------------------------
/entity/openai_scheme.py:
--------------------------------------------------------------------------------
1 | from pydoc import describe
2 | from typing import Optional
3 | from pydantic import BaseModel, Field
4 |
5 |
# Request body for a text completion; all fields are required.
# NOTE(review): presumably forwarded to an OpenAI-style completion API
# (inferred from the module name) — confirm against the handler.
class CompletionRequest(BaseModel):
    model: str
    prompt: str
    temperature: float
    top_p: float
    # Whether the caller asked for a streamed response.
    stream: bool

# Completion payload; `text` is presumably the generated completion.
class CompletionResponseData(BaseModel):
    text: str


# Response envelope: status code, message and the completion payload.
class CompletionResponse(BaseModel):
    status: int=Field(..., description="状态")
    msg: str=Field(..., description="消息")
    data: CompletionResponseData=Field(..., description="数据")
--------------------------------------------------------------------------------
/entity/user_scheme.py:
--------------------------------------------------------------------------------
1 | from pydoc import describe
2 | from typing import Optional
3 |
4 | from typing import List, Tuple, Set
5 | from pydantic import BaseModel, Field
6 |
# Request to display the sign-up (user registration) QR code; takes no fields.
class ShowQrcodeRequest(BaseModel):
    pass

# Payload of the QR-code display response.
class QrcodeResponseData(BaseModel):
    ticketUrl: str = Field(..., description="二维码地址")
    # NOTE(review): typed `int` here but `str` in QrcodeStatusRequest.sceneStr;
    # confirm which type the scene value really has.
    sceneStr: int = Field(..., description="二维码场景值")

# Response envelope for the QR-code display request.
class ShowQrcodeResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: QrcodeResponseData = Field(..., description="数据")


# Request to poll a QR code's status by its scene value.
class QrcodeStatusRequest(BaseModel):
    sceneStr: str = Field(..., description="二维码场景值")


# Payload of the QR-code status response: a status code and a user ID.
class QrcodeStatusData(BaseModel):
    status: int = Field(..., description="状态码")
    uid: int = Field(..., description="用户ID")

# Response envelope for the QR-code status request.
class QrcodeStatusResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: QrcodeStatusData = Field(..., description="数据")


# Request to generate a user secret key; takes no fields.
class SecretGenerateRequest(BaseModel):
    pass


# Response for secret-key generation; `data` is a plain string
# (presumably the generated secret).
class SecretGenerateResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: str = Field(..., description="数据")

# Request to look up a user's secret key by user ID.
class SecretQueryRequest(BaseModel):
    uid: int = Field(..., description="用户ID")

# Payload of the secret-key lookup response.
class SecretQueryData(BaseModel):
    secret: str = Field(..., description="秘钥")

# Response envelope for the secret-key lookup.
class SecretQueryResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: SecretQueryData = Field(..., description="数据")

# Request to generate a payment QR code for a user.
class PayQrcodeGenerateRequest(BaseModel):
    uid: int = Field(..., description="用户ID")

# Payment QR-code details: base64 image, price in cents (fen) and order ID.
# NOTE(review): not referenced by PayQrcodeGenerateResponse below, whose
# `data` is typed `str`; confirm which shape the endpoint returns.
class PayQrcodeGenerateData(BaseModel):
    qrcode_img_bs64: str = Field(..., description="二维码图片base64")
    price: int = Field(..., description="价格,单位分")
    order_id: str = Field(..., description="订单ID")


# Response envelope for payment QR-code generation.
class PayQrcodeGenerateResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: str = Field(..., description="数据")

# Request to query a payment QR code's status by order ID.
class PayQrcodeStatusRequest(BaseModel):
    order_id: str = Field(..., description="订单ID")

# Payload of the payment-status response: status, price in cents, paid flag.
class PayQrcodeStatusData(BaseModel):
    status: int = Field(..., description="状态码")
    price: int = Field(..., description="价格,单位分")
    payed: bool = Field(..., description="是否支付")

# Response envelope for the payment-status query.
class PayQrcodeStatusResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: PayQrcodeStatusData = Field(..., description="数据")




# Account/password registration request.
class RegisterRequest(BaseModel):
    account: str = Field(..., description="账号")
    password: str = Field(..., description="密码")


# Account/password login request.
class LoginRequest(BaseModel):
    account: str = Field(..., description="账号")
    password: str = Field(..., description="密码")


# Login response payload carrying the issued token.
class LoginResponseData(BaseModel):
    token: str = Field(..., description="token")


# Response envelope for login.
class LoginResponse(BaseModel):
    status: int = Field(..., description="状态码")
    msg: str = Field(..., description="消息")
    data: LoginResponseData = Field(..., description="数据")
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- ecoding: utf-8 -*-
2 | # @Author: anyang
3 | # @Time: 2024/4/5
4 |
5 | import asyncio
6 | import uvicorn
7 | from fastapi import FastAPI
8 | from blueprints.blueprints_image_generation import image_router
9 | import yaml
10 | # from app.app import app
11 | from support.mj_config import MjConfig
12 | # from union_api_server.Injector import injector
13 | from support.Injector import injector
14 | import subprocess
15 | from wss.mj_wss_proxy import MjWssSercice
16 | from wss.mj_wss_manager import MjWssManager
17 |
# The ASGI application; image-generation endpoints are mounted under /image.
router = FastAPI()
router.include_router(image_router, tags=["图片生成"], prefix="/image")


@router.on_event("startup")
async def startup_event():
    """Application startup hook; currently performs no work."""
    return None
24 |
25 |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description='union_api_server server.')
    parser.add_argument('--port', type=int, help='server port', default=6013)
    # Parse the CLI first. Then `--help` or an invalid argument exits before
    # any background listeners are started; previously they were already
    # running when argparse bailed out.
    args = parser.parse_args()

    # Resolve MjConfig from the injector. The value itself is not used below.
    mj_config = injector.get(MjConfig)

    # Start the Discord websocket listeners.
    # NOTE(review): the `common.launch_discord_bot` config flag is not
    # consulted here; the listeners always start.
    mj_wss_manager = injector.get(MjWssManager)
    mj_wss_manager.start_all()

    uvicorn.run(router, port=args.port, host="0.0.0.0")
45 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | toml
2 | numpy
3 | uvicorn
4 | fastapi
5 | Pillow
6 | requests
7 | pyyaml
8 | python-multipart
9 | tinydb
10 | ordered-set
11 | dill
12 | peewee
13 | celery
14 | redis
15 | qrcode
16 | BeautifulSoup4
17 | lxml
18 | oss2
19 | alibabacloud_imageseg20191230==3.0.0
20 | cachetools
21 | pymongo
22 | injector
23 | requests_toolbelt
--------------------------------------------------------------------------------
/resources/api_params/blend.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 2,
3 | "guild_id": "$guild_id$",
4 | "channel_id": "$channel_id$",
5 | "application_id": "$app_id$",
6 | "session_id": "$session_id$",
7 | "nonce": "$nonce$",
8 | "data": {
9 | "version": "1166847114203123796",
10 | "id": "1062880104792997970",
11 | "name": "blend",
12 | "type": 1,
13 | "options": [],
14 | "attachments": []
15 | }
16 | }
--------------------------------------------------------------------------------
/resources/api_params/describe.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 2,
3 | "application_id": "$app_id$",
4 | "guild_id": "$guild_id$",
5 | "channel_id": "$channel_id$",
6 | "session_id": "$session_id$",
7 | "data": {
8 | "version": "1237876415471554625",
9 | "id": "1092492867185950852",
10 | "name": "describe",
11 | "type": 1,
12 | "options": [
13 | {
14 | "type": 3,
15 | "name": "link",
16 | "value": "$prompt$"
17 | }
18 | ],
19 | "attachments": []
20 | },
21 | "nonce": "$nonce$",
22 | "analytics_location": "slash_ui"
23 | }
--------------------------------------------------------------------------------
/resources/api_params/imagine.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 2,
3 | "application_id": "$app_id$",
4 | "guild_id": "$guild_id$",
5 | "channel_id": "$channel_id$",
6 | "session_id": "$session_id$",
7 | "nonce": "$nonce$",
8 | "data": {
9 | "version": "1237876415471554623",
10 | "id": "938956540159881230",
11 | "name": "imagine",
12 | "type": 1,
13 | "options": [
14 | {
15 | "type": 3,
16 | "name": "prompt",
17 | "value": "$prompt$"
18 | }
19 | ],
20 | "attachments": []
21 | }
22 | }
--------------------------------------------------------------------------------
/resources/api_params/info.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 2,
3 | "application_id": "936929561302675456",
4 | "guild_id": "1233678696364245002",
5 | "channel_id": "1233678696364245005",
6 | "session_id": "c38d67b8f1eded0b5603f44c31765817",
7 | "data": {
8 | "version": "1237876415735660565",
9 | "id": "972289487818334209",
10 | "name": "info",
11 | "type": 1,
12 | "options": [],
13 | "application_command": {
14 | "id": "972289487818334209",
15 | "type": 1,
16 | "application_id": "936929561302675456",
17 | "version": "1237876415735660565",
18 | "name": "info",
19 | "description": "View information about your profile.",
20 | "dm_permission": true,
21 | "contexts": [
22 | 0,
23 | 1,
24 | 2
25 | ],
26 | "integration_types": [
27 | 0,
28 | 1
29 | ],
30 | "global_popularity_rank": 2,
31 | "options": [],
32 | "description_localized": "View information about your profile.",
33 | "name_localized": "info"
34 | },
35 | "attachments": []
36 | },
37 | "nonce": "1239288436951613440",
38 | "analytics_location": "slash_ui"
39 | }
--------------------------------------------------------------------------------
/resources/api_params/message.json:
--------------------------------------------------------------------------------
1 | {
2 | "content": "",
3 | "channel_id": "$channel_id$",
4 | "type": 0,
5 | "sticker_ids": [],
6 | "attachments": [
7 | {
8 | "id": "0",
9 | "filename": "$file_name$",
10 | "uploaded_filename": "$final_file_name$"
11 | }
12 | ]
13 | }
--------------------------------------------------------------------------------
/resources/api_params/reroll.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 3,
3 | "guild_id": "$guild_id$",
4 | "channel_id": "$channel_id$",
5 | "message_id": "$message_id$",
6 | "application_id": "936929561302675456",
7 | "session_id": "$session_id$",
8 | "nonce": "$nonce$",
9 | "message_flags": 0,
10 | "data": {
11 | "component_type": 2,
12 | "custom_id": "MJ::JOB::reroll::0::$message_hash$::SOLO"
13 | }
14 | }
--------------------------------------------------------------------------------
/resources/api_params/shorten.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 2,
3 | "application_id": "$app_id$",
4 | "guild_id": "$guild_id$",
5 | "channel_id": "$channel_id$",
6 | "session_id": "$session_id$",
7 | "data": {
8 | "version": "1237876415471554626",
9 | "id": "1121575372539039774",
10 | "name": "shorten",
11 | "type": 1,
12 | "options": [
13 | {
14 | "type": 3,
15 | "name": "prompt",
16 | "value": "$prompt$"
17 | }
18 | ],
19 | "attachments": []
20 | },
21 | "nonce": "$nonce$",
22 | "analytics_location": "slash_ui"
23 | }
--------------------------------------------------------------------------------
/resources/api_params/upscale.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 3,
3 | "guild_id": "$guild_id$",
4 | "channel_id": "$channel_id$",
5 | "message_id": "$message_id$",
6 | "application_id": "$app_id$",
7 | "session_id": "$session_id$",
8 | "nonce": "$nonce$",
9 | "message_flags": 0,
10 | "data": {
11 | "component_type": 2,
12 | "custom_id": "$custom_id$"
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/resources/api_params/variation.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": 3,
3 | "guild_id": "$guild_id$",
4 | "channel_id": "$channel_id$",
5 | "message_id": "$message_id$",
6 | "application_id": "936929561302675456",
7 | "session_id": "$session_id$",
8 | "nonce": "$nonce$",
9 | "message_flags": 0,
10 | "data": {
11 | "component_type": 2,
12 | "custom_id": "$custom_id$"
13 | }
14 | }
--------------------------------------------------------------------------------
/resources/config/prod_mj_config.yaml:
--------------------------------------------------------------------------------
1 | # midjourney的账号
2 | accounts:
3 | - user_token: ""
4 | bot_token: ""
5 | guild_id: ""
6 | channel_id: ""
7 | app_id: ""
8 | session_id: ""
9 |
10 | ng:
11 | discord_ws: https://gateway.discord.gg
12 | # discord_server: http://discordapi.aidomore.com/api/v9/interactions
13 | discord_server: https://discord.com
14 | discord_upload_server: https://discord-attachments-uploads-prd.storage.googleapis.com
15 |
16 | common:
17 | # 是否启动discord bot
18 | launch_discord_bot: false
19 | img_upload_url: ""
20 |
21 | redis:
22 | host: 0.0.0.0
23 | port: 6379
24 | password: "123"
25 | db: 0
--------------------------------------------------------------------------------
/service/__pycache__/celery_client.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/celery_client.cpython-38.pyc
--------------------------------------------------------------------------------
/service/__pycache__/celery_service.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/celery_service.cpython-38.pyc
--------------------------------------------------------------------------------
/service/__pycache__/discord_http_service.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/discord_http_service.cpython-38.pyc
--------------------------------------------------------------------------------
/service/__pycache__/mj_data_service.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/mj_data_service.cpython-38.pyc
--------------------------------------------------------------------------------
/service/__pycache__/notify_service.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/notify_service.cpython-38.pyc
--------------------------------------------------------------------------------
/service/__pycache__/template_controller.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/service/__pycache__/template_controller.cpython-38.pyc
--------------------------------------------------------------------------------
/service/discord_http_service.py:
--------------------------------------------------------------------------------
1 | from support.mj_config import MjConfig
2 | import redis
3 | import requests
4 | import json
5 | from support.task_controller import MjTask
6 | from service.template_controller import TemplateController
7 | import base64
8 | import io
9 | import hashlib
10 | from support.mj_account import MjAccount
11 |
class DiscordHttpService:
    """Submit Midjourney commands to Discord through its HTTP API.

    Each command is built from a JSON template supplied by
    ``TemplateController``. The template is filled with the identifiers of
    the configured ``MjAccount`` and posted to Discord, authorised with that
    account's user token.
    """

    # Fixed multipart boundary used when posting ``payload_json`` forms.
    _MULTIPART_BOUNDARY = "----WebKitFormBoundaryqznUd46iGT62TY0s"

    # Blend "dimensions" option -> aspect-ratio suffix appended to the prompt.
    _BLEND_ASPECT_RATIOS = {
        "Square": "--ar 1:1",
        "Landscape": "--ar 3:2",
        "Portrait": "--ar 2:3",
    }

    def __init__(self, mj_config: MjConfig, redis_client: redis.Redis, template_controller: TemplateController, mj_account: MjAccount) -> None:
        self.redis_client = redis_client
        self.mj_config = mj_config
        self.mj_account = mj_account
        self.template_controller = template_controller

    # ------------------------------------------------------------------
    # Endpoint URLs
    # ------------------------------------------------------------------
    def _discord_host(self):
        """Return the configured Discord host (``ng.discord_server``)."""
        return self.mj_config['ng']['discord_server']

    def get_interaction_url(self):
        """Return the URL of Discord's interactions endpoint."""
        return f"{self._discord_host()}/api/v9/interactions"

    def get_messages_url(self):
        """Return the messages endpoint of the account's channel."""
        channel_id = self.mj_account.get_channel_id()
        return f"{self._discord_host()}/api/v9/channels/{channel_id}/messages"

    def get_discord_upload_attachment_url(self):
        """Return the attachment-upload endpoint of the account's channel."""
        channel_id = self.mj_account.get_channel_id()
        return f"{self._discord_host()}/api/v9/channels/{channel_id}/attachments"

    # ------------------------------------------------------------------
    # Shared request helpers
    # ------------------------------------------------------------------
    def _base_template_map(self, **extra):
        """Return the account identifiers every template needs, merged with ``extra``."""
        template_map = {
            "app_id": self.mj_account.get_app_id(),
            "guild_id": self.mj_account.get_gulid_id(),
            "channel_id": self.mj_account.get_channel_id(),
            "session_id": self.mj_account.get_session_id(),
        }
        template_map.update(extra)
        return template_map

    def _headers(self, content_type):
        """Return request headers authorised with the account's user token."""
        return {
            "Content-Type": content_type,
            "Authorization": self.mj_account.get_user_token(),
        }

    @staticmethod
    def _raise_for_status(res, ok_codes=(200, 204)):
        """Raise ``Exception`` with Discord's error body unless ``res`` succeeded.

        The body is decoded as JSON when possible. A non-JSON body (for
        example an HTML error page from a proxy) is reported as raw text,
        instead of hiding the real failure behind a JSON decode error.
        """
        if res.status_code in ok_codes:
            return
        try:
            detail = json.loads(res.text)
        except ValueError:
            detail = res.text
        raise Exception(detail)

    def _post_interaction_form(self, payload):
        """POST ``payload`` to the interactions endpoint as a multipart ``payload_json`` form.

        Raises:
            Exception: if Discord does not answer 200/204.
        """
        # Imported lazily (as before) so the dependency is only needed here.
        from requests_toolbelt.multipart.encoder import MultipartEncoder

        boundary = self._MULTIPART_BOUNDARY
        form_data = MultipartEncoder(
            fields={"payload_json": (None, json.dumps(payload))},
            boundary=boundary,
        )
        res = requests.post(
            url=self.get_interaction_url(),
            data=form_data,
            headers=self._headers(f"multipart/form-data; boundary={boundary}"),
        )
        self._raise_for_status(res)
        print("res.content", res.content)
        return res

    def _post_interaction_json(self, payload):
        """POST ``payload`` as JSON to the interactions endpoint.

        Raises:
            Exception: if Discord does not answer 200/204. These calls used to
                ignore the status, so a rejected interaction failed silently.
        """
        res = requests.post(
            url=self.get_interaction_url(),
            data=json.dumps(payload),
            headers=self._headers("application/json"),
        )
        print(res)
        self._raise_for_status(res)
        return res

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------
    def imageine(self, task: MjTask):
        """Submit ``/imagine`` for ``task``.

        When ``task.image_prompt`` is set, it is prepended to ``task.prompt``
        with a space between them.

        Raises:
            Exception: if Discord rejects the request.
        """
        if task.image_prompt:
            prompt = task.image_prompt + " " + task.prompt
        else:
            prompt = task.prompt
        template_map = self._base_template_map(prompt=prompt, nonce=task.nonce)
        self._post_interaction_form(self.template_controller.get_imagine(template_map))

    def upscale(self, task: MjTask):
        """Click the upscale button ``task.custom_upscale_id`` on ``task.message_id``."""
        template_map = self._base_template_map(
            nonce=task.nonce,
            message_id=task.message_id,
            custom_id=task.custom_upscale_id,
        )
        self._post_interaction_json(self.template_controller.get_upscale(template_map))

    def variation(self, task: MjTask):
        """Click the variation button ``task.custom_variation_id`` on ``task.message_id``."""
        template_map = self._base_template_map(
            nonce=task.nonce,
            message_id=task.message_id,
            custom_id=task.custom_variation_id,
        )
        # The upscale template is a generic button-click interaction
        # (message_id + custom_id), so it is reused for variations.
        self._post_interaction_json(self.template_controller.get_upscale(template_map))

    def describe(self, task: MjTask):
        """Submit ``/describe`` with ``task.image_prompt`` as the image link."""
        template_map = self._base_template_map(
            nonce=task.nonce,
            message_id=task.message_id,
            custom_id=task.custom_variation_id,
            prompt=task.image_prompt,
        )
        self._post_interaction_json(self.template_controller.get_describe(template_map))

    def blend(self, task: MjTask):
        """Blend ``task.blend_imgs`` by submitting them as an ``/imagine`` prompt.

        ``task.dimensions`` ("Square", "Landscape" or "Portrait") adds the
        matching ``--ar`` suffix. Any other value adds nothing; previously it
        appended a stray trailing space.

        Raises:
            Exception: if ``task.blend_imgs`` is empty or Discord rejects the request.
        """
        if not task.blend_imgs:
            raise Exception("blend_imgs is empty")

        prompt = " ".join(task.blend_imgs)
        aspect_ratio = self._BLEND_ASPECT_RATIOS.get(task.dimensions) if task.dimensions else None
        if aspect_ratio:
            prompt = prompt + " " + aspect_ratio

        template_map = self._base_template_map(prompt=prompt, nonce=task.nonce)
        self._post_interaction_form(self.template_controller.get_imagine(template_map))

    def action(self, task: MjTask):
        """Click an arbitrary button ``task.custom_id`` on ``task.reference_message_id``."""
        template_map = self._base_template_map(
            nonce=task.nonce,
            message_id=task.reference_message_id,
            custom_id=task.custom_id,
        )
        self._post_interaction_json(self.template_controller.get_upscale(template_map))

    def shorten(self, task: MjTask):
        """Submit ``/shorten`` for ``task.prompt``."""
        template_map = self._base_template_map(nonce=task.nonce, prompt=task.prompt)
        self._post_interaction_json(self.template_controller.get_shorten(template_map))

    def message(self, final_file_name):
        """Post a channel message attaching the already-uploaded ``final_file_name``.

        Returns:
            The ``url`` Discord reports for the message's first attachment.

        Raises:
            Exception: if Discord does not answer 200.
        """
        file_name = final_file_name.split("/")[-1]
        template_map = self._base_template_map(
            final_file_name=final_file_name,
            file_name=file_name,
        )
        message_template = self.template_controller.get_message(template_map)
        res = requests.post(
            url=self.get_messages_url(),
            data=json.dumps(message_template),
            headers=self._headers("application/json"),
        )
        print(res)
        if res.status_code != 200:
            raise Exception("message upload failed")
        return json.loads(res.text)['attachments'][0]['url']

    def upload_img_if_bs64(self, img: str):
        """Alias of :meth:`upload2discord`."""
        return self.upload2discord(img)

    def upload2discord(self, img: str):
        """Upload a ``data:image/...`` base64 URL to Discord and return its URL.

        The image is first registered with the channel's attachments endpoint.
        The bytes are then PUT to the returned upload URL, and finally the
        file is posted as a channel message via :meth:`message`.

        Returns:
            The attachment URL, or ``None`` when ``img`` is empty or is not a
            ``data:image/`` URL.
            NOTE(review): a plain http(s) URL therefore yields ``None`` rather
            than being passed through; confirm callers expect that.

        Raises:
            Exception: if either upload step fails.
        """
        if not img or not img.startswith("data:image/"):
            return None

        image_data = base64.b64decode(img.split(",")[-1])
        file_size = len(image_data)
        file_name = f"{hashlib.md5(image_data).hexdigest()}.png"

        # Step 1: ask Discord for an upload slot. The account token is used
        # here; the old code read a nonexistent `self.USER_TOKEN` attribute.
        payload = {"files": [{"filename": file_name, "file_size": file_size, "id": "0"}]}
        res = requests.post(
            self.get_discord_upload_attachment_url(),
            headers=self._headers("application/json"),
            data=json.dumps(payload),
        )
        if res.status_code != 200:
            raise Exception("上传图片失败")

        attach = json.loads(res.text)['attachments'][0]
        upload_url = attach['upload_url']
        upload_filename = attach['upload_filename']
        print("upload_url", upload_url)
        print("upload_filename", upload_filename)

        # Step 2: PUT the raw bytes to the upload URL.
        put_headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
            "Content-Type": "application/octet-stream",
            "Content-Length": str(file_size),
        }
        res = requests.put(upload_url, data=image_data, headers=put_headers)
        print(res.status_code)
        if res.status_code != 200:
            raise Exception("上传图片失败")

        # Step 3: post the uploaded file as a message (this class's own
        # `message`; the old code called a nonexistent attribute).
        return self.message(upload_filename)
--------------------------------------------------------------------------------
/service/discord_ws_service.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/service/mail_service.py:
--------------------------------------------------------------------------------
1 | import smtplib
2 | import random
3 | from email.mime.text import MIMEText
4 | from email.mime.multipart import MIMEMultipart
5 |
def send_verification_code(email, *, sender_email="496812133@qq.com",
                           sender_password="your_password",
                           smtp_host="smtp.example.com", smtp_port=587):
    """Email a freshly generated 6-digit verification code to ``email``.

    Args:
        email: Recipient address.
        sender_email: Sending address; also used as the SMTP login name.
        sender_password: SMTP password for ``sender_email``.
        smtp_host: SMTP server host. STARTTLS is issued after connecting.
        smtp_port: SMTP server port.

    Returns:
        The code that was sent, so the caller can store it and compare it
        with what the user enters. The old version returned nothing, which
        made the code impossible to verify.

    Raises:
        Propagates any smtplib / socket error from connecting, logging in or
        sending.
    """
    # `secrets`, not `random`: a verification code is an authentication
    # secret and must come from a cryptographically secure source.
    import secrets

    verification_code = "".join(secrets.choice("0123456789") for _ in range(6))

    msg = MIMEMultipart()
    msg['From'] = sender_email
    msg['To'] = email
    msg['Subject'] = "验证码"
    body = f"您的验证码是:{verification_code}"
    msg.attach(MIMEText(body, 'plain'))

    with smtplib.SMTP(smtp_host, smtp_port) as smtp:
        smtp.starttls()
        smtp.login(sender_email, sender_password)
        smtp.send_message(msg)

    print("验证码已发送至您的邮箱,请查收。")
    return verification_code
30 |
# Manual smoke test: send a code to the given address. The call is guarded
# so that merely importing this module never connects to an SMTP server.
if __name__ == "__main__":
    send_verification_code("recipient_email@example.com")
--------------------------------------------------------------------------------
/service/mj_data_service.py:
--------------------------------------------------------------------------------
1 |
2 | from queue import Queue
3 | import json
4 | from dataclasses import dataclass,field
5 | from datetime import datetime
6 | from typing import Optional
7 | import redis
8 | from typing import List
9 |
def default_custom_ids():
    """Build a brand-new empty list, suitable as a dataclass ``default_factory``."""
    return list()
12 |
@dataclass
class MjTask:
    """Mutable record describing one Midjourney task.

    Every field has a default: ``None``, or a fresh empty list for the
    ``*_custom_ids`` fields. ``MjTask()`` therefore creates an empty task that
    can be filled in attribute by attribute. The generated ``__init__`` is
    what sets those defaults. The former ``init=False`` meant no ``__init__``
    ran, so the ``default_factory`` fields were never set (reading them
    raised ``AttributeError``), and ``dimensions`` had no default at all.
    """

    # Task type.
    task_type: Optional[str] = None
    # Task status.
    task_status: Optional[str] = None
    # Task ID.
    task_id: Optional[str] = None
    # Prompt.
    prompt: Optional[str] = None
    # Prompt, English version.
    promptEn: Optional[str] = None
    # Task description (this field used to be declared twice).
    description: Optional[str] = None
    # Submission time.
    create_time: Optional[datetime] = None
    # Execution start time.
    start_time: Optional[datetime] = None
    # Finish time.
    end_time: Optional[datetime] = None
    # Image URL.
    image_url: Optional[str] = None
    # Progress.
    progress: Optional[str] = None
    # Failure reason.
    fail_reason: Optional[str] = None
    # Callback URL.
    notify_hook: Optional[str] = None

    nonce: Optional[str] = None

    image_prompt: Optional[str] = None

    # Prompt reported back by Midjourney when the job was created.
    submit_prompt: Optional[str] = None

    # Previously associated message IDs.
    message_id: Optional[str] = None
    reference_message_id: Optional[str] = None

    finished_message_id: Optional[str] = None

    buttons: Optional[List[str]] = None

    # Action.
    custom_id: Optional[str] = None

    # Upscale.
    upscale_custom_ids: List[str] = field(default_factory=list)
    custom_upscale_id: Optional[str] = None
    upscale_index: Optional[int] = None

    # Variation.
    variation_custom_ids: List[str] = field(default_factory=list)
    custom_variation_id: Optional[str] = None
    variation_index: Optional[int] = None

    # Blend.
    blend_imgs: Optional[List[str]] = None
    dimensions: Optional[str] = None
78 |
class MjDataService:
    """In-memory task registry plus a Redis-backed nonce counter.

    Tasks are held in a FIFO ``Queue`` (drained by the consumer), a plain list,
    and lookup dicts keyed by task_id, nonce and message_id.
    """

    # Redis key holding the shared nonce counter.
    NONCE_KEY = "mj.nonce"

    def __init__(self, redis_client: "redis.Redis") -> None:
        self.redis_client = redis_client
        self.mj_tasks = Queue()
        self.mj_tasks_list = []
        self.task_id_map = {}
        self.nonce_map = {}
        self.message_id_map = {}

    def next_nonce(self):
        """Return the next value of the Redis counter ``mj.nonce``.

        The counter is seeded with 1 only when the key does not exist
        (``SET ... NX``) and is then incremented atomically, so the first
        nonce is still 2. A separate GET followed by SET could reset the
        counter after another client had already incremented it, which would
        hand out duplicate nonces.
        """
        self.redis_client.set(self.NONCE_KEY, 1, nx=True)
        return self.redis_client.incr(self.NONCE_KEY)

    def get_tasks_queue(self) -> Queue:
        """Return the FIFO queue of submitted tasks."""
        return self.mj_tasks

    def get_tasks_list(self) -> list:
        """Return every task ever added, in insertion order."""
        return self.mj_tasks_list

    def add_task(self, task: "MjTask") -> None:
        """Enqueue ``task`` and index it by task_id and nonce."""
        self.mj_tasks.put(task)
        self.task_id_map[task.task_id] = task
        self.nonce_map[task.nonce] = task
        self.mj_tasks_list.append(task)

    def get_task_by_nonce(self, nonce: str) -> "Optional[MjTask]":
        """Return the task registered under ``nonce``, or ``None``."""
        return self.nonce_map.get(nonce)

    def get_task_by_task_id(self, task_id: str) -> "Optional[MjTask]":
        """Return the task with ``task_id``, or ``None``."""
        return self.task_id_map.get(task_id)

    def get_task_by_message_id(self, message_id: str) -> "Optional[MjTask]":
        """Return the task indexed under ``message_id``, or ``None``.

        Only tasks passed to ``update_message_id_map`` are indexed here.
        """
        return self.message_id_map.get(message_id)

    def update_task_nonce(self, task: "MjTask", nonce):
        # Only sets the attribute; nonce_map is not re-indexed.
        task.nonce = nonce

    def update_task_message_id(self, task: "MjTask", message_id):
        # Only sets the attribute; call update_message_id_map to index it.
        task.message_id = message_id

    def update_task_progress(self, task: "MjTask", progress: str) -> None:
        task.progress = progress

    def update_task_status(self, task: "MjTask", status: str) -> None:
        task.task_status = status

    def update_task_image_url(self, task: "MjTask", image_url: str) -> None:
        task.image_url = image_url

    def update_message_id_map(self, task: "MjTask"):
        """Index ``task`` under its current ``message_id``."""
        self.message_id_map[task.message_id] = task

    def update_upsacle_custom_ids(self, task: "MjTask", custom_ids: List[str]) -> None:
        # Misspelled name kept for existing callers; prefer update_upscale_custom_ids.
        task.upscale_custom_ids = custom_ids

    def update_upscale_custom_ids(self, task: "MjTask", custom_ids: List[str]) -> None:
        """Correctly spelled alias of ``update_upsacle_custom_ids``."""
        self.update_upsacle_custom_ids(task, custom_ids)

    def update_variation_custom_ids(self, task: "MjTask", custom_ids: List[str]) -> None:
        task.variation_custom_ids = custom_ids

    def update_finished_message_id(self, task: "MjTask", finished_message_id: str) -> None:
        task.finished_message_id = finished_message_id

    def update_task_description(self, task: "MjTask", description: str) -> None:
        task.description = description

    def update_buttons(self, task: "MjTask", buttons: List[str]) -> None:
        task.buttons = buttons
--------------------------------------------------------------------------------
/service/notify_service.py:
--------------------------------------------------------------------------------
1 | from support.mj_config import MjConfig
2 | import redis
3 | import requests
4 | import json
5 | import datetime
6 | from support.task_controller import MjTask
7 | from service.template_controller import TemplateController
8 | from service.mj_data_service import MjDataService
9 | import logging
10 | logger = logging.getLogger(__name__)
11 |
12 |
class NotifyService:
    """Reports task state changes to each task's ``notify_hook`` via HTTP POST."""

    # Seconds to wait for a webhook. Without a timeout, requests.post can
    # block indefinitely on an unresponsive hook and stall the caller.
    DEFAULT_TIMEOUT = 10

    def __init__(self, redis_client: "redis.Redis", timeout: float = DEFAULT_TIMEOUT) -> None:
        self.redis_client = redis_client
        self.timeout = timeout

    def notify_task_change(self, task: "MjTask"):
        """POST a JSON snapshot of ``task`` to ``task.notify_hook``.

        Does nothing when ``task`` is falsy or when its hook is empty or does
        not start with ``http``. Delivery is best-effort: non-2xx responses
        and request errors are logged, never raised. Always returns ``None``.
        """
        if not task:
            return
        notify_hook = task.notify_hook
        if not notify_hook or not notify_hook.startswith("http"):
            return
        logger.info(
            f"notify_task_change, task_id: {task.task_id}, hook: {notify_hook}")

        try:
            req = {
                "task_id": task.task_id,
                "task_type": task.task_type,
                "task_status": task.task_status,
                "image_url": task.image_url,
                "progress": task.progress,
                "action_index": task.upscale_index,
                "description": task.description,
                "buttons": task.buttons,
                "fail_reason": task.fail_reason
            }
            res = requests.post(notify_hook, json=req, timeout=self.timeout)
            # Any 2xx status counts as delivered.
            if res.status_code // 100 == 2:
                logger.info(f"notify_task_change res 2xx, task_id: {task.task_id}, hook: {notify_hook}, res: {res.text}")
                return
            logger.error(
                f"notify_task_change res not 2xx, task_id: {task.task_id}, hook: {notify_hook}, res: {res.text}")
        except Exception as e:
            logger.error(f"notify_task_change, task_id: {task.task_id}, hook: {notify_hook}, error: {e}")
        return
--------------------------------------------------------------------------------
/service/template_controller.py:
--------------------------------------------------------------------------------
1 | import json
2 |
class TemplateController:
    """Loads Discord interaction payload templates and fills their placeholders.

    Each template is a JSON file under ``./resources/api_params``. Placeholders
    are written as ``$key$`` inside JSON string values and are substituted
    from a ``template_map`` dict.
    """

    TEMPLATE_DIR = "./resources/api_params"

    def __init__(self) -> None:
        self.blend_template = self._load("blend.json")
        self.describe_template = self._load("describe.json")
        self.variation_template = self._load("variation.json")
        self.upscale_template = self._load("upscale.json")
        self.imagine_template = self._load("imagine.json")
        self.shorten_template = self._load("shorten.json")
        self.message_template = self._load("message.json")

    @classmethod
    def _load(cls, file_name):
        """Parse one JSON template, closing the file afterwards."""
        with open(f"{cls.TEMPLATE_DIR}/{file_name}", "r", encoding="utf-8") as f:
            return json.load(f)

    # Replace every $key$ in the template string with template_map[key] (raw, unescaped).
    def replace_template(self, template, template_map):
        for key, value in template_map.items():
            if "$"+key+"$" not in template:
                continue
            template = template.replace("$"+key+"$", value)
        return template

    def _render(self, template, template_map):
        """Return a new dict: ``template`` with its ``$key$`` placeholders filled.

        String values are JSON-escaped before substitution. The placeholders
        sit inside JSON string literals, so a raw value containing ``"``,
        ``\\`` or a newline (common in user prompts) would otherwise corrupt
        the serialized document and make ``json.loads`` fail. Non-string
        values are passed through unchanged, exactly as ``replace_template``
        receives them. The stored template itself is never modified.
        """
        escaped_map = {
            key: json.dumps(value)[1:-1] if isinstance(value, str) else value
            for key, value in template_map.items()
        }
        return json.loads(self.replace_template(json.dumps(template), escaped_map))

    def get_imagine(self, template_map):
        return self._render(self.imagine_template, template_map)

    def get_blend(self, template_map):
        return self._render(self.blend_template, template_map)

    def get_describe(self, template_map):
        return self._render(self.describe_template, template_map)

    def get_variation(self, template_map):
        return self._render(self.variation_template, template_map)

    def get_upscale(self, template_map):
        return self._render(self.upscale_template, template_map)

    def get_shorten(self, template_map):
        return self._render(self.shorten_template, template_map)

    def get_message(self, template_map):
        return self._render(self.message_template, template_map)
--------------------------------------------------------------------------------
/support/Injector.py:
--------------------------------------------------------------------------------
1 | from injector import Injector, Module, provider, singleton
2 | import redis
3 | import yaml
4 | from support.task_controller import TaskController
5 | from support.mj_config import MjConfig
6 | from wss.mj_wss_proxy import MjWssSercice
7 | from service.discord_http_service import DiscordHttpService
8 | from service.mj_data_service import MjDataService
9 | from service.template_controller import TemplateController
10 | from service.notify_service import NotifyService
11 | from support.load_balancer import LoadBalancer
12 | from wss.mj_wss_manager import MjWssManager
13 |
14 |
class SupportModule(Module):
    """injector ``Module`` that wires the proxy's services together.

    Every provider is marked ``@singleton``, so the injector builds each
    service at most once and shares that instance.
    """

    @provider
    @singleton
    def provide_mj_config(self) -> MjConfig:
        """Load the application configuration."""
        return MjConfig()

    @provider
    @singleton
    def provide_redis(self, mj_config: MjConfig) -> redis.Redis:
        """Create the Redis client from the ``redis`` section of the config."""
        redis_cfg = mj_config.mj_config['redis']
        return redis.Redis(
            host=redis_cfg['host'],
            port=redis_cfg['port'],
            db=redis_cfg['db'],
            password=redis_cfg['password'],
        )

    @provider
    @singleton
    def provide_task_controller(self, redis_client: redis.Redis, load_balancer: LoadBalancer, mj_data_service: MjDataService) -> TaskController:
        """Create the task controller (its constructor starts the consumer thread)."""
        return TaskController(redis_client, load_balancer, mj_data_service)

    @provider
    @singleton
    def provide_mj_data_service(self, redis_client: redis.Redis) -> MjDataService:
        """Create the in-memory task registry."""
        return MjDataService(redis_client)

    @provider
    @singleton
    def provide_template_controller(self) -> TemplateController:
        """Create the payload template loader."""
        return TemplateController()

    @provider
    @singleton
    def provide_notify_service(self, redis_client: redis.Redis) -> NotifyService:
        """Create the webhook notifier."""
        return NotifyService(redis_client)

    @provider
    @singleton
    def provide_load_balancer(self, mj_config: MjConfig, redis_client: redis.Redis, template_controller: TemplateController) -> LoadBalancer:
        """Create the per-account Discord load balancer."""
        return LoadBalancer(mj_config, redis_client, template_controller)

    @provider
    @singleton
    def provide_mj_wss_manager(self, mj_config: MjConfig, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> MjWssManager:
        """Create the websocket manager."""
        return MjWssManager(mj_config, redis_client, mj_data_service, notify_service)
70 |
# Create the application-wide injector from SupportModule's providers.
injector = Injector(modules=[SupportModule()])
73 |
--------------------------------------------------------------------------------
/support/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__init__.py
--------------------------------------------------------------------------------
/support/__pycache__/Injector.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/Injector.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/image_controller.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/image_controller.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/load_balancer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/load_balancer.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/mj_account.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/mj_account.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/mj_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/mj_config.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/mj_task.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/mj_task.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/task_controller.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/task_controller.cpython-38.pyc
--------------------------------------------------------------------------------
/support/__pycache__/template_controller.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/support/__pycache__/template_controller.cpython-38.pyc
--------------------------------------------------------------------------------
/support/image_controller.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import redis
3 | import uuid
4 | from support.mj_config import MjConfig
5 | from service.discord_http_service import DiscordHttpService
6 | import json
7 | import io
8 | from PIL import Image
9 | import hashlib
10 | import base64
11 |
class ImageController:
    """Turns user-supplied images into URLs Midjourney can consume.

    Base64 ``data:image/...`` URLs are uploaded (to Discord by default);
    ``http`` URLs are passed through unchanged.
    """

    # Seconds to wait on each outbound HTTP request. Without a timeout, an
    # unresponsive endpoint would block the caller indefinitely.
    REQUEST_TIMEOUT = 60
    # Public base URL of the alternative image server used by upload2fastapi.
    FASTAPI_IMAGE_BASE_URL = "http://43.153.103.254/images"

    def __init__(self, mj_config: "MjConfig", redis_client: "redis.Redis", discord_http_service: "DiscordHttpService") -> None:
        self.redis_client = redis_client
        self.mj_config = mj_config.mj_config
        self.discord_http_service = discord_http_service

        self.discord_upload_server = self.mj_config['ng']['discord_upload_server']
        # NOTE(review): reads the singular 'account' section, while
        # LoadBalancer reads 'accounts'; confirm the config provides both.
        self.CHANNEL_ID = self.mj_config['account']['channel_id']
        self.USER_TOKEN = self.mj_config['account']['user_token']

    def upload_img_if_bs64(self, img: str):
        """Return a URL for ``img``, uploading it to Discord when it is base64."""
        return self.upload2discord(img)

    def upload2discord(self, img: str):
        """Upload a base64 data-URL image as a Discord attachment.

        Returns ``None`` for empty input, ``img`` unchanged when it already
        starts with ``http``, ``None`` for any other non-data-URL string, and
        otherwise the URL returned by ``discord_http_service.message``.

        Raises:
            Exception: when requesting the upload slot or uploading the bytes
                does not return HTTP 200.
        """
        if not img:
            return None
        if img.startswith("http"):
            return img
        if not img.startswith("data:image/"):
            return None

        discord_upload_attachment_url = self.discord_http_service.get_discord_upload_attachment_url()

        # The base64 payload is everything after the last comma.
        image_data = base64.b64decode(img.split(",")[-1])
        file_size = len(image_data)
        # Name the file after its content hash.
        file_name = f"{hashlib.md5(image_data).hexdigest()}.png"

        # Step 1: ask Discord for an upload slot for this attachment.
        headers = {"Content-Type": "application/json", "Authorization": self.USER_TOKEN}
        payload = {"files": [{"filename": file_name, "file_size": file_size, "id": "0"}]}
        res = requests.post(discord_upload_attachment_url, headers=headers,
                            data=json.dumps(payload), timeout=self.REQUEST_TIMEOUT)
        if res.status_code != 200:
            raise Exception("上传图片失败")

        attach = json.loads(res.text)['attachments'][0]
        upload_url = attach['upload_url']
        upload_filename = attach['upload_filename']
        print("upload_url", upload_url)
        print("upload_filename", upload_filename)

        # Step 2: PUT the raw bytes to the returned upload URL.
        put_headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
            "Content-Type": "application/octet-stream",
            "Content-Length": str(file_size),
        }
        res = requests.put(upload_url, data=image_data, headers=put_headers,
                           timeout=self.REQUEST_TIMEOUT)
        print(res.status_code)
        if res.status_code != 200:
            raise Exception("上传图片失败")

        # Step 3: post a message referencing the upload to obtain its final URL.
        return self.discord_http_service.message(upload_filename)

    def upload2fastapi(self, img: str):
        """Alternative uploader: POST the data URL to a custom image server.

        Returns ``FASTAPI_IMAGE_BASE_URL/<filename>`` using the filename the
        server reports. ``http`` URLs are passed through; empty or
        unrecognized input yields ``None``.
        """
        if not img:
            return None
        if img.startswith("http"):
            return img
        if not img.startswith("data:image/"):
            return None
        url = self.mj_config['common']['img_upload_url']
        res = requests.post(url, data=json.dumps({"bs64": img}), timeout=self.REQUEST_TIMEOUT)
        filename = json.loads(res.text)["data"]["filename"]
        return f"{self.FASTAPI_IMAGE_BASE_URL}/{filename}"
--------------------------------------------------------------------------------
/support/load_balancer.py:
--------------------------------------------------------------------------------
1 | from service.discord_http_service import DiscordHttpService
2 | import redis
3 | from support.mj_config import MjConfig
4 | from service.template_controller import TemplateController
5 | from support.mj_account import MjAccount
6 | import random
7 |
class LoadBalancer:
    """Spreads Discord requests across the configured Midjourney accounts.

    One DiscordHttpService is built per entry of ``mj_config['accounts']``,
    and each call to get_discord_http_service picks one uniformly at random.
    """

    def __init__(self, mj_config: "MjConfig", redis_client: "redis.Redis", template_controller: "TemplateController"):
        self.mj_config = mj_config.mj_config
        self.redis_client = redis_client
        self.template_controller = template_controller

        self.discord_http_services = []
        self.init_discord_http_services()

    def init_discord_http_services(self):
        """Rebuild the service list with one DiscordHttpService per account."""
        self.discord_http_services = [
            DiscordHttpService(self.mj_config, self.redis_client,
                               self.template_controller, MjAccount(account))
            for account in self.mj_config['accounts']
        ]

    def get_discord_http_service(self) -> "DiscordHttpService":
        """Return a randomly chosen DiscordHttpService.

        Raises:
            ValueError: if no accounts are configured.
        """
        if not self.discord_http_services:
            raise ValueError("LoadBalancer: no Discord accounts configured under 'accounts'")
        return random.choice(self.discord_http_services)
31 |
--------------------------------------------------------------------------------
/support/mj_account.py:
--------------------------------------------------------------------------------
1 |
2 |
class MjAccount:
    """Read-only accessors over one Discord account entry from the config."""

    def __init__(self, account_config: dict) -> None:
        self.account_config = account_config

    def get_user_token(self):
        """Return the account's ``user_token``."""
        return self.account_config['user_token']

    def get_bot_token(self):
        """Return the account's ``bot_token``."""
        return self.account_config['bot_token']

    def get_channel_id(self):
        """Return the account's ``channel_id``."""
        return self.account_config['channel_id']

    def get_gulid_id(self):
        """Return the account's ``guild_id`` (misspelled name kept for callers)."""
        return self.account_config['guild_id']

    def get_guild_id(self):
        """Correctly spelled alias of ``get_gulid_id``."""
        return self.get_gulid_id()

    def get_app_id(self):
        """Return the account's ``app_id``."""
        return self.account_config['app_id']

    def get_session_id(self):
        """Return the account's ``session_id``."""
        return self.account_config['session_id']
24 |
--------------------------------------------------------------------------------
/support/mj_config.py:
--------------------------------------------------------------------------------
1 |
2 | import yaml
class MjConfig:
    """Holds the parsed YAML configuration of the proxy in ``mj_config``."""

    DEFAULT_CONFIG_PATH = "./resources/config/prod_mj_config.yaml"

    def __init__(self, config_path: str = DEFAULT_CONFIG_PATH) -> None:
        # Use a context manager so the file handle is closed after parsing.
        # NOTE(review): FullLoader is kept for compatibility; yaml.safe_load
        # would be stricter if the config uses no Python-specific tags.
        with open(config_path, "r", encoding="utf-8") as f:
            self.mj_config = yaml.load(f, Loader=yaml.FullLoader)

    def get_accounts(self):
        """Return the ``accounts`` list from the configuration."""
        return self.mj_config['accounts']
--------------------------------------------------------------------------------
/support/task_controller.py:
--------------------------------------------------------------------------------
1 |
2 | from service.mj_data_service import MjTask
3 | from queue import Queue
4 | import time
5 | import threading
6 | import redis
7 | from service.discord_http_service import DiscordHttpService
8 | from service.mj_data_service import MjDataService
9 | from support.load_balancer import LoadBalancer
10 | import uuid
11 |
12 |
def consume_tasks(mj_data_service: "MjDataService", load_balancer: "LoadBalancer"):
    """Worker loop: take tasks off the shared queue and dispatch them, forever.

    ``None`` items and tasks without a ``task_type`` are skipped. Each other
    task is handed to a DiscordHttpService chosen by ``load_balancer``, using
    the method that matches ``task.task_type``.

    An exception while dispatching one task is logged and the loop moves on.
    Letting it propagate would end this function and, with it, the consumer
    thread, leaving every later task stuck in the queue.
    """
    import logging
    log = logging.getLogger(__name__)

    # task_type -> DiscordHttpService method name ("imageine" is the real,
    # misspelled method name on that service).
    handler_names = {
        "imagine": "imageine",
        "upscale": "upscale",
        "variation": "variation",
        "describe": "describe",
        "blend": "blend",
        "action": "action",
        "shorten": "shorten",
    }
    task_queue = mj_data_service.get_tasks_queue()

    while True:
        task = task_queue.get()
        if task is None or task.task_type is None:
            continue

        try:
            discord_http_service = load_balancer.get_discord_http_service()
            handler_name = handler_names.get(task.task_type)
            if handler_name is None:
                print("unknown task type: " + task.task_type)
                continue
            getattr(discord_http_service, handler_name)(task)
        except Exception:
            log.exception("consume_tasks: failed to dispatch task %s (type %s)",
                          getattr(task, "task_id", None), task.task_type)
42 |
class TaskController:
    """Entry point for task submission; also owns the background consumer.

    Every ``submit_*`` method gives the task a fresh UUID4 ``task_id``, hands
    it to ``MjDataService.add_task`` and returns the id. The Discord calls
    themselves happen in ``consume_tasks``, which runs on a daemon thread
    started by the constructor.
    """

    def __init__(self, redis_client: redis.Redis, load_balancer: LoadBalancer, mj_data_service: MjDataService) -> None:
        self.mj_data_service: MjDataService = mj_data_service

        worker = threading.Thread(
            target=consume_tasks,
            args=(self.mj_data_service, load_balancer),
            daemon=True,
        )
        worker.start()

    def submit_imagine(self, task_data: MjTask):
        return self.base_submit(task_data)

    def submit_upscale(self, task_data: MjTask):
        return self.base_submit(task_data)

    def submit_describe(self, task_data: MjTask):
        return self.base_submit(task_data)

    def submit_blend(self, task_data: MjTask):
        return self.base_submit(task_data)

    def submit_variation(self, task_data: MjTask):
        return self.base_submit(task_data)

    def submit_action(self, task_data: MjTask):
        return self.base_submit(task_data)

    def submit_shorten(self, task_data: MjTask):
        return self.base_submit(task_data)

    def base_submit(self, task_data: MjTask):
        """Assign a new task id, register the task, and return the id."""
        new_id = str(uuid.uuid4())
        task_data.task_id = new_id
        self.mj_data_service.add_task(task_data)
        return task_data.task_id
85 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # from bootstrap import Bootstrap
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/common_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/common_util.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/logger_util.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/wx_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/utils/__pycache__/wx_util.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/bootstrap.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import asyncio
4 | import random
5 | import time
6 | import uuid
7 | import yaml
8 | from .logger_util import get_logger
9 | import consul
10 | import consul.base
11 | import consul.callback
12 | from prometheus_client import CollectorRegistry, Gauge, push_to_gateway, start_http_server
13 |
14 | class Bootstrap():
15 | def __init__(self, file_yaml='config.yaml') -> None:
16 | config_yaml: dict = self.read_config(file_yaml)
17 | env_dict = os.environ
18 | # main
19 | self.RUN_NAME = config_yaml.get('RUN_NAME')
20 | self.LOG_PATH = config_yaml.get('LOG_PATH')
21 | self.ENV_IP = config_yaml.get('ENV_IP', env_dict.get('ENV_IP'))
22 | self.ENV_PORT = config_yaml.get('ENV_PORT', env_dict.get('ENV_PORT'))
23 | self.ENV_PORTS = config_yaml.get('ENV_PORTS', env_dict.get('ENV_PORTS'))
24 | self.DWONSTREAMS = config_yaml.get('DWONSTREAMS')
25 |
26 | # consul
27 | self.CONSUL_IP = config_yaml.get('CONSUL_IP')
28 | self.CONSUL_PORT = config_yaml.get('CONSUL_PORT')
29 | self.CONSUL_TOKEN = config_yaml.get('CONSUL_TOKEN')
30 | self.CONSUL_CHECK_IP = config_yaml.get('CONSUL_CHECK_IP')
31 | self.CONSUL_CHECK_TICK = config_yaml.get('CONSUL_CHECK_TICK', 5)
32 | self.CONSUL_CHECK_TIMEOUT = config_yaml.get('CONSUL_CHECK_TIMEOUT', 30)
33 | self.CONSUL_CHECK_DEREG = config_yaml.get('CONSUL_CHECK_DEREG', 30)
34 |
35 | # metrics
36 | self.METRICS_IP = config_yaml.get('METRICS_IP')
37 | self.METRICS_PORT = config_yaml.get('METRICS_PORT')
38 | self.METRICS_TOKEN = config_yaml.get('METRICS_TOKEN')
39 | self.METRICS_CHECK_TICK = config_yaml.get('METRICS_CHECK_TICK', 5)
40 |
41 | def lifespan(self, _):
42 | print("startup")
43 | self.background_start()
44 | yield
45 | print(f"shutdown wait {self.CONSUL_CHECK_DEREG}s for consul dereg")
46 | self.deregister_consul()
47 | # time.sleep(self.CONSUL_CHECK_DEREG)
48 |
49 | def init(self):
50 | self.logger = get_logger(self.RUN_NAME, self.LOG_PATH)
51 | self.consul_cli = consul.Consul(self.CONSUL_IP, self.CONSUL_PORT, self.CONSUL_TOKEN)
52 | self.register_consul()
53 | self.metrics_registry = CollectorRegistry()
54 | self.metrics_reg_base()
55 | self.metrics_reg()
56 | self.service_dict = {k: [] for k in self.DWONSTREAMS}
57 | self.background_service_update()
58 | self.not_shutdown = True
59 |
60 | def background_start(self):
61 | if self.service_dict:
62 | asyncio.create_task(self.background_service_update())
63 | asyncio.create_task(self.background_metrics_update())
64 |
65 | def read_config(self, file_yaml):
66 | with open(file=file_yaml, mode='r', encoding='utf-8') as f:
67 | return yaml.load(stream=f.read(), Loader=yaml.FullLoader)
68 |
69 | def register_consul(self):
70 | # 健康检查ip端口,检查时间:5,超时时间:30,注销时间:30s
71 | # curl --request PUT http://127.0.0.1:6011/v1/agent/service/register/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81
72 | # curl --request PUT http://127.0.0.1:6011/v1/agent/service/deregister/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81
73 | # curl --request GET http://127.0.0.1:6011/v1/agent/checks?token=6df8babf-1061-2469-79f0-4488640cba81
74 | # curl --request GET http://127.0.0.1:6011/v1/agent/check/test_service_up?token=6df8babf-1061-2469-79f0-4488640cba81
75 |
76 | self.consul_id = uuid.uuid4().hex
77 | checker = consul.Check().tcp(self.CONSUL_CHECK_IP, self.ENV_PORT,
78 | self.CONSUL_CHECK_TICK * 1000 * 1000,
79 | self.CONSUL_CHECK_TIMEOUT * 1000 * 1000,
80 | self.CONSUL_CHECK_DEREG * 1000 * 1000)
81 | # checker = consul.Check().http(f'http://127.0.0.1:8531/health',
82 | # 1 * 1000 * 1000, 2 * 1000 * 1000, 3 * 1000 * 1000)
83 | # res = self.consul_cli.catalog.register(
84 | # self.RUN_NAME,
85 | # self.ENV_IP,
86 | # {
87 | # "Service": self.RUN_NAME,
88 | # "ID": self.consul_id,
89 | # "Tags": [],
90 | # "Port": self.ENV_PORT,
91 | # },
92 | # check=checker,
93 | # token=self.CONSUL_TOKEN,
94 | # )
95 | res = self.consul_cli.agent.service.register(
96 | self.RUN_NAME,
97 | service_id=self.consul_id,
98 | address=self.ENV_IP,
99 | port=self.ENV_PORT,
100 | token=self.CONSUL_TOKEN,
101 | check=checker)
102 | if not res:
103 | raise RuntimeError('register_consul fail!')
104 |
105 | def consul_token_param(self):
106 | params = []
107 | if self.CONSUL_TOKEN:
108 | params.append(('token', self.CONSUL_TOKEN))
109 | return params
110 |
111 | def deregister_consul(self):
112 | # curl --request PUT http://127.0.0.1:6011/v1/agent/service/deregister/test_service_down
113 | # curl --request PUT http://101.43.141.18:6011/v1/agent/service/register/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81
114 | # curl --request PUT http://101.43.141.18:6011/v1/agent/service/deregister/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81
115 | # curl --request GET http://101.43.141.18:6011/v1/agent/services?token=6df8babf-1061-2469-79f0-4488640cba81
116 | # curl --request GET http://101.43.141.18:6011/v1/agent/service/test_service_down?token=6df8babf-1061-2469-79f0-4488640cba81
117 |
118 | res = self.consul_cli.agent.service.deregister(self.consul_id, token=self.CONSUL_TOKEN)
119 | # res = self.consul_cli.catalog.deregister()
120 | # res = self.consul_cli.agent.agent.http.put(consul.callback.CB.bool(), '/v1/agent/service/deregister/%s' % self.RUN_NAME, params=self.consul_token_param())
121 | if not res:
122 | self.logger.error('deregister_consul fail!')
123 |
124 | def get_server(self, service_name):
125 | service_list = self.service_dict[service_name]
126 | if not service_list:
127 | return None
128 | return random.choice(service_list)
129 |
async def background_service_update(self):
    """Refresh ``self.service_dict`` from the Consul agent until shutdown.

    On every tick (``self.CONSUL_CHECK_TICK`` seconds), each tracked service
    name is looked up via ``/v1/agent/services`` with a
    ``Service == "<name>"`` filter.

    On success, the entry is replaced by a list of ``{'ip', 'port'}`` dicts
    and the ``metrics_downstream`` gauge is set to the instance count.

    On failure, the previous list is kept and the gauge is set to ``-1``.
    A failed request no longer terminates the loop.
    """
    def cb(response):
        # Presumably raises for error statuses (python-consul internal helper);
        # 404 is mapped to None here, anything else is decoded as JSON.
        consul.callback.CB._status(response)
        if response.code == 404:
            return None
        return json.loads(response.body)

    while self.not_shutdown:
        for service_name in self.service_dict:
            service_filter = ('filter', f'Service == "{service_name}"')
            try:
                # NOTE(review): synchronous HTTP call inside a coroutine; it
                # blocks the event loop while the agent answers.
                service_info = self.consul_cli.agent.agent.http.get(
                    cb, '/v1/agent/services',
                    params=(*self.consul_token_param(), service_filter))
            except Exception:
                # Keep the discovery loop alive when the agent is unreachable.
                self.logger.exception('service_dict[%s] query failed!', service_name)
                service_info = None
            if isinstance(service_info, dict):
                self.service_dict[service_name] = [
                    {'ip': item.get('Address'), 'port': item.get('Port')}
                    for item in service_info.values()
                ]
                self.metrics_downstream.labels(src=self.RUN_NAME, dst=service_name).set(len(service_info))
            else:
                # Log the value that failed to parse (previously the whole service_dict).
                self.logger.error('service_dict[%s] parse error! parsed to: %s', service_name, service_info)
                self.metrics_downstream.labels(src=self.RUN_NAME, dst=service_name).set(-1)
        await asyncio.sleep(self.CONSUL_CHECK_TICK)
147 |
async def background_metrics_update(self):
    """Push ``self.metrics_registry`` to the metrics gateway until shutdown.

    The gateway address is ``METRICS_IP:METRICS_PORT`` and the job name is
    ``self.RUN_NAME``; a push is attempted every ``METRICS_CHECK_TICK``
    seconds. A failed push is logged and retried on the next tick; previously
    any exception escaped and ended this background loop for good.
    """
    while self.not_shutdown:
        gateway = f'{self.METRICS_IP}:{self.METRICS_PORT}'
        try:
            # NOTE(review): push_to_gateway is a blocking call made from a
            # coroutine; it stalls the event loop while it runs.
            push_to_gateway(gateway, job=self.RUN_NAME, registry=self.metrics_registry)
            self.logger.info('push_to_gateway')
        except Exception:
            self.logger.exception('push_to_gateway failed: %s', gateway)
        await asyncio.sleep(self.METRICS_CHECK_TICK)
153 |
def metrics_reg_base(self):
    """Create the gauges shared by every service on ``self.metrics_registry``.

    ``service_downstream{src, dst}``: number of discovered instances of the
    downstream service ``dst`` as seen from ``src``.
    """
    self.metrics_downstream = Gauge(
        'service_downstream',
        'downstream service num',
        labelnames=['src', 'dst'],
        registry=self.metrics_registry,
    )
156 |
def metrics_reg(self):
    """Register service-specific metrics; subclasses must override this.

    Example override::

        c = Counter('my_requests_total', 'HTTP requests total',
                    ['method', 'endpoint'], registry=self.metrics_registry)
        c.labels(method='get', endpoint='/').inc()
        c.labels(method='post', endpoint='/submit').inc()
    """
    raise NotImplementedError()
163 |
--------------------------------------------------------------------------------
/utils/common_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from typing import List
4 | import os
5 | import shutil
6 | # import pickle
7 | import json
8 | import string
9 | from PIL import Image
10 | import math
11 | import requests
12 | import base64
13 | from urllib.parse import urlparse
14 | from io import BytesIO
15 | import time
16 | import dill as pickle
17 |
18 |
def data_generator(data, batch_size):
    """Yield ``(index, total, batch)`` for consecutive slices of *data*.

    ``index`` is the 0-based batch number, ``total`` the number of batches
    that will be yielded, and ``batch`` is
    ``data[index * batch_size:(index + 1) * batch_size]``.
    """
    assert batch_size > 0
    # Ceiling division. The old ``len(data) // batch_size + 1`` reported one
    # batch too many whenever len(data) was an exact multiple of batch_size.
    total = (len(data) + batch_size - 1) // batch_size
    for index, start in enumerate(range(0, len(data), batch_size)):
        yield index, total, data[start:start + batch_size]
28 |
29 |
def write2file(filename, text):
    """Overwrite *filename* with *text*, encoded as UTF-8."""
    with open(filename, mode="w", encoding="utf-8") as handle:
        handle.write(text)
33 |
34 |
def write2file_arr(filename, arr):
    """Write the strings in *arr* to *filename* (UTF-8), one per line.

    Lines are separated by ``"\\n"`` with no trailing newline after the last
    item. Missing parent directories are created first.
    """
    dirname = os.path.dirname(filename)
    # os.makedirs instead of shelling out to ``mkdir -p``. The shell version
    # allowed injection through the path and split paths containing spaces,
    # and it ran a failing ``mkdir -p`` for bare filenames (dirname == "").
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(arr))
48 |
49 |
def readfile2arr(filename, skip=None):
    """Return the whitespace-stripped lines of a UTF-8 text file.

    *skip*, when given, is used as a slice start: the first *skip* lines are
    dropped (``None`` keeps every line).
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f]
    return lines[skip:]
57 |
58 |
# Minimum edit distance between two strings.
def minDistance(word1: str, word2: str) -> int:
    """Return the Levenshtein distance between *word1* and *word2*.

    Insertions, deletions and substitutions each cost 1. The DP table is
    evaluated row by row, keeping only the previous row.
    """
    previous = list(range(len(word2) + 1))
    for row, ch1 in enumerate(word1, start=1):
        current = [row] + [0] * len(word2)
        for col, ch2 in enumerate(word2, start=1):
            if ch1 == ch2:
                current[col] = previous[col - 1]
            else:
                current[col] = 1 + min(previous[col], current[col - 1], previous[col - 1])
        previous = current
    return previous[-1]
79 |
80 |
def get_words2ids_pad(texts, words2id, max_seq):
    """Convert texts to a 2-D NumPy array of fixed-length id rows.

    Each text is stripped, every character is mapped through *words2id*
    (an unknown character raises ``KeyError``), the ids are truncated to
    *max_seq*, and the row is right-padded with ``0``.
    """
    stripped = [text.strip() for text in texts]
    rows = []
    for text in stripped:
        ids = [words2id[ch] for ch in text][:max_seq]
        rows.append(ids + [0] * (max_seq - len(ids)))
    return np.array(rows)
94 |
95 |
def remove_dir_and_files(root_dir):
    """Delete *root_dir* and everything under it; any error is ignored."""
    shutil.rmtree(root_dir, ignore_errors=True)
99 |
100 |
def remove_subdirs_and_files(root_dir):
    """Empty *root_dir* but keep the directory itself.

    Files and symlinks are unlinked (a symlink to a directory is removed, not
    followed); real subdirectories are deleted recursively.
    """
    for entry in os.listdir(root_dir):
        target = os.path.join(root_dir, entry)
        is_plain = os.path.isfile(target) or os.path.islink(target)
        if is_plain:
            os.unlink(target)
            continue
        if os.path.isdir(target):
            shutil.rmtree(target)
108 |
109 |
def copy_files(file_list, ori_dir, des_dir):
    """Copy each file named in *file_list* from *ori_dir* into *des_dir*.

    Uses ``shutil.copy`` (content plus permission bits) for every entry.
    """
    sources = (os.path.join(ori_dir, name) for name in file_list)
    for source in sources:
        shutil.copy(source, des_dir)
114 |
115 |
def save_pickle(filename, obj):
    """Serialize *obj* into *filename* with the module-level ``pickle`` (``dill`` in this file)."""
    with open(filename, mode="wb") as out:
        pickle.dump(obj, out)
119 |
120 |
def load_pickle(filename):
    """Load an object previously written with ``save_pickle``.

    Paths ending in ``.json`` are delegated to ``read_json`` instead.
    Security: unpickling can execute arbitrary code, so never call this on
    files from an untrusted source.
    """
    if not filename.endswith(".json"):
        with open(filename, "rb") as src:
            return pickle.load(src)
    return read_json(filename)
127 |
128 |
def arr2str_tab(arr):
    """Join the items of *arr* into one tab-separated string.

    ints (including bools) are rendered with ``str``, floats with three
    decimals, and strings verbatim. Items of any other type are silently
    skipped (NOTE(review): this shifts later columns; confirm intended).
    """
    def _render(item):
        if isinstance(item, int):
            return str(item)
        if isinstance(item, float):
            return "%.3f" % item
        if isinstance(item, str):
            return item
        return None

    return "\t".join(text for text in map(_render, arr) if text is not None)
140 |
141 |
def to_torch(data, type=None, device="cpu"):
    """Convert *data* to a torch tensor on *device* via ``np.array``.

    When *type* is given it is used as the NumPy ``dtype`` of the
    intermediate array; otherwise NumPy infers it.
    """
    import torch

    # ``== None`` is kept (rather than ``is None``) so that dtype objects
    # comparing equal to None are treated exactly as before.
    array = np.array(data) if type == None else np.array(data, dtype=type)  # noqa: E711
    return torch.from_numpy(array).to(device)
149 |
150 |
def read_json(filename):
    """Parse the UTF-8 JSON file *filename* and return the decoded object."""
    with open(filename, encoding="utf-8") as src:
        return json.load(src)
155 |
156 |
def cached_pkl(cache_path: str, load_from_cache: bool, gen_func, *args, **kargs):
    """Return ``gen_func(*args, **kargs)``, memoized on disk at *cache_path*.

    If *load_from_cache* is true and *cache_path* exists, the value is loaded
    with ``load_pickle``. Otherwise it is regenerated and written back with
    ``save_pickle``, overwriting any previous cache file.
    """
    if os.path.exists(cache_path) and load_from_cache:
        print("gen from cached pkl")
        return load_pickle(cache_path)

    print("gen from origin file")
    result = gen_func(*args, **kargs)
    save_pickle(cache_path, result)
    return result
166 |
167 |
def alignment(arr: List, max_length: int, padding):
    """Truncate or right-pad *arr* to *max_length* items.

    Returns:
        ``(aligned, mask)``: ``mask[i]`` is ``1.0`` for items that came from
        *arr* and ``0.0`` for padding. The caller's list is never modified.
    """
    arr_len = len(arr)
    mask = [1.0] * arr_len
    if arr_len >= max_length:
        return arr[:max_length], mask[:max_length]
    pad = max_length - arr_len
    # Concatenate instead of ``arr += ...``; the in-place form extended the
    # caller's list as a side effect.
    return arr + [padding] * pad, mask + [0.0] * pad
178 |
179 |
def collate_fn_from_map(data, keys=None, device="cpu", ignore_keys=None):
    """Collate a list of dict samples into a dict of per-key lists.

    Args:
        data: non-empty list of samples. Samples are merged only when the
            first one is a dict; otherwise an empty dict is returned.
        keys: optional mapping ``{key: dtype}``; each listed key that appears
            in the result is converted to a tensor via ``to_torch``.
        device: device forwarded to ``to_torch``.
        ignore_keys: iterable of keys to drop (default: none). A ``None``
            default replaces the former shared mutable ``[]``.

    Returns:
        ``{key: [value from each sample that has the key]}``, with the
        ``keys`` entries converted to tensors.
    """
    assert data
    assert isinstance(data, list)
    skipped = set(ignore_keys or ())
    result = {}
    if isinstance(data[0], dict):
        for sample in data:
            for k, v in sample.items():
                if k not in skipped:
                    result.setdefault(k, []).append(v)
    # torch is imported lazily by to_torch, so callers that do not request
    # tensor conversion no longer need torch installed.
    if keys:
        for k in result:
            if k in keys:
                result[k] = to_torch(result[k], type=keys[k], device=device)
    return result
200 |
201 |
def save_json_datas(file: str, datas):
    """Write *datas* to *file* as indented (4 spaces) UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``). Missing
    parent directories are created. This replaces the old hand-rolled loop,
    which called ``os.mkdir`` on every "/" prefix behind a broad ``except``.
    """
    parent = os.path.dirname(file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file, "w", encoding="utf-8") as f:
        f.write(json.dumps(datas, indent=4, ensure_ascii=False))
221 |
222 |
def is_chinese(uchar):
    """Return True if *uchar* lies in the CJK Unified Ideographs range U+4E00..U+9FA5."""
    return "\u4e00" <= uchar <= "\u9fa5"
229 |
230 |
def save_huggface_json_datas(file, datas):
    """Write *datas* as JSON Lines: one ``json.dumps`` record per line, non-ASCII kept."""
    write2file_arr(file, [json.dumps(record, ensure_ascii=False) for record in datas])
236 |
237 |
def read_huggingface_json_datas(file):
    """Parse a JSON Lines file into a list of objects, one per (stripped) line."""
    return [json.loads(line) for line in readfile2arr(file)]
245 |
246 |
def get_punctuations():
    """Return a set of single-character ASCII and CJK punctuation marks."""
    return set(string.punctuation + "。!?”’“‘…·《》【】—-,、,")
250 |
251 |
def random_select_unrepeat(items, size):
    """Draw *size* distinct elements from *items* at random (sampling without replacement)."""
    return np.random.choice(items, size=size, replace=False)
256 |
257 |
# Split a document into sentences, optionally merging them up to max_len.
def cut_to_sents(text, max_len=120, merge=False):
    """Split *text* into sentences of at most *max_len* characters.

    Sentences end at "?" / "?", "!" / "!", "。", "……", "\\r" or "\\n". The
    delimiter stays attached to the sentence it ends. Sentences longer than
    *max_len* are cut into *max_len*-sized pieces.

    With ``merge=True``, consecutive pieces are concatenated greedily while
    the result stays within *max_len*.

    The previous version returned before the pairing/merge logic, so
    delimiters came back as separate items and *merge* was ignored. That dead
    code also referenced an undefined ``_sentences``. Its regex began with
    ``(?|``, which Python's ``re`` rejects.
    """
    import re

    # One capture group, so re.split keeps each delimiter:
    # [s0, d0, s1, d1, ..., sN].
    parts = re.split(r"([?\uff1f。!\uff01\r\n]|……)", text)
    sentences = [parts[i] + parts[i + 1] for i in range(0, len(parts) - 1, 2)]
    sentences.append(parts[-1])

    clip_sents = []
    for sent in sentences:
        for start in range(0, len(sent), max_len):
            clip_sents.append(sent[start:start + max_len])
    if not merge:
        return clip_sents

    merged_sents = []
    current = ""
    for piece in clip_sents:
        if len(current) + len(piece) > max_len:
            merged_sents.append(current)
            current = ""
        current += piece
    if current:
        merged_sents.append(current)
    return merged_sents
302 |
303 |
def getImage(filename, bs64=False):
    """Resolve *filename* to a PIL image, or to a base64 data URL when *bs64* is true.

    Inputs are tried in this order:
      * ``None`` -> ``None``;
      * a ``PIL.Image.Image`` -> used as-is;
      * a string starting with ``data:image`` -> returned unchanged
        (regardless of *bs64*);
      * anything ``is_url`` accepts -> downloaded with a 15 s timeout;
        ``None`` (and a printed message) when the response is not ok;
      * an existing file path -> opened from disk;
      * anything else -> decoded as base64 via ``base64_to_image``.
    """
    if filename is None:
        return None

    if isinstance(filename, Image.Image):
        im = filename
    elif filename.startswith("data:image"):
        return filename
    elif is_url(filename):
        response = requests.get(filename, timeout=15)
        if not response.ok:
            print(f"error get img from url:{filename}")
            return None
        im = Image.open(BytesIO(response.content))
    elif os.path.isfile(filename):
        im = Image.open(filename)
    else:
        # (An unreachable duplicate ``isinstance(filename, Image.Image)``
        # branch used to sit before this one.)
        im = base64_to_image(filename)

    return image_to_base64(im) if bs64 else im
329 |
330 |
def image_to_base64(image: Image.Image, fmt="png") -> str:
    """Encode *image* as a ``data:image/<fmt>;base64,...`` URL.

    *fmt* is passed to ``Image.save`` as the output format and is also used
    verbatim in the MIME type.
    """
    buffer = BytesIO()
    image.save(buffer, format=fmt)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{fmt};base64,{encoded}"
340 |
341 |
def base64_to_image(image_base64):
    """Decode a base64 string into a PIL image.

    If the string contains a comma (e.g. a ``data:image/...;base64,`` URL),
    the second comma-separated field is decoded; otherwise the whole string is.
    """
    fields = image_base64.split(",")
    encoded = fields[0] if len(fields) == 1 else fields[1]
    return Image.open(BytesIO(base64.b64decode(encoded)))
356 |
357 |
def url_to_base64(url, timeout=15):
    """Download *url* and return its body as a ``data:image/jpeg;base64,...`` URL.

    The MIME type is always reported as JPEG, whatever the actual content is.

    Args:
        url: resource to fetch.
        timeout: seconds to wait for the server; previously there was no
            timeout, so a stalled server hung the caller forever.

    Raises:
        requests.HTTPError: on a 4xx/5xx response (an error page used to be
            encoded as if it were an image).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    base64_image = base64.b64encode(response.content).decode("utf-8")
    return f"data:image/jpeg;base64,{base64_image}"
364 |
365 |
def image_to_np(image):
    """Return ``np.array(image)``: a new NumPy array built from *image* (e.g. a PIL image)."""
    converted = np.array(image)
    return converted
368 |
369 |
def hash_lora_file(filename):
    """Hash only the tensor payload of a ``.safetensors`` file.

    The file begins with an 8-byte little-endian length ``n``, followed by an
    ``n``-byte header. Only the bytes after ``8 + n`` are hashed.

    Returns:
        ``(file_hash, legacy_hash)``: the SHA-256 hex digest of the payload
        and its first 12 characters.

    Raises:
        ValueError: if the file is shorter than the 8-byte length prefix.
    """
    import hashlib

    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024
    # A single binary handle replaces the previous text-mode open + mmap,
    # which existed only to read these 8 bytes (and failed on empty files).
    with open(filename, "rb") as file_obj:
        header = file_obj.read(8)
        if len(header) < 8:
            raise ValueError(f"{filename!r} is too short to be a safetensors file")
        n = int.from_bytes(header, "little")
        file_obj.seek(n + 8)
        for chunk in iter(lambda: file_obj.read(blksize), b""):
            hash_sha256.update(chunk)
    file_hash = hash_sha256.hexdigest()
    return file_hash, file_hash[:12]
392 |
393 |
def hash_sd_model(filename):
    """Legacy model hash: SHA-256 of the 64 KiB (0x10000 bytes) at offset 1 MiB (0x100000).

    Only a small slice of the file is hashed, so collisions are likely.
    Returns the full hex digest, or ``"NOFILE"`` if *filename* does not exist.
    """
    import hashlib

    try:
        handle = open(filename, "rb")
    except FileNotFoundError:
        return "NOFILE"
    with handle:
        handle.seek(0x100000)
        return hashlib.sha256(handle.read(0x10000)).hexdigest()
408 |
409 |
def hash_file(filename):
    """Short fingerprint: first 10 hex chars of SHA-256 over the first 64 KiB of *filename*.

    Only the head of the file is hashed, so files differing later collide.
    Returns ``"NOFILE"`` if *filename* does not exist.
    """
    import hashlib

    try:
        with open(filename, "rb") as handle:
            head = handle.read(0x10000)
    except FileNotFoundError:
        return "NOFILE"
    return hashlib.sha256(head).hexdigest()[:10]
423 |
424 |
def file_walk(root_path, excludes=None):
    """Recursively list file paths under *root_path*.

    A file is skipped when its extension (the text after the last ".", or
    the whole name if it has no dot) is contained in *excludes*.
    *excludes* defaults to ``None`` (nothing excluded), replacing a shared
    mutable ``[]`` default.
    """
    excludes = () if excludes is None else excludes
    return [
        os.path.join(dir_path, file_name)
        for dir_path, _dir_names, file_names in os.walk(root_path)
        for file_name in file_names
        if file_name.split(".")[-1] not in excludes
    ]
435 |
436 |
def get_latest_file(root_path):
    """Return the file under *root_path* with the largest ``os.path.getctime``.

    Raises ``ValueError`` (from ``max``) when the tree contains no files.
    Note: ctime is the metadata-change time on POSIX, creation time on Windows.
    """
    return max(file_walk(root_path), key=os.path.getctime)
441 |
442 |
def webp2png(file_list):
    """Convert each image in *file_list* to a PNG saved next to it.

    The output path is the source path with its extension replaced by
    ``.png``. The EXIF orientation is applied before saving, and each
    conversion is printed.
    """
    from PIL import Image, ImageOps

    for srcImagePath in file_list:
        dstImagePath = os.path.splitext(srcImagePath)[0] + ".png"
        # The context manager closes each source file; previously one file
        # handle leaked per converted image.
        with Image.open(srcImagePath) as image:
            ImageOps.exif_transpose(image).save(dstImagePath)
        print("%s ---> %s" % (srcImagePath, dstImagePath))
452 |
453 |
def get_file_prefix(fname):
    """Return *fname* without its last ``.suffix``.

    A name with no dot yields ``""``.
    """
    head, _sep, _tail = fname.rpartition(".")
    return head
457 |
458 |
def download_file(src, dest):
    """Download URL *src* to path *dest* with the external ``wget`` command.

    As before, a failed download is not raised; callers should check *dest*
    afterwards. If ``wget`` is not installed, a message is printed instead.
    """
    import subprocess

    # Pass an argument list with no shell. The old f-string given to
    # os.system let URLs/paths containing spaces, '&', ';' or quotes be
    # interpreted by the shell (command injection).
    try:
        subprocess.run(["wget", "-O", dest, src], check=False)
    except FileNotFoundError:
        print(f"wget not found, cannot download {src}")
468 |
469 |
def img_compress(in_file, out_file, target_size=40):
    """Downscale the image in *in_file* and save it to *out_file*.

    If the input file is larger than *target_size* KB, both dimensions are
    multiplied by ``sqrt(target_size / size_kb)``. This assumes the encoded
    size scales roughly with the pixel count. RGBA images are converted to RGB.
    """
    from PIL import Image, ImageFile

    # Accept truncated images and images above Pillow's decompression-bomb
    # limit (178956970 pixels). NOTE: these are process-wide Pillow settings.
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    Image.MAX_IMAGE_PIXELS = None

    o_size = int(os.path.getsize(in_file) / 1024)  # file size in KB
    # The context manager closes the source file, which used to leak.
    with Image.open(in_file) as src:
        im = src.convert("RGB") if src.mode == "RGBA" else src
        if o_size > target_size:
            scale = math.sqrt(target_size / o_size)
            im = im.resize((int(im.width * scale), int(im.height * scale)))
        im.save(out_file)
495 |
496 |
def img_compress_bs64(bs64_in, target_size=1000):
    """Shrink a base64 / data-URL image toward *target_size* KB of base64 text.

    Returns *bs64_in* untouched when *target_size* is ``None``. Otherwise the
    image is decoded and always re-encoded as a PNG data URL. It is first
    downscaled by ``sqrt(target_size / size)`` when the input string exceeds
    *target_size* KB (``size = len(bs64_in) / 1024``).
    """
    if target_size is None:
        return bs64_in
    size_kb = len(bs64_in) / 1024
    image = base64_to_image(bs64_in)
    if size_kb > target_size:
        ratio = math.sqrt(target_size / size_kb)
        image = image.resize((int(image.width * ratio), int(image.height * ratio)))
    return image_to_base64(image)
512 |
513 |
def is_url(data):
    """Heuristically decide whether *data* refers to a remote resource.

    True for strings starting with ``http://`` or ``https://``, and for bare
    dotted-quad IPv4 addresses without port or path. False for empty or
    ``None`` input and for everything else.
    """
    def _looks_like_ipv4(text):
        octets = text.split(".")
        return len(octets) == 4 and all(
            octet.isdigit() and 0 <= int(octet) <= 255 for octet in octets
        )

    if not data:
        return False
    return data.startswith(("http://", "https://")) or _looks_like_ipv4(data)
536 |
537 |
def fmt_data(time_obj):
    """Format the Unix timestamp *time_obj* as local time ``YYYY-MM-DD HH:MM:SS``."""
    local = time.localtime(time_obj)
    return time.strftime("%Y-%m-%d %H:%M:%S", local)
542 |
543 |
def get_self_ip():
    """Return the local IPv4 address used for outbound traffic.

    A UDP socket is "connected" to 8.8.8.8:80, which only makes the OS choose
    a route and local address (no packet is sent); that address is read back.
    On a socket error, returns the placeholder string '无法获取IP地址'
    ("unable to get IP address").
    """
    import socket

    try:
        # ``with`` closes the socket; previously one descriptor leaked per call.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.connect(('8.8.8.8', 80))
            return sock.getsockname()[0]
    except socket.error:
        return '无法获取IP地址'
559 |
560 |
def get_remote_ip(timeout=10):
    """Return this host's public IP as reported by https://ipinfo.io/ip.

    The response body is returned unmodified. *timeout* (seconds) bounds the
    request; previously there was none, so a stalled connection could hang
    the caller indefinitely.
    """
    res = requests.get("https://ipinfo.io/ip", timeout=timeout)
    return res.text
565 |
566 |
def get_osr_create_event_loop():
    """Return the current thread's asyncio event loop, creating and installing one if needed."""
    import asyncio

    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        # Raised when the current thread has no event loop.
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop
577 |
# Remote (cross-site) networking
def get_vnc_ip(ifname="oray_vnc"):
    """Return the IPv4 address of network interface *ifname*, or ``None``.

    *ifname* defaults to ``"oray_vnc"``, the interface this helper was
    written for (presumably created by Oray's remote-networking client;
    confirm). Linux-only: the address is read with the ``SIOCGIFADDR`` ioctl.
    Returns ``None`` if the interface does not exist or the ioctl fails.
    """
    import fcntl
    import socket
    import struct

    def get_ip_address(name):
        # Close the probe socket even when the ioctl fails (it used to leak).
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            try:
                ifreq = fcntl.ioctl(
                    s.fileno(),
                    0x8915,  # SIOCGIFADDR
                    struct.pack('256s', name[:15].encode('utf-8')),
                )
            except IOError:
                return None
        # struct ifreq: 16-byte name, then sockaddr_in (2-byte family,
        # 2-byte port, 4-byte address), so the address is at bytes 20..24.
        return socket.inet_ntoa(ifreq[20:24])

    # Only the requested interface matters, so match it by name instead of
    # issuing an ioctl for every NIC on the host.
    for _index, name in socket.if_nameindex():
        if name == ifname:
            return get_ip_address(name)
    return None
609 |
610 |
def encode_base64(s):
    """Return the base64 encoding of the UTF-8 bytes of *s*, as a str."""
    return base64.b64encode(s.encode('utf-8')).decode('utf-8')
619 |
620 |
621 |
622 |
623 |
if __name__ == "__main__":
    # Manual smoke test: round-trip a local image through base64 encoding and
    # img_compress_bs64 (about 100 KB target), then write the result to ./a.png.
    from PIL import Image, ImageOps

    source_image = Image.open("/home/ubuntu/projects/imgs/origin/0f3335e556.jpg")
    compressed = img_compress_bs64(image_to_base64(source_image), target_size=100)
    base64_to_image(compressed).save("a.png")
642 |
--------------------------------------------------------------------------------
/utils/logger_util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
def get_logger(logger_name, file_path=None, level=logging.DEBUG, as_default_logger=True) -> logging.Logger:
    """Return logger *logger_name* writing to ``<file_path>/<logger_name>.log``.

    Args:
        logger_name: name passed to ``logging.getLogger``; also the log
            file's stem.
        file_path: directory for the log file. Defaults to ``'../logs'``,
            relative to the current working directory, and is created if
            missing.
        level: level set on both the logger and its file handler.
        as_default_logger: if true, the logger object is installed as
            ``logging.root``.

    The logger does not propagate to its ancestors. Calling this again for
    the same name and directory reuses the existing file handler. Previously
    each call added another handler, so every record was written once per call.
    """
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    logger.setLevel(level)
    if not file_path:
        file_path = '../logs'
    os.makedirs(file_path, exist_ok=True)

    log_file = os.path.abspath(f'{file_path}/{logger_name}.log')
    fh = next(
        (h for h in logger.handlers
         if isinstance(h, logging.FileHandler) and h.baseFilename == log_file),
        None,
    )
    if fh is None:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)
    fh.setLevel(level)
    if as_default_logger:
        logging.root = logger
    return logger
18 |
--------------------------------------------------------------------------------
/utils/wx_util.py:
--------------------------------------------------------------------------------
# Convert an XML document into a dict.
def trans_xml_to_dict(data_xml):
    """Parse *data_xml* and map each element under ``<xml>`` to its text.

    Parsing uses BeautifulSoup's ``html.parser`` backend, which lowercases tag
    names. ``find_all()`` visits all descendants, not just direct children;
    for repeated names, the last occurrence wins. Returns ``{}`` when there is
    no ``<xml>`` element.
    """
    # The module never imported BeautifulSoup, so every call raised NameError.
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(data_xml, "html.parser")
    xml = soup.find('xml')
    if not xml:
        return {}
    return {item.name: item.text for item in xml.find_all()}
11 |
12 |
# Convert a dict into an XML document.
def trans_dict_to_xml(data_dict):
    """Serialize *data_dict* as ``<xml><k>v</k>...</xml>``, encoded as UTF-8 bytes.

    Keys are emitted in sorted order. The ``detail`` value is wrapped in
    ``<![CDATA[...]]>`` unless it already starts with ``<![CDATA[``. Values
    are inserted with ``str.format`` and are not XML-escaped.
    (This rebuilds the standard WeChat-pay helper; the stored copy had lost
    its CDATA line and tag text and no longer parsed.)
    """
    data_xml = []
    for k in sorted(data_dict.keys()):
        v = data_dict.get(k)
        if k == 'detail' and not v.startswith('<![CDATA['):
            v = '<![CDATA[{}]]>'.format(v)
        data_xml.append('<{key}>{value}</{key}>'.format(key=k, value=v))
    # Encode to UTF-8 so non-ASCII (e.g. Chinese) text is transmitted intact.
    return '<xml>{}</xml>'.format(''.join(data_xml)).encode('utf-8')
24 |
--------------------------------------------------------------------------------
/wss/__pycache__/mj_wss_manager.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/__pycache__/mj_wss_manager.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/__pycache__/mj_wss_proxy.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/__pycache__/mj_wss_proxy.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/base_message_handler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/base_message_handler.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/describe_success_handler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/describe_success_handler.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/imagine_handler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/imagine_handler.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/imagine_hanlder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/imagine_hanlder.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/message_create_handler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/message_create_handler.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/upscale_handler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/upscale_handler.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/__pycache__/variation_handler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anychnn/py-midjourney-proxy/4a2e7be893140a3d41d513290a00d6d79e675b80/wss/handler/__pycache__/variation_handler.cpython-38.pyc
--------------------------------------------------------------------------------
/wss/handler/base_message_handler.py:
--------------------------------------------------------------------------------
1 |
2 |
class MessageHandler:
    """Base class for Discord gateway message handlers.

    Subclasses override :meth:`handle`; this default ignores every message.
    """

    def __init__(self) -> None:
        pass

    def handle(self, message):
        """Process one gateway *message*. The default does nothing and returns ``None``."""
        return None
--------------------------------------------------------------------------------
/wss/handler/describe_success_handler.py:
--------------------------------------------------------------------------------
1 | from wss.handler.base_message_handler import MessageHandler
2 | import redis
3 | import re
4 | import json
5 | from service.mj_data_service import MjDataService, MjTask
6 | from service.notify_service import NotifyService
7 | import logging
logger = logging.getLogger(__name__)


class DescribeSuccessHandler(MessageHandler):
    """Completes pending ``describe`` tasks from Midjourney result embeds.

    Reacts to ``MESSAGE_UPDATE`` events of message type 20 that carry at least
    one embed. The first embed's image URL is matched against pending tasks
    with ``task_type == 'describe'`` (via ``task.image_prompt``); the embed
    description becomes the task's description.
    """

    def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service

    def handle(self, message):
        """Handle one gateway message.

        Sets ``d['mj_proxy_handled'] = True`` when a task was completed.
        """
        if message['t'] != 'MESSAGE_UPDATE':
            return
        d = message['d']
        embeds = d.get('embeds')
        if not embeds or d['type'] != 20:
            return

        embed = embeds[0]
        # An embed without an image cannot be matched to a task; skip it
        # instead of raising KeyError.
        image_url = (embed.get('image') or {}).get('url')
        if not image_url:
            return

        target_task = next(
            (task for task in self.mj_data_service.get_tasks_list()
             if task.task_status == 'pending'
             and task.image_prompt == image_url
             and task.task_type == 'describe'),
            None,
        )
        if target_task is None:
            return

        self.mj_data_service.update_task_status(target_task, "success")
        self.mj_data_service.update_task_description(target_task, embed['description'])
        self.mj_data_service.update_task_progress(target_task, "100%")
        self.notify_service.notify_task_change(target_task)
        d['mj_proxy_handled'] = True
--------------------------------------------------------------------------------
/wss/handler/imagine_hanlder.py:
--------------------------------------------------------------------------------
1 | from wss.handler.base_message_handler import MessageHandler
2 | import redis
3 | import re
4 | import json
5 | from service.mj_data_service import MjDataService, MjTask
6 | import logging
7 | from service.notify_service import NotifyService
8 |
logger = logging.getLogger(__name__)


class ImagineHandler(MessageHandler):
    """Tracks imagine tasks through Discord gateway events.

    * ``MESSAGE_CREATE`` type 20: the job was accepted. The task is found by
      ``nonce``, bound to the Discord message id, and its progress is set
      to 0%.
    * ``MESSAGE_UPDATE`` type 20: a progress update (and preview image) for
      that message id.
    * ``MESSAGE_CREATE`` type 0: the final result. The pending task whose
      ``submit_prompt`` equals the prompt in the content is marked successful.
    """

    def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # "**a white haired boy --niji 5** - <@1233674918630264866> (Waiting to start)"
        # Group 1 is the prompt; group 2 is the text in the first parentheses
        # (status or progress such as "15%"). This is a raw string: the old
        # literal relied on invalid escapes ("\*", "\d"), which newer Pythons
        # warn about.
        self.prompt_pattern = re.compile(r"\*\*(.*?)\*\* - <@\d+> \((.*?)\)")

    def extract_prompt(self, prompt):
        """Return the prompt embedded in Midjourney message text *prompt*.

        Raises:
            Exception: if the text does not start with the
                ``**prompt** - <@user> (status)`` layout.
        """
        matched = self.prompt_pattern.match(prompt)
        if not matched:
            raise Exception("prompt format error")
        return matched.group(1)

    def get_buttons(self, d):
        """Flatten the buttons of every action row in ``d['components']`` (missing -> [])."""
        buttons = []
        for row in d.get('components') or []:
            buttons.extend(row['components'])
        return buttons

    def handle(self, message):
        """Dispatch one gateway message.

        Sets ``d['mj_proxy_handled'] = True`` when the message was consumed.
        """
        t = message['t']
        if t == 'MESSAGE_CREATE':
            self._handle_create(message['d'])
        elif t == 'MESSAGE_UPDATE':
            self._handle_update(message['d'])

    def _parse_content(self, content):
        """Return ``(prompt, progress)`` from message text, or ``(None, None)`` if it does not match."""
        matched = self.prompt_pattern.findall(content)
        if not matched:
            return None, None
        return matched[0]

    def _handle_create(self, d):
        """Handle ``MESSAGE_CREATE``: type 0 is a finished image, type 20 an accepted job."""
        message_id = d['id']
        d_type = d['type']
        attachments = d['attachments']
        prompt, _progress = self._parse_content(d['content'])

        if d_type == 0:
            if prompt is None:
                # Not an imagine-style message. Matching with prompt=None
                # could pick a pending task whose submit_prompt is still None
                # (presumably one not yet acknowledged; confirm the MjTask
                # default) and attach an unrelated image to it.
                return
            target_task = next(
                (task for task in self.mj_data_service.get_tasks_list()
                 if task.task_status == 'pending' and task.submit_prompt == prompt),
                None,
            )
            if not (target_task and len(attachments) > 0):
                raise Exception("no task found")

            self.mj_data_service.update_buttons(target_task, self.get_buttons(d))
            self.mj_data_service.update_finished_message_id(target_task, message_id)
            self.mj_data_service.update_task_message_id(target_task, message_id)
            self.mj_data_service.update_task_status(target_task, "success")
            self.mj_data_service.update_task_image_url(target_task, attachments[0]['url'])
            self.mj_data_service.update_task_progress(target_task, "100%")
            self.notify_service.notify_task_change(target_task)
            d['mj_proxy_handled'] = True

        elif d_type == 20:
            nonce = d['nonce']
            task = self.mj_data_service.get_task_by_nonce(nonce)
            if not task:
                logger.error(f"no task found, nonce: {nonce}")
                return

            task.submit_prompt = prompt
            self.mj_data_service.update_task_progress(task, "0%")
            self.mj_data_service.update_task_nonce(task, nonce)
            self.mj_data_service.update_task_message_id(task, message_id)
            self.mj_data_service.update_message_id_map(task)
            self.notify_service.notify_task_change(task)
            d['mj_proxy_handled'] = True

    def _handle_update(self, d):
        """Handle ``MESSAGE_UPDATE`` type 20: propagate job progress and the preview image."""
        content = d['content']
        attachments = d['attachments']
        if not content:
            return
        message_id = d['id']
        if d['type'] != 20:
            return

        # '**a white haired girl --niji 5** - <@1233674918630264866> (15%) (relaxed)'
        _prompt, progress = self._parse_content(content)
        channel_id = d['channel_id']
        task = self.mj_data_service.get_task_by_message_id(message_id)
        if not task:
            logger.error(f"no task found, message_id: {message_id}")
            return

        self.mj_data_service.update_task_progress(task, progress)
        if attachments:
            self.mj_data_service.update_task_image_url(task, attachments[0]['url'])
        self.notify_service.notify_task_change(task)
        d['mj_proxy_handled'] = True
        # Previously a bare print(); routed through the module logger.
        logger.info(
            f'imagine message update :channel_id:{channel_id},content:{content}message_id:{message_id},progress:{progress}')
154 |
--------------------------------------------------------------------------------
/wss/handler/shorten_success_handler.py:
--------------------------------------------------------------------------------
1 | from wss.handler.base_message_handler import MessageHandler
2 | import redis
3 | import re
4 | import json
5 | from service.mj_data_service import MjDataService, MjTask
6 | import logging
7 | from service.notify_service import NotifyService
8 |
logger = logging.getLogger(__name__)


class ShortenSuccessHandler(MessageHandler):
    """Completes pending ``shorten`` tasks from Midjourney result embeds.

    Reacts to ``MESSAGE_UPDATE`` events of message type 19 that carry at least
    one embed. The first embed's image URL is matched against pending tasks
    with ``task_type == 'shorten'`` (via ``task.image_prompt``); the embed
    description becomes the task's description.

    NOTE(review): this matching was copied from the describe handler; confirm
    that shorten results really arrive as type-19 updates with an image embed.
    """

    def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service

    def handle(self, message):
        """Handle one gateway message.

        Sets ``d['mj_proxy_handled'] = True`` when a task was completed.
        """
        if message['t'] != 'MESSAGE_UPDATE':
            return
        d = message['d']
        embeds = d.get('embeds')
        if not embeds or d['type'] != 19:
            return

        embed = embeds[0]
        # An embed without an image cannot be matched; skip it instead of
        # raising KeyError.
        image_url = (embed.get('image') or {}).get('url')
        if not image_url:
            return

        target_task = next(
            (task for task in self.mj_data_service.get_tasks_list()
             if task.task_status == 'pending'
             and task.image_prompt == image_url
             and task.task_type == 'shorten'),
            None,
        )
        if target_task is None:
            return

        self.mj_data_service.update_task_status(target_task, "success")
        self.mj_data_service.update_task_description(target_task, embed['description'])
        self.mj_data_service.update_task_progress(target_task, "100%")
        self.notify_service.notify_task_change(target_task)
        d['mj_proxy_handled'] = True


# The class was originally (mis)named by copy-paste; keep the old name importable.
DescribeSuccessHandler = ShortenSuccessHandler
--------------------------------------------------------------------------------
/wss/handler/upscale_handler.py:
--------------------------------------------------------------------------------
1 | from wss.handler.base_message_handler import MessageHandler
2 | import redis
3 | import re
4 | import json
5 | from service.mj_data_service import MjDataService,MjTask
6 | from service.notify_service import NotifyService
7 |
class UpscaleHandler(MessageHandler):
    """Complete pending ``action`` tasks when an upscaled image is posted.

    Handles ``MESSAGE_CREATE`` gateway events whose message has Discord type
    19 and whose content contains ``"Image"``, e.g.
    ``**a white haired boy --niji 5** - Image #1 <@1233674918630264866>``.
    The task is found through the id of the message the reply references.
    """

    def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # Raw strings: the patterns contain regex escapes (\*, \d, \() that are
        # invalid escape sequences in ordinary string literals.
        # "**<prompt>** - <@user_id> (<progress>)"
        self.prompt_pattern = re.compile(r"\*\*(.*?)\*\* - <@\d+> \((.*?)\)")
        # "**a white haired boy --niji 5** - Image #1 <@1233674918630264866>"
        self.upscale_pattern = re.compile(r"\*\*(.*?)\*\* - Image #(\d) <@\d+>")
        # "**<prompt>** - Variations (...) by <@user_id> (...)"
        self.variation_pattern = re.compile(r"\*\*(.*?)\*\* - Variations \(.*?\) by <@\d+> \((.*?)\)")

    def get_reference_message_id(self, d):
        """Return the id of the message *d* replies to, or None if it is not a reply."""
        reference = d.get('message_reference') or {}
        return reference.get('message_id')

    def get_buttons(self, d):
        """Flatten the components of every action row in ``d['components']``."""
        buttons = []
        for row in d.get('components') or []:
            buttons.extend(row.get('components') or [])
        return buttons

    def handle(self, message):
        """Mark the pending ``action`` task referenced by an upscale reply as done.

        The payload is flagged ``mj_proxy_handled`` only when a task was
        actually updated, so unmatched messages still reach later handlers.
        """
        if message.get('t') != 'MESSAGE_CREATE':
            return
        d = message.get('d') or {}
        content = d.get('content') or ''
        if d.get('type') != 19 or 'Image' not in content:
            return

        reference_message_id = self.get_reference_message_id(d)
        attachments = d.get('attachments') or []
        # Without a referenced message the task cannot be found; without an
        # attachment there is no image to record.
        if reference_message_id is None or not attachments:
            return

        target_task = next(
            (
                task
                for task in self.mj_data_service.get_tasks_list()
                if task.task_status == 'pending'
                and task.reference_message_id == reference_message_id
                and task.task_type == 'action'
            ),
            None,
        )
        if not target_task:
            return

        message_id = d['id']
        channel_id = d.get('channel_id')
        self.mj_data_service.update_buttons(target_task, self.get_buttons(d))
        self.mj_data_service.update_task_message_id(target_task, message_id)
        self.mj_data_service.update_task_status(target_task, "success")
        self.mj_data_service.update_task_image_url(target_task, attachments[0]['url'])
        self.mj_data_service.update_task_progress(target_task, "100%")
        self.notify_service.notify_task_change(target_task)

        d['mj_proxy_handled'] = True
        print(
            f'channel_id:{channel_id},content:{content},message_id:{message_id}')
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/wss/handler/variation_handler.py:
--------------------------------------------------------------------------------
1 | from wss.handler.base_message_handler import MessageHandler
2 | import redis
3 | import re
4 | import json
5 | from service.mj_data_service import MjDataService,MjTask
6 | from service.notify_service import NotifyService
7 |
class VariationHandler(MessageHandler):
    """Complete pending ``action`` tasks when a variation grid is posted.

    Handles ``MESSAGE_CREATE`` gateway events whose message has Discord type
    19 and whose content contains ``"Variations"``. The task is found through
    the id of the message the reply references.
    """

    def __init__(self, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # Raw strings: the patterns contain regex escapes (\*, \d, \() that are
        # invalid escape sequences in ordinary string literals.
        # "**<prompt>** - <@user_id> (<progress>)"
        self.prompt_pattern = re.compile(r"\*\*(.*?)\*\* - <@\d+> \((.*?)\)")
        # "**a white haired boy --niji 5** - Image #1 <@1233674918630264866>"
        self.upscale_pattern = re.compile(r"\*\*(.*?)\*\* - Image #(\d) <@\d+>")
        # "**<prompt>** - Variations (...) by <@user_id> (...)"
        self.variation_pattern = re.compile(r"\*\*(.*?)\*\* - Variations \(.*?\) by <@\d+> \((.*?)\)")

    def get_reference_message_id(self, d):
        """Return the id of the message *d* replies to, or None if it is not a reply."""
        reference = d.get('message_reference') or {}
        return reference.get('message_id')

    def get_buttons(self, d):
        """Flatten the components of every action row in ``d['components']``."""
        buttons = []
        for row in d.get('components') or []:
            buttons.extend(row.get('components') or [])
        return buttons

    def handle(self, message):
        """Mark the pending ``action`` task referenced by a variation reply as done.

        The payload is flagged ``mj_proxy_handled`` only when a task was
        actually updated, so unmatched messages still reach later handlers.
        """
        if message.get('t') != 'MESSAGE_CREATE':
            return
        d = message.get('d') or {}
        content = d.get('content') or ''
        if d.get('type') != 19 or 'Variations' not in content:
            return

        reference_message_id = self.get_reference_message_id(d)
        attachments = d.get('attachments') or []
        # Without a referenced message the task cannot be found; without an
        # attachment there is no image to record.
        if reference_message_id is None or not attachments:
            return

        target_task = next(
            (
                task
                for task in self.mj_data_service.get_tasks_list()
                if task.task_status == 'pending'
                and task.reference_message_id == reference_message_id
                and task.task_type == 'action'
            ),
            None,
        )
        if not target_task:
            return

        self.mj_data_service.update_buttons(target_task, self.get_buttons(d))
        self.mj_data_service.update_task_message_id(target_task, d['id'])
        self.mj_data_service.update_task_status(target_task, "success")
        self.mj_data_service.update_task_image_url(target_task, attachments[0]['url'])
        self.mj_data_service.update_task_progress(target_task, "100%")
        self.notify_service.notify_task_change(target_task)
        d['mj_proxy_handled'] = True
70 |
71 |
--------------------------------------------------------------------------------
/wss/mj_wss_manager.py:
--------------------------------------------------------------------------------
1 |
2 | from support.mj_config import MjConfig
3 | import redis
4 | from wss.mj_wss_proxy import MjWssSercice
5 | from service.mj_data_service import MjDataService
6 | from service.notify_service import NotifyService
7 | from support.mj_account import MjAccount
8 | from typing import List
9 |
class MjWssManager:
    """Create and start one gateway service (``MjWssSercice``) per account.

    Accounts come from ``mj_config.get_accounts()``. Each one is wrapped in an
    ``MjAccount`` and paired with the shared redis client, data service and
    notify service.
    """

    def __init__(self, mj_config: MjConfig, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService) -> None:
        self.mj_config = mj_config
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # Filled eagerly so start_all() can be called right after construction.
        self.wss_list: List[MjWssSercice] = []
        self.init_wss_list()

    def init_wss_list(self):
        """Append one gateway service per configured account to ``wss_list``."""

        def build_service(raw_account):
            # Wrap the raw account entry first, then hand it to the service.
            account = MjAccount(raw_account)
            return MjWssSercice(
                self.mj_config.mj_config,
                self.redis_client,
                self.mj_data_service,
                self.notify_service,
                account,
            )

        self.wss_list.extend(map(build_service, self.mj_config.get_accounts()))

    def start_all(self):
        """Call ``start()`` on every managed service, in creation order."""
        for service in self.wss_list:
            service.start()
--------------------------------------------------------------------------------
/wss/mj_wss_proxy.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import websocket
3 | import json
4 | import ssl
5 | import time
6 | import asyncio
7 | from utils.logger_util import get_logger
8 | import logging
9 | from typing import List
10 | from wss.handler.base_message_handler import MessageHandler
11 | from wss.handler.upscale_handler import UpscaleHandler
12 | from wss.handler.imagine_hanlder import ImagineHandler
13 | from wss.handler.upscale_handler import UpscaleHandler
14 | from wss.handler.describe_success_handler import DescribeSuccessHandler
15 | from wss.handler.variation_handler import VariationHandler
16 | import redis
17 | from service.mj_data_service import MjDataService
18 | from service.notify_service import NotifyService
19 | from support.mj_account import MjAccount
20 |
logger = get_logger(__name__)


class MjWssSercice:
    """Discord gateway (websocket) client for a single Midjourney account.

    The client does the following:

    * sends IDENTIFY with the account's bot token on every (re)connect;
    * sends heartbeats from a background thread;
    * reconnects when the socket drops;
    * dispatches ``MESSAGE_*`` events from the account's channel to the
      registered message handlers.
    """

    # Seconds to wait before reconnecting after the socket closes or fails.
    RECONNECT_DELAY_SECONDS = 5
    # Upper bound on the heartbeat period. The gateway's HELLO interval is
    # used instead when it is shorter.
    MAX_HEARTBEAT_SECONDS = 30
    # Gateway dispatch events forwarded to the message handlers.
    DISPATCHED_EVENTS = ('MESSAGE_CREATE', 'MESSAGE_UPDATE', 'MESSAGE_DELETE')

    def __init__(self, config, redis_client: redis.Redis, mj_data_service: MjDataService, notify_service: NotifyService, account: MjAccount):
        self.config = config
        self.account = account
        # Last gateway sequence number ("s") received; echoed in heartbeats.
        # None until the first dispatch of the current session arrives.
        self.sequence = None
        self.heartbeat_seconds = self.MAX_HEARTBEAT_SECONDS
        self.redis_client = redis_client
        self.mj_data_service = mj_data_service
        self.notify_service = notify_service
        # Discord WebSocket address
        self.ws_url = config['ng']['discord_ws']
        self.bot_token = account.get_bot_token()
        self.ws = None

        self.message_handlers: List[MessageHandler] = []
        self.init_handlers()

    def init_handlers(self):
        """Register each handler once, in the order events are offered to them."""
        for handler_type in (UpscaleHandler, ImagineHandler, VariationHandler, DescribeSuccessHandler):
            self.message_handlers.append(handler_type(
                self.redis_client, self.mj_data_service, self.notify_service))

    # Log in
    def on_open(self, ws):
        """Send the gateway IDENTIFY (op 2) each time a connection opens."""
        logger.info("### opened ###")
        # A fresh IDENTIFY starts a new session with its own sequence numbers.
        self.sequence = None
        ws.send(json.dumps({
            'op': 2,
            'd': {
                'token': self.bot_token,
                'intents': 513,
                'properties': {
                    '$os': 'linux',
                    '$browser': 'my_library_name',
                    '$device': 'my_library_name'
                }
            }
        }))

    # Listen for messages
    def on_message(self, ws, message):
        """Handle one raw gateway frame."""
        data = json.loads(message)
        op = data.get('op')
        # Every dispatch carries a sequence number that heartbeats must echo.
        if data.get('s') is not None:
            self.sequence = data['s']

        if op == 11:
            # HEARTBEAT_ACK
            return
        if op == 10:
            # HELLO: never heartbeat less often than the gateway requires.
            interval_ms = (data.get('d') or {}).get('heartbeat_interval')
            if interval_ms:
                logger.info(f"receive hello, heartbeat_interval: {interval_ms}")
                self.heartbeat_seconds = min(self.MAX_HEARTBEAT_SECONDS, interval_ms / 1000)
            return
        if op != 0 or data.get('t') not in self.DISPATCHED_EVENTS:
            return
        if self.ignoreMessage(data):
            return

        d = data['d']
        for handler in self.message_handlers:
            if d.get('mj_proxy_handled'):
                break
            try:
                handler.handle(data)
            except Exception as e:
                # A failing handler must not prevent the remaining ones from running.
                logger.error(f"{type(handler).__name__} failed: {e}")

    def ignoreMessage(self, data: dict):
        """Return True when *data* is not a message for this account's channel.

        Relevant frames are logged with bulky fields stripped. The stripping is
        done on a copy so handlers still receive the complete payload.
        """
        d = data.get('d')
        if not isinstance(d, dict):
            return True
        channel_id = d.get('channel_id')
        if not channel_id or channel_id != self.account.get_channel_id():
            return True
        if data.get('op') in (0, 10):
            printable = dict(data, d=dict(d))
            self.remove_extra_info(printable)
            logger.info(json.dumps(printable))
        return False

    # Strip redundant information before printing
    def remove_extra_info(self, data: dict):
        """Delete bulky, log-irrelevant keys from ``data['d']`` in place."""
        d = data.get('d')
        if not isinstance(d, dict):
            return
        for key in ('author', 'mentions', 'member', 'interaction_metadata', 'interaction', 'referenced_message'):
            d.pop(key, None)

    def on_error(self, ws, error):
        logger.error(error)

    def on_close(self, ws, close_status_code, close_msg):
        logger.info(f"### closed ### code: {close_status_code}, msg: {close_msg}")

    def run_heart(self):
        """Send a heartbeat (op 1) every ``heartbeat_seconds``, forever.

        Send failures (socket closed or reconnecting) are logged and do not
        stop the loop, so heartbeats resume on the next connection.
        """
        while True:
            time.sleep(self.heartbeat_seconds)
            ws = self.ws
            if ws is None:
                continue
            try:
                ws.send(json.dumps({"op": 1, "d": self.sequence}))
            except Exception as e:
                logger.error(f"heartbeat send failed: {e}")

    def websocket_start_inner(self):
        """Run the websocket and reconnect after every close or error."""
        self.ws = websocket.WebSocketApp(
            self.ws_url, on_open=self.on_open, on_message=self.on_message,
            on_error=self.on_error, on_close=self.on_close)

        # Reconnect on disconnect
        while True:
            try:
                # NOTE(review): CERT_NONE disables TLS certificate verification,
                # and the bot token is sent over this socket. Confirm that this
                # is required (e.g. a self-signed proxy) before keeping it.
                self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
            except Exception as e:
                logger.error(e)
            # run_forever also returns normally when the socket closes; pause
            # so a refused connection does not turn into a busy loop.
            time.sleep(self.RECONNECT_DELAY_SECONDS)

    def start(self):
        # websocket.enableTrace(True)
        thread_hi = threading.Thread(target=self.run_heart, name="mj-wss-heartbeat")
        thread_hi.start()

        thread_websocket = threading.Thread(target=self.websocket_start_inner, name="mj-wss")
        thread_websocket.start()


if __name__ == "__main__":
    pass
171 |
--------------------------------------------------------------------------------