├── data
│   └── .gitkeep
├── cache
│   └── .gitkeep
├── chat_web
│   ├── bot.png
│   ├── user.png
│   ├── bot-mini.png
│   ├── speech.js
│   ├── index.html
│   ├── style.css
│   ├── constants.js
│   └── index.js
├── flask_chat
│   ├── chat_server.py
│   ├── semantic_sort.py
│   └── chat.py
└── README.md
/data/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cache/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chat_web/bot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dataaug/chatbot_multiround/HEAD/chat_web/bot.png
--------------------------------------------------------------------------------
/chat_web/user.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dataaug/chatbot_multiround/HEAD/chat_web/user.png
--------------------------------------------------------------------------------
/chat_web/bot-mini.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dataaug/chatbot_multiround/HEAD/chat_web/bot-mini.png
--------------------------------------------------------------------------------
/chat_web/speech.js:
--------------------------------------------------------------------------------
1 | // Text to Speech
2 |
3 | const synth = window.speechSynthesis;
4 |
5 | const textToSpeech = (string) => {
6 | let voice = new SpeechSynthesisUtterance(string);
7 | voice.text = string;
8 | voice.lang = "zh-CN";
9 | voice.volume = 1;
10 | voice.rate = 1;
11 | voice.pitch = 1; // Range 0 to 2; 1 is the default pitch
12 | synth.speak(voice);
13 | }
--------------------------------------------------------------------------------
/chat_web/index.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html>
2 | <html lang="zh-CN">
3 | <head>
4 |   <meta charset="utf-8">
5 |   <title>Chatbot</title>
6 |   <link rel="stylesheet" href="style.css">
7 | </head>
8 | <body>
9 |   <div class="container">
10 |     <div class="chat">
11 |       <div class="messages" id="messages"></div>
12 |       <input id="input" type="text" placeholder="Say something...">
13 |     </div>
14 |   </div>
15 |   <!-- jQuery is required by index.js ($.ajax) -->
16 |   <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
17 |   <script src="constants.js"></script>
18 |   <script src="speech.js"></script>
19 |   <script src="index.js"></script>
20 | </body>
21 | </html>
--------------------------------------------------------------------------------
/flask_chat/chat_server.py:
--------------------------------------------------------------------------------
1 | import os
2 | from flask_cors import CORS
3 | from flask import Flask, request, jsonify
4 | from chat import Generator
5 |
6 | # needed when importing from a different directory
7 | import sys
8 |
9 | app = Flask(__name__)
10 | CORS(app, resources=r'/*')
11 |
12 |
13 | engine = Generator()
14 | sents = []
15 |
16 | @app.route("/", methods=["GET", 'POST'])
17 | def hello_world():
18 | if request.method == 'POST':
19 | print('文本生成暂不支持图片和视频请求')  # image/video requests are not supported by text generation
20 | return jsonify(list_of_data=[])
21 |
22 | else:  # a GET request carries the user utterance (key_word)
23 | key_word = request.args.get("key_word")
24 | industry = request.args.get("industry")
25 | global sents
26 | sents.append(key_word)
27 | print('当前对话:', sents)
28 | res = engine.generate(sents)
29 | if '[CLS]' in res[0] or key_word.lower() == 'exit':  # the bot ended the conversation, or the user typed exit
30 | res = ['无言以对......请开始下一个话题']
31 | with open('chat_saved.txt', 'a', encoding='utf-8') as fw:
32 | fw.write('[SEP]'.join(sents) + '\n')
33 | sents = []
34 | else:
35 | sents.append(res[0])
36 | print(res)
37 |
38 | return jsonify(list_of_data=res)
39 |
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-turn Chatbot
2 | ### This repository provides training data, training code, inference code, and an interactive web page
3 | ### A multi-turn Chinese chatbot fine-tuned with GPT-2 on 1.1M+ cleaned chat dialogues; candidate replies are filtered with semantic similarity and text Jaccard similarity.
4 |
5 | ## Inference (quick start)
6 | Download the trained model from [Baidu Netdisk](https://pan.baidu.com/s/1vuZM4Amjz8qY-Tjz4yhACg) (extraction code: 3nup), unzip it into the cache folder, then run:
7 | ```bash
8 | cd flask_chat
9 | python3 chat.py
10 | ```
11 | You can now chat with the bot on the command line.
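chat.py loads the checkpoint from `../cache/checkpoint-download`, so after unzipping, the cache folder is expected to look roughly like this (the file names below are only the usual Hugging Face checkpoint layout and may differ in the downloaded archive):

```
cache/
└── checkpoint-download/
    ├── config.json
    ├── pytorch_model.bin
    ├── vocab.txt
    └── tokenizer_config.json
```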
12 |
13 |
14 | ## Chat web page
15 | Adapted from https://github.com/sylviapap/chatbot
16 |
17 |
18 | The chat web page offers a friendlier way to interact with the AI; running this part of the code quickly deploys the chatbot as a service.
19 |
20 |
21 | First, start the Flask service:
22 |
23 | ```bash
24 | cd flask_chat
25 | export FLASK_APP=chat_server
26 | nohup flask run --host=0.0.0.0 -p 5000 > flask.log 2>&1 &
27 | ```
28 | Then open chat_web/index.html in a browser to see the interactive page.
29 |
30 |
31 | To access the deployed chat service over the public internet, change the relevant IP in index.js to your server's IP.
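The Flask endpoint can also be queried directly, which is handy for checking the deployment. Below is a minimal sketch using the `requests` library; the `key_word` and `industry` query parameters mirror what index.js sends, and the host and port are the defaults from the command above:

```python
# Minimal sketch: call the chat endpoint the same way chat_web/index.js does.
import requests

resp = requests.get(
    "http://127.0.0.1:5000/",
    params={"key_word": "你好", "industry": "nothing"},
)
print(resp.json()["list_of_data"][0])  # the bot's reply
```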
32 |
33 | ## Training data
34 | Cleaned from the 1M-dialogue chat corpus in https://github.com/yangjianxin1/GPT2-chitchat and the Tieba and TV-drama dialogue corpora in https://github.com/codemayq/chinese_chatbot_corpus, keeping only conversations with more than 3 turns. Download it from [Baidu Netdisk](https://pan.baidu.com/s/1hUHtzIxZS4U6GGuE4TrEtw) (extraction code: 7rm7) and place the files in the data folder.
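The cleaning scripts themselves are not part of this repository; the sketch below only illustrates the target format used by the training command further down, assuming the cleaned conversations are already available as lists of turns (the input data here is a placeholder):

```python
# Illustrative sketch: write conversations into the line format expected by
# run_clm.py, keeping only conversations with more than 3 turns.
dialogues = [
    ["天气真不错", "你也很不错", "你真会夸人", "过奖了"],
    ["你好啊", "你好"],  # too short, dropped
]

with open("data/merge_train.txt", "w", encoding="utf-8") as fw:
    for turns in dialogues:
        if len(turns) <= 3:
            continue
        fw.write("[CLS]" + "[SEP]".join(turns) + "[SEP]\n")
```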
35 |
36 | ## Training code
37 | Adapted from the Hugging Face example https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
38 | ```bash
39 | python3 run_clm.py \
40 | --model_name_or_path uer/gpt2-distil-chinese-cluecorpussmall \
41 | --train_file data/merge_train.txt \
42 | --do_train \
43 | --do_eval \
44 | --output_dir cache/new \
45 | --overwrite_output_dir
46 | ```
47 |
48 |
49 | ## Technical details
50 | The model is GPT-2, trained on data in the following format:
51 | ```
52 | [CLS]天气真不错[SEP]你也很不错[SEP]你真会夸人[SEP]过奖了[SEP]
53 | [CLS]你好啊[SEP]你好[SEP]吃了[SEP]那我们一起出去玩吧[SEP]
54 | ```
55 | At inference time, 30 candidate replies are sampled for the user input (top-p sampling). Each candidate is then scored against the user's utterance with a RoBERTa model, aiming to select the reply with the highest semantic similarity.
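The sampling step appears in full in flask_chat/chat.py; condensed, it looks like this:

```python
# Condensed from flask_chat/chat.py: sample 30 candidate continuations with
# nucleus (top-p) sampling; the dialogue history is joined with [SEP].
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "../cache/checkpoint-download"           # same path chat.py uses
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

history = ["天气真不错"]                              # utterances so far, newest last
prefix = "[SEP]".join(history)
input_ids = tokenizer.encode(prefix, return_tensors="pt")
outputs = model.generate(
    input_ids,
    do_sample=True,            # sample instead of greedy decoding
    top_k=0,                   # disable top-k so only top-p filtering applies
    top_p=0.92,                # nucleus sampling threshold
    max_length=len(prefix) + 30,   # character-length budget, as in chat.py
    num_return_sequences=30,   # candidates for the reranking step below
)
```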
56 |
57 |
58 | At the same time, the Jaccard similarity between the user's utterance and each candidate is computed; the ranking score of a reply is the semantic similarity + (1 - Jaccard similarity), and the highest-ranked result is returned.
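The combination of the two scores mirrors cal_jaccard and the reranking loop in flask_chat/chat.py:

```python
# Condensed from flask_chat/chat.py: combine RoBERTa cosine similarity with
# character-level Jaccard similarity and return the best candidate.
def cal_jaccard(text1, text2):
    a, b = set(text1), set(text2)
    same = a & b
    return len(same) / (len(a) + len(b) - len(same))

# `scored` is the output of Semantic_sort.sem_sort: [(candidate, cosine_sim), ...]
def pick_reply(user_utterance, scored):
    reranked = sorted(
        ((cand, sim + (1 - cal_jaccard(user_utterance, cand))) for cand, sim in scored),
        key=lambda x: x[1],
        reverse=True,
    )
    return reranked[0][0]  # highest combined score wins
```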
59 |
60 | TODO: complete the dependency list and document the test environment
61 |
62 |
--------------------------------------------------------------------------------
/chat_web/style.css:
--------------------------------------------------------------------------------
1 | * {
2 | box-sizing: border-box;
3 | }
4 |
5 | html {
6 | height: 100%;
7 |
8 | }
9 |
10 | body {
11 | font-family: 'Roboto', 'Oxygen',
12 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial, Helvetica,
13 | sans-serif;
14 | -webkit-font-smoothing: antialiased;
15 | -moz-osx-font-smoothing: grayscale;
16 | background-color: rgba(223, 242, 247, .5);
17 | height: 100%;
18 | margin: 0;
19 | }
20 |
21 | span {
22 | padding-right: 15px;
23 | padding-left: 15px;
24 | }
25 |
26 | .container {
27 | display: flex;
28 | justify-content: center;
29 | align-items: center;
30 | width: 100%;
31 | height: 100%;
32 | }
33 |
34 | .chat {
35 | height: 300px;
36 | width: 50vw;
37 | display: flex;
38 | flex-direction: column;
39 | justify-content: center;
40 | align-items: center;
41 | }
42 |
43 | ::-webkit-input-placeholder {
44 | color: #711;
45 | }
46 |
47 | input {
48 | border: 0;
49 | padding: 15px;
50 | margin-left: auto;
51 | border-radius: 10px;
52 | }
53 |
54 | .messages {
55 | display: flex;
56 | flex-direction: column;
57 | overflow: scroll;
58 | height: 90%;
59 | width: 100%;
60 | background-color: white;
61 | padding: 15px;
62 | margin: 15px;
63 | border-radius: 10px;
64 | }
65 |
66 | #bot {
67 | margin-left: auto;
68 | }
69 |
70 | .bot {
71 | font-family: Consolas, 'Courier New', Menlo, source-code-pro, Monaco,
72 | monospace;
73 | }
74 |
75 | .avatar {
76 | height: 25px;
77 | }
78 |
79 | .response {
80 | display: flex;
81 | align-items: center;
82 | margin: 1%;
83 | }
84 |
85 |
86 | /* Mobile */
87 |
88 | @media only screen and (max-width: 980px) {
89 | .container {
90 | flex-direction: column;
91 | justify-content: flex-start;
92 | }
93 | .chat {
94 | width: 75vw;
95 | margin: 10vw;
96 | }
97 | }
--------------------------------------------------------------------------------
/flask_chat/semantic_sort.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
2 | import torch
3 | import requests
4 | import copy
5 | import re
6 |
7 |
8 | class Semantic_sort():
9 | def __init__(self) -> None:
10 | # builds sentence vectors ("queries") for all sentences
11 | self.tokenizer = AutoTokenizer.from_pretrained(
12 | 'hfl/chinese-roberta-wwm-ext')
13 | self.model = AutoModelForSequenceClassification.from_pretrained(
14 | 'hfl/chinese-roberta-wwm-ext') #.cuda()
15 | self.model.eval()
16 |
17 | def get_tokens(self, lines):  # input: list of sentences; output: list of [CLS] token vectors
18 | inputs = self.tokenizer(lines, return_tensors="pt", truncation=True,
19 | padding='max_length', max_length=40) #.to('cuda')
20 | outputs = self.model(**inputs, output_hidden_states=True)
21 | last_hidden = outputs[1][-1]  # last layer of the hidden states
22 | cls_token = last_hidden[:, 0, :]
23 | return cls_token
24 |
25 | def mask_sents(self, key_word='', sents=[]):
26 | assert key_word and sents
27 | for i, sent in enumerate(sents):
28 | sent = re.sub('|'.join(key_word.split(' ')), '[MASK]', sent)
29 | sents[i] = sent
30 |
31 | return sents
32 |
33 | def sem_sort(self, keyword='', sentences=[], mask=False):
34 | assert keyword and sentences
35 |
36 | recalls = sentences
37 |
38 | tmp = copy.deepcopy(recalls)  # keep a copy so the masked text is not what gets returned
39 | if mask:
40 | recalls = self.mask_sents(keyword, recalls)
41 | print('keyword', keyword)
42 | print('recalls', recalls)
43 | # encode the keyword and the candidates in one batch to avoid two forward passes
44 | with torch.no_grad():
45 | queries = self.get_tokens([keyword] + recalls)
46 | key = queries[0]
47 | queries = queries[1:]
48 |
49 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
50 | cos_sim = cos(queries, key.unsqueeze(0))
51 |
52 | cos_sim = [float(x) for x in cos_sim]
53 |
54 | sent_score = list(zip(tmp, cos_sim))
55 | # sent_score.sort(key=lambda x: x[0], reverse=True)
56 | return sent_score
57 |
58 |
59 | if __name__ == '__main__':
60 | sem_sorter = Semantic_sort()
61 | res = sem_sorter.sem_sort('电视', ['老电视', '新电视', '电视柜'])
62 | print(res)
63 |
--------------------------------------------------------------------------------
/chat_web/constants.js:
--------------------------------------------------------------------------------
1 | // Options the user could type in
2 | const prompts = [
3 | ["hi", "hey", "hello", "good morning", "good afternoon"],
4 | ["how are you", "how is life", "how are things"],
5 | ["what are you doing", "what is going on", "what is up"],
6 | ["how old are you"],
7 | ["who are you", "are you human", "are you bot", "are you human or bot"],
8 | ["who created you", "who made you"],
9 | [
10 | "your name please",
11 | "your name",
12 | "may i know your name",
13 | "what is your name",
14 | "what call yourself"
15 | ],
16 | ["i love you"],
17 | ["happy", "good", "fun", "wonderful", "fantastic", "cool"],
18 | ["bad", "bored", "tired"],
19 | ["help me", "tell me story", "tell me joke"],
20 | ["ah", "yes", "ok", "okay", "nice"],
21 | ["bye", "good bye", "goodbye", "see you later"],
22 | ["what should i eat today"],
23 | ["bro"],
24 | ["what", "why", "how", "where", "when"],
25 | ["no","not sure","maybe","no thanks"],
26 | [""],
27 | ["haha","ha","lol","hehe","funny","joke"]
28 | ]
29 |
30 | // Possible responses, in corresponding order
31 |
32 | const replies = [
33 | ["Hello!", "Hi!", "Hey!", "Hi there!","Howdy"],
34 | [
35 | "Fine... how are you?",
36 | "Pretty well, how are you?",
37 | "Fantastic, how are you?"
38 | ],
39 | [
40 | "Nothing much",
41 | "About to go to sleep",
42 | "Can you guess?",
43 | "I don't know actually"
44 | ],
45 | ["I am infinite"],
46 | ["I am just a bot", "I am a bot. What are you?"],
47 | ["The one true God, JavaScript"],
48 | ["I am nameless", "I don't have a name"],
49 | ["I love you too", "Me too"],
50 | ["Have you ever felt bad?", "Glad to hear it"],
51 | ["Why?", "Why? You shouldn't!", "Try watching TV"],
52 | ["What about?", "Once upon a time..."],
53 | ["Tell me a story", "Tell me a joke", "Tell me about yourself"],
54 | ["Bye", "Goodbye", "See you later"],
55 | ["Sushi", "Pizza"],
56 | ["Bro!"],
57 | ["Great question"],
58 | ["That's ok","I understand","What do you want to talk about?"],
59 | ["Please say something :("],
60 | ["Haha!","Good one!"]
61 | ]
62 |
63 | // Random for any other user input
64 |
65 | const alternative = [
66 | "Same",
67 | "Go on...",
68 | "Bro...",
69 | "Try again",
70 | "I'm listening...",
71 | "I don't understand :/"
72 | ]
73 |
74 | // Whatever else you want :)
75 |
76 | const coronavirus = ["Please stay home", "Wear a mask", "Fortunately, I don't have COVID", "These are uncertain times"]
--------------------------------------------------------------------------------
/flask_chat/chat.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from datetime import datetime
3 | from elasticsearch import Elasticsearch
4 | import os
5 | import jieba
6 | import re
7 | import sys
8 | import json
9 |
10 | from transformers import pipeline, set_seed
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 | import torch.nn as nn
13 | import torch
14 | import requests
15 | from semantic_sort import Semantic_sort
16 |
17 |
18 | class Generator():
19 | def __init__(self) -> None:
20 | # set_seed(42)
21 | model_cache = "../cache/checkpoint-download"
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_cache)
23 | self.model = AutoModelForCausalLM.from_pretrained(model_cache)  # .cuda()  GPU disabled for now
24 | self.semantic = Semantic_sort().sem_sort
25 |
26 | def cal_jaccard(self, text1, text2):
27 | text1 = set(text1)
28 | text2 = set(text2)
29 | same = text1.intersection(text2)
30 | return float(len(same))/(len(text1)+len(text2)-len(same))
31 |
32 | # generation
33 | def generate(self, key_word, FIN='', SIN=''):  # TODO: use the industry information
34 | def cut_SEP(input_ids):
35 | return input_ids[:,:-1]
36 |
37 | # set seed to reproduce results. Feel free to change the seed though to get different results
38 | with torch.no_grad():
39 | input_ids = self.tokenizer.encode('[SEP]'.join(key_word), return_tensors='pt')# .cuda()
40 | print('输入:', self.tokenizer.decode(input_ids[0]), input_ids[0])
41 | # input_ids = cut_SEP(input_ids)
42 | print('输入:', self.tokenizer.decode(input_ids[0]))
43 | input_record = self.tokenizer.decode(input_ids[0])
44 |
45 | # torch.manual_seed(2)
46 | num_return_sequences = 30
47 | # activate sampling and deactivate top_k by setting top_k sampling to 0
48 |
49 | sample_output = self.model.generate(
50 | input_ids,
51 | do_sample=True,
52 | max_length=len(input_record) + 30,
53 | top_k=0,
54 | top_p=0.92,
55 | num_return_sequences=num_return_sequences,
56 | # eos_token_id = 101
57 | # temperature=0.7
58 | )
59 |
60 | seqs = []
61 | len_prefix = len(input_record)
62 | print('input_record',input_record)
63 | print(len_prefix)
64 | print("Output:\n" + 100 * '-')
65 | for i in range(num_return_sequences):
66 | res = self.tokenizer.decode(sample_output[i], skip_special_tokens=False)
67 | print(res)
68 | res = res[len_prefix:]
69 | print(res)
70 | print(re.findall(r'(.*?)\[SEP\]', res))
71 | try:
72 | seqs.append(re.findall(r'(.*?)\[SEP\]', res)[0])
73 | except IndexError:  # no [SEP] found in the generated continuation
74 | print('输出出现异常:')
75 | print(res)
76 | pass
77 | # print(res)
78 | seqs = list(set(seqs))
79 | seqs = self.semantic(key_word[-1], seqs)
80 | # seqs = [x for x in seqs if self.cal_jaccard(key_word[-1], x[0]) < 0.8]
81 | seqs = [[x[0], x[1] + (1 - self.cal_jaccard(key_word[-1], x[0]))] for x in seqs]
82 | seqs.sort(key = lambda x: x[1], reverse = True)
83 | print(seqs)
84 | seqs = [re.sub(' ','',x[0]) for x in seqs]
85 | seqs = [seqs[0]]
86 | print(seqs)
87 | return list(set(seqs))
88 |
89 |
90 |
91 |
92 |
93 | if __name__ == '__main__':
94 | engine = Generator()
95 | sents = []
96 | while True:
97 | key_word = input("发言:")
98 | sents.append(key_word)
99 | res = engine.generate(sents)
100 | if '[CLS]' in res[0]:
101 | print('无言以对......')
102 | sents = []
103 | else:
104 | sents.append(res[0])
105 | print(res)
106 |
--------------------------------------------------------------------------------
/chat_web/index.js:
--------------------------------------------------------------------------------
1 | document.addEventListener("DOMContentLoaded", () => {
2 | const inputField = document.getElementById("input");
3 | inputField.addEventListener("keydown", (e) => {
4 | if (e.code === "Enter") {
5 | let input = inputField.value;
6 | inputField.value = "";
7 | // output(input);
8 | console.log('数据传入:' + input)
9 | const res = search('nothing', input);
10 | console.log('机器回应:' + res)
11 | addChat(input, res)
12 | }
13 | });
14 | });
15 |
16 | function search(industry, key_word) {
17 | var data = {
18 | key_word: key_word,
19 | industry: industry
20 | };
21 |
22 | var arr = [];
23 | $.ajax({
24 | type: 'GET',
25 | url: 'http://127.0.0.1:5000/', // local Flask chat service
26 | data: data,
27 | dataType: 'json',
28 | async: false, // synchronous request so the reply can be returned directly
29 | success: function (data) {
30 | arr = data.list_of_data
31 | },
32 | error: function (xhr, type) {
33 | }
34 | });
35 | return arr[0];
36 | }
37 |
38 |
39 | function output(input) {
40 | let product;
41 |
42 | // Regex remove non word/space chars
43 | // Trim trailing whitespace
44 | // Remove digits - not sure if this is best
45 | // But solves problem of entering something like 'hi1'
46 |
47 | let text = input.toLowerCase().replace(/[^\w\s]/gi, "").replace(/[\d]/gi, "").trim();
48 | text = text
49 | .replace(/ a /g, " ") // 'tell me a story' -> 'tell me story'
50 | .replace(/i feel /g, "")
51 | .replace(/whats/g, "what is")
52 | .replace(/please /g, "")
53 | .replace(/ please/g, "")
54 | .replace(/r u/g, "are you");
55 |
56 | if (compare(prompts, replies, text)) {
57 | // Search for exact match in `prompts`
58 | product = compare(prompts, replies, text);
59 | } else if (text.match(/thank/gi)) {
60 | product = "You're welcome!"
61 | } else if (text.match(/(corona|covid|virus)/gi)) {
62 | // If no match, check if message contains `coronavirus`
63 | product = coronavirus[Math.floor(Math.random() * coronavirus.length)];
64 | } else {
65 | // If all else fails: random alternative
66 | product = alternative[Math.floor(Math.random() * alternative.length)];
67 | }
68 |
69 | // Update DOM
70 | addChat(input, product);
71 | }
72 |
73 | function compare(promptsArray, repliesArray, string) {
74 | let reply;
75 | let replyFound = false;
76 | for (let x = 0; x < promptsArray.length; x++) {
77 | for (let y = 0; y < promptsArray[x].length; y++) {
78 | if (promptsArray[x][y] === string) {
79 | let replies = repliesArray[x];
80 | reply = replies[Math.floor(Math.random() * replies.length)];
81 | replyFound = true;
82 | // Stop inner loop when input value matches prompts
83 | break;
84 | }
85 | }
86 | if (replyFound) {
87 | // Stop outer loop when reply is found instead of iterating through the entire array
88 | break;
89 | }
90 | }
91 | return reply;
92 | }
93 |
94 | function addChat(input, product) {
95 | const messagesContainer = document.getElementById("messages");
96 |
97 | let userDiv = document.createElement("div");
98 | userDiv.id = "user";
99 | userDiv.className = "user response";
100 | userDiv.innerHTML = `<img src="user.png" class="avatar"> <span>${input}</span>`;
101 | messagesContainer.appendChild(userDiv);
102 |
103 | let botDiv = document.createElement("div");
104 | let botImg = document.createElement("img");
105 | let botText = document.createElement("span");
106 | botDiv.id = "bot";
107 | botImg.src = "bot-mini.png";
108 | botImg.className = "avatar";
109 | botDiv.className = "bot response";
110 | botText.innerText = "Typing...";
111 | botDiv.appendChild(botText);
112 | botDiv.appendChild(botImg);
113 | messagesContainer.appendChild(botDiv);
114 | // Keep messages at most recent
115 | messagesContainer.scrollTop = messagesContainer.scrollHeight - messagesContainer.clientHeight;
116 |
117 | // Fake delay to seem "real"
118 | setTimeout(() => {
119 | botText.innerText = `${product}`;
120 | textToSpeech(product)
121 | }, 2000
122 | )
123 |
124 | }
--------------------------------------------------------------------------------