├── CNGPT
│   ├── API.py
│   ├── READ
│   │   ├── Nano.txt
│   │   └── t.PNG
│   ├── README.md
│   └── bot.py
├── LICENSE
├── README.md
├── TF1_GPT-2
│   ├── DEVELOPERS.md
│   ├── README.md
│   ├── download_model.py
│   ├── requirements.txt
│   ├── src
│   │   ├── bot.py
│   │   ├── encoder.py
│   │   ├── generate_unconditional_samples.py
│   │   ├── interactive_conditional_samples.py
│   │   ├── model.py
│   │   └── sample.py
│   └── 捕获.PNG
├── TF2_GPT-2
│   ├── Api.py
│   ├── README.md
│   ├── bot.py
│   └── 捕获.PNG
└── bot.py

-------------------------------------------------------------------------------- /CNGPT/API.py: --------------------------------------------------------------------------------

import torch
import jieba
import os
import numpy as np

from Dtasest import MyDataset
from Layers import Config
from Layers import utils
from Layers.Model import GPT_Model


GPT = GPT_Model  # 模型
GPTconfig = Config.GPTConfig  # 模型配置
Sample = utils.sample  # 采样函数


def save_txt(model_path: str, train_data_name: str, context: str, steps: int):
    """
    :param model_path: 预训练模型的位置
    :param train_data_name: 训练用的数据
    :param context: 生成文章的标题
    :param steps: 生成文章的字数
    :return:
    """

    path_ = os.path.join('datas', train_data_name)
    f = open(path_, encoding='utf-8').read()
    aa = jieba.lcut(f)
    #print(aa)
    # 构建 GPT 模型
    train_dataset = MyDataset(aa, 20)
    mconf = GPTconfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=12, n_head=12, n_embd=768) # a GPT-1
    model = GPT(config=mconf)
    #print(model)
    # 加载预训练模型
    pre_model_path = os.path.join('Pre_models', model_path)
    model.load_state_dict(torch.load(pre_model_path, map_location='cpu'))

    x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...]
    y = Sample(model, x, steps=steps, temperature=1.0, sample=True, top_k=10)[0]
    completion = ''.join([train_dataset.itos[int(i)] for i in y])
    f = open('save.txt', 'w', encoding='utf-8')
    f.write(completion)
    f.close()

-------------------------------------------------------------------------------- /CNGPT/READ/Nano.txt: --------------------------------------------------------------------------------


-------------------------------------------------------------------------------- /CNGPT/READ/t.PNG: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/FloatTech/AI-Bot/083cc20b48d5f4538db3629d3c1d0e6e19bb1cff/CNGPT/READ/t.PNG
-------------------------------------------------------------------------------- /CNGPT/README.md: --------------------------------------------------------------------------------

# 基于StarxSky的CNGPT中文GPT-2接入Bot的实现

## 使用方法(和以前一样先上目录图)
![im](https://github.com/FloatTech/AI-Bot/blob/main/CNGPT/READ/t.PNG?raw=true)
- Steps 1 :
- 使用前需要先用 git clone 克隆 [GPT-2](https://github.com/StarxSky/GPT-2) 仓库
- 并在其中找到```CNGPT```目录并进入
- 在```CNGPT```目录下执行```pip install -r requirments.txt```安装所需的包
- 将所下载的预训练模型或者已训练好的模型放置在```Pre_models```目录下
- 将此仓库的```API.py```和```bot.py```移动到```CNGPT```目录下
### 您需要修改```bot.py```的以下代码:
- 将您下载的或者通过```CNGPT```训练的模型路径填写到对应的位置
- 将```CNGPT```目录下```datas```中的```train.text```填写到对应的位置
- 注意!!您用哪种文本语料训练的```CNGPT```,就需要把对应的语料路径填写进去!!(默认的语料库是```datas```中的```train.text```)
- 如果还有问题,或想了解训练 CNGPT 语言模型的详细方法,请[点击这里查看CNGPT的详情](https://github.com/StarxSky/GPT-2/blob/main/CNGPT/README.md)
```python

# GPT-2生成文章插件
class GeneratePlugin(Plugin):
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        GPT.save_txt(context='你好!',  # 文章题目
                     steps=20,  # 生成文章的字数
                     model_path='',  # 模型存放的路径
                     train_data_name=''  # 训练数据的文件名字
                     )

        f = open('save.txt', encoding='utf-8').read()
        self.send_msg(text('哒哒哒~~~生成完成:{}'.format(f)))
```
-------------------------------------------------------------------------------- /CNGPT/bot.py: --------------------------------------------------------------------------------

import re
import time
import queue
import logging
import threading
import collections
import json as json_
import API as GPT
import numpy as np
import os
import psutil
import websocket

WS_URL = "ws://127.0.0.1:6700/ws"  # WebSocket 地址
NICKNAME = ["BOT", "ROBOT"]  # 机器人昵称
SUPER_USER = [12345678, 23456789]  # 主人的 QQ 号
# 日志设置 level=logging.DEBUG -> 日志级别为 DEBUG
logging.basicConfig(level=logging.DEBUG, format="[void] %(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Plugin:
    def __init__(self, context: dict):
        self.ws = WS_APP
        self.context = context

    def match(self) -> bool:
        return self.on_full_match("hello")

    def handle(self):
        self.send_msg(text("hello world!"))

    def on_message(self) -> bool:
        return self.context["post_type"] == "message"

    def on_full_match(self, keyword="") -> bool:
        return self.on_message() and self.context["message"] == keyword

    def on_reg_match(self, pattern="") -> bool:
        return self.on_message() and re.search(pattern, self.context["message"])

    def only_to_me(self) -> bool:
        flag = False
        for nick in NICKNAME + [f"[CQ:at,qq={self.context['self_id']}] "]:
            if self.on_message() and nick in self.context["message"]:
                flag = True
                self.context["message"] = self.context["message"].replace(nick, "")
        return flag

    def super_user(self) -> bool:
        return self.context["user_id"] in SUPER_USER

    def admin_user(self) -> bool:
        return self.super_user() or self.context["sender"]["role"] in ("admin", "owner")

    def call_api(self, action: str, params: dict) -> dict:
        echo_num, q = echo.get()
        data = json_.dumps({"action": action, "params": params, "echo": echo_num})
        logger.info("发送调用 <- " + data)
        self.ws.send(data)
        try:  # 阻塞至响应或者等待30s超时
            return q.get(timeout=30)
        except queue.Empty:
            logger.error(f"API调用[{echo_num}] 超时......")

    def send_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_msg-%E5%8F%91%E9%80%81%E6%B6%88%E6%81%AF
        if "group_id" in self.context and self.context["group_id"]:
            return self.send_group_msg(*message)
        else:
            return self.send_private_msg(*message)

    def send_private_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_private_msg-%E5%8F%91%E9%80%81%E7%A7%81%E8%81%8A%E6%B6%88%E6%81%AF
        params = {"user_id": self.context["user_id"], "message": message}
        ret = self.call_api("send_private_msg", params)
        return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"]

    def send_group_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_group_msg-%E5%8F%91%E9%80%81%E7%BE%A4%E6%B6%88%E6%81%AF
        params = {"group_id": self.context["group_id"], "message": message}
        ret = self.call_api("send_group_msg", params)
        return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"]


def text(string: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E7%BA%AF%E6%96%87%E6%9C%AC
    return {"type": "text", "data": {"text": string}}


def image(file: str, cache=True) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E5%9B%BE%E7%89%87
    return {"type": "image", "data": {"file": file, "cache": cache}}


def record(file: str, cache=True) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E8%AF%AD%E9%9F%B3
    return {"type": "record", "data": {"file": file, "cache": cache}}


def at(qq: int) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E6%9F%90%E4%BA%BA
    return {"type": "at", "data": {"qq": qq}}


def xml(data: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#xml-%E6%B6%88%E6%81%AF
    return {"type": "xml", "data": {"data": data}}


def json(data: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#json-%E6%B6%88%E6%81%AF
    return {"type": "json", "data": {"data": data}}


def music(data: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E9%9F%B3%E4%B9%90%E5%88%86%E4%BA%AB-
    return {"type": "music", "data": {"type": "qq", "id": data}}


"""
在下面加入你自定义的插件,自动加载本文件所有的 Plugin 的子类
只需要写一个 Plugin 的子类,重写 match() 和 handle()
match() 返回 True 则自动回调 handle()
"""


class TestPlugin(Plugin):
    def match(self):  # 说 hello 则回复
        return self.on_full_match("hello")

    def handle(self):
        self.send_msg(at(self.context["user_id"]), text("hello world!"))


class f(Plugin):
    def match(self):
        return self.on_full_match("mua~")

    def handle(self):
        self.send_msg(at(self.context["user_id"]), text("恶心🤢"))


class ss(Plugin):
    def match(self):
        return self.on_full_match("沙比")

    def handle(self):

        po = np.random.random(1)
        op = np.random.random(1)
        if op > po:
            self.send_msg(at(self.context["user_id"]), text('歪!!骂谁呐!'))
        else:
            self.send_msg(at(self.context["user_id"]), text('草草....草尼🐎🐎(¬︿̫̿¬☆)不理你了'))


class ADD(Plugin):
    def match(self):
        return self.only_to_me() and self.on_full_match("好慢啊你")

    def handle(self):
        self.send_msg(at(self.context["user_id"]), text("要不你来试试?!!呜呜呜😭"))


class SELF(Plugin):
    def match(self):
        return self.on_full_match("检查身体")

    def handle(self):
        info = os.popen('ver').read()  # 读取命令输出(os.system 只返回退出码)

        net_work = psutil.cpu_stats()

        mem = psutil.virtual_memory()
        # 系统总计内存
        All_M = float(mem.total) / 1024 / 1024 / 1024
        # 系统已经使用内存
        use_ing = float(mem.used) / 1024 / 1024 / 1024

        # 系统空闲内存
        free = float(mem.free) / 1024 / 1024 / 1024

        all_m = '系统总计内存:%.3fGB' % All_M
        Use = '系统已经使用内存:%.3fGB' % use_ing
        Free = '系统空闲内存:%.3fGB' % free
        C_k = 'CPU状态:{}'.format(net_work)

        self.send_msg(text('{}\n\n{}\n\n{}\n\n{}\n{}'.format(info, all_m, Use, Free, C_k)))


class TestPlugin3(Plugin):
    def match(self):  # 戳一戳机器人则回复
        return self.context["post_type"] == "notice" and self.context["sub_type"] == "poke" \
            and self.context["target_id"] == self.context["self_id"]

    def handle(self):
        k = np.random.random(1)
        j = np.random.random(1)
        x = "请不要戳我 >_<"
        h = "歪!!戳我干嘛!!(╯▔皿▔)╯"
        if k < j:
            self.send_msg(text(x))
        else:
            self.send_msg(text(h))


class TPugin(Plugin):
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        self.send_msg(text('构思中可能需要几分钟,取决于我的小脑袋ε=ε=ε=(~ ̄▽ ̄)~........'))


# GPT-2生成文章插件
class GeneratePlugin(Plugin):
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        GPT.save_txt(context='你好!',  # 文章题目
                     steps=20,  # 生成文章的字数
                     model_path='model.bin',  # 模型的路径
                     train_data_name='a.txt'  # 训练数据的名字
                     )

        f = open('save.txt', encoding='utf-8').read()
        self.send_msg(text('哒哒哒~~~生成完成:{}'.format(f)))

        # 这里是私发可以改为群发

"""
在上面自定义你的插件
"""


def plugin_pool(context: dict):
    # 遍历所有的 Plugin 的子类,执行匹配
    for P in Plugin.__subclasses__():
        plugin = P(context)
        if plugin.match():
            plugin.handle()


class Echo:
    def __init__(self):
        self.echo_num = 0
        self.echo_list = collections.deque(maxlen=20)

    def get(self):
        self.echo_num += 1
        q = queue.Queue(maxsize=1)
        self.echo_list.append((self.echo_num, q))
        return self.echo_num, q

    def match(self, context: dict):
        for obj in self.echo_list:
            if context["echo"] == obj[0]:
                obj[1].put(context)


def on_message(_, message):
    # https://github.com/botuniverse/onebot-11/blob/master/event/README.md
    context = json_.loads(message)
    if "echo" in context:
        logger.debug("调用返回 -> " + message)
        # 响应报文通过队列传递给调用 API 的函数
        echo.match(context)
    elif "meta_event_type" in context:
        logger.debug("心跳事件 -> " + message)
    else:
        logger.info("收到事件 -> " + message)
        # 消息事件,开启线程
        t = threading.Thread(target=plugin_pool, args=(context, ))
        t.start()


if __name__ == "__main__":
    echo = Echo()
    WS_APP = websocket.WebSocketApp(
        WS_URL,
        on_message=on_message,
        on_open=lambda _: logger.debug("连接成功......"),
        on_close=lambda _: logger.debug("重连中......"),
    )
    while True:  # 掉线重连
        WS_APP.run_forever()
        time.sleep(5)
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------

MIT License

Copyright (c) 2022 FloatTech

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

# AI-Bot
- ![Star](https://img.shields.io/github/stars/FloatTech/AI-Bot) ![MIT](https://img.shields.io/github/license/FloatTech/AI-Bot)


- 一个基于WATERMELON重构的GPT-2的AI-Bot
- ## [go-cqhttp](https://github.com/Mrs4s/go-cqhttp/releases)
- ### [FloatTech 原版bot.py](https://github.com/FloatTech/AI-Bot/blob/main/bot.py)


## 在Binder上直接运行测试
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/FloatTech/AI-Bot/HEAD)
## 目前有三种实现方式(每一种实现中都自带自己所实现的```bot.py```文件,注意:每次运行时需要先运行```go-cqhttp```再运行```bot.py```)
|实现|
|---------------------------------------------------|
| [Pytorch的CNGPT](https://github.com/StarxSky/GPT-2/tree/main/CNGPT) -------[其```bot.py```实现](https://github.com/FloatTech/AI-Bot/blob/main/CNGPT/bot.py)|
| [TF2的GPT-2](https://github.com/FloatTech/AI-Bot/tree/main/TF2_GPT-2) -------[其```bot.py```实现](https://github.com/FloatTech/AI-Bot/blob/main/TF2_GPT-2/bot.py)|
| [TF1的GPT-2](https://github.com/FloatTech/AI-Bot/tree/main/TF1_GPT-2) --------[其```bot.py```实现](https://github.com/FloatTech/AI-Bot/tree/main/TF1_GPT-2/src/bot.py)|


|进度|
|----------------------|
| 目前 [CNGPT-2](https://github.com/StarxSky/GPT-2/tree/main/CNGPT)(基于Pytorch重构的GPT-2)现已移植到Bot上,可以实现中文生成 --2022/03/19|
| 目前 [TF1的GPT-2](https://github.com/FloatTech/AI-Bot/tree/main/TF1_GPT-2)已经可以调用API --2022/02/03|
| 目前 [TF1的GPT-2](https://github.com/FloatTech/AI-Bot/tree/main/TF1_GPT-2)已经可以使用文章生成功能 --2022/02/05|
| 目前 [TF2的GPT-2](https://github.com/FloatTech/AI-Bot/tree/main/TF2_GPT-2)已经可以使用文章生成功能 --2022/02/08|
| 目前 Pytorch的GPT-2研发中 --2022/02/11|

## 感谢
|THANKS TO!|
|-----------|
| [OpenAI](https://github.com/openai/gpt-2)|
| [理理](https://github.com/Yiwen-Chan)|
| [皮皮佬](https://github.com/DawnNights)|
| [源文雨](https://github.com/fumiama)|
| [myr](https://github.com/MayuriNFC)|
# [许可证](https://github.com/FloatTech/AI-Bot/blob/main/LICENSE)

-------------------------------------------------------------------------------- /TF1_GPT-2/DEVELOPERS.md: --------------------------------------------------------------------------------

# Installation

Git clone this repository, and `cd` into directory for remaining commands
```
git clone https://github.com/openai/gpt-2.git && cd gpt-2
```

Then, follow instructions for either native or Docker installation.

## Native Installation

All steps can optionally be done in a virtual environment using tools such as `virtualenv` or `conda`.
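For example, one way to do this with the standard `venv` module (a minimal sketch; the environment name `gpt2-env` is arbitrary):
```
python3 -m venv gpt2-env
source gpt2-env/bin/activate   # on Windows: gpt2-env\Scripts\activate
```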

Install tensorflow 1.12 (with GPU support, if you have a GPU and want everything to run faster)
```
pip3 install tensorflow==1.12.0
```
or
```
pip3 install tensorflow-gpu==1.12.0
```

Install other python packages:
```
pip3 install -r requirements.txt
```

Download the model data
```
python3 download_model.py 124M
python3 download_model.py 355M
python3 download_model.py 774M
python3 download_model.py 1558M
```

## Docker Installation

Build the Dockerfile and tag the created image as `gpt-2`:
```
docker build --tag gpt-2 -f Dockerfile.gpu . # or Dockerfile.cpu
```

Start an interactive bash session from the `gpt-2` docker image.

You can opt to use the `--runtime=nvidia` flag if you have access to a NVIDIA GPU
and a valid install of [nvidia-docker 2.0](https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0)).
```
docker run --runtime=nvidia -it gpt-2 bash
```

# Running

| WARNING: Samples are unfiltered and may contain offensive content. |
| --- |

Some of the examples below may include Unicode text characters. Set the environment variable:
```
export PYTHONIOENCODING=UTF-8
```
to override the standard stream settings in UTF-8 mode.

## Unconditional sample generation

To generate unconditional samples from the small model:
```
python3 src/generate_unconditional_samples.py | tee /tmp/samples
```
There are various flags for controlling the samples:
```
python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7 | tee /tmp/samples
```

To check flag descriptions, use:
```
python3 src/generate_unconditional_samples.py -- --help
```

## Conditional sample generation

To give the model custom prompts, you can use:
```
python3 src/interactive_conditional_samples.py --top_k 40
```

To check flag descriptions, use:
```
python3 src/interactive_conditional_samples.py -- --help
```
-------------------------------------------------------------------------------- /TF1_GPT-2/README.md: --------------------------------------------------------------------------------

## 使用TF1的OpenAI官方预训练模型的实现
- Python3.6
- [官方模型详情](https://github.com/FloatTech/AI-Bot/blob/main/TF1_GPT-2/DEVELOPERS.md)
- 模型经过改动方便使用
- [Bot.py实现](https://github.com/FloatTech/AI-Bot/blob/main/TF1_GPT-2/src/bot.py)
# 使用前准备
```
1. >> pip install tensorflow==1.15.0
2. >> pip install -r requirements.txt
3. >> python download_model.py 124M  //选择模型参数大小 124M,355M,774M,1558M
   python download_model.py 355M
   python download_model.py 774M
   python download_model.py 1558M

```


# 使用时一定要注意将所下载预训练模型的路径填写到API实现文件的模型参数 models_dir= 中
```python

GPT.interact_model(
    model_name='124M',  # 所下载模型的名称
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='models',  # 更改为自己所下载的预训练模型地址
    Input_m=''  # 输入文本接口:可在调用时定义一个变量并将其传入,示例请看 interactive_conditional_samples.py 中的解释
)

```

# 使用
```
1. >> python download_model.py 124M
## 记得修改完bot.py的模型路径后再运行bot.py
2.
>> python src/bot.py 41 | $ ./go-cqhttp 42 | or win >go-cqhttp.exe 43 | 44 | 45 | ``` 46 | 47 | # 运行官方Demo(若不成功需要进行改动生成模块) 48 | 49 | ``` 50 | 1.生成随机文本 51 | >> python src/generate_unconditional_samples.py 52 | 2.给定开头生成文章 53 | >> python src/interactive_conditional_samples.py --top_k 40 54 | (查看全部:) 55 | >> python src/interactive_conditional_samples.py -- --help 56 | 57 | ``` 58 | 59 | # 实现API 60 | - [API实现](https://github.com/FloatTech/AI-Bot/blob/main/TF1_GPT-2/src/interactive_conditional_samples.py) 61 | - [bot.py实现](https://github.com/FloatTech/AI-Bot/blob/main/TF1_GPT-2/src/bot.py) 62 | - ![Image](https://github.com/FloatTech/AI-Bot/blob/main/TF1_GPT-2/%E6%8D%95%E8%8E%B7.PNG?raw=true) 63 | - 注:API的功能为输入一段命题使GPT-2生成一篇文章,若想实现其他功能欢迎PR!! 64 | 65 | -------------------------------------------------------------------------------- /TF1_GPT-2/download_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import requests 4 | from tqdm import tqdm 5 | 6 | if len(sys.argv) != 2: 7 | print('You must enter the model name as a parameter, e.g.: download_model.py 124M') 8 | sys.exit(1) 9 | 10 | model = sys.argv[1] 11 | 12 | subdir = os.path.join('models', model) 13 | if not os.path.exists(subdir): 14 | os.makedirs(subdir) 15 | subdir = subdir.replace('\\','/') # needed for Windows 16 | 17 | for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']: 18 | 19 | r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True) 20 | 21 | with open(os.path.join(subdir, filename), 'wb') as f: 22 | file_size = int(r.headers["content-length"]) 23 | chunk_size = 1000 24 | with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: 25 | # 1k for chunk_size, since Ethernet packet size is around 1500 bytes 26 | for chunk in r.iter_content(chunk_size=chunk_size): 27 | f.write(chunk) 28 | pbar.update(chunk_size) 29 | -------------------------------------------------------------------------------- /TF1_GPT-2/requirements.txt: -------------------------------------------------------------------------------- 1 | fire>=0.1.3 2 | regex==2017.4.5 3 | requests==2.21.0 4 | tqdm==4.31.1 5 | tensorflow==1.15.0 6 | psutil 7 | websocket-client 8 | -------------------------------------------------------------------------------- /TF1_GPT-2/src/bot.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import queue 4 | import logging 5 | import threading 6 | import collections 7 | import json as json_ 8 | import os 9 | import psutil 10 | 11 | import websocket 12 | 13 | import interactive_conditional_samples as GPT 14 | 15 | 16 | 17 | 18 | 19 | 20 | WS_URL = "ws://127.0.0.1:6700/ws" # WebSocket 地址 21 | NICKNAME = ["BOT", "ROBOT"] # 机器人昵称 22 | SUPER_USER = [1237545454] # 主人的 QQ 号 23 | # 日志设置 level=logging.DEBUG -> 日志级别为 DEBUG 24 | logging.basicConfig(level=logging.DEBUG, format="[void] %(asctime)s - %(levelname)s - %(message)s") 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class Plugin: 29 | def __init__(self, context: dict): 30 | self.ws = WS_APP 31 | self.context = context 32 | 33 | def match(self) -> bool: 34 | return self.on_full_match("hello") 35 | 36 | def handle(self): 37 | self.send_msg(text("hello world!")) 38 | 39 | def on_message(self) -> bool: 40 | return self.context["post_type"] == "message" 41 | 42 | def 
on_full_match(self, keyword="") -> bool:
        return self.on_message() and self.context["message"] == keyword

    def on_reg_match(self, pattern="") -> bool:
        return self.on_message() and re.search(pattern, self.context["message"])

    def only_to_me(self) -> bool:
        flag = False
        for nick in NICKNAME + [f"[CQ:at,qq={self.context['self_id']}] "]:
            if self.on_message() and nick in self.context["message"]:
                flag = True
                self.context["message"] = self.context["message"].replace(nick, "")
        return flag

    def super_user(self) -> bool:
        return self.context["user_id"] in SUPER_USER

    def admin_user(self) -> bool:
        return self.super_user() or self.context["sender"]["role"] in ("admin", "owner")

    def call_api(self, action: str, params: dict) -> dict:
        echo_num, q = echo.get()
        data = json_.dumps({"action": action, "params": params, "echo": echo_num})
        logger.info("发送调用 <- " + data)
        self.ws.send(data)
        try:  # 阻塞至响应或者等待30s超时
            return q.get(timeout=30)
        except queue.Empty:
            logger.error(f"API调用[{echo_num}] 超时......")

    def send_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_msg-%E5%8F%91%E9%80%81%E6%B6%88%E6%81%AF
        if "group_id" in self.context and self.context["group_id"]:
            return self.send_group_msg(*message)
        else:
            return self.send_private_msg(*message)

    def send_private_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_private_msg-%E5%8F%91%E9%80%81%E7%A7%81%E8%81%8A%E6%B6%88%E6%81%AF
        params = {"user_id": self.context["user_id"], "message": message}
        ret = self.call_api("send_private_msg", params)
        return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"]

    def send_group_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_group_msg-%E5%8F%91%E9%80%81%E7%BE%A4%E6%B6%88%E6%81%AF
        params = {"group_id": self.context["group_id"], "message": message}
        ret = self.call_api("send_group_msg", params)
        return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"]


def text(string: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E7%BA%AF%E6%96%87%E6%9C%AC
    return {"type": "text", "data": {"text": string}}


def image(file: str, cache=True) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E5%9B%BE%E7%89%87
    return {"type": "image", "data": {"file": file, "cache": cache}}


def record(file: str, cache=True) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E8%AF%AD%E9%9F%B3
    return {"type": "record", "data": {"file": file, "cache": cache}}


def at(qq: int) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E6%9F%90%E4%BA%BA
    return {"type": "at", "data": {"qq": qq}}


def xml(data: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#xml-%E6%B6%88%E6%81%AF
    return {"type": "xml", "data": {"data": data}}


def json(data: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#json-%E6%B6%88%E6%81%AF
    return {"type": "json", "data": {"data": data}}

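# 用法示意(假设的 QQ 号,仅作演示):上面的构造函数各返回一个 OneBot-11 消息段 dict,
# 可以任意组合后传给 send_msg,例如:
#   self.send_msg(at(123456), text(" 你好!"), image("demo.jpg"))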
def music(data: str) -> dict:
    # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E9%9F%B3%E4%B9%90%E5%88%86%E4%BA%AB-
    return {"type": "music", "data": {"type": "qq", "id": data}}


"""
在下面加入你自定义的插件,自动加载本文件所有的 Plugin 的子类
只需要写一个 Plugin 的子类,重写 match() 和 handle()
match() 返回 True 则自动回调 handle()
"""


class TestPlugin(Plugin):
    def match(self):  # 说 hello 则回复
        return self.on_full_match("hello")

    def handle(self):
        self.send_msg(at(self.context["user_id"]), text("hello world!"))


class TestPlugin2(Plugin):
    def match(self):  # 艾特机器人说菜单则回复
        return self.only_to_me() and self.on_full_match("菜单")

    def handle(self):
        self.send_msg(text("没有菜单"))


class ADD(Plugin):
    def match(self):
        return self.only_to_me() and self.on_full_match("好慢啊你")

    def handle(self):
        self.send_msg(at(self.context["user_id"]), text("要不你来试试?!!呜呜呜😭"))


class SELF(Plugin):
    def match(self):
        return self.on_full_match("检查身体")

    def handle(self):
        info = os.popen('ver').read()  # 对于win用户;读取命令输出(os.system 只返回退出码)

        mem = psutil.virtual_memory()
        # 系统总计内存
        All_M = float(mem.total) / 1024 / 1024 / 1024
        # 系统已经使用内存
        use_ing = float(mem.used) / 1024 / 1024 / 1024

        # 系统空闲内存
        free = float(mem.free) / 1024 / 1024 / 1024

        all_m = '系统总计内存:%.3fGB' % All_M
        Use = '系统已经使用内存:%.3fGB' % use_ing
        Free = '系统空闲内存:%.3fGB' % free
        self.send_msg(text('{}\n\n{}\n\n{}\n\n{}'.format(info, all_m, Use, Free)))


class TestPlugin3(Plugin):
    def match(self):  # 戳一戳机器人则回复
        return self.context["post_type"] == "notice" and self.context["sub_type"] == "poke"\
            and self.context["target_id"] == self.context["self_id"]

    def handle(self):
        self.send_msg(text("请不要戳我 >_<"))


class TPugin(Plugin):
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        self.send_msg(text('构思中,可能需要几分钟,取决于我的小脑袋.......'))


class GeneratePlugin(Plugin):
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        a = 'hello'  # 生成文章的命题(脑累了,过几天将其与群内信息接在一起实现自定义命题)

        GPT.interact_model(
            model_name='124M',  # 模型名称
            seed=None,
            nsamples=1,
            batch_size=1,
            length=None,
            temperature=1,
            top_k=0,
            top_p=1,
            models_dir='C:\\Users\\xbj0916\\Desktop\\新建文件夹\\models\\',  # 将这里改为你自己所通过download_model.py下载的预训练模型路径
            Input_m='{}'.format(a))

        f = open('s.txt', encoding='utf-8').read()  # 读取所生成的文本文件,详情请见 interactive_conditional_samples.py
        self.send_private_msg(text('哒哒哒~~~生成完成:{}'.format(f)))  # 这里是私发可以改为群发


"""
在上面自定义你的插件
"""


def plugin_pool(context: dict):
    # 遍历所有的 Plugin 的子类,执行匹配
    for P in Plugin.__subclasses__():
        plugin = P(context)
        if plugin.match():
            plugin.handle()


class Echo:
    def __init__(self):
        self.echo_num = 0
        self.echo_list = collections.deque(maxlen=20)

    def get(self):
        self.echo_num += 1
        q = queue.Queue(maxsize=1)
        self.echo_list.append((self.echo_num, q))
        return self.echo_num, q

    def match(self, context: dict):
        for obj in
self.echo_list: 269 | if context["echo"] == obj[0]: 270 | obj[1].put(context) 271 | 272 | 273 | def on_message(_, message): 274 | # https://github.com/botuniverse/onebot-11/blob/master/event/README.md 275 | context = json_.loads(message) 276 | if "echo" in context: 277 | logger.debug("调用返回 -> " + message) 278 | # 响应报文通过队列传递给调用 API 的函数 279 | echo.match(context) 280 | elif "meta_event_type" in context: 281 | logger.debug("心跳事件 -> " + message) 282 | else: 283 | logger.info("收到事件 -> " + message) 284 | # 消息事件,开启线程 285 | t = threading.Thread(target=plugin_pool, args=(context, )) 286 | t.start() 287 | 288 | 289 | if __name__ == "__main__": 290 | echo = Echo() 291 | WS_APP = websocket.WebSocketApp( 292 | WS_URL, 293 | on_message=on_message, 294 | on_open=lambda _: logger.debug("连接成功......"), 295 | on_close=lambda _: logger.debug("重连中......"), 296 | ) 297 | while True: # 掉线重连 298 | WS_APP.run_forever() 299 | time.sleep(5) 300 | -------------------------------------------------------------------------------- /TF1_GPT-2/src/encoder.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 | import os 4 | import json 5 | import regex as re 6 | from functools import lru_cache 7 | 8 | @lru_cache() 9 | def bytes_to_unicode(): 10 | """ 11 | Returns list of utf-8 byte and a corresponding list of unicode strings. 12 | The reversible bpe codes work on unicode strings. 13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 15 | This is a signficant percentage of your normal, say, 32K bpe vocab. 16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 17 | And avoids mapping to whitespace/control characters the bpe code barfs on. 18 | """ 19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 20 | cs = bs[:] 21 | n = 0 22 | for b in range(2**8): 23 | if b not in bs: 24 | bs.append(b) 25 | cs.append(2**8+n) 26 | n += 1 27 | cs = [chr(n) for n in cs] 28 | return dict(zip(bs, cs)) 29 | 30 | def get_pairs(word): 31 | """Return set of symbol pairs in a word. 32 | 33 | Word is represented as tuple of symbols (symbols being variable-length strings). 
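    For example, get_pairs(("l", "o", "w")) returns {("l", "o"), ("o", "w")}.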
34 | """ 35 | pairs = set() 36 | prev_char = word[0] 37 | for char in word[1:]: 38 | pairs.add((prev_char, char)) 39 | prev_char = char 40 | return pairs 41 | 42 | class Encoder: 43 | def __init__(self, encoder, bpe_merges, errors='replace'): 44 | self.encoder = encoder 45 | self.decoder = {v:k for k,v in self.encoder.items()} 46 | self.errors = errors # how to handle errors in decoding 47 | self.byte_encoder = bytes_to_unicode() 48 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 49 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 50 | self.cache = {} 51 | 52 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 53 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 54 | 55 | def bpe(self, token): 56 | if token in self.cache: 57 | return self.cache[token] 58 | word = tuple(token) 59 | pairs = get_pairs(word) 60 | 61 | if not pairs: 62 | return token 63 | 64 | while True: 65 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 66 | if bigram not in self.bpe_ranks: 67 | break 68 | first, second = bigram 69 | new_word = [] 70 | i = 0 71 | while i < len(word): 72 | try: 73 | j = word.index(first, i) 74 | new_word.extend(word[i:j]) 75 | i = j 76 | except: 77 | new_word.extend(word[i:]) 78 | break 79 | 80 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 81 | new_word.append(first+second) 82 | i += 2 83 | else: 84 | new_word.append(word[i]) 85 | i += 1 86 | new_word = tuple(new_word) 87 | word = new_word 88 | if len(word) == 1: 89 | break 90 | else: 91 | pairs = get_pairs(word) 92 | word = ' '.join(word) 93 | self.cache[token] = word 94 | return word 95 | 96 | def encode(self, text): 97 | bpe_tokens = [] 98 | for token in re.findall(self.pat, text): 99 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 100 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 101 | return bpe_tokens 102 | 103 | def decode(self, tokens): 104 | text = ''.join([self.decoder[token] for token in tokens]) 105 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 106 | return text 107 | 108 | def get_encoder(model_name, models_dir): 109 | with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: 110 | encoder = json.load(f) 111 | with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: 112 | bpe_data = f.read() 113 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 114 | return Encoder( 115 | encoder=encoder, 116 | bpe_merges=bpe_merges, 117 | ) 118 | -------------------------------------------------------------------------------- /TF1_GPT-2/src/generate_unconditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import model, sample, encoder 10 | 11 | def sample_model( 12 | model_name='124M', 13 | seed=None, 14 | nsamples=0, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=0, 19 | top_p=1, 20 | models_dir='models', 21 | ): 22 | """ 23 | Run the sample_model 24 | :model_name=124M : String, which model to use 25 | :seed=None : Integer seed for random number generators, fix seed to 26 | reproduce results 27 | :nsamples=0 : Number of samples to return, if 0, continues 
to 28 | generate samples indefinately. 29 | :batch_size=1 : Number of batches (only affects speed/memory). 30 | :length=None : Number of tokens in generated text, if None (default), is 31 | determined by model hyperparameters 32 | :temperature=1 : Float value controlling randomness in boltzmann 33 | distribution. Lower temperature results in less random completions. As the 34 | temperature approaches zero, the model will become deterministic and 35 | repetitive. Higher temperature results in more random completions. 36 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 37 | considered for each step (token), resulting in deterministic completions, 38 | while 40 means 40 words are considered at each step. 0 (default) is a 39 | special setting meaning no restrictions. 40 generally is a good value. 40 | :models_dir : path to parent folder containing model subfolders 41 | (i.e. contains the folder) 42 | """ 43 | models_dir = os.path.expanduser(os.path.expandvars(models_dir)) 44 | enc = encoder.get_encoder(model_name, models_dir) 45 | hparams = model.default_hparams() 46 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: 47 | hparams.override_from_dict(json.load(f)) 48 | 49 | if length is None: 50 | length = hparams.n_ctx 51 | elif length > hparams.n_ctx: 52 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 53 | 54 | with tf.Session(graph=tf.Graph()) as sess: 55 | np.random.seed(seed) 56 | tf.set_random_seed(seed) 57 | 58 | output = sample.sample_sequence( 59 | hparams=hparams, length=length, 60 | start_token=enc.encoder['<|endoftext|>'], 61 | batch_size=batch_size, 62 | temperature=temperature, top_k=top_k, top_p=top_p 63 | )[:, 1:] 64 | 65 | saver = tf.train.Saver() 66 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) 67 | saver.restore(sess, ckpt) 68 | 69 | generated = 0 70 | while nsamples == 0 or generated < nsamples: 71 | out = sess.run(output) 72 | for i in range(batch_size): 73 | generated += batch_size 74 | text = enc.decode(out[i]) 75 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 76 | print(text) 77 | 78 | if __name__ == '__main__': 79 | fire.Fire(sample_model) 80 | 81 | -------------------------------------------------------------------------------- /TF1_GPT-2/src/interactive_conditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import model, sample, encoder 10 | 11 | def interact_model( 12 | model_name='',#模型名称 13 | seed=None, 14 | nsamples=1, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=0, 19 | top_p=1, 20 | models_dir='',#预训练模型路径 21 | Input_m = ''#输入文本接口可在需要调用时定义一个变量并将其索引引用示例请看解释: 22 | ): 23 | """ 24 | Interactively run the model 25 | :model_name=124M : String, which model to use 26 | :seed=None : Integer seed for random number generators, fix seed to reproduce 27 | results 28 | :nsamples=1 : Number of samples to return total 29 | :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples. 30 | :length=None : Number of tokens in generated text, if None (default), is 31 | determined by model hyperparameters 32 | :temperature=1 : Float value controlling randomness in boltzmann 33 | distribution. Lower temperature results in less random completions. 
As the 34 | temperature approaches zero, the model will become deterministic and 35 | repetitive. Higher temperature results in more random completions. 36 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 37 | considered for each step (token), resulting in deterministic completions, 38 | while 40 means 40 words are considered at each step. 0 (default) is a 39 | special setting meaning no restrictions. 40 generally is a good value. 40 | :models_dir : path to parent folder containing model subfolders 41 | (i.e. contains the folder) 42 | 43 | ### 使用模型参数 Input_m api示例: 44 | : >>> import .... 45 | : >>> a = str(input()) 46 | : >>> interact_model(input_m = a) 47 | 48 | """ 49 | models_dir = os.path.expanduser(os.path.expandvars(models_dir)) 50 | if batch_size is None: 51 | batch_size = 1 52 | assert nsamples % batch_size == 0 53 | 54 | enc = encoder.get_encoder(model_name, models_dir) 55 | hparams = model.default_hparams() 56 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: 57 | hparams.override_from_dict(json.load(f)) 58 | 59 | if length is None: 60 | length = hparams.n_ctx // 2 61 | elif length > hparams.n_ctx: 62 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 63 | 64 | with tf.Session(graph=tf.Graph()) as sess: 65 | context = tf.placeholder(tf.int32, [batch_size, None]) 66 | np.random.seed(seed) 67 | tf.set_random_seed(seed) 68 | output = sample.sample_sequence( 69 | hparams=hparams, length=length, 70 | context=context, 71 | batch_size=batch_size, 72 | temperature=temperature, top_k=top_k, top_p=top_p 73 | ) 74 | 75 | saver = tf.train.Saver() 76 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) 77 | saver.restore(sess, ckpt) 78 | 79 | if True:#之前使用while循环 80 | raw_text = Input_m#input("Model prompt >>> ") 81 | if not raw_text:# 之前使用while循环 82 | 83 | 84 | print('Prompt should not be empty!') 85 | raw_text = Input_m#input("Model prompt >>> ") 86 | context_tokens = enc.encode(raw_text) 87 | generated = 0 88 | for _ in range(nsamples // batch_size): 89 | out = sess.run(output, feed_dict={ 90 | context: [context_tokens for _ in range(batch_size)] 91 | })[:, len(context_tokens):] 92 | for i in range(batch_size): 93 | generated += 1 94 | text = enc.decode(out[i]) 95 | 96 | f = open('s.txt','w',encoding='utf-8')#将生成文件写入并保存 97 | f.write(text) 98 | f.close() 99 | 100 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 101 | print(text) 102 | print("=" * 80) 103 | 104 | if __name__ == '__main__': 105 | fire.Fire(interact_model) 106 | 107 | -------------------------------------------------------------------------------- /TF1_GPT-2/src/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.training import HParams 4 | 5 | def default_hparams(): 6 | return HParams( 7 | n_vocab=0, 8 | n_ctx=1024, 9 | n_embd=768, 10 | n_head=12, 11 | n_layer=12, 12 | ) 13 | 14 | def shape_list(x): 15 | """Deal with dynamic shape in tensorflow cleanly.""" 16 | static = x.shape.as_list() 17 | dynamic = tf.shape(x) 18 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 19 | 20 | def softmax(x, axis=-1): 21 | x = x - tf.reduce_max(x, axis=axis, keepdims=True) 22 | ex = tf.exp(x) 23 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) 24 | 25 | def gelu(x): 26 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) 27 | 28 | def norm(x, scope, *, axis=-1, 
epsilon=1e-5): 29 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" 30 | with tf.variable_scope(scope): 31 | n_state = x.shape[-1].value 32 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) 33 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) 34 | u = tf.reduce_mean(x, axis=axis, keepdims=True) 35 | s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) 36 | x = (x - u) * tf.rsqrt(s + epsilon) 37 | x = x*g + b 38 | return x 39 | 40 | def split_states(x, n): 41 | """Reshape the last dimension of x into [n, x.shape[-1]/n].""" 42 | *start, m = shape_list(x) 43 | return tf.reshape(x, start + [n, m//n]) 44 | 45 | def merge_states(x): 46 | """Smash the last two dimensions of x into a single dimension.""" 47 | *start, a, b = shape_list(x) 48 | return tf.reshape(x, start + [a*b]) 49 | 50 | def conv1d(x, scope, nf, *, w_init_stdev=0.02): 51 | with tf.variable_scope(scope): 52 | *start, nx = shape_list(x) 53 | w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)) 54 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) 55 | c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) 56 | return c 57 | 58 | def attention_mask(nd, ns, *, dtype): 59 | """1's in the lower triangle, counting from the lower right corner. 60 | 61 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. 62 | """ 63 | i = tf.range(nd)[:,None] 64 | j = tf.range(ns) 65 | m = i >= j - ns + nd 66 | return tf.cast(m, dtype) 67 | 68 | 69 | def attn(x, scope, n_state, *, past, hparams): 70 | assert x.shape.ndims == 3 # Should be [batch, sequence, features] 71 | assert n_state % hparams.n_head == 0 72 | if past is not None: 73 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] 74 | 75 | def split_heads(x): 76 | # From [batch, sequence, features] to [batch, heads, sequence, features] 77 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3]) 78 | 79 | def merge_heads(x): 80 | # Reverse of split_heads 81 | return merge_states(tf.transpose(x, [0, 2, 1, 3])) 82 | 83 | def mask_attn_weights(w): 84 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 
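        # Causal mask: keep only positions j <= i; masked logits get a -1e10 bias so their softmax weight is ~0.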
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        h = gelu(conv1d(x, 'c_fc', n_state))
        h2 = conv1d(h, 'c_proj', nx)
        return h2


def block(x, scope, *, past, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
        x = x + m
        return x, present

def past_shape(*, hparams, batch_size=None, sequence=None):
    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]

def expand_tile(value, size):
    """Add a new axis of given size."""
    value = tf.convert_to_tensor(value, name='value')
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)

def positions_for(tokens, past_length):
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)


def model(hparams, X, past=None, scope='model', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                              initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                              initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Language model loss.  Do tokens <n predict token n?
-------------------------------------------------------------------------------- /TF2_GPT-2/README.md: --------------------------------------------------------------------------------

1. >> git clone https://github.com/starxsky/gpt-2
2. >> python pre_process.py
3. >> python train_gpt2.py
4. windows==> go-cqhttp.exe
   Linux===> ./go-cqhttp
5. >> python bot.py

# Bot.py中的API配置
```python
GPT.sequence_gen(
    model_path = "C:\\Users\\xbj0916\\Desktop\\TF2_GPT-2\\TF2_GPT\\model\\",  # 只有运行完pre_process.py&train_gpt2.py才能看到
    model_param = "C:\\Users\\xbj0916\\Desktop\\TF2_GPT-2\\TF2_GPT\\model\\model_par.json",  # 只有运行完pre_process.py&train_gpt2.py才能看到
    vocab = "C:\\Users\\xbj0916\\Desktop\\TF2_GPT-2\\TF2_GPT\\data\\bpe_model.model",  # 只有运行完pre_process.py&train_gpt2.py才能看到
    seq_len = 512,
    temperature = 1,
    top_k = 8,
    top_p = 0.9,
    nucleus_sampling = False,
    context = "sample context")  # 文章开头标题
```
- 配置完成后在QQ只需一句"生成文章"即可
## 有关TF2_GPT-2详细信息与使用请移步至[Watermelon's TF2_GPT-2](https://github.com/starxsky/tf2_gpt-2)

## 训练
1. >> python pre_process.py
2. >> python train_gpt2.py
# 直接生成文章
1. >> python sequence_generator.py

-------------------------------------------------------------------------------- /TF2_GPT-2/bot.py: --------------------------------------------------------------------------------

import re
import time
import queue
import logging
import threading
import collections
import json as json_
import os
import psutil
import websocket

import numpy as np
import Api as GPT


WS_URL = "ws://127.0.0.1:6700/ws"  # WebSocket 地址
NICKNAME = ["BOT", "ROBOT"]  # 机器人昵称
SUPER_USER = [1237545454]  # 主人的 QQ 号
# 日志设置 level=logging.DEBUG -> 日志级别为 DEBUG
logging.basicConfig(level=logging.DEBUG, format="[void] %(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Plugin:
    def __init__(self, context: dict):
        self.ws = WS_APP
        self.context = context

    def match(self) -> bool:
        return self.on_full_match("hello")

    def handle(self):
        self.send_msg(text("hello world!"))

    def on_message(self) -> bool:
        return self.context["post_type"] == "message"

    def on_full_match(self, keyword="") -> bool:
        return self.on_message() and self.context["message"] == keyword

    def on_reg_match(self, pattern="") -> bool:
        return self.on_message() and re.search(pattern, self.context["message"])

    def only_to_me(self) -> bool:
        flag = False
        for nick in NICKNAME + [f"[CQ:at,qq={self.context['self_id']}] "]:
            if self.on_message() and nick in self.context["message"]:
                flag = True
                self.context["message"] = self.context["message"].replace(nick, "")
        return flag

    def super_user(self) -> bool:
        return self.context["user_id"] in SUPER_USER

    def admin_user(self) -> bool:
        return self.super_user() or self.context["sender"]["role"] in ("admin", "owner")

    def call_api(self, action: str, params: dict) -> dict:
        echo_num, q = echo.get()
        data = json_.dumps({"action": action, "params": params, "echo": echo_num})
        logger.info("发送调用 <- " + data)
        self.ws.send(data)
        try:  # 阻塞至响应或者等待30s超时
            return q.get(timeout=30)
        except queue.Empty:
            logger.error(f"API调用[{echo_num}] 超时......")

    def send_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_msg-%E5%8F%91%E9%80%81%E6%B6%88%E6%81%AF
        if "group_id" in self.context and self.context["group_id"]:
            return self.send_group_msg(*message)
        else:
            return self.send_private_msg(*message)

    def send_private_msg(self, *message) -> int:
# https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_private_msg-%E5%8F%91%E9%80%81%E7%A7%81%E8%81%8A%E6%B6%88%E6%81%AF 82 | params = {"user_id": self.context["user_id"], "message": message} 83 | ret = self.call_api("send_private_msg", params) 84 | return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"] 85 | 86 | def send_group_msg(self, *message) -> int: 87 | # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_group_msg-%E5%8F%91%E9%80%81%E7%BE%A4%E6%B6%88%E6%81%AF 88 | params = {"group_id": self.context["group_id"], "message": message} 89 | ret = self.call_api("send_group_msg", params) 90 | return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"] 91 | 92 | 93 | 94 | 95 | def text(string: str) -> dict: 96 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E7%BA%AF%E6%96%87%E6%9C%AC 97 | return {"type": "text", "data": {"text": string}} 98 | 99 | 100 | def image(file: str, cache=True) -> dict: 101 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E5%9B%BE%E7%89%87 102 | return {"type": "image", "data": {"file": file, "cache": cache}} 103 | 104 | 105 | def record(file: str, cache=True) -> dict: 106 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E8%AF%AD%E9%9F%B3 107 | return {"type": "record", "data": {"file": file, "cache": cache}} 108 | 109 | 110 | def at(qq: int) -> dict: 111 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E6%9F%90%E4%BA%BA 112 | return {"type": "at", "data": {"qq": qq}} 113 | 114 | 115 | def xml(data: str) -> dict: 116 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#xml-%E6%B6%88%E6%81%AF 117 | return {"type": "xml", "data": {"data": data}} 118 | 119 | 120 | def json(data: str) -> dict: 121 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#json-%E6%B6%88%E6%81%AF 122 | return {"type": "json", "data": {"data": data}} 123 | 124 | 125 | def music(data: str) -> dict: 126 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E9%9F%B3%E4%B9%90%E5%88%86%E4%BA%AB- 127 | return {"type": "music", "data": {"type": "qq", "id": data}} 128 | 129 | 130 | """ 131 | 在下面加入你自定义的插件,自动加载本文件所有的 Plugin 的子类 132 | 只需要写一个 Plugin 的子类,重写 match() 和 handle() 133 | match() 返回 True 则自动回调 handle() 134 | """ 135 | 136 | 137 | class TestPlugin(Plugin): 138 | def match(self): # 说 hello 则回复 139 | return self.on_full_match("hello") 140 | 141 | def handle(self): 142 | self.send_msg(at(self.context["user_id"]), text("hello world!")) 143 | 144 | 145 | class f(Plugin) : 146 | def match(self): 147 | return self.on_full_match("mua~") 148 | 149 | def handle(self): 150 | self.send_msg(at(self.context["user_id"]),text("恶心🤢")) 151 | 152 | 153 | 154 | 155 | class ss(Plugin) : 156 | def match(self) : 157 | return self.on_full_match("沙比") 158 | 159 | def handle(self): 160 | 161 | po = np.random.random(1) 162 | op = np.random.random(1) 163 | if op > po : 164 | self.send_msg(at(self.context["user_id"]),text('歪!!骂谁呐!')) 165 | else : 166 | self.send_msg(at(self.context["user_id"]),text('草草....草尼🐎🐎(¬︿̫̿¬☆)不理你了')) 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | class ADD(Plugin): 178 | def match(self) : 179 | return self.only_to_me() and self.on_full_match("好慢啊你") 180 | 181 | def handle(self): 182 | 183 | self.send_msg(at(self.context["user_id"]),text("要不你来试试?!!呜呜呜😭")) 184 | 185 | 186 | 187 | 188 | 189 | class SELF(Plugin) : 
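    # “检查身体”插件:回复系统信息与内存、CPU 使用情况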
class SysInfoPlugin(Plugin):  # renamed from `SELF`
    def match(self):
        return self.on_full_match("检查身体")

    def handle(self):
        # `ver` is a Windows shell builtin; os.popen captures its output
        # (the original os.system('ver') only returned the exit code)
        info = os.popen('ver').read().strip()

        cpu_stats = psutil.cpu_stats()

        mem = psutil.virtual_memory()
        gib = 1024 ** 3
        total = '系统总计内存:%.3f GB' % (mem.total / gib)   # total system memory
        used = '系统已经使用内存:%.3f GB' % (mem.used / gib)  # memory in use
        free = '系统空闲内存:%.3f GB' % (mem.free / gib)      # free memory
        cpu = 'CPU状态:{}'.format(cpu_stats)                  # CPU stats

        self.send_msg(text('{}\n\n{}\n\n{}\n\n{}\n{}'.format(info, total, used, free, cpu)))


class TestPlugin3(Plugin):
    def match(self):  # reply when someone pokes the bot
        return self.context["post_type"] == "notice" and self.context["sub_type"] == "poke"\
            and self.context["target_id"] == self.context["self_id"]

    def handle(self):
        # pick one of two replies at random
        if np.random.random() < np.random.random():
            self.send_msg(text("请不要戳我 >_<"))
        else:
            self.send_msg(text("歪!!戳我干嘛!!(╯▔皿▔)╯"))


class ThinkingPlugin(Plugin):
    # renamed from `TPugin`; it matches the same trigger as GeneratePlugin below
    # and is defined first, so this notice goes out before generation starts
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        self.send_msg(text('构思中可能需要几分钟,取决于我的小脑袋ε=ε=ε=(~ ̄▽ ̄)~........'))


# GPT-2 article-generation plugin
class GeneratePlugin(Plugin):
    def match(self):
        return self.on_full_match('生成文章')

    def handle(self):
        GPT.sequence_gen(
            model_path="C:\\Users\\xbj0916\\Desktop\\TF2_GPT-2\\TF2_GPT\\model\\",
            model_param="C:\\Users\\xbj0916\\Desktop\\TF2_GPT-2\\TF2_GPT\\model\\model_par.json",
            vocab="C:\\Users\\xbj0916\\Desktop\\TF2_GPT-2\\TF2_GPT\\data\\bpe_model.model",
            seq_len=512,
            temperature=1,
            top_k=8,
            top_p=0.9,
            nucleus_sampling=False,
            context="sample context")  # article title / opening context

        f = open('s.txt', encoding='utf-8').read()
        self.send_msg(text('哒哒哒~~~生成完成:{}'.format(f)))
        # send_msg picks private or group delivery from the context; call
        # send_group_msg directly if you always want a group message


"""
Define your custom plugins above
"""


def plugin_pool(context: dict):
    # iterate over every Plugin subclass and run its match()
    for P in Plugin.__subclasses__():
        plugin = P(context)
        if plugin.match():
            plugin.handle()


class Echo:
    def __init__(self):
        self.echo_num = 0
        self.echo_list = collections.deque(maxlen=20)

    def get(self):
        self.echo_num += 1
        q = queue.Queue(maxsize=1)
        self.echo_list.append((self.echo_num, q))
        return self.echo_num, q

    def match(self, context: dict):
        for obj in self.echo_list:
            if context["echo"] == obj[0]:
                obj[1].put(context)


def on_message(_, message):
    # https://github.com/botuniverse/onebot-11/blob/master/event/README.md
    context = json_.loads(message)
    if "echo" in context:
        logger.debug("call returned -> " + message)
        # hand the response back to the waiting call_api() through its queue
        echo.match(context)
    elif "meta_event_type" in context:
        logger.debug("heartbeat event -> " + message)
    else:
        logger.info("received event -> " + message)
        # message event: handle it on a new thread
        t = threading.Thread(target=plugin_pool, args=(context, ))
        t.start()
if __name__ == "__main__":
    echo = Echo()
    WS_APP = websocket.WebSocketApp(
        WS_URL,
        on_message=on_message,
        on_open=lambda _: logger.debug("connected......"),
        # *_ absorbs the extra close arguments newer websocket-client versions pass
        on_close=lambda *_: logger.debug("reconnecting......"),
    )
    while True:  # reconnect on disconnect
        WS_APP.run_forever()
        time.sleep(5)
--------------------------------------------------------------------------------
/TF2_GPT-2/捕获.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FloatTech/AI-Bot/083cc20b48d5f4538db3629d3c1d0e6e19bb1cff/TF2_GPT-2/捕获.PNG
--------------------------------------------------------------------------------
/bot.py:
--------------------------------------------------------------------------------
import re
import time
import queue
import logging
import threading
import collections
import json as json_

import websocket

WS_URL = "ws://127.0.0.1:6700/ws"  # WebSocket address
NICKNAME = ["BOT", "ROBOT"]  # bot nicknames
SUPER_USER = [12345678, 23456789]  # owner QQ IDs
# logging setup: level=logging.DEBUG -> log level DEBUG
logging.basicConfig(level=logging.DEBUG, format="[void] %(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Plugin:
    def __init__(self, context: dict):
        self.ws = WS_APP
        self.context = context

    def match(self) -> bool:
        return self.on_full_match("hello")

    def handle(self):
        self.send_msg(text("hello world!"))

    def on_message(self) -> bool:
        return self.context["post_type"] == "message"

    def on_full_match(self, keyword="") -> bool:
        return self.on_message() and self.context["message"] == keyword

    def on_reg_match(self, pattern="") -> bool:
        return self.on_message() and re.search(pattern, self.context["message"])

    def only_to_me(self) -> bool:
        flag = False
        for nick in NICKNAME + [f"[CQ:at,qq={self.context['self_id']}] "]:
            if self.on_message() and nick in self.context["message"]:
                flag = True
                self.context["message"] = self.context["message"].replace(nick, "")
        return flag

    def super_user(self) -> bool:
        return self.context["user_id"] in SUPER_USER

    def admin_user(self) -> bool:
        return self.super_user() or self.context["sender"]["role"] in ("admin", "owner")

    def call_api(self, action: str, params: dict) -> dict:
        echo_num, q = echo.get()
        data = json_.dumps({"action": action, "params": params, "echo": echo_num})
        logger.info("sending call <- " + data)
        self.ws.send(data)
        try:  # block until the response arrives, or time out after 30 s
            return q.get(timeout=30)
        except queue.Empty:
            logger.error(f"API call [{echo_num}] timed out......")

    def send_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_msg-%E5%8F%91%E9%80%81%E6%B6%88%E6%81%AF
        if "group_id" in self.context and self.context["group_id"]:
            return self.send_group_msg(*message)
        else:
            return self.send_private_msg(*message)

    def send_private_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_private_msg-%E5%8F%91%E9%80%81%E7%A7%81%E8%81%8A%E6%B6%88%E6%81%AF
        params = {"user_id": self.context["user_id"], "message": message}
        ret = self.call_api("send_private_msg", params)
        return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"]

    def send_group_msg(self, *message) -> int:
        # https://github.com/botuniverse/onebot-11/blob/master/api/public.md#send_group_msg-%E5%8F%91%E9%80%81%E7%BE%A4%E6%B6%88%E6%81%AF
        params = {"group_id": self.context["group_id"], "message": message}
        ret = self.call_api("send_group_msg", params)
        return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"]
self.context["group_id"], "message": message} 79 | ret = self.call_api("send_group_msg", params) 80 | return 0 if ret is None or ret["status"] == "failed" else ret["data"]["message_id"] 81 | 82 | 83 | def text(string: str) -> dict: 84 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E7%BA%AF%E6%96%87%E6%9C%AC 85 | return {"type": "text", "data": {"text": string}} 86 | 87 | 88 | def image(file: str, cache=True) -> dict: 89 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E5%9B%BE%E7%89%87 90 | return {"type": "image", "data": {"file": file, "cache": cache}} 91 | 92 | 93 | def record(file: str, cache=True) -> dict: 94 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E8%AF%AD%E9%9F%B3 95 | return {"type": "record", "data": {"file": file, "cache": cache}} 96 | 97 | 98 | def at(qq: int) -> dict: 99 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E6%9F%90%E4%BA%BA 100 | return {"type": "at", "data": {"qq": qq}} 101 | 102 | 103 | def xml(data: str) -> dict: 104 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#xml-%E6%B6%88%E6%81%AF 105 | return {"type": "xml", "data": {"data": data}} 106 | 107 | 108 | def json(data: str) -> dict: 109 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#json-%E6%B6%88%E6%81%AF 110 | return {"type": "json", "data": {"data": data}} 111 | 112 | 113 | def music(data: str) -> dict: 114 | # https://github.com/botuniverse/onebot-11/blob/master/message/segment.md#%E9%9F%B3%E4%B9%90%E5%88%86%E4%BA%AB- 115 | return {"type": "music", "data": {"type": "qq", "id": data}} 116 | 117 | 118 | """ 119 | 在下面加入你自定义的插件,自动加载本文件所有的 Plugin 的子类 120 | 只需要写一个 Plugin 的子类,重写 match() 和 handle() 121 | match() 返回 True 则自动回调 handle() 122 | """ 123 | 124 | 125 | class TestPlugin(Plugin): 126 | def match(self): # 说 hello 则回复 127 | return self.on_full_match("hello") 128 | 129 | def handle(self): 130 | self.send_msg(at(self.context["user_id"]), text("hello world!")) 131 | 132 | 133 | class TestPlugin2(Plugin): 134 | def match(self): # 艾特机器人说菜单则回复 135 | return self.only_to_me() and self.on_full_match("菜单") 136 | 137 | def handle(self): 138 | self.send_msg(text("没有菜单")) 139 | 140 | 141 | class TestPlugin3(Plugin): 142 | def match(self): # 戳一戳机器人则回复 143 | return self.context["post_type"] == "notice" and self.context["sub_type"] == "poke"\ 144 | and self.context["target_id"] == self.context["self_id"] 145 | 146 | def handle(self): 147 | self.send_msg(text("请不要戳我 >_<")) 148 | 149 | 150 | """ 151 | 在上面自定义你的插件 152 | """ 153 | 154 | 155 | def plugin_pool(context: dict): 156 | # 遍历所有的 Plugin 的子类,执行匹配 157 | for P in Plugin.__subclasses__(): 158 | plugin = P(context) 159 | if plugin.match(): 160 | plugin.handle() 161 | 162 | 163 | class Echo: 164 | def __init__(self): 165 | self.echo_num = 0 166 | self.echo_list = collections.deque(maxlen=20) 167 | 168 | def get(self): 169 | self.echo_num += 1 170 | q = queue.Queue(maxsize=1) 171 | self.echo_list.append((self.echo_num, q)) 172 | return self.echo_num, q 173 | 174 | def match(self, context: dict): 175 | for obj in self.echo_list: 176 | if context["echo"] == obj[0]: 177 | obj[1].put(context) 178 | 179 | 180 | def on_message(_, message): 181 | # https://github.com/botuniverse/onebot-11/blob/master/event/README.md 182 | context = json_.loads(message) 183 | if "echo" in context: 184 | logger.debug("调用返回 -> " + message) 185 | # 响应报文通过队列传递给调用 API 的函数 186 | echo.match(context) 187 | elif "meta_event_type" in 
"""
Define your custom plugins above
"""


def plugin_pool(context: dict):
    # iterate over every Plugin subclass and run its match()
    for P in Plugin.__subclasses__():
        plugin = P(context)
        if plugin.match():
            plugin.handle()


class Echo:
    def __init__(self):
        self.echo_num = 0
        self.echo_list = collections.deque(maxlen=20)

    def get(self):
        self.echo_num += 1
        q = queue.Queue(maxsize=1)
        self.echo_list.append((self.echo_num, q))
        return self.echo_num, q

    def match(self, context: dict):
        for obj in self.echo_list:
            if context["echo"] == obj[0]:
                obj[1].put(context)


def on_message(_, message):
    # https://github.com/botuniverse/onebot-11/blob/master/event/README.md
    context = json_.loads(message)
    if "echo" in context:
        logger.debug("call returned -> " + message)
        # hand the response back to the waiting call_api() through its queue
        echo.match(context)
    elif "meta_event_type" in context:
        logger.debug("heartbeat event -> " + message)
    else:
        logger.info("received event -> " + message)
        # message event: handle it on a new thread
        t = threading.Thread(target=plugin_pool, args=(context, ))
        t.start()


if __name__ == "__main__":
    echo = Echo()
    WS_APP = websocket.WebSocketApp(
        WS_URL,
        on_message=on_message,
        on_open=lambda _: logger.debug("connected......"),
        # *_ absorbs the extra close arguments newer websocket-client versions pass
        on_close=lambda *_: logger.debug("reconnecting......"),
    )
    while True:  # reconnect on disconnect
        WS_APP.run_forever()
        time.sleep(5)
--------------------------------------------------------------------------------