├── start.sh ├── examples ├── exp1.png ├── exp2.png ├── exp3.png ├── exp4.jpg └── exp6.png ├── go.mod ├── .gitignore ├── go.sum ├── config.json ├── LICENSE ├── wechat_client.go ├── Readme.md └── bot.py /start.sh: -------------------------------------------------------------------------------- 1 | go run wechat_client.go & python bot.py --config $1 2 | 3 | -------------------------------------------------------------------------------- /examples/exp1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxysxy/SXYWeChatBot/HEAD/examples/exp1.png -------------------------------------------------------------------------------- /examples/exp2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxysxy/SXYWeChatBot/HEAD/examples/exp2.png -------------------------------------------------------------------------------- /examples/exp3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxysxy/SXYWeChatBot/HEAD/examples/exp3.png -------------------------------------------------------------------------------- /examples/exp4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxysxy/SXYWeChatBot/HEAD/examples/exp4.jpg -------------------------------------------------------------------------------- /examples/exp6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxysxy/SXYWeChatBot/HEAD/examples/exp6.png -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module SXYWechatBot 2 | 3 | go 1.19 4 | 5 | require github.com/eatmoreapple/openwechat v1.4.3 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | storage.json 2 | .vscode 3 | myconf.json 4 | latest*.png 5 | .DS_Store 6 | anything-v4.0 7 | chatglm2-6b 8 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/eatmoreapple/openwechat v1.4.3 h1:hpqR3M0c180GN5e6sfkqdTmna1+vnvohqv8LkS7MecI= 2 | github.com/eatmoreapple/openwechat v1.4.3/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ= 3 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "OpenAI-GPT" : { 3 | "Enable" : true, 4 | "OpenAI-Key" : "Please write your openai api key", 5 | "GPT-Model" : "gpt-3.5-turbo-0301", 6 | "Temperature" : 0.7, 7 | "MaxPromptTokens" : 2560, 8 | "MaxTokens" : 4096 9 | }, 10 | "ChatGLM" : { 11 | "Enable" : false, 12 | "GPT-Model" : "THUDM/chatglm2-6b" 13 | }, 14 | "Diffusion" : { 15 | "Diffusion-Model" : "stabilityai/stable-diffusion-2-1", 16 | "NoNSFWChecker" : true, 17 | "UseFP16" : true 18 | } 19 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) <2023> 2 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 3 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 4 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /wechat_client.go: -------------------------------------------------------------------------------- 1 | package main 2 | import ( 3 | "fmt" 4 | "strings" 5 | "net/http" 6 | "io/ioutil" 7 | "time" 8 | "os" 9 | "bytes" 10 | "encoding/json" 11 | "github.com/eatmoreapple/openwechat" 12 | ) 13 | 14 | func Use(vals ...interface{}) { 15 | for _, val := range vals { 16 | _ = val 17 | } 18 | } 19 | 20 | type SendTextRequest struct { 21 | InGroup bool `json:"in_group"` //本来想用于区分在群聊和非群聊时的上下文记忆规则,但是最终没有实现... 22 | UserID string `json:"user_id"` 23 | Text string `json:"text"` 24 | } 25 | 26 | type SendTextResponse struct { 27 | UserID string `json:"user_id"` 28 | Text string `json:"text"` 29 | HasError bool `json:"error"` 30 | ErrorMessage string `json:"error_msg"` 31 | } 32 | 33 | type SendImageRequest struct { 34 | UserName string `json:"user_name"` 35 | FileNames []string `json:"filenames"` 36 | HasError bool `json:"error"` 37 | ErrorMessage string `json:"error_msg"` 38 | } 39 | 40 | type GenerateImageRequest struct { 41 | UserName string `json:"user_name"` 42 | Prompt string `json:"prompt"` 43 | } 44 | 45 | func HttpPost(url string, data interface{}, timelim int) []byte { 46 | // 超时时间 47 | timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) //是的,这里有个bug,但是这里就是靠这个bug正常运行的!!!??? 48 | 49 | client := &http.Client{Timeout: timeout} 50 | jsonStr, _ := json.Marshal(data) 51 | resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonStr)) 52 | if err != nil { 53 | return []byte("") 54 | } 55 | defer resp.Body.Close() 56 | 57 | result, _ := ioutil.ReadAll(resp.Body) 58 | return result 59 | 60 | // ——————————————— 61 | // 版权声明:本文为CSDN博主「gaoluhua」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 62 | // 原文链接:https://blog.csdn.net/gaoluhua/article/details/124855716 63 | } 64 | 65 | func main() { 66 | 67 | bot := openwechat.DefaultBot(openwechat.Desktop) // 桌面模式,上面登录不上的可以尝试切换这种模式 68 | reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json") 69 | defer reloadStorage.Close() 70 | 71 | err := bot.PushLogin(reloadStorage, openwechat.NewRetryLoginOption()) 72 | if err != nil { 73 | fmt.Println(err) 74 | return 75 | } 76 | 77 | // 获取登陆的用户 78 | self, err := bot.GetCurrentUser() 79 | if err != nil { 80 | fmt.Println(err) 81 | return 82 | } 83 | 84 | Use(self) 85 | 86 | // 注册消息处理函数 87 | bot.MessageHandler = func(msg *openwechat.Message) { 88 | if msg.IsTickledMe() { 89 | msg.ReplyText("别拍了,机器人是会被拍坏掉的。") 90 | return 91 | } 92 | 93 | if !msg.IsText() { 94 | return 95 | } 96 | 97 | // fmt.Println(msg.Content) 98 | 99 | content := msg.Content 100 | if msg.IsSendByGroup() && !msg.IsAt() { 101 | return 102 | } 103 | 104 | if msg.IsSendByGroup() && msg.IsAt() { 105 | atheader := fmt.Sprintf("@%s", self.NickName) 106 | //fmt.Println(atheader) 107 | if strings.HasPrefix(content, atheader) { 108 | content = strings.TrimLeft(content[len(atheader):], "  \t\n") 109 | } 110 | } 111 | //fmt.Println(content) 112 | 113 | content = strings.TrimRight(content, "  \t\n") 114 | if content == "查看机器人信息" { 115 | info := HttpPost("http://localhost:11111/info", nil, 20) 116 | msg.ReplyText(string(info)) 117 | 118 | } else if strings.HasPrefix(content, "生成图片") { 119 | // 调用Stable Diffusion 120 | // msg.ReplyText("这个功能还没有实现,可以先期待一下~") 121 | sender, _ := msg.Sender() 122 | 123 | content = strings.TrimLeft(content[len("生成图片"):], " \t\n") 124 | 125 | resp_raw := HttpPost("http://localhost:11111/draw", GenerateImageRequest{UserName : sender.ID(), Prompt : content}, 120) 126 | if len(resp_raw) == 0 { 127 | msg.ReplyText("生成图片出错啦QwQ,或许可以再试一次") 128 | return 129 | } 130 | 131 | resp := SendImageRequest{} 132 | json.Unmarshal(resp_raw, &resp) 133 | //fmt.Println(resp.FileName) 134 | if resp.HasError { 135 | msg.ReplyText( fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage) ) 136 | } else { 137 | for i := 0; i < len(resp.FileNames); i++ { 138 | img, _ := os.Open(resp.FileNames[i]) 139 | defer img.Close() 140 | msg.ReplyImage(img) 141 | } 142 | } 143 | 144 | } else { 145 | // 调用GPT 146 | 147 | sender, _ := msg.Sender() 148 | //var group openwechat.Group{} = nil 149 | var group *openwechat.Group = nil 150 | 151 | if msg.IsSendByGroup() { 152 | group = &openwechat.Group{User : sender} 153 | } 154 | 155 | if content == "重置上下文" { 156 | if !msg.IsSendByGroup() { 157 | HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup : msg.IsSendByGroup(), UserID : sender.ID(), Text : ""}, 60) 158 | } else { 159 | HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup : msg.IsSendByGroup(), UserID : group.ID(), Text : ""}, 60) 160 | } 161 | msg.ReplyText("OK,我忘掉了之前的上下文。") 162 | return 163 | } 164 | 165 | resp := SendTextResponse{} 166 | resp_raw := []byte("") 167 | 168 | if !msg.IsSendByGroup() { 169 | resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup : false, UserID : sender.ID(), Text : msg.Content}, 60) 170 | } else { 171 | resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup : false, UserID : group.ID(), Text : msg.Content}, 60) 172 | } 173 | if len(resp_raw) == 0 { 174 | msg.ReplyText("运算超时了QAQ,或许可以再试一次。") 175 | return 176 | } 177 | 178 | json.Unmarshal(resp_raw, &resp) 179 | 180 | if len(resp.Text) == 0 { 181 | msg.ReplyText("GPT对此没有什么想说的,换个话题吧。") 182 | } else { 183 | if resp.HasError { 184 | if msg.IsSendByGroup() { 185 | sender_in_group, _ := msg.SenderInGroup() 186 | nickname := sender_in_group.NickName 187 | msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage)) 188 | } else { 189 | msg.ReplyText(resp.ErrorMessage) 190 | } 191 | } else { 192 | if msg.IsSendByGroup() { 193 | sender_in_group, _ := msg.SenderInGroup() 194 | nickname := sender_in_group.NickName 195 | msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text)) 196 | } else { 197 | msg.ReplyText(resp.Text) 198 | } 199 | } 200 | } 201 | 202 | } 203 | } 204 | 205 | bot.Block() 206 | } -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # SXYWeChatBot 2 | 3 | 一个接入了ChatGPT和NovelAI的微信聊天机器人,兼容windows、mac、linux平台,代码很短很简单很容易扩展别的功能~ 4 | 5 | ## [安装配置方法](#ch1) 6 | ### [依赖](#ch11) 7 | ### [修改配置](#ch12) 8 | ### [然后就可以运行了](#ch13) 9 | ### [注意](#ch14) 10 | 11 | ## [机器人使用方法](#ch2) 12 | ### [例子](#ch21) 13 | 14 | ## [更新日志](#ch3) 15 | 16 | ## [使用协议](#ch4) 17 | 18 | ## 安装配置方法 19 |

20 | 21 | ### 依赖 22 |

23 | 24 | 用到了两种编程语言:go和python3。使用go是因为本项目依赖于强力的使用go写成的openwechat(请按照这里配置openwechat的方法安装openwechat)实现对微信会话的获取以及发送消息的功能。调用ChatGPT以及Stable Diffusion模型则使用python3。 25 | 26 | python3需要再安装这些库,使用pip安装就可以: 27 | 28 | ``` 29 | pip install torch flask openai transformers diffusers accelerate 30 | ``` 31 | 32 | 如果你要使用ChatGLM,注意torch的版本应该>=2.0,transformers的版本为4.30.2,并且还要安装 33 | 34 | ``` 35 | pip install gradio mdtex2html sentencepiece cpm_kernels 36 | ``` 37 | 38 | 当然如果使用cuda加速建议按照pytorch官网提供的方法安装支持cuda加速的torch版本。 39 | 40 | Apple Silicon的macbook上可以使用mps后端加速,我开发的时候使用的就是M1 Max芯片的Macbook Pro。 41 | 42 | ### 修改配置 43 |

44 | 45 | 打开config.json进行修改 46 | 47 | 如果你使用OpenAI的GPT: 48 | 49 | 默认的配置文件就是使用OpenAI的GPT的,你需要有一个OpenAI账号,然后将API Key写到config.json的OpenAI-API-Key字段后,然后保存。其余的配置通常按照默认的就可以,或者可以前往OpenAI官网查看其他可用的GPT模型, 50 | 51 | 如果你使用ChatGLM: 52 | 53 | 先把OpenAI-GPT的Enable改为false,然后再把ChatGLM的Enable改为true即可。 54 | 55 | 因为懒省事所以有一些参数是写死在代码里的(坏文明),也是可以调整的,比如超时时间可以在wechat_client.go的代码中修改,这样在生成高分辨率、迭代次数非常多的图片的时候留有更多的时间。总之就是代码太简单了,自己看着改一下就行了(就是作者懒)。还有bot.py运行的时候用WSGI什么的,也就是加两行代码。(懒+1) 56 | 57 | 关于Diffusion模型: 58 | 59 | UseFP16一般打开就可以了,能显著减少显存需求,并且对画质几乎没有影响。 60 | 61 | NoNSFWChecker一般打开就可以了,用于过滤生成含有NSFW内容的图片比如涩图,被过滤的图片会变为纯黑色。 62 | 63 | ### 然后就可以运行了 64 |

65 | 66 | Mac/Linux用户可以直接运行start.sh: 67 | 68 | ``` 69 | ./start.sh config.json 70 | ``` 71 | 72 | 或者分开运行bot和wechat_client: 73 | 74 | ``` 75 | python bot.py 76 | go run wechat_client.go 77 | ``` 78 | 79 | ### 注意 80 |

81 | 82 | 第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆(这时候go程序会卡住没有任何提示,注意掏出手机确认登录微信)。 83 | 84 | 第一次运行需要下载Diffusion模型,文件很大,并且从外网下载,需要有比较快速稳定的网络条件。 85 | 86 | 使用Diffusion系列模型生成图片对显存容量的需求很大,在默认开启16位浮点数、768x768分辨率的条件下,迭代20次需要14GB显存(参考:RTX4080仅有12GB显存)。搭载Apple Silicon的Macbook/Mac Studio因为统一内存,有比较好的表现。 87 | 88 | Diffusion推荐使用的模型: 89 | 90 | ``` 91 | andite/anything-v4.0 : 二次元浓度很高,画人的水平不错。(目前在huggingface上该模型已被删除,如果有本地缓存的话还能找到) 92 | stabilityai/stable-diffusion-2-1 : 比较通用,能生成各种图片,二次元风格真实风格都可以,但是画人的能力很差,经常出现崩坏的手,缺胳膊少腿等问题。。。 93 | ``` 94 | 95 | bot.py会占用本地11111网络端口,如果发生冲突可以在bot.py中修改这个端口号(没弄配置文件里?没错还是作者懒,要不是作者不想把自己的API Key写代码里开源了,连config.json配置文件都不会有(笑)) 96 | 97 | 没写自动通过好友请求的功能,呃。。。等我啥时候不小心再点开这个工程文件夹的时候再加入这个功能好了。 98 | 99 | 生成图片的时候图片都会临时保存为latest.png,那么这样的话面对多个请求同时生成图片的时候,可能会意外覆写latest.png导致返回错误的图片。解决方法可以是比如图片的二进制数据直接通过socket传到wechat_clinet里面,而非通过文件的方式。但是我就临时学了一晚上的go语言,我不会,写wechat_client.go能跑起来已经是难为死我了。。。。 100 | 101 | ## 机器人使用方法 102 |

103 | 104 | 在微信上私聊机器人登陆的微信号,或者将机器人拉入微信群,@机器人 使用就可以。对话不需要特殊指令,直接聊就可以,汉语英语日语等都可以。使用Stable Diffusion模型生成图片时候需要使用特殊指令 生成图片,格式为 105 | 106 | ``` 107 | 生成图片: 咒语 108 | 负面咒语 109 | ``` 110 | 111 | 或者 112 | 113 | ``` 114 | 生成图片(宽 高 迭代次数): 咒语 115 | 负面咒语 116 | ``` 117 | 118 | 咒语只能用英语,如果使用默认的模型,咒语的长度不能超过77个单词(CLIP的TextEncoder的限制,可以更换模型解决),负面咒语就是negative prompt,比如不想让模型生成丑陋的脸,崩坏的手之类的,就在负面咒语的部分写上ugly face, corrupted face之类的。 119 | 120 | 还有个特殊指令是 重置上下文。text-davinci系列模型和chat.openai.com的ChatGPT还是有一些区别的,它承受不了太长的上下文。当需要你告诉机器人“重置上下文”的时候他会告诉你。 121 | 122 | ### 例子 123 |

124 | 125 | ![](examples/exp1.png) 126 | 127 | ![](examples/exp2.png) 128 | 129 | ![](examples/exp3.png) 130 | 131 | ``` 132 | 生成图片(800 600 120): best quality, high resolution, (((masterpiece))), dazzling, extremely detailed, cyberpunck city landscape 133 | ``` 134 | 135 | 这个例子使用了参数,要求图片的分辨率为800x600。生成的时候迭代120次。 136 | 137 | ![](examples/exp4.jpg) 138 | 139 | 下面这个例子使用 andite/anything-v4.0 模型 140 | 141 | ``` 142 | 生成图片(800 720 25):otokonoko,masterpiece, best quality,best qualityc, maid uniform,white hair,very long hair,golden eyes,cat ears,cat tail,smile,long stocks,white stocks,bedroom,(looking at viewer) 143 | low quality, dark, fuzzy, normal quality, ugly, twisted face, scary eyes, sexual implication 144 | ``` 145 | 146 | ![](examples/exp6.png) 147 | 148 | ## 更新日志 149 |

150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 |
版本 日期 说明
v1.2 2023.07.02 跟进OpenAI的更新,使用openai.ChatCompletion对话API而不是文本补全API
支持清华大学的开源的ChatGLM2作为GPT模型。
由于anything-v4.0模型在hunggingface上被删除,默认的Diffusion模型改为stabilityai/stable-diffusion-2-1
更新了所依赖的openwechat的版本到v1.4.3
v1.1 2023.02.07 1.修改默认的Diffusion模型为andite/anything-v4.0
2.新增特殊指令“查看机器人信息”
3.新增半精度浮点数开关,并默认开启,减少内存占用
4.新增内容安全性检查开关
v1.0 2023.02.05 初始版本
175 | 176 | ## 使用协议 177 |

178 | 179 | 1.作者sxysxy依法享有对本软件的软件著作权:非商业使用遵循MIT协议即可(见LICENSE文件),商业使用联系作者,邮箱sxysxygm@gmail.com或1441157749@qq.com。(The author sxysxy is legally entitled to the software copyright: for non-commercial use, please follow the MIT license(see LICENSE file). For commercial use, contact the author at sxysxygm@gmail.com or 1441157749@qq.com) 180 | 181 | 2.内容版权:我不具有模型训练数据的版权,也不具有其创作成果的版权,我不能保证机器人的创作成果商业使用的合理与合法性。因为机器人的创作内容产生的一切纠纷与本项目作者无关。(Content copyright: I do not have the copyright of the model training data, nor the copyright of its creation results. I cannot guarantee the reasonableness and legality of the commercial use of the robot's creation results. All disputes arising from the creation content of robots have nothing to do with the author of this project.) 182 | 183 | 3.请不要诱导AI产生有害的内容,比如在公共场合创作R18、zz敏感、歧视偏见等的内容。(Plase do not induce AI to produce harmful content, such as R18, political sensitive, discriminatory and biased content in public situations.) 184 | 185 | 4.使用本项目的人,已知悉并同意”使用协议“的内容,否则请删除本项目的所有文件。(The uers of this project has known and agreed to the content of the "Usage Agreement", otherwise, please delate all documents and programs of this project.) 186 | 187 | -------------------------------------------------------------------------------- /bot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import openai 3 | import re 4 | from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler 5 | from transformers import AutoTokenizer, AutoModel 6 | import torch 7 | import argparse 8 | import flask 9 | import typing 10 | import traceback 11 | 12 | ps = argparse.ArgumentParser() 13 | ps.add_argument("--config", default="config.json", help="Configuration file") 14 | args = ps.parse_args() 15 | 16 | with open(args.config) as f: 17 | config_json = json.load(f) 18 | 19 | class GlobalData: 20 | # OPENAI_ORGID = config_json[""] 21 | OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"] 22 | OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"] 23 | OPENAI_MODEL_TEMPERATURE = int(config_json["OpenAI-GPT"]["Temperature"]) 24 | OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"])) 25 | 26 | CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"] 27 | 28 | context_for_users = {} 29 | context_for_groups = {} 30 | 31 | GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))") 32 | GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))") 33 | GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+") 34 | GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数 35 | 36 | USE_OPENAIGPT = False 37 | USE_CHATGLM = False 38 | 39 | if config_json["OpenAI-GPT"]["Enable"]: 40 | print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).") 41 | USE_OPENAIGPT = True 42 | elif config_json["ChatGLM"]["Enable"]: 43 | print(f"Use ChatGLM({GlobalData.CHATGLM_MODEL}) as GPT-Model.") 44 | chatglm_tokenizer = AutoTokenizer.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True) 45 | chatglm_model = AutoModel.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True) 46 | if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 47 | chatglm_model = chatglm_model.to('mps') 48 | elif torch.cuda.is_available(): 49 | chatglm_model = chatglm_model.to('cuda') 50 | chatglm_model = chatglm_model.eval() 51 | USE_CHATGLM = True 52 | 53 | app = flask.Flask(__name__) 54 | 55 | # 这个用于放行生成的任何图片,替换掉默认的NSFW检查器,公共场合慎重使用 56 | def run_safety_nochecker(image, device, dtype): 57 | print("警告:屏蔽了内容安全性检查,可能会产生有害内容") 58 | return image, None 59 | 60 | sd_args = { 61 | "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"], 62 | "torch_dtype" : (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32) 63 | } 64 | 65 | sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args) 66 | sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) 67 | if config_json["Diffusion"]["NoNSFWChecker"]: 68 | setattr(sd_pipe, "run_safety_checker", run_safety_nochecker) 69 | 70 | if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 71 | sd_pipe = sd_pipe.to("mps") 72 | elif torch.cuda.is_available(): 73 | sd_pipe = sd_pipe.to("cuda") 74 | 75 | GPT_SUCCESS = 0 76 | GPT_NORESULT = 1 77 | GPT_ERROR = 2 78 | 79 | def CallOpenAIGPT(prompts : typing.List[str]): 80 | try: 81 | res = openai.ChatCompletion.create( 82 | model=config_json["OpenAI-GPT"]["GPT-Model"], 83 | messages=prompts 84 | ) 85 | if len(res["choices"]) > 0: 86 | return (GPT_SUCCESS, res["choices"][0]["message"]["content"].strip()) 87 | else: 88 | return (GPT_NORESULT, "") 89 | except openai.InvalidRequestError as e: 90 | return (GPT_ERROR, e) 91 | except Exception as e: 92 | traceback.print_exception(e) 93 | return (GPT_ERROR, str(e)) 94 | 95 | def CallChatGLM(msg, history : typing.List[str]): 96 | try: 97 | resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history) 98 | if isinstance(resp, tuple): 99 | resp = resp[0] 100 | return (GPT_SUCCESS, resp) 101 | except Exception as e: 102 | return (GPT_ERROR, str(e)) 103 | 104 | def add_context(uid : str, is_user : bool, msg : str): 105 | if not uid in GlobalData.context_for_users: 106 | GlobalData.context_for_users[uid] = [] 107 | if USE_OPENAIGPT: 108 | GlobalData.context_for_users[uid].append({ 109 | "role" : "system", 110 | "content" : msg 111 | } 112 | ) 113 | elif USE_CHATGLM: 114 | GlobalData.context_for_users[uid].append(msg) 115 | 116 | def get_context(uid : str): 117 | if not uid in GlobalData.context_for_users: 118 | GlobalData.context_for_users[uid] = [] 119 | return GlobalData.context_for_users[uid] 120 | 121 | 122 | @app.route("/chat_clear", methods=['POST']) 123 | def app_chat_clear(): 124 | data = json.loads(flask.globals.request.get_data()) 125 | GlobalData.context_for_users[data["user_id"]] = [] 126 | print(f"Cleared context for {data['user_id']}") 127 | return "" 128 | 129 | @app.route("/chat", methods=['POST']) 130 | def app_chat(): 131 | data = json.loads(flask.globals.request.get_data()) 132 | #print(data) 133 | uid = data["user_id"] 134 | 135 | if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']: 136 | data["text"] += "。" 137 | 138 | if USE_OPENAIGPT: 139 | add_context(uid, True, data["text"]) 140 | #prompt = GlobalData.context_for_users[uid] 141 | prompt = get_context(uid) 142 | resp = CallOpenAIGPT(prompt=prompt) 143 | #GlobalData.context_for_users[data["user_id"]] = (prompt + resp) 144 | add_context(uid, False, resp[1]) 145 | #print(f"Prompt = {prompt}\nResponse = {resp[1]}") 146 | elif USE_CHATGLM: 147 | #prompt = GlobalData.context_for_users[uid] 148 | prompt = get_context(uid) 149 | resp = CallChatGLM(msg=data["text"], history=prompt) 150 | add_context(uid, True, (data["text"], resp[1])) 151 | else: 152 | pass 153 | 154 | if resp[0] == GPT_SUCCESS: 155 | return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""}) 156 | else: 157 | return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]}) 158 | 159 | @app.route("/draw", methods=['POST']) 160 | def app_draw(): 161 | data = json.loads(flask.globals.request.get_data()) 162 | 163 | prompt = data["prompt"] 164 | 165 | i = 0 166 | for i in range(len(prompt)): 167 | if prompt[i] == ':' or prompt[i] == ':': 168 | break 169 | if i == len(prompt): 170 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"}) 171 | 172 | 173 | match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i]) 174 | if not match_args is None: 175 | W = int(match_args.group(2)) 176 | H = int(match_args.group(3)) 177 | ITS = int(match_args.group(4)) 178 | NUM_PIC = int(match_args.group(5)) 179 | else: 180 | match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i]) 181 | if not match_args is None: 182 | W = int(match_args.group(2)) 183 | H = int(match_args.group(3)) 184 | ITS = int(match_args.group(4)) 185 | NUM_PIC = 1 186 | else: 187 | if len(prompt[:i].strip()) != 0: 188 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"}) 189 | else: 190 | W = 768 191 | H = 768 192 | ITS = config_json.get('DefaultDiffutionIterations', 20) 193 | NUM_PIC = 1 194 | 195 | if W > 2500 or H > 2500: 196 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"}) 197 | 198 | if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS: 199 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"}) 200 | 201 | prompt = prompt[(i+1):].strip() 202 | 203 | prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt) 204 | prompt = prompts[0] 205 | 206 | neg_prompt = None 207 | if len(prompts) > 1: 208 | neg_prompt = prompts[1] 209 | 210 | print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}") 211 | 212 | try: 213 | if NUM_PIC > 1 and torch.backends.mps.is_available(): #Apple silicon上的bug:https://github.com/huggingface/diffusers/issues/363 214 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, 215 | "error_msg" : "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考https://github.com/huggingface/diffusers/issues/363"}) 216 | 217 | images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC] 218 | if len(images) == 0: 219 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "没有产生任何图像"}) 220 | filenames = [] 221 | for i, img in enumerate(images): 222 | img.save(f"latest-{i}.png") 223 | filenames.append(f"latest-{i}.png") 224 | return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""}) 225 | 226 | except Exception as e: 227 | return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)}) 228 | 229 | @app.route("/info", methods=['POST', 'GET']) 230 | def app_info(): 231 | return "\n".join([f"GPT模型:{config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}", f"Diffusion模型:{config_json['Diffusion']['Diffusion-Model']}", 232 | "默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20", 233 | f"使用半精度浮点数 : {'是' if config_json['Diffusion'].get('UseFP16', True) else '否'}", 234 | f"屏蔽NSFW检查:{'是' if config_json['Diffusion']['NoNSFWChecker'] else '否'}", 235 | "清空上下文指令:重置上下文", 236 | "生成图片指令:生成图片(宽 高 迭代次数):正面提示 换行写负面提示,其中(宽 高 迭代次数)和换行写的负面提示都是可以省略的"]) 237 | 238 | if __name__ == "__main__": 239 | 240 | if USE_OPENAIGPT: 241 | openai.api_key = GlobalData.OPENAI_APIKEY 242 | 243 | app.run(host="0.0.0.0", port=11111) --------------------------------------------------------------------------------