├── data └── .gitkeep ├── cache └── .gitkeep ├── chat_web ├── bot.png ├── user.png ├── bot-mini.png ├── speech.js ├── index.html ├── style.css ├── constants.js └── index.js ├── flask_chat ├── chat_server.py ├── semantic_sort.py └── chat.py └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cache/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chat_web/bot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dataaug/chatbot_multiround/HEAD/chat_web/bot.png -------------------------------------------------------------------------------- /chat_web/user.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dataaug/chatbot_multiround/HEAD/chat_web/user.png -------------------------------------------------------------------------------- /chat_web/bot-mini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dataaug/chatbot_multiround/HEAD/chat_web/bot-mini.png -------------------------------------------------------------------------------- /chat_web/speech.js: -------------------------------------------------------------------------------- 1 | // Text to Speech 2 | 3 | const synth = window.speechSynthesis; 4 | 5 | const textToSpeech = (string) => { 6 | let voice = new SpeechSynthesisUtterance(string); 7 | voice.text = string; 8 | voice.lang = "zh-CN"; 9 | voice.volume = 1; 10 | voice.rate = 1; 11 | voice.pitch = 1; // Can be 0, 1, or 2 12 | synth.speak(voice); 13 | } -------------------------------------------------------------------------------- /chat_web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Chatbot 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 |
14 |
15 | 16 |
17 | Robot cartoon 18 |
19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /flask_chat/chat_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | from flask_cors import CORS 3 | from flask import Flask, request, jsonify 4 | from chat import Generator 5 | 6 | # 需要在不同目录import 7 | import sys 8 | 9 | app = Flask(__name__) 10 | CORS(app, resources=r'/*') 11 | 12 | 13 | engine = Generator() 14 | sents = [] 15 | 16 | @app.route("/", methods=["GET", 'POST']) 17 | def hello_world(): 18 | if request.method == 'POST': 19 | print('文本生成暂不支持图片和视频请求') 20 | return 21 | 22 | else: # GET方法 代表是关键词请求 23 | key_word = request.args.get("key_word") 24 | industry = request.args.get("industry") 25 | global sents 26 | sents.append(key_word) 27 | print('当前对话:', sents) 28 | res = engine.generate(sents) 29 | if '[CLS]' in res[0] or key_word.lower() == 'exit': #机器人结束对话或者出现结束代码 30 | res = ['无言以对......请开始下一个话题'] 31 | with open('chat_saved.txt', 'a', encoding='utf-8') as fw: 32 | fw.write('[SEP]'.join(sents) + '\n') 33 | sents = [] 34 | else: 35 | sents.append(res[0]) 36 | print(res) 37 | 38 | return jsonify(list_of_data=res) 39 | 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 多轮对话机器人 2 | ### 该仓库提供训练数据,训练代码,推断代码以及可交互的网页 3 | ### 多轮中文聊天机器人,采用GPT2进行微调,清洗聊天数据110w+,采用语义相似度和文本jaccard相似度过滤回话。 4 | 5 | ## 推断代码 (快速开始) 6 | 训练好的模型 [百度网盘](https://pan.baidu.com/s/1vuZM4Amjz8qY-Tjz4yhACg) 提取码:3nup 将模型下载解压后存入cache,执行以下命令 7 | ```bash 8 | cd chat_web 9 | python3 chat.py 10 | ``` 11 | 即可和聊天机器人交互 12 | 13 | 14 | ## 对话网页 15 | 改自项目:https://github.com/sylviapap/chatbot 16 | 17 | 18 | 对话网页可以提供更友好的AI,运行这部分代码可以快速部署AI聊天机器人服务 19 | 20 | 21 | 首先,开启flask服务 22 | 23 | ```bash 24 | cd flask_chat 25 | export FLASK_APP=chat_server 26 | nohup flask run --host=0.0.0.0 -p 5000 > flask.log 2>&1 & 27 | ``` 28 | 点开 chat_web/index.html就可以看到可交互网页了 29 | 30 | 31 | 如果需要公网访问架设的聊天服务,请将 index.js 中相关ip改为你的服务器ip 32 | 33 | ## 训练数据 34 | 清洗自 https://github.com/yangjianxin1/GPT2-chitchat 中的100w聊天语料,以及 https://github.com/codemayq/chinese_chatbot_corpus 中的贴吧及电视剧对白语料,仅保留3段以上的对话。请前往 [百度网盘](https://pan.baidu.com/s/1hUHtzIxZS4U6GGuE4TrEtw) 提取码:7rm7 下载后存入data文件夹下 35 | 36 | ## 训练代码 37 | 改自huggingface源码 https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py 38 | ```bash 39 | python3 run_clm.py \ 40 | --model_name_or_path uer/gpt2-distil-chinese-cluecorpussmall \ 41 | --train_file data/merge_train.txt \ 42 | --do_train \ 43 | --do_eval \ 44 | --output_dir cache/new \ 45 | --overwrite_output_dir 46 | ``` 47 | 48 | 49 | ## 技术细节 50 | 采用GPT-2训练,训练数据格式如下 51 | ``` 52 | [CLS]天气真不错[SEP]你也很不错[SEP]你真会夸人[SEP]过奖了[SEP] 53 | [CLS]你好啊[SEP]你好[SEP]吃了[SEP]那我们一起出去玩吧[SEP] 54 | ``` 55 | 推断过程根据用户输入采样30条回应(topp采样),这个回应将通过roberta模型与用户发言进行语义排序,希望选择语义相似度最高的句子。 56 | 57 | 58 | 与此同时,计算用户发言和候选语句jaccard相似度,将语义相似度 + (1 - jaccard相似分)就得到机器回复的排序。选择最高排序结果返回。 59 | 60 | TODO: 完善依赖,测试环境 61 | 62 | -------------------------------------------------------------------------------- /chat_web/style.css: -------------------------------------------------------------------------------- 1 | * { 2 | box-sizing: border-box; 3 | } 4 | 5 | html { 6 | height: 100%; 7 | 8 | } 9 | 10 | body { 11 | font-family: 'Roboto', 'Oxygen', 12 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial, Helvetica, 13 | sans-serif; 14 | -webkit-font-smoothing: antialiased; 15 | -moz-osx-font-smoothing: grayscale; 16 | background-color: rgba(223, 242, 247, .5); 17 | height: 100%; 18 | margin: 0; 19 | } 20 | 21 | span { 22 | padding-right: 15px; 23 | padding-left: 15px; 24 | } 25 | 26 | .container { 27 | display: flex; 28 | justify-content: center; 29 | align-items: center; 30 | width: 100%; 31 | height: 100%; 32 | } 33 | 34 | .chat { 35 | height: 300px; 36 | width: 50vw; 37 | display: flex; 38 | flex-direction: column; 39 | justify-content: center; 40 | align-items: center; 41 | } 42 | 43 | ::-webkit-input-placeholder { 44 | color: .711 45 | } 46 | 47 | input { 48 | border: 0; 49 | padding: 15px; 50 | margin-left: auto; 51 | border-radius: 10px; 52 | } 53 | 54 | .messages { 55 | display: flex; 56 | flex-direction: column; 57 | overflow: scroll; 58 | height: 90%; 59 | width: 100%; 60 | background-color: white; 61 | padding: 15px; 62 | margin: 15px; 63 | border-radius: 10px; 64 | } 65 | 66 | #bot { 67 | margin-left: auto; 68 | } 69 | 70 | .bot { 71 | font-family: Consolas, 'Courier New', Menlo, source-code-pro, Monaco, 72 | monospace; 73 | } 74 | 75 | .avatar { 76 | height: 25px; 77 | } 78 | 79 | .response { 80 | display: flex; 81 | align-items: center; 82 | margin: 1%; 83 | } 84 | 85 | 86 | /* Mobile */ 87 | 88 | @media only screen and (max-width: 980px) { 89 | .container { 90 | flex-direction: column; 91 | justify-content: flex-start; 92 | } 93 | .chat { 94 | width: 75vw; 95 | margin: 10vw; 96 | } 97 | } -------------------------------------------------------------------------------- /flask_chat/semantic_sort.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments 2 | import torch 3 | import requests 4 | import copy 5 | import re 6 | 7 | 8 | class Semantic_sort(): 9 | def __init__(self) -> None: 10 | # 存储所有句子的句向量 queries 11 | self.tokenizer = AutoTokenizer.from_pretrained( 12 | 'hfl/chinese-roberta-wwm-ext') 13 | self.model = AutoModelForSequenceClassification.from_pretrained( 14 | 'hfl/chinese-roberta-wwm-ext') #.cuda() 15 | self.model.eval() 16 | 17 | def get_tokens(self, lines): # 输入句子列表 输出CLS token 列表 18 | inputs = self.tokenizer(lines, return_tensors="pt", truncation=True, 19 | padding='max_length', max_length=40) #.to('cuda') 20 | outputs = self.model(**inputs, output_hidden_states=True) 21 | last_hidden = outputs[1][-1] # hidden states的最后一层输出 22 | cls_token = last_hidden[:, 0, :] 23 | return cls_token 24 | 25 | def mask_sents(self, key_word='', sents=[]): 26 | assert key_word, sents 27 | for i, sent in enumerate(sents): 28 | sent = re.sub('|'.join(key_word.split(' ')), '[MASK]', sent) 29 | sents[i] = sent 30 | 31 | return sents 32 | 33 | def sem_sort(self, keyword='', sentences=[], mask=False): 34 | assert keyword, sentences 35 | 36 | recalls = sentences 37 | 38 | tmp = copy.deepcopy(recalls) # 暂存 以便不展现mask后结果 39 | if mask: 40 | recalls = self.mask_sents(keyword, recalls) 41 | print('keyword', keyword) 42 | print('recalls', recalls) 43 | # 同时传入keyword和句子 避免两次前传播 44 | with torch.no_grad(): 45 | queries = self.get_tokens([keyword] + recalls) 46 | key = queries[0] 47 | queries = queries[1:] 48 | 49 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 50 | cos_sim = cos(queries, key.unsqueeze(0)) 51 | 52 | cos_sim = [float(x) for x in cos_sim] 53 | 54 | sent_score = list(zip(tmp, cos_sim)) 55 | # sent_score.sort(key=lambda x: x[0], reverse=True) 56 | return sent_score 57 | 58 | 59 | if __name__ == '__main__': 60 | sem_sorter = Semantic_sort() 61 | res = sem_sorter.sem_sort('电视', ['老电视', '新电视', '电视柜']) 62 | print(res) 63 | -------------------------------------------------------------------------------- /chat_web/constants.js: -------------------------------------------------------------------------------- 1 | // Options the user could type in 2 | const prompts = [ 3 | ["hi", "hey", "hello", "good morning", "good afternoon"], 4 | ["how are you", "how is life", "how are things"], 5 | ["what are you doing", "what is going on", "what is up"], 6 | ["how old are you"], 7 | ["who are you", "are you human", "are you bot", "are you human or bot"], 8 | ["who created you", "who made you"], 9 | [ 10 | "your name please", 11 | "your name", 12 | "may i know your name", 13 | "what is your name", 14 | "what call yourself" 15 | ], 16 | ["i love you"], 17 | ["happy", "good", "fun", "wonderful", "fantastic", "cool"], 18 | ["bad", "bored", "tired"], 19 | ["help me", "tell me story", "tell me joke"], 20 | ["ah", "yes", "ok", "okay", "nice"], 21 | ["bye", "good bye", "goodbye", "see you later"], 22 | ["what should i eat today"], 23 | ["bro"], 24 | ["what", "why", "how", "where", "when"], 25 | ["no","not sure","maybe","no thanks"], 26 | [""], 27 | ["haha","ha","lol","hehe","funny","joke"] 28 | ] 29 | 30 | // Possible responses, in corresponding order 31 | 32 | const replies = [ 33 | ["Hello!", "Hi!", "Hey!", "Hi there!","Howdy"], 34 | [ 35 | "Fine... how are you?", 36 | "Pretty well, how are you?", 37 | "Fantastic, how are you?" 38 | ], 39 | [ 40 | "Nothing much", 41 | "About to go to sleep", 42 | "Can you guess?", 43 | "I don't know actually" 44 | ], 45 | ["I am infinite"], 46 | ["I am just a bot", "I am a bot. What are you?"], 47 | ["The one true God, JavaScript"], 48 | ["I am nameless", "I don't have a name"], 49 | ["I love you too", "Me too"], 50 | ["Have you ever felt bad?", "Glad to hear it"], 51 | ["Why?", "Why? You shouldn't!", "Try watching TV"], 52 | ["What about?", "Once upon a time..."], 53 | ["Tell me a story", "Tell me a joke", "Tell me about yourself"], 54 | ["Bye", "Goodbye", "See you later"], 55 | ["Sushi", "Pizza"], 56 | ["Bro!"], 57 | ["Great question"], 58 | ["That's ok","I understand","What do you want to talk about?"], 59 | ["Please say something :("], 60 | ["Haha!","Good one!"] 61 | ] 62 | 63 | // Random for any other user input 64 | 65 | const alternative = [ 66 | "Same", 67 | "Go on...", 68 | "Bro...", 69 | "Try again", 70 | "I'm listening...", 71 | "I don't understand :/" 72 | ] 73 | 74 | // Whatever else you want :) 75 | 76 | const coronavirus = ["Please stay home", "Wear a mask", "Fortunately, I don't have COVID", "These are uncertain times"] -------------------------------------------------------------------------------- /flask_chat/chat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from datetime import datetime 3 | from elasticsearch import Elasticsearch 4 | import os 5 | import jieba 6 | import re 7 | import sys 8 | import json 9 | import sys 10 | from transformers import pipeline, set_seed 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | import torch.nn as nn 13 | import torch 14 | import requests 15 | from semantic_sort import Semantic_sort 16 | 17 | 18 | class Generator(): 19 | def __init__(self) -> None: 20 | # set_seed(42) 21 | model_cache = "../cache/checkpoint-download" 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_cache) 23 | self.model = AutoModelForCausalLM.from_pretrained(model_cache)# .cuda() 暂不用显卡 24 | self.semantic = Semantic_sort().sem_sort 25 | 26 | def cal_jaccard(self, text1, text2): 27 | text1 = set(text1) 28 | text2 = set(text2) 29 | same = text1.intersection(text2) 30 | return float(len(same))/(len(text1)+len(text2)-len(same)) 31 | 32 | # 生成 33 | def generate(self, key_word, FIN = '', SIN = ''): # TODO 使用行业信息 34 | def cut_SEP(input_ids): 35 | return input_ids[:,:-1] 36 | 37 | # set seed to reproduce results. Feel free to change the seed though to get different results 38 | with torch.no_grad(): 39 | input_ids = self.tokenizer.encode('[SEP]'.join(key_word), return_tensors='pt')# .cuda() 40 | print('输入:', self.tokenizer.decode(input_ids[0]), input_ids[0]) 41 | # input_ids = cut_SEP(input_ids) 42 | print('输入:', self.tokenizer.decode(input_ids[0])) 43 | input_record = self.tokenizer.decode(input_ids[0]) 44 | 45 | # torch.manual_seed(2) 46 | num_return_sequences = 30 47 | # activate sampling and deactivate top_k by setting top_k sampling to 0 48 | 49 | sample_output = self.model.generate( 50 | input_ids, 51 | do_sample=True, 52 | max_length=len(input_record) + 30, 53 | top_k=0, 54 | top_p=0.92, 55 | num_return_sequences=num_return_sequences, 56 | # eos_token_id = 101 57 | # temperature=0.7 58 | ) 59 | 60 | seqs = [] 61 | len_prefix = len(input_record) 62 | print('input_record',input_record) 63 | print(len_prefix) 64 | print("Output:\n" + 100 * '-') 65 | for i in range(num_return_sequences): 66 | res = self.tokenizer.decode(sample_output[i], skip_special_tokens=False) 67 | print(res) 68 | res = res[len_prefix:] 69 | print(res) 70 | print(re.findall('(.*?)\[SEP\]',res)) 71 | try: 72 | seqs.append(re.findall('(.*?)\[SEP\]',res)[0]) 73 | except: 74 | print('输出出现异常:') 75 | print(res) 76 | pass 77 | # print(res) 78 | seqs = list(set(seqs)) 79 | seqs = self.semantic(key_word[-1], seqs) 80 | # seqs = [x for x in seqs if self.cal_jaccard(key_word[-1], x[0]) < 0.8] 81 | seqs = [[x[0], x[1] + (1 - self.cal_jaccard(key_word[-1], x[0]))] for x in seqs] 82 | seqs.sort(key = lambda x: x[1], reverse = True) 83 | print(seqs) 84 | seqs = [re.sub(' ','',x[0]) for x in seqs] 85 | seqs = [seqs[0]] 86 | print(seqs) 87 | return list(set(seqs)) 88 | 89 | 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | engine = Generator() 95 | sents = [] 96 | while True: 97 | key_word = input("发言:") 98 | sents.append(key_word) 99 | res = engine.generate(sents) 100 | if '[CLS]' in res[0]: 101 | print('无言以对......') 102 | sents = [] 103 | else: 104 | sents.append(res[0]) 105 | print(res) 106 | -------------------------------------------------------------------------------- /chat_web/index.js: -------------------------------------------------------------------------------- 1 | document.addEventListener("DOMContentLoaded", () => { 2 | const inputField = document.getElementById("input"); 3 | inputField.addEventListener("keydown", (e) => { 4 | if (e.code === "Enter") { 5 | let input = inputField.value; 6 | inputField.value = ""; 7 | // output(input); 8 | console.log('数据传入:' + input) 9 | res = search('nothing' ,input); 10 | console.log('机器回应:' + res) 11 | addChat(input, res) 12 | } 13 | }); 14 | }); 15 | 16 | function search(industry, key_word) { 17 | var data = { 18 | key_word: key_word, 19 | industry: industry 20 | }; 21 | 22 | var arr = []; 23 | $.ajax({ 24 | type: 'GET', 25 | url: 'http://127.0.0.1:5000/', // 本地flask 对话服务 26 | data: data, 27 | dataType: 'json', 28 | async: false,//同步 29 | success: function (data) { 30 | arr = data.list_of_data 31 | }, 32 | error: function (xhr, type) { 33 | } 34 | }); 35 | return arr[0]; 36 | } 37 | 38 | 39 | function output(input) { 40 | let product; 41 | 42 | // Regex remove non word/space chars 43 | // Trim trailing whitespce 44 | // Remove digits - not sure if this is best 45 | // But solves problem of entering something like 'hi1' 46 | 47 | let text = input.toLowerCase().replace(/[^\w\s]/gi, "").replace(/[\d]/gi, "").trim(); 48 | text = text 49 | .replace(/ a /g, " ") // 'tell me a story' -> 'tell me story' 50 | .replace(/i feel /g, "") 51 | .replace(/whats/g, "what is") 52 | .replace(/please /g, "") 53 | .replace(/ please/g, "") 54 | .replace(/r u/g, "are you"); 55 | 56 | if (compare(prompts, replies, text)) { 57 | // Search for exact match in `prompts` 58 | product = compare(prompts, replies, text); 59 | } else if (text.match(/thank/gi)) { 60 | product = "You're welcome!" 61 | } else if (text.match(/(corona|covid|virus)/gi)) { 62 | // If no match, check if message contains `coronavirus` 63 | product = coronavirus[Math.floor(Math.random() * coronavirus.length)]; 64 | } else { 65 | // If all else fails: random alternative 66 | product = alternative[Math.floor(Math.random() * alternative.length)]; 67 | } 68 | 69 | // Update DOM 70 | addChat(input, product); 71 | } 72 | 73 | function compare(promptsArray, repliesArray, string) { 74 | let reply; 75 | let replyFound = false; 76 | for (let x = 0; x < promptsArray.length; x++) { 77 | for (let y = 0; y < promptsArray[x].length; y++) { 78 | if (promptsArray[x][y] === string) { 79 | let replies = repliesArray[x]; 80 | reply = replies[Math.floor(Math.random() * replies.length)]; 81 | replyFound = true; 82 | // Stop inner loop when input value matches prompts 83 | break; 84 | } 85 | } 86 | if (replyFound) { 87 | // Stop outer loop when reply is found instead of interating through the entire array 88 | break; 89 | } 90 | } 91 | return reply; 92 | } 93 | 94 | function addChat(input, product) { 95 | const messagesContainer = document.getElementById("messages"); 96 | 97 | let userDiv = document.createElement("div"); 98 | userDiv.id = "user"; 99 | userDiv.className = "user response"; 100 | userDiv.innerHTML = `${input}`; 101 | messagesContainer.appendChild(userDiv); 102 | 103 | let botDiv = document.createElement("div"); 104 | let botImg = document.createElement("img"); 105 | let botText = document.createElement("span"); 106 | botDiv.id = "bot"; 107 | botImg.src = "bot-mini.png"; 108 | botImg.className = "avatar"; 109 | botDiv.className = "bot response"; 110 | botText.innerText = "Typing..."; 111 | botDiv.appendChild(botText); 112 | botDiv.appendChild(botImg); 113 | messagesContainer.appendChild(botDiv); 114 | // Keep messages at most recent 115 | messagesContainer.scrollTop = messagesContainer.scrollHeight - messagesContainer.clientHeight; 116 | 117 | // Fake delay to seem "real" 118 | setTimeout(() => { 119 | botText.innerText = `${product}`; 120 | textToSpeech(product) 121 | }, 2000 122 | ) 123 | 124 | } --------------------------------------------------------------------------------