├── db ├── config.json ├── models │ ├── Qa.py │ ├── __init__.py │ ├── Dialogue.py │ └── UserProfile.py ├── database │ └── tips.txt └── preprocess.py ├── docker └── .gitkeep ├── examples └── .gitkeep ├── mobile └── .gitkeep ├── logs ├── tips.txt └── logger.py ├── scripts └── tips.txt ├── dialogue ├── load_datasets.py ├── tools │ ├── logger.py │ └── read_data.py ├── tensorflow │ ├── nlu │ │ └── intent │ │ │ └── intent_manager.py │ ├── task │ │ ├── common │ │ │ ├── common.py │ │ │ ├── kb.py │ │ │ └── pre_treat.py │ │ ├── README.md │ │ ├── config │ │ │ ├── model_config.json │ │ │ └── get_config.py │ │ ├── model │ │ │ ├── tracker.py │ │ │ ├── model.py │ │ │ └── chatter.py │ │ ├── data │ │ │ ├── semi_dict.json │ │ │ └── ontology.json │ │ └── task_chatter.py │ ├── apis.py │ ├── positional_encoding.py │ ├── optimizers.py │ ├── preprocess.py │ ├── loader.py │ ├── utils.py │ ├── load_dataset.py │ ├── scheduled_sampling │ │ └── transformer.py │ ├── layers.py │ ├── seq2seq │ │ ├── model.py │ │ └── modules.py │ ├── smn │ │ └── model.py │ ├── gpt2 │ │ └── gpt2.py │ ├── beamsearch.py │ └── modules.py ├── pipeline.py ├── pytorch │ ├── apis.py │ ├── layers.py │ ├── transformer │ │ └── model.py │ ├── utils.py │ ├── load_dataset.py │ ├── seq2seq │ │ └── model.py │ ├── beamsearch.py │ └── modules.py ├── constants.py ├── config.py ├── metrics.py ├── actuator.py ├── debug.py └── tools.py ├── check.sh ├── app ├── assets │ ├── chat.png │ └── main.png ├── static │ └── favicon.ico ├── view │ └── __init__.py └── templates │ └── error │ └── 404.html ├── run.sh ├── .dockerignore ├── DockerFile ├── configs ├── constant.py ├── __init__.py ├── configs.py └── configs.json ├── requirements.txt ├── gunicorn.conf.py ├── check.py ├── server.py ├── .gitignore ├── README.CN.md ├── README.md └── docs ├── Attention_Is_All_You_Need.md └── Massive_Exploration_of_Neural_Machine_Translation_Architectures.md /db/config.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /db/models/Qa.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docker/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mobile/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /db/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/tips.txt: -------------------------------------------------------------------------------- 1 | # 系统更新日志 -------------------------------------------------------------------------------- /scripts/tips.txt: -------------------------------------------------------------------------------- 1 | # 项目脚本 -------------------------------------------------------------------------------- /dialogue/load_datasets.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /dialogue/tools/logger.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /check.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | python check.py -------------------------------------------------------------------------------- /db/database/tips.txt: -------------------------------------------------------------------------------- 1 | # 这里对数据进行数据库转化,如kv等格式 -------------------------------------------------------------------------------- /dialogue/tensorflow/nlu/intent/intent_manager.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/assets/chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengBoCong/nlp-dialogue/HEAD/app/assets/chat.png -------------------------------------------------------------------------------- /app/assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengBoCong/nlp-dialogue/HEAD/app/assets/main.png -------------------------------------------------------------------------------- /app/static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengBoCong/nlp-dialogue/HEAD/app/static/favicon.ico -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python actuator.py --version tf --model transformer --act pre_treat -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | docker* 2 | docs 3 | .git* 4 | **/*.pyc 5 | **/__pycache__ 6 | !docker/configs 7 | data/ 8 | examples/ -------------------------------------------------------------------------------- /app/view/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint 2 | 3 | views = Blueprint("views", __name__, url_prefix="/") 4 | -------------------------------------------------------------------------------- /DockerFile: -------------------------------------------------------------------------------- 1 | FROM python:3.7-slim-stretch 2 | ADD requirements.txt / 3 | RUN pip install -r /requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ 4 | ADD . /app 5 | WORKDIR /app 6 | EXPOSE 8081 7 | CMD ["python", "./server.py"] -------------------------------------------------------------------------------- /configs/constant.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | """ Global static variables 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | DIALOGUE_APIS_MODULE = {"tf": "dialogue.tensorflow.apis", "torch": "dialogue.pytorch.apis"} 9 | -------------------------------------------------------------------------------- /dialogue/pipeline.py: -------------------------------------------------------------------------------- 1 | #!
-*- coding: utf-8 -*- 2 | """ Assembly Inference Pipeline 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/common/common.py: -------------------------------------------------------------------------------- 1 | from optparse import OptionParser 2 | 3 | 4 | class CmdParser(OptionParser): 5 | def error(self, msg): 6 | print('Error!提示信息如下:') 7 | self.print_help() 8 | self.exit(0) 9 | 10 | def exit(self, status=0, msg=None): 11 | exit(status) 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.2 2 | tensorflow==2.4 3 | Flask==1.1.4 4 | Flask-Script==2.0.6 5 | Flask-Mail==0.9.1 6 | Flask-Caching==1.10.1 7 | Flask-SQLAlchemy==2.5.1 8 | Flask-Login==0.5.0 9 | Flask-Migrate==3.1.0 10 | mysqlclient==2.0.3 11 | gunicorn==20.1.0 12 | gevent==21.1.2 13 | Flask-SocketIO==5.1.1 14 | jieba==0.42.1 15 | LAC==2.1.2 16 | pkuseg==0.0.25 17 | -------------------------------------------------------------------------------- /dialogue/pytorch/apis.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | """ PyTorch Server Api 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | from flask import Blueprint 13 | 14 | apis = Blueprint("torch_apis", __name__, url_prefix="/apis/torch") 15 | 16 | 17 | @apis.route('test', methods=['GET', 'POST']) 18 | def test(): 19 | return "torch_test" 20 | -------------------------------------------------------------------------------- /dialogue/tensorflow/apis.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | """ TensorFlow Server Api 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | from flask import Blueprint 13 | 14 | apis = Blueprint("tf_apis", __name__, url_prefix="/apis/tf") 15 | 16 | 17 | @apis.route('test', methods=['GET', 'POST']) 18 | def test(): 19 | return "tf_apis" 20 | 21 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/README.md: -------------------------------------------------------------------------------- 1 | # 目录 2 | + [运行说明](#运行说明) 3 | + [模型效果](#模型效果) 4 | 5 | 6 | # 运行说明 7 | + 运行入口: 8 | + task_chatter.py为seq2seq的执行入口文件:指令需要附带运行参数 9 | + 执行的指令格式: 10 | + task:python task_chatter.py -t/--type [执行模式] 11 | + 执行类别:pre_treat(默认)/train/chat 12 | + 执行指令示例: 13 | + python task_chatter.py 14 | + python task_chatter.py -t pre_treat 15 | + pre_treat模式为文本预处理模式,如果在没有分词结果集的情况下,需要先运行pre_treat模式 16 | + train模式为训练模式 17 | + chat模式为对话模式。chat模式下运行时,输入exit即退出对话。 18 | 19 | + 正常执行顺序为pre_treat->train->chat 20 | 21 | # 模型效果 22 | 待完善 -------------------------------------------------------------------------------- /gunicorn.conf.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | """ Project Server Entrance 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | workers = 5 9 | threads = 2 10 | daemon = 'false' 11 | worker_class = 'gevent' # 采用gevent库,支持异步处理请求,提高吞吐量 12 | bind = '0.0.0.0:8000' 13 | worker_connections = 2000 14 | pidfile = 'var/run/gunicorn.pid' 15 | accesslog = 'logs/run/gunicorn_acess.log' 16 | errorlog = 'logs/run/gunicorn_error.log' 17 | loglevel = 'warning' 18 | access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"' 19 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/config/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data": "checkpoint", 3 | "units": 150, 4 | "vocab_size": 20000, 5 | "embedding_dim": 256, 6 | "use_pretrained_embedding": true, 7 | "kb": "data/kb.json", 8 | "sent_groups": "data/groups.json", 9 | "database": "data/database.json", 10 | "ontology": "data/ontology.json", 11 | "semi_dict": "data/semi_dict.json", 12 | "dialogues_train": "data/woz_train_en.json", 13 | "tokenized_data": "data/dialogues_tokenized.txt", 14 | "dict_fn": "data/task_dict.json", 15 | "kb_indicator_len": 3, 16 | "beam_size": 3, 17 | "max_length": 300, 18 | "max_train_data_size": 30, 19 | "epochs": 2 20 | } -------------------------------------------------------------------------------- /app/templates/error/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 | 404 Not Found 15 |
16 | 17 | 18 | 21 | 24 | -------------------------------------------------------------------------------- /dialogue/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 全局通用常量 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | -------------------------------------------------------------------------------- /dialogue/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """ 抽象配置类及各模型默认配置类 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | common = {} 23 | 24 | transformer = {} 25 | 26 | class Config(object): 27 | def __str__(self): 28 | return str(self.__dict__) 29 | 30 | __repr__ = __str__ 31 | 32 | def __getitem__(self, item): 33 | return getattr(self, item) 34 | -------------------------------------------------------------------------------- /check.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if __name__ == '__main__': 4 | work_path = os.getcwd() 5 | 6 | # checkpoints 下的目录层级 7 | if not os.path.exists(work_path + "\\dialogue\\checkpoints\\tensorflow"): 8 | os.makedirs(work_path + "\\dialogue\\checkpoints\\tensorflow") 9 | if not os.path.exists(work_path + "\\dialogue\\checkpoints\\pytorch"): 10 | os.makedirs(work_path + "\\dialogue\\checkpoints\\pytorch") 11 | 12 | # data下的目录层级 13 | if not os.path.exists(work_path + "\\dialogue\\data\\history"): 14 | os.makedirs(work_path + "\\dialogue\\data\\history") 15 | if not os.path.exists(work_path + "\\dialogue\\data\\preprocess"): 16 | os.makedirs(work_path + "\\dialogue\\data\\preprocess") 17 | if not os.path.exists(work_path + "\\dialogue\\data\\pytorch"): 18 | os.makedirs(work_path + "\\dialogue\\data\\pytorch") 19 | 20 | # model 下的目录层级 21 | if not os.path.exists(work_path + "\\dialogue\\models\\tensorflow"): 22 | os.makedirs(work_path + "\\dialogue\\models\\tensorflow") 23 | if not os.path.exists(work_path + "\\dialogue\\models\\pytorch"): 24 | os.makedirs(work_path + "\\dialogue\\models\\pytorch") 25 | -------------------------------------------------------------------------------- /dialogue/tensorflow/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | 6 | def _get_angles(pos: np.ndarray, i: np.ndarray, d_model: int) -> np.ndarray: 7 | """pos/10000^(2i/d_model) 8 | 9 | :param pos: 字符总的数量按顺序递增 10 | :param i: 词嵌入大小按顺序递增 11 | :param d_model: 词嵌入大小 12 | :return: shape=(pos.shape[0], d_model) 13 | """ 14 | angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) 15 | return pos * angle_rates 16 | 17 | 18 | def positional_encoding(position: int, d_model: int, d_type: tf.dtypes.DType = tf.float32) -> tf.Tensor: 19 | """PE(pos,2i) = sin(pos/10000^(2i/d_model)) | PE(pos,2i+1) = cos(pos/10000^(2i/d_model)) 20 | 21 | :param position: 字符总数 22 | :param d_model: 词嵌入大小 23 | :param d_type: 运算精度 24 | """ 25 | angle_rads = _get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) 26 | 27 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) 28 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) 29 | pos_encoding = angle_rads[np.newaxis, ...]
30 | 31 | return tf.cast(pos_encoding, dtype=d_type) 32 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/config/get_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | seq2seq_config = os.path.dirname(__file__) + r'\model_config.json' 5 | path = os.path.dirname(__file__)[:-6] 6 | 7 | 8 | def get_config_json(config_file='main.json'): 9 | with open(config_file, 'r') as file: 10 | return json.load(file) 11 | 12 | 13 | def config(config_file=seq2seq_config): 14 | return get_config_json(config_file=config_file) 15 | 16 | 17 | conf = {} 18 | 19 | conf = config() 20 | 21 | # task模型相关配置 22 | epochs = conf['epochs'] 23 | vocab_size = conf['vocab_size'] 24 | beam_size = conf['beam_size'] 25 | embedding_dim = conf['embedding_dim'] 26 | max_length = conf['max_length'] 27 | units = conf['units'] 28 | task_train_data = path + conf['train_data'] 29 | sent_groups = path + conf['sent_groups'] # 含插槽的句子组合 30 | database = path + conf['database'] 31 | ontology = path + conf['ontology'] 32 | semi_dict = path + conf['semi_dict'] 33 | dialogues_train = path + conf['dialogues_train'] 34 | dict_fn = path + conf['dict_fn'] 35 | dialogues_tokenized = path + conf['tokenized_data'] 36 | kb_indicator_len = conf['kb_indicator_len'] 37 | max_train_data_size = conf['max_train_data_size'] 38 | -------------------------------------------------------------------------------- /db/models/Dialogue.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | """ Session 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from configs import db 9 | from datetime import datetime 10 | 11 | 12 | class Dialogue(db.Model): 13 | """ 对话文本 14 | """ 15 | __tablename__ = "DIALOGUE_DIALOGUE" 16 | 17 | ID = db.Column(db.String(50), primary_key=True, nullable=False, comment="ID") 18 | CREATE_DATETIME = db.Column(db.DateTime, default=datetime.now(), nullable=False, comment="创建时间") 19 | EMAIL = db.Column(db.String(60), index=True, nullable=False, default="", unique=True, comment="邮箱账号") 20 | IDENTITY = db.Column(db.Enum("Agent", "User"), nullable=False, comment="发送者身份") 21 | UTTERANCE = db.Column(db.String(255), nullable=False, default="", comment="文本内容") 22 | 23 | def __repr__(self): 24 | return '\n' % self.EMAIL 25 | 26 | def to_json(self): 27 | """ Dialogue字符串格式化 28 | """ 29 | return { 30 | 'ID': self.ID, 31 | 'CREATE_DATETIME': self.CREATE_DATETIME.strftime('%Y-%m-%d %H:%M:%S'), 32 | 'EMAIL': self.EMAIL, 33 | 'IDENTITY': self.IDENTITY, 34 | 'UTTERANCE': self.UTTERANCE, 35 | } 36 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/common/kb.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import reduce 3 | from collections import defaultdict 4 | 5 | 6 | def load_kb(kb_fn, primary): 7 | with open(kb_fn) as f: 8 | data = json.load(f) 9 | 10 | kb = KnowledgeBase(data[0].keys(), primary) 11 | 12 | for obj in data: 13 | kb.add(obj) 14 | 15 | return kb 16 | 17 | 18 | class KnowledgeBase: 19 | """ 20 | 提供基于知识检索的API 21 | """ 22 | 23 | def __init__(self, columns, primary): 24 | self.columns = columns 25 | self.primary = primary 26 | self.index = {k: defaultdict(list) for k in self.columns} 27 | self.objs = {} 28 | 29 | def add(self, obj): 30 | """ 31 | 添加一个知识对象到KB中 32 | """ 33 | for key, value in 
obj.items(): 34 | self.index[key][value].append(obj[self.primary]) 35 | 36 | self.objs[obj[self.primary]] = obj 37 | 38 | def get(self, primary): 39 | """ 40 | 通过key查询知识对象 41 | """ 42 | return self.objs[primary] 43 | 44 | def search(self, key, value): 45 | return set(self.index[key][value]); 46 | 47 | def search_multi(self, kvs): 48 | """ 49 | 通过key批量查询知识对象 50 | :params kvs: key和value的列表,使用lambda表达式进行累积操作 51 | """ 52 | ret = reduce(lambda y, x: y & set(self.index[x[0]][x[1]]) 53 | if y is not None else set(self.index[x[0]][x[1]]), kvs, None) 54 | return ret if ret is not None else set() 55 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/model/tracker.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def inform_slot_tracker(units, n_choices, name="inform_slot_tracker"): 5 | """ 6 | informable插槽跟踪器,informable插槽是用户告知系统的信息,用 7 | 来约束对话的一些条件,系统为了完成任务必须满足这些条件 8 | 用来获得时间t的状态的槽值分布,比如price=cheap 9 | 输入为状态跟踪器的输入'state_t',输出为槽值分布'P(v_s_t| state_t)' 10 | """ 11 | inputs = tf.keras.Input(shape=(units,), name="inform_slot_tracker_inputs") 12 | outputs = tf.keras.layers.Dense(units=n_choices)(inputs) 13 | return tf.keras.Model(inputs=inputs, outputs=outputs, name=name) 14 | 15 | 16 | def request_slot_tracker(units, name="request_slot_tracker"): 17 | """ 18 | requestable插槽跟踪器,requestable插槽是用户询问系统的信息 19 | 用来获得时间t的状态的非分类插槽槽值分布, 20 | 比如: 21 | address=1 (地址被询问) 22 | phone=0 (用户不关心电话号码) 23 | 输入为状态跟踪器的输入'state_t',输出为槽值二元分布'P(v_s_t| state_t)' 24 | """ 25 | inputs = tf.keras.Input(shape=(units,), name="request_slot_tracker_inputs") 26 | outputs = tf.keras.layers.Dense(units=2)(inputs) 27 | return tf.keras.Model(inputs=inputs, outputs=outputs, name=name) 28 | 29 | 30 | def state_tracker(units, vocab_size, embedding_dim, name="state_tracker"): 31 | inputs = tf.keras.Input(shape=(None,), name="state_tracker_inputs") 32 | embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs) 33 | output, state = tf.keras.layers.GRU(units=units, 34 | return_sequences=True, 35 | return_state=True, 36 | dropout=0.9)(inputs=embedding) 37 | return tf.keras.Model(inputs=[inputs], outputs=[output, state], name=name) 38 | -------------------------------------------------------------------------------- /dialogue/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """模型评估指标 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | from typing import List 23 | 24 | 25 | def recall_at_position_k_in_n(labels: list, k: list = [1], n: int = 10, tar: float = 1.0) -> List: 26 | """ Rn@k 召回率指标计算 27 | 28 | :param labels: 数据列表 29 | :param k: top k 30 | :param n: 样本范围 31 | :param tar: 目标值 32 | :return: 所得指标值 33 | """ 34 | score = labels[0] 35 | label = labels[1] 36 | 37 | length = len(k) 38 | sum_k = [0.0] * length 39 | total = 0 40 | for i in range(0, len(label), n): 41 | total += 1 42 | remain = [label[i + index] for index in np.argsort(score[i:i + n])] 43 | for j in range(length): 44 | sum_k[j] += 1.0 * remain[-k[j]:].count(tar) / remain.count(tar) 45 | 46 | for i in range(length): 47 | sum_k[i] /= total 48 | 49 | return sum_k 50 | -------------------------------------------------------------------------------- /dialogue/actuator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 总执行器入口 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import sys 22 | from argparse import ArgumentParser 23 | 24 | pipelines = { 25 | # "preprocess": pa 26 | } 27 | 28 | def preprocess(): 29 | pass 30 | 31 | 32 | def train(): 33 | pass 34 | 35 | 36 | def valid(): 37 | pass 38 | 39 | 40 | def run(): 41 | pass 42 | 43 | 44 | def main() -> None: 45 | parser = ArgumentParser(description="total actuator", usage="the first parameter must be --pipeline PIPELINE") 46 | parser.add_argument("--pipeline", default="chain", type=str, required=True, 47 | help="execution mode, preprocess/train/valid/run") 48 | 49 | options = parser.parse_args().__dict__ 50 | 51 | if not options.get("pipeline"): 52 | raise AttributeError("actuator.py: error: PIPELINE: [preprocess/train/valid/run]") 53 | 54 | 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | #!
-*- coding: utf-8 -*- 2 | """ Server Entrance 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import importlib 13 | import os 14 | from configs import create_app 15 | from configs import db 16 | from flask import g 17 | from flask import render_template 18 | from flask_migrate import Migrate 19 | from flask_script import Manager 20 | from flask_script import Shell 21 | from configs import DIALOGUE_APIS_MODULE 22 | 23 | application = create_app(config_name=os.environ.get("ENV") or "default") 24 | module = importlib.import_module(DIALOGUE_APIS_MODULE.get(os.environ.get("DIALOGUE_MODULE") or "tf")) 25 | application.register_blueprint(module.apis) 26 | 27 | migrate = Migrate(application, db) 28 | server = Manager(application) 29 | 30 | with application.app_context(): 31 | g.contextPath = "" 32 | 33 | 34 | @application.errorhandler(404) 35 | def route_not_found(e): 36 | """ Api/Route not found 37 | """ 38 | return render_template("error/404.html"), 404 39 | 40 | 41 | @application.teardown_appcontext 42 | def shutdown_session(exception=None): 43 | """ The last operation performed when the Server shutdown 44 | """ 45 | db.session.remove() 46 | # TODO: send mail while application shutdown 47 | 48 | 49 | @server.command 50 | def check(): 51 | """ Check instructions before starting 52 | """ 53 | if not os.path.exists("logs/run"): 54 | os.mkdir("logs/run") 55 | # TODO: check system integrity 56 | 57 | 58 | def make_shell_context(): 59 | """ Make it possible to control the running Server program through shell commands 60 | """ 61 | return dict(app=application, db=db) 62 | 63 | 64 | if __name__ == "__main__": 65 | server.add_command("shell", Shell(make_context=make_shell_context)) 66 | server.run() 67 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | #!
-*- coding: utf-8 -*- 2 | """ Global Configuration 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import uuid 14 | from configs.configs import config 15 | from configs.constant import * 16 | from flask import Flask 17 | from flask_caching import Cache 18 | from flask_login import LoginManager 19 | from flask_mail import Mail 20 | from flask_socketio import SocketIO 21 | from flask_sqlalchemy import SQLAlchemy 22 | 23 | mail = Mail() 24 | db = SQLAlchemy() 25 | socket_io = SocketIO() 26 | login_manager = LoginManager() 27 | login_manager.session_protection = "strong" 28 | login_manager.login_view = "views.login" 29 | login_manager.login_message = "Token is invalid, please regain permissions" 30 | 31 | 32 | @login_manager.user_loader 33 | def load_user(user_id): 34 | """ Activate session 35 | """ 36 | return {"ID": "null"} 37 | 38 | 39 | basedir = os.path.abspath(os.path.dirname(__file__)) 40 | 41 | 42 | def create_app(config_name): 43 | """ Server app related configuration 44 | """ 45 | app = Flask(__name__, template_folder="../app/templates", static_folder="../app/static") 46 | app.config.from_object(config[config_name]) 47 | config[config_name].init_app(app=app) 48 | app.secret_key = uuid.uuid1().__str__() 49 | app.jinja_env.variable_start_string = "[[" 50 | app.jinja_env.variable_end_string = "]]" 51 | cache = Cache(config={"CACHE_TYPE": "simple"}) 52 | 53 | db.init_app(app=app) 54 | mail.init_app(app=app) 55 | cache.init_app(app=app) 56 | login_manager.init_app(app=app) 57 | socket_io.init_app(app=app) 58 | 59 | from app.view import views 60 | from dialogue.pytorch.apis import apis 61 | app.register_blueprint(views) 62 | app.register_blueprint(apis) 63 | 64 | return app 65 | -------------------------------------------------------------------------------- /db/models/UserProfile.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | """ UserProfile 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from configs import db 9 | from datetime import datetime 10 | from flask_login import UserMixin 11 | 12 | 13 | class UserProfile(db.Model, UserMixin): 14 | """ 用户画像 15 | """ 16 | __tablename__ = "DIALOGUE_USERPROFILE" 17 | 18 | ID = db.Column(db.String(50), primary_key=True, nullable=False, comment="ID") 19 | CREATE_DATETIME = db.Column(db.DateTime, default=datetime.now(), nullable=False, comment="创建时间") 20 | LAST_DATETIME = db.Column(db.DateTime, index=True, nullable=False, comment="最后登录时间") 21 | UPDATE_DATETIME = db.Column(db.DateTime, nullable=False, comment="更新时间") 22 | EMAIL = db.Column(db.String(60), index=True, nullable=False, default="", unique=True, comment="邮箱账号") 23 | NAME = db.Column(db.String(50), nullable=False, default="", comment="名称") 24 | AVATAR_URL = db.Column(db.String(255), nullable=False, default="", comment="图片地址") 25 | SEX = db.Column(db.String(1), nullable=False, default="0", comment="性别") 26 | AGE = db.Column(db.Integer, nullable=False, default=0, comment="年龄") 27 | Contact = db.Column(db.String(30), nullable=False, default="", comment="联系方式") 28 | 29 | def __repr__(self): 30 | return '\n' % self.NAME 31 | 32 | def to_json(self): 33 | """ UserProfile字符串格式化 34 | """ 35 | return { 36 | 'ID': self.ID, 37 | 'CREATE_DATETIME': self.CREATE_DATETIME.strftime('%Y-%m-%d %H:%M:%S'), 38 | 'LAST_DATETIME': self.LAST_DATETIME.strftime('%Y-%m-%d %H:%M:%S'), 39 | 'UPDATE_DATETIME': self.UPDATE_DATETIME.strftime('%Y-%m-%d %H:%M:%S'), 40 | 'EMAIL': self.EMAIL, 41 | 'NAME': self.NAME, 42 | 'AVATAR_URL': self.AVATAR_URL, 43 | 'SEX': self.SEX, 44 | 'AGE': self.AGE, 45 | 'Contact': self.Contact, 46 | } 47 | -------------------------------------------------------------------------------- /dialogue/tensorflow/optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """模型优化相关实现模块 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 25 | def __init__(self, d_model, warmup_steps=4000): 26 | super(CustomSchedule, self).__init__() 27 | self.d_model = d_model 28 | self.d_model = tf.cast(self.d_model, tf.float32) 29 | self.warmup_steps = warmup_steps 30 | 31 | def __call__(self, step): 32 | arg1 = tf.math.rsqrt(step) 33 | arg2 = step * (self.warmup_steps ** -1.5) 34 | return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) 35 | 36 | # def get_config(self): 37 | # print("") 38 | 39 | 40 | def loss_func_mask(real, pred, weights=None): 41 | """ 屏蔽填充的SparseCategoricalCrossentropy损失 42 | 43 | 真实标签real中有0填充部分,这部分不记入预测损失 44 | 45 | :param weights: 样本权重 46 | :param real: 真实标签张量 47 | :param pred: logits张量 48 | :return: 损失平均值 49 | """ 50 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") 51 | mask = tf.math.logical_not(tf.math.equal(real, 0)) # 填充位为0,掩蔽 52 | 53 | loss_ = loss_object(real, pred, sample_weight=weights) 54 | mask = tf.cast(mask, dtype=loss_.dtype) 55 | loss_ *= mask 56 | 57 | return tf.reduce_mean(loss_) 58 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/model/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import model.tracker as tracker 3 | import warnings 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | 8 | def get_slots_tracker(onto, units): 9 | """ 10 | 根据inform和request的槽位的个数,生成对应的tracker 11 | :param onto: 处理过的本体数据集 12 | :param state_tracker_hidden_size: 处理过的本体数据集 13 | """ 14 | slot_trackers = {} 15 | slot_len_sum = 0 16 | 17 | for slot in onto: 18 | if len(onto[slot]) > 2: 19 | slot_trackers[slot] = tracker.inform_slot_tracker( 20 | units=units, n_choices=len(onto[slot]), name="inform_slot_tracker_{}".format(slot)) 21 | slot_len_sum += len(onto[slot]) 22 | else: 23 | slot_trackers[slot] = tracker.request_slot_tracker( 24 | units=units, name="request_slot_tracker_{}".format(slot)) 25 | slot_len_sum += 2 26 | 27 | return slot_trackers, slot_len_sum 28 | 29 | 30 | def task_encoder(units, vocab_size, embedding_dim, name="task_encoder"): 31 | """ 32 | task的encoder,使用双向LSTM对用户语句进行编码,输出序列和合并后的隐藏层 33 | """ 34 | inputs = tf.keras.Input(shape=(None,), name='task_encoder_inputs') 35 | embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs) 36 | output, forward_state, backward_state = tf.keras.layers.Bidirectional( 37 | tf.keras.layers.GRU(units=units, return_sequences=True, 38 | return_state=True, dropout=0.9), merge_mode='concat')(embedding) 39 | 40 | state = tf.concat([forward_state, backward_state], -1) 41 | return tf.keras.Model(inputs=inputs, outputs=[output, state], name=name) 42 | 43 | 44 | def task(units, onto, vocab_size, embedding_dim, max_sentence_len, name="task_model"): 45 | """ 46 | Task-Orient模型,使用函数式API实现,将encoder和decoder封装 47 | :param vocab_size:token大小 48 | """ 49 | usr_utts = tf.keras.Input(shape=max_sentence_len, name="task_model_inputs") 50 | kb_indicator = tf.keras.Input(shape=1) 51 | _, encoder_hidden = task_encoder(units=units, vocab_size=vocab_size, embedding_dim=embedding_dim)( 52 | usr_utts) 53 | inputs = tf.concat([encoder_hidden, 
kb_indicator], -1) 54 | _, state = tracker.state_tracker(units=units, vocab_size=vocab_size, embedding_dim=embedding_dim)( 55 | inputs) 56 | slot_trackers, slot_len_sum = get_slots_tracker(onto=onto, units=units) 57 | state_pred = {slot: slot_trackers[slot](state) for slot in onto} 58 | 59 | return tf.keras.Model(inputs=[usr_utts, kb_indicator], outputs=state_pred, name=name) 60 | -------------------------------------------------------------------------------- /configs/configs.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | """ Server Configuration 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | from datetime import timedelta 14 | from flask import Flask 15 | 16 | 17 | class Config: 18 | """ server configuration 19 | """ 20 | # session 21 | PERMANENT_SESSION_LIFETIME = timedelta(hours=3) 22 | # mail 23 | MAIL_SERVER = os.environ.get("MAIL_SERVER") 24 | MAIL_PROT = 25 25 | MAIL_USE_TLS = True 26 | MAIL_USE_SSL = False 27 | MAIL_USERNAME = os.environ.get("MAIL_USERNAME") 28 | MAIL_PASSWORD = os.environ.get("MAIL_PASSWORD") 29 | # db 30 | SQLALCHEMY_COMMIT_ON_TEARDOWN = True 31 | SQLALCHEMY_TRACK_MODIFICATIONS = True 32 | SQLALCHEMY_ECHO = True 33 | SQLALCHEMY_POOL_SIZE = 20 34 | SQLALCHEMY_MAX_OVERFLOW = 10 35 | SQLALCHEMY_POOL_RECYCLE = 1200 36 | 37 | @classmethod 38 | def init_app(cls, app: Flask): 39 | """ common configuration 40 | """ 41 | app.config["DEBUG"] = cls.DEBUG 42 | app.config["PERMANENT_SESSION_LIFETIME"] = cls.PERMANENT_SESSION_LIFETIME 43 | app.config["MAIL_SERVER"] = cls.MAIL_SERVER 44 | app.config["MAIL_PROT"] = cls.MAIL_PROT 45 | app.config["MAIL_USE_TLS"] = cls.MAIL_USE_TLS 46 | app.config["MAIL_USE_SSL"] = cls.MAIL_USE_SSL 47 | app.config["MAIL_USERNAME"] = cls.MAIL_USERNAME 48 | app.config["MAIL_PASSWORD"] = cls.MAIL_PASSWORD 49 | app.config["SQLALCHEMY_DATABASE_URI"] = cls.SQLALCHEMY_DATABASE_URI 50 | app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"] = cls.SQLALCHEMY_COMMIT_ON_TEARDOWN 51 | app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = cls.SQLALCHEMY_TRACK_MODIFICATIONS 52 | app.config["SQLALCHEMY_ECHO"] = cls.SQLALCHEMY_ECHO 53 | app.config["SQLALCHEMY_POOL_SIZE"] = cls.SQLALCHEMY_POOL_SIZE 54 | app.config["SQLALCHEMY_MAX_OVERFLOW"] = cls.SQLALCHEMY_MAX_OVERFLOW 55 | app.config["SQLALCHEMY_POOL_RECYCLE"] = cls.SQLALCHEMY_POOL_RECYCLE 56 | 57 | 58 | class DevelopmentConfig(Config): 59 | """ development configuration 60 | """ 61 | DEBUG = True 62 | SQLALCHEMY_DATABASE_URI = "mysql://root:Andie130857@localhost:3306/verb?charset=utf8&autocommit=true" 63 | 64 | 65 | class ProductionConfig(Config): 66 | """ production configuration 67 | """ 68 | SQLALCHEMY_DATABASE_URI = 'mysql://root:Andie130857@localhost:3306/verb?charset=utf8&autocommit=true' 69 | 70 | 71 | config = { 72 | 'development': DevelopmentConfig, 73 | 'production': ProductionConfig, 74 | 'default': DevelopmentConfig 75 | } 76 | -------------------------------------------------------------------------------- /dialogue/tensorflow/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """预处理操作,包含在正式使用模型前(训练、评估、推断等操作前)进行相关预处理 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import pysolr 23 | import tensorflow as tf 24 | from typing import NoReturn 25 | 26 | 27 | def create_search_data(data_path: str, solr_server: str, max_database_size: int, 28 | vocab_size: int, dict_path: str, unk_sign: str = "") -> NoReturn: 29 | """ 生成轮次tf-idf为索引的候选回复 30 | 31 | :param data_path: 文本数据路径 32 | :param solr_server: solr服务的地址 33 | :param max_database_size: 从文本中读取最大数据量 34 | :param vocab_size: 词汇量大小 35 | :param dict_path: 字典保存路径 36 | :param unk_sign: 未登录词 37 | :return: 无返回值 38 | """ 39 | if not os.path.exists(data_path): 40 | print("没有找到对应的文本数据,请确认文本数据存在") 41 | exit(0) 42 | 43 | responses = [] 44 | all_text_list = [] 45 | solr = pysolr.Solr(url=solr_server, always_commit=True) 46 | solr.ping() 47 | 48 | print("检测到对应文本,正在处理文本数据") 49 | with open(data_path, "r", encoding="utf-8") as file: 50 | count = 0 51 | odd_flag = True 52 | for line in file: 53 | odd_flag = not odd_flag 54 | if odd_flag: 55 | continue 56 | 57 | line = line.strip("\n").replace("/", "") 58 | apart = line.split("\t")[1:] 59 | all_text_list.extend(apart) 60 | for i in range(len(apart)): 61 | responses.append({"utterance": apart[i]}) 62 | 63 | count += 1 64 | print("\r已处理了 {} 轮次对话".format(count), flush=True, end="") 65 | if max_database_size == count: 66 | break 67 | 68 | solr.delete(q="*:*") 69 | solr.add(docs=responses) 70 | 71 | tokenizer = tf.keras.preprocessing.text.Tokenizer(filters="", num_words=vocab_size, oov_token=unk_sign) 72 | tokenizer.fit_on_texts(all_text_list) 73 | with open(dict_path, "w", encoding="utf-8") as dict_file: 74 | dict_file.write(tokenizer.to_json()) 75 | 76 | print("\n文本处理完毕,已更新候选回复集,并且以保存字典") 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | #data/ 12 | .idea/ 13 | db/data/history/ 14 | db/data/preprocess/ 15 | db/data/pytorch/ 16 | temp/ 17 | 18 | dialogue/checkpoints/ 19 | dialogue/models/ 20 | dialogue/debug.py 21 | dialogue/tensorflow/seq2seq/actuator.py 22 | dialogue/tensorflow/transformer/actuator.py 23 | dialogue/tensorflow/smn/actuator.py 24 | debug.py 25 | *.DS_Store 26 | development.txt 27 | logs/run/ 28 | 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | -------------------------------------------------------------------------------- /logs/logger.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | """ log and pipeline data collector 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | from datetime import datetime 14 | 15 | 16 | class Collector(object): 17 | """ used for collect pipeline data, special training logs, inference 18 | logs, evaluate logs, etc. And provide visual data and log views. 
19 | """ 20 | 21 | def __init__(self, log_dir: str): 22 | self.log_dir = log_dir 23 | self.collector_dir = os.path.join(log_dir, datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) 24 | 25 | if not os.path.exists(self.collector_dir): 26 | os.makedirs(self.collector_dir) 27 | 28 | def write_runtime_log(self, file_name: str, line: str): 29 | """ write runtime log file 30 | :param file_name: log write location file name 31 | :param line: log description 32 | """ 33 | with open(os.path.join(self.collector_dir, "runtime.logs"), 'a', encoding="utf-8") as file: 34 | file.write("INFO {} {} {}\n".format(datetime.now(), file_name, line)) 35 | 36 | def write_training_log(self, metrics: dict, if_batch_end: bool = False): 37 | """ write training log file 38 | """ 39 | pass 40 | 41 | def write_evaluate_log(self, metrics: dict): 42 | """ write evaluate log file 43 | """ 44 | pass 45 | 46 | # if os.path.exists(os.path.join(log_dir, )) 47 | 48 | 49 | # def log_operator(level: str, log_file: str = None, 50 | # log_format: str = "[%(levelname)s] - [%(asctime)s] - [file: %(filename)s] - " 51 | # "[function: %(funcName)s] - [%(message)s]") -> logging.Logger: 52 | # """ 日志操作方法,日志级别有"CRITICAL","FATAL","ERROR","WARN","WARNING","INFO","DEBUG","NOTSET" 53 | # CRITICAL = 50, FATAL = CRITICAL, ERROR = 40, WARNING = 30, WARN = WARNING, INFO = 20, DEBUG = 10, NOTSET = 0 54 | # 55 | # :param log_file: 日志路径 56 | # :param level: 日志级别 57 | # :param log_format: 日志信息格式 58 | # :return: 日志记录器 59 | # """ 60 | # if log_file is None: 61 | # log_file = os.path.abspath(__file__)[ 62 | # :os.path.abspath(__file__).rfind("\\dialogue\\")] + "\\dialogue\\data\\preprocess\\runtime.logs" 63 | # 64 | # logger = logging.getLogger() 65 | # logger.setLevel(level) 66 | # file_handler = logging.FileHandler(log_file, encoding="utf-8") 67 | # file_handler.setLevel(level=level) 68 | # formatter = logging.Formatter(log_format) 69 | # file_handler.setFormatter(formatter) 70 | # logger.addHandler(file_handler) 71 | # 72 | # return logger 73 | 74 | if __name__ == "__main__": 75 | print(os.path.join("D:\\te", "test", "te")) 76 | print(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) 77 | -------------------------------------------------------------------------------- /db/preprocess.py: -------------------------------------------------------------------------------- 1 | #!
-*- coding: utf-8 -*- 2 | """ 数据预料处理 3 | """ 4 | # Author: DengBoCong 5 | # 6 | # License: Apache-2.0 License 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import abc 13 | import os 14 | from dialogue.tokenizer import pad_sequences 15 | from dialogue.tokenizer import Segment 16 | from dialogue.tokenizer import Tokenizer 17 | 18 | 19 | class DataProcessor(abc.ABC): 20 | """ 数据格式基类 21 | """ 22 | 23 | def __init__(self, tokenizer: Tokenizer, segment: Segment = None): 24 | """ 25 | :param tokenizer: token处理器,这个必传 26 | :param segment: 分词器,后面不会用到分词的话就不用传segment 27 | :return: None 28 | """ 29 | self.tokenizer = tokenizer 30 | self.segment = segment 31 | 32 | @abc.abstractmethod 33 | def to_npy(self, *args, **kwargs): 34 | """ 该方法用于将处理好的语料数据转换成npy文件格式 35 | 36 | Note: 37 | a): 38 | """ 39 | raise NotImplementedError("Must be implemented in subclasses.") 40 | 41 | @abc.abstractmethod 42 | def to_file(self, *args, **kwargs): 43 | """ 该方法用于将处理好的语料数据转换成文件格式 44 | 45 | Note: 46 | a): 47 | """ 48 | 49 | 50 | class TextPair(DataProcessor): 51 | """ text pair或text pair + label形式数据类型 52 | """ 53 | 54 | def __init__(self, tokenizer: Tokenizer, segment: Segment = None): 55 | """ 56 | :param tokenizer: token处理器,这个必传 57 | :param segment: 分词器,后面不会用到分词的话就不用传segment 58 | :return: None 59 | """ 60 | super(TextPair, self).__init__(tokenizer, segment) 61 | 62 | def to_npy(self, batch_size: int, output_dir: str, file_path: str, 63 | split: str, if_seg: bool = False, d_type: str = "int32") -> None: 64 | """ 保存为npy文件 65 | :param batch_size: 每个npy文件保存的样本数 66 | :param output_dir: 文件输出目录 67 | :param file_path: 未分词或已分词文本列表文件路径,一行一个文本 68 | :param split: 文本分隔符,list模式不传则每个element视为list,file模式必传 69 | :param if_seg: 是否进行分词,注意使用需要初始化传入segment 70 | :param d_type: label数据类型 71 | :return: None 72 | """ 73 | if if_seg and not self.segment: 74 | raise TypeError("Segment must be instantiated in the init method") 75 | 76 | with open(os.path.join(output_dir, "outputs.txt"), "w", encoding="utf-8" 77 | ) as output_file, open(file_path, "r", encoding="utf-8") as input_file: 78 | for line in input_file: 79 | line = line.strip().strip("\n") 80 | if line == "": 81 | continue 82 | 83 | elements = line.split(split) 84 | if len(elements) < 2 or len(elements) > 3: 85 | raise RuntimeError("TextPair - to_npy: The data does not meet the format requirements") 86 | 87 | 88 | if __name__ == "__main__": 89 | print([1, 2, "fsd", 1.2]) 90 | -------------------------------------------------------------------------------- /dialogue/pytorch/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """公用层组件 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | from typing import Tuple 26 | from typing import NoReturn 27 | 28 | 29 | class BahdanauAttention(nn.Module): 30 | """ bahdanau attention实现 31 | 32 | :param enc_units: encoder单元大小 33 | :param dec_units: decoder单元大小 34 | """ 35 | 36 | def __init__(self, enc_units: int, dec_units: int) -> NoReturn: 37 | super(BahdanauAttention, self).__init__() 38 | self.W1 = nn.Linear(in_features=2 * enc_units, out_features=dec_units) 39 | self.W2 = nn.Linear(in_features=2 * enc_units, out_features=dec_units) 40 | self.V = nn.Linear(in_features=dec_units, out_features=1) 41 | 42 | def forward(self, query: torch.Tensor, values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 43 | """ 44 | :param query: 隐层状态 45 | :param values: encoder输出状态 46 | """ 47 | values = values.permute(1, 0, 2) 48 | hidden_with_time_axis = torch.unsqueeze(input=query, dim=1) 49 | score = self.V(torch.tanh(self.W1(values) + self.W2(hidden_with_time_axis))) 50 | 51 | attention_weights = F.softmax(input=score, dim=1) 52 | context_vector = attention_weights * values 53 | context_vector = torch.sum(input=context_vector, dim=1) 54 | 55 | return context_vector, attention_weights 56 | 57 | 58 | class PositionalEncoding(nn.Module): 59 | 60 | def __init__(self, d_model, dropout=0.1, max_len=5000): 61 | """ PE(pos,2i) = sin(pos/10000^(2i/d_model)) | PE(pos,2i+1) = cos(pos/10000^(2i/d_model)) 62 | 63 | :param d_model: 词嵌入大小 64 | :param dropout: 采样率 65 | :param max_len: 最大位置长度 66 | """ 67 | super(PositionalEncoding, self).__init__() 68 | self.dropout = nn.Dropout(p=dropout) 69 | 70 | pe = torch.zeros(max_len, d_model) 71 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 72 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 73 | pe[:, 0::2] = torch.sin(position * div_term) 74 | pe[:, 1::2] = torch.cos(position * div_term) 75 | pe = pe.unsqueeze(0).transpose(0, 1) 76 | self.register_buffer('pe', pe) 77 | 78 | def forward(self, x): 79 | """ 80 | :param x: 输入 81 | """ 82 | 83 | x = x + self.pe[:x.size(0), :] 84 | return self.dropout(x) 85 | -------------------------------------------------------------------------------- /configs/configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "seq2seq": { 3 | "cell_type": "lstm", 4 | "if_bidirectional": true, 5 | "units": 1024, 6 | "vocab_size": 1000, 7 | "embedding_dim": 256, 8 | "encoder_layers": 2, 9 | "decoder_layers": 2, 10 | "max_train_data_size": 0, 11 | "max_valid_data_size": 0, 12 | "max_sentence": 40, 13 | "dict_path": "data\\\\preprocess\\\\seq2seq_dict.json", 14 | "checkpoint_dir": "checkpoints\\\\tensorflow\\\\seq2seq", 15 | "resource_data_path": "data\\\\LCCC.json", 16 | "tokenized_data_path": "data\\\\preprocess\\\\lccc_tokenized.txt", 17 | "preprocess_data_path": "data\\\\preprocess\\\\single_tokenized.txt", 18 | "valid_data_path": "data\\\\preprocess\\\\single_tokenized.txt", 19 | "history_image_dir": "data\\\\history\\\\seq2seq\\\\", 20 | "valid_freq": 5, 21 | "checkpoint_save_freq": 2, 22 | "checkpoint_save_size": 3, 23 | "batch_size": 32, 24 | "buffer_size": 20000, 25 | "beam_size": 3, 26 | "valid_data_split": 0.2, 27 | "epochs": 5, 28 | 
"start_sign": "", 29 | "end_sign": "", 30 | "unk_sign": "", 31 | "encoder_save_path": "models\\tensorflow\\seq2seq\\encoder", 32 | "decoder_save_path": "models\\tensorflow\\seq2seq\\decoder" 33 | }, 34 | "smn": { 35 | "max_sentence": 50, 36 | "max_utterance": 10, 37 | "units": 200, 38 | "vocab_size": 2000, 39 | "embedding_dim": 200, 40 | "max_train_data_size": 0, 41 | "max_valid_data_size": 0, 42 | "checkpoint_save_freq": 2, 43 | "checkpoint_save_size": 1, 44 | "valid_data_split": 0.0, 45 | "max_database_size": 0, 46 | "learning_rate": 0.001, 47 | "act": "pre_treat", 48 | "dict_path": "data\\preprocess\\smn_dict.json", 49 | "checkpoint_dir": "checkpoints\\tensorflow\\smn", 50 | "train_data_path": "data\\ubuntu_train.txt", 51 | "valid_data_path": "data\\ubuntu_valid.txt", 52 | "solr_server": "http://49.235.33.100:8983/solr/smn/", 53 | "candidate_database": "data\\preprocess\\candidate.json", 54 | "model_save_path": "models\\tensorflow\\smn", 55 | "batch_size": 32, 56 | "buffer_size": 20000, 57 | "epochs": 5, 58 | "start_sign": "", 59 | "end_sign": "", 60 | "unk_sign": "" 61 | }, 62 | "transformer": { 63 | "num_layers": 2, 64 | "num_heads": 8, 65 | "units": 512, 66 | "dropout": 0.1, 67 | "vocab_size": 1500, 68 | "embedding_dim": 256, 69 | "learning_rate_beta_1": 0.9, 70 | "learning_rate_beta_2": 0.98, 71 | "max_train_data_size": 200, 72 | "max_valid_data_size": 100, 73 | "max_sentence": 40, 74 | "valid_data_file": "", 75 | "valid_freq": 5, 76 | "checkpoint_save_freq": 2, 77 | "checkpoint_save_size": 3, 78 | "batch_size": 32, 79 | "buffer_size": 20000, 80 | "beam_size": 3, 81 | "valid_data_split": 0.2, 82 | "epochs": 5, 83 | "start_sign": "", 84 | "end_sign": "", 85 | "unk_sign": "", 86 | "dict_path": "data\\preprocess\\transformer_dict.json", 87 | "checkpoint_dir": "checkpoints\\tensorflow\\transformer", 88 | "raw_data_path": "data\\LCCC.json", 89 | "tokenized_data_path": "data\\preprocess\\lccc_tokenized.txt", 90 | "preprocess_data_path": "data\\preprocess\\single_tokenized.txt", 91 | "valid_data_path": "data\\preprocess\\single_tokenized.txt", 92 | "history_image_dir": "data\\history\\transformer\\", 93 | "encoder_save_path": "models\\tensorflow\\transformer\\encoder", 94 | "decoder_save_path": "models\\tensorflow\\transformer\\decoder" 95 | } 96 | } -------------------------------------------------------------------------------- /dialogue/tensorflow/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """应用于server的加载模型推断组件 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import json 23 | import tensorflow as tf 24 | from dialogue.tensorflow.seq2seq.modules import Seq2SeqModule 25 | from dialogue.tensorflow.smn.modules import SMNModule 26 | from dialogue.tensorflow.transformer.modules import TransformerModule 27 | from typing import Tuple 28 | 29 | 30 | def load_transformer(config_path: str) -> TransformerModule: 31 | """加载Transformer的Modules 32 | 33 | :param config_path: 34 | :return: TransformerModule 35 | """ 36 | options, work_path = check_and_read_path(config_path=config_path) 37 | if options is None: 38 | return None 39 | 40 | encoder = tf.keras.models.load_model(filepath=(work_path + options["encoder_save_path"]).replace("\\", "/")) 41 | decoder = tf.keras.models.load_model(filepath=(work_path + options["decoder_save_path"]).replace("\\", "/")) 42 | 43 | modules = TransformerModule(max_sentence=options["max_sentence"], 44 | dict_path=work_path + options["dict_path"], encoder=encoder, decoder=decoder) 45 | 46 | return modules 47 | 48 | 49 | def load_seq2seq(config_path: str) -> Seq2SeqModule: 50 | """加载Seq2Seq的Modules 51 | 52 | :param config_path: 53 | :return: Seq2SeqModule 54 | """ 55 | options, work_path = check_and_read_path(config_path=config_path) 56 | if options is None: 57 | return None 58 | 59 | encoder = tf.keras.models.load_model(filepath=(work_path + options["encoder_save_path"]).replace("\\", "/")) 60 | decoder = tf.keras.models.load_model(filepath=(work_path + options["decoder_save_path"]).replace("\\", "/")) 61 | 62 | modules = Seq2SeqModule(max_sentence=options["max_sentence"], 63 | dict_path=work_path + options["dict_path"], encoder=encoder, decoder=decoder) 64 | 65 | return modules 66 | 67 | 68 | def load_smn(config_path: str) -> SMNModule: 69 | """加载Seq2Seq的Modules 70 | 71 | :param config_path: 72 | :return: Seq2SeqModule 73 | """ 74 | options, work_path = check_and_read_path(config_path=config_path) 75 | if options is None: 76 | return None 77 | 78 | model = tf.keras.models.load_model(filepath=(work_path + options["model_save_path"]).replace("\\", "/")) 79 | modules = SMNModule(max_sentence=options["max_sentence"], dict_path=work_path + options["dict_path"], model=model) 80 | 81 | return modules 82 | 83 | 84 | def check_and_read_path(config_path: str) -> Tuple: 85 | """ 检查配置文件路径及读取配置文件内容及当前工作目录 86 | 87 | :param config_path: 88 | :return: options, work_path 89 | """ 90 | if config_path == "": 91 | print("加载失败") 92 | return None, None 93 | 94 | with open(config_path, "r", encoding="utf-8") as config_file: 95 | options = json.load(config_file) 96 | 97 | file_path = os.path.abspath(__file__) 98 | work_path = file_path[:file_path.find("tensorflow")] 99 | 100 | return options, work_path 101 | 102 | 103 | if __name__ == '__main__': 104 | load_transformer(r"D:\DengBoCong\Project\nlp-dialogue\dialogue\config\transformer.json") 105 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/data/semi_dict.json: -------------------------------------------------------------------------------- 1 | {"area": ["area of town","section of the city","side of town","part of town","of town","area","location"], 2 | "food": ["food type","type of cuisine","type of food","kind of food","types of food","food","cuisine"], 3 | "pricerange": 
["price range"], 4 | "phone" : ["phone number","telephone number","phone","number"], 5 | "address" : ["address"], 6 | "postcode" : ["postal code","post code","zip code","postcode","area code"], 7 | "name" : ["name"], 8 | "none": ["NONE"], 9 | "exist" : ["NONE"], 10 | "afghan": ["afghan"], 11 | "african": ["african"], 12 | "afternoon tea": ["afternoon tea"], 13 | "any": ["no specific","no preference","dont really care","do not care","dont care","does not matter","any"], 14 | "asian oriental": ["asian oriental"], 15 | "australasian": ["australasian"], 16 | "australian": ["australian"], 17 | "austrian": ["austrian"], 18 | "barbeque": ["barbeque","bbq"], 19 | "basque": ["basque"], 20 | "belgian": ["belgian"], 21 | "bistro": ["bistro"], 22 | "brazilian": ["brazilian"], 23 | "british": ["british"], 24 | "canapes": ["canapes"], 25 | "cantonese": ["cantonese"], 26 | "caribbean": ["caribbean"], 27 | "catalan": ["catalan"], 28 | "centre": ["centre", "center", "central","downtown"], 29 | "cheap": ["cheap","inexpensive"], 30 | "chinese": ["chinese"], 31 | "christmas": ["christmas"], 32 | "corsica": ["corsica"], 33 | "creative": ["creative"], 34 | "crossover": ["crossover"], 35 | "cuban": ["cuban"], 36 | "danish": ["danish"], 37 | "east": ["east", "eastern"], 38 | "eastern european": ["eastern european"], 39 | "english": ["english"], 40 | "eritrean": ["eritrean"], 41 | "european": ["european"], 42 | "expensive": ["expensive","upscale","high priced"], 43 | "french": ["french"], 44 | "fusion": ["fusion"], 45 | "gastropub": ["gastropub"], 46 | "german": ["german"], 47 | "greek": ["greek"], 48 | "halal": ["halal"], 49 | "hungarian": ["hungarian"], 50 | "indian": ["indian"], 51 | "indonesian": ["indonesian"], 52 | "international": ["international"], 53 | "irish": ["irish"], 54 | "italian": ["italian"], 55 | "jamaican": ["jamaican"], 56 | "japanese": ["japanese"], 57 | "korean": ["korean"], 58 | "kosher": ["kosher"], 59 | "latin american": ["latin american"], 60 | "lebanese": ["lebanese"], 61 | "light bites": ["light bites"], 62 | "malaysian": ["malaysian"], 63 | "mediterranean": ["mediterranean"], 64 | "mexican": ["mexican"], 65 | "middle eastern": ["middle eastern"], 66 | "moderate": ["not too expensive","moderate","mid range","mid price","reasonably priced","medium"], 67 | "modern american": ["modern american"], 68 | "modern eclectic": ["modern eclectic"], 69 | "modern european": ["modern european"], 70 | "modern global": ["modern global"], 71 | "molecular gastronomy": ["molecular gastronomy"], 72 | "moroccan": ["moroccan"], 73 | "new zealand": ["new zealand"], 74 | "north": ["north", "northern"], 75 | "north african": ["north african"], 76 | "north american": ["north american"], 77 | "north indian": ["north indian"], 78 | "northern european": ["northern european"], 79 | "panasian": ["panasian"], 80 | "persian": ["persian"], 81 | "polish": ["polish"], 82 | "polynesian": ["polynesian"], 83 | "portuguese": ["portuguese"], 84 | "romanian": ["romanian"], 85 | "russian": ["russian"], 86 | "scandinavian": ["scandinavian"], 87 | "scottish": ["scottish"], 88 | "seafood": ["sea food","seafood"], 89 | "singaporean": ["singapore","singaporean"], 90 | "south": ["south", "southern"], 91 | "south african": ["south african"], 92 | "south indian": ["south indian"], 93 | "spanish": ["spanish"], 94 | "sri lankan": ["sri lankan"], 95 | "steakhouse": ["steakhouse"], 96 | "swedish": ["swedish"], 97 | "swiss": ["swiss"], 98 | "thai": ["thai"], 99 | "the americas": ["the americas"], 100 | "traditional": ["traditional"], 101 | 
"turkish": ["turkish"], 102 | "tuscan": ["tuscan"], 103 | "unusual": ["unusual"], 104 | "vegetarian": ["vegetarian"], 105 | "venetian": ["venetian"], 106 | "vietnamese": ["vietnamese"], 107 | "welsh": ["welsh"], 108 | "west": ["west", "western"], 109 | "world": ["world"] 110 | } 111 | -------------------------------------------------------------------------------- /dialogue/tensorflow/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """支持各类任务的工具,检查点加载、分词器加载、句子预处理、mask等等 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from sklearn.feature_extraction.text import TfidfVectorizer 23 | from typing import Tuple 24 | from typing import List 25 | 26 | 27 | def combine_mask(seq: tf.Tensor) -> Tuple: 28 | """对input中的不能见单位进行mask 29 | 30 | :param seq: 输入序列 31 | :param d_type: 运算精度 32 | :return: mask 33 | """ 34 | look_ahead_mask = _create_look_ahead_mask(seq) 35 | padding_mask = create_padding_mask(seq) 36 | return tf.maximum(look_ahead_mask, padding_mask) 37 | 38 | 39 | def create_padding_mask(seq: tf.Tensor) -> Tuple: 40 | """ 用于创建输入序列的扩充部分的mask 41 | 42 | :param seq: 输入序列 43 | :return: mask 44 | """ 45 | seq = tf.cast(x=tf.math.equal(seq, 0), dtype=tf.float32) 46 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) 47 | 48 | 49 | def _create_look_ahead_mask(seq: tf.Tensor) -> Tuple: 50 | """ 用于创建当前点以后位置部分的mask 51 | 52 | :param seq: 输入序列 53 | :return: mask 54 | """ 55 | seq_len = tf.shape(seq)[1] 56 | look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) 57 | return look_ahead_mask 58 | 59 | 60 | def load_checkpoint(checkpoint_dir: str, execute_type: str, checkpoint_save_size: int, model: tf.keras.Model = None, 61 | encoder: tf.keras.Model = None, decoder: tf.keras.Model = None) -> tf.train.CheckpointManager: 62 | """加载检查点,同时支持Encoder-Decoder结构加载,两种类型的模型二者只能传其一 63 | 64 | :param checkpoint_dir: 检查点保存目录 65 | :param execute_type: 执行类型 66 | :param checkpoint_save_size: 检查点最大保存数量 67 | :param model: 传入的模型 68 | :param encoder: 传入的Encoder模型 69 | :param decoder: 传入的Decoder模型 70 | """ 71 | if model is not None: 72 | checkpoint = tf.train.Checkpoint(model=model) 73 | elif encoder is not None and decoder is not None: 74 | checkpoint = tf.train.Checkpoint(encoder=encoder, decoder=decoder) 75 | else: 76 | print("加载检查点所传入模型有误,请检查后重试!") 77 | 78 | checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=checkpoint_dir, 79 | max_to_keep=checkpoint_save_size) 80 | 81 | if checkpoint_manager.latest_checkpoint: 82 | checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial() 83 | else: 84 | if execute_type != "train" and execute_type != "pre_treat": 85 
| print("没有检查点,请先执行train模式") 86 | exit(0) 87 | 88 | return checkpoint_manager 89 | 90 | 91 | def get_tf_idf_top_k(history: list, k: int = 5) -> List: 92 | """ 使用tf_idf算法计算权重最高的k个词,并返回 93 | 94 | :param history: 上下文语句 95 | :param k: 返回词数量 96 | :return: top_5_key 97 | """ 98 | tf_idf = {} 99 | 100 | vectorizer = TfidfVectorizer(analyzer="word") 101 | weights = vectorizer.fit_transform(history).toarray()[-1] 102 | key_words = vectorizer.get_feature_names() 103 | 104 | for i in range(len(weights)): 105 | tf_idf[key_words[i]] = weights[i] 106 | 107 | top_k_key = [] 108 | tf_idf_sorted = sorted(tf_idf.items(), key=lambda x: x[1], reverse=True)[:k] 109 | for element in tf_idf_sorted: 110 | top_k_key.append(element[0]) 111 | 112 | return top_k_key 113 | -------------------------------------------------------------------------------- /dialogue/tensorflow/load_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Dataset加载模块,内含各模型针对性的以及公用性的数据加载方法 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from dialogue.tools.read_data import read_data 23 | from dialogue.tools import load_tokenizer 24 | from typing import Tuple 25 | 26 | 27 | def load_data(dict_path: str, buffer_size: int, batch_size: int, train_data_type: str, valid_data_type: str, 28 | max_sentence: int, valid_data_split: float = 0.0, train_data_path: str = "", valid_data_path: str = "", 29 | max_train_data_size: int = 0, max_valid_data_size: int = 0, **kwargs) -> Tuple: 30 | """ 数据加载方法 31 | 32 | :param dict_path: 字典路径 33 | :param buffer_size: Dataset加载缓存大小 34 | :param batch_size: Dataset加载批大小 35 | :param train_data_type: 读取训练数据类型,单轮/多轮... 36 | :param valid_data_type: 读取验证数据类型,单轮/多轮... 
37 | :param max_sentence: 单个句子最大长度 38 | :param valid_data_split: 用于从训练数据中划分验证数据 39 | :param train_data_path: 文本数据路径 40 | :param valid_data_path: 验证数据文本路径 41 | :param max_train_data_size: 最大训练数据量 42 | :param max_valid_data_size: 最大验证数据量 43 | :return: 训练Dataset、验证Dataset、训练数据总共的步数、验证数据总共的步数和检查点前缀 44 | """ 45 | tokenizer = load_tokenizer(dict_path=dict_path) 46 | 47 | train_flag = True # 是否开启训练标记 48 | train_steps_per_epoch = 0 49 | train_first, train_second, train_third = None, None, None 50 | 51 | valid_flag = True # 是否开启验证标记 52 | valid_steps_per_epoch = 0 53 | valid_first, valid_second, valid_third = None, None, None 54 | 55 | if train_data_path != "": 56 | train_first, train_second, train_third = read_data( 57 | data_path=train_data_path, max_data_size=max_train_data_size, 58 | max_sentence=max_sentence, data_type=train_data_type, tokenizer=tokenizer, **kwargs 59 | ) 60 | else: 61 | train_flag = False 62 | 63 | if valid_data_path != "": 64 | print("读取验证对话对...") 65 | valid_first, valid_second, valid_third = read_data( 66 | data_path=valid_data_path, max_data_size=max_valid_data_size, 67 | max_sentence=max_sentence, data_type=valid_data_type, tokenizer=tokenizer, **kwargs 68 | ) 69 | elif valid_data_split != 0.0: 70 | train_size = int(len(train_first) * (1.0 - valid_data_split)) 71 | valid_first = train_first[train_size:] 72 | valid_second = train_second[train_size:] 73 | valid_third = train_third[train_size:] 74 | train_first = train_first[:train_size] 75 | train_second = train_second[:train_size] 76 | train_third = train_third[:train_size] 77 | else: 78 | valid_flag = False 79 | 80 | if train_flag: 81 | train_dataset = tf.data.Dataset.from_tensor_slices((train_first, train_second, train_third)).cache().shuffle( 82 | buffer_size, reshuffle_each_iteration=True).prefetch(tf.data.experimental.AUTOTUNE) 83 | train_dataset = train_dataset.batch(batch_size, drop_remainder=True) 84 | train_steps_per_epoch = len(train_first) // batch_size 85 | else: 86 | train_dataset = None 87 | 88 | if valid_flag: 89 | valid_dataset = tf.data.Dataset.from_tensor_slices((valid_first, valid_second, valid_third)) \ 90 | .prefetch(tf.data.experimental.AUTOTUNE) 91 | valid_dataset = valid_dataset.batch(batch_size, drop_remainder=True) 92 | valid_steps_per_epoch = len(valid_first) // batch_size 93 | else: 94 | valid_dataset = None 95 | 96 | return train_dataset, valid_dataset, train_steps_per_epoch, valid_steps_per_epoch 97 | -------------------------------------------------------------------------------- /dialogue/pytorch/transformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """transformer的Pytorch实现核心core 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import torch 22 | from torch.nn import TransformerDecoder 23 | from torch.nn import TransformerDecoderLayer 24 | from torch.nn import TransformerEncoder 25 | from torch.nn import TransformerEncoderLayer 26 | from typing import Any, Optional, NoReturn 27 | 28 | 29 | class Transformer(torch.nn.Module): 30 | """ Transformer Model """ 31 | 32 | def __init__(self, d_model: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, 33 | units: int = 2048, dropout: float = 0.1, activation: str = "relu") -> NoReturn: 34 | """ 35 | :param d_model: 深度,词嵌入维度 36 | :param num_heads: 注意力头数 37 | :param num_encoder_layers: encoder层数 38 | :param num_decoder_layers: decoder层数 39 | :param units: 单元数 40 | :param dropout: 采样率 41 | :param activation: 激活方法 42 | """ 43 | super(Transformer, self).__init__() 44 | 45 | encoder_layer = TransformerEncoderLayer(d_model, num_heads, units, dropout, activation) 46 | encoder_norm = torch.nn.LayerNorm(d_model) 47 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 48 | 49 | decoder_layer = TransformerDecoderLayer(d_model, num_heads, units, dropout, activation) 50 | decoder_norm = torch.nn.LayerNorm(d_model) 51 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) 52 | 53 | self._reset_parameters() 54 | 55 | self.d_model = d_model 56 | self.num_heads = num_heads 57 | 58 | def forward(self, enc_inputs: torch.Tensor, dec_inputs: torch.Tensor, enc_mask: Optional[torch.Tensor] = None, 59 | dec_mask: Optional[torch.Tensor] = None, enc_outputs_mask: Optional[torch.Tensor] = None, 60 | enc_key_padding_mask: Optional[torch.Tensor] = None, dec_key_padding_mask: Optional[torch.Tensor] = None, 61 | enc_outputs_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 62 | """ 63 | :param enc_inputs: encoder 输入 64 | :param dec_inputs: decoder 输入 65 | :param enc_mask: encoder 输入序列的mask 66 | :param dec_mask: decoder 输入序列的mask 67 | :param enc_outputs_mask: encoder 输出序列的mask 68 | :param enc_key_padding_mask: the ByteTensor mask for src keys per batch (optional). 69 | :param dec_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). 70 | :param enc_outputs_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). 
71 | """ 72 | 73 | if enc_inputs.size(1) != dec_inputs.size(1): 74 | raise RuntimeError("the batch number of src and tgt must be equal") 75 | 76 | if enc_inputs.size(2) != self.d_model or dec_inputs.size(2) != self.d_model: 77 | raise RuntimeError("the feature number of src and tgt must be equal to d_model") 78 | 79 | memory = self.encoder(enc_inputs, mask=enc_mask, src_key_padding_mask=enc_key_padding_mask) 80 | output = self.decoder(dec_inputs, memory, tgt_mask=dec_mask, memory_mask=enc_outputs_mask, 81 | tgt_key_padding_mask=dec_key_padding_mask, 82 | memory_key_padding_mask=enc_outputs_key_padding_mask) 83 | return output 84 | 85 | def _reset_parameters(self): 86 | r"""Initiate parameters in the transformer model.""" 87 | 88 | for p in self.parameters(): 89 | if p.dim() > 1: 90 | torch.nn.init.xavier_uniform_(p) 91 | -------------------------------------------------------------------------------- /dialogue/pytorch/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """相关工具集 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import json 23 | import time 24 | import torch 25 | from typing import NoReturn 26 | from typing import Tuple 27 | 28 | 29 | def generate_square_subsequent_mask(self, sz): 30 | """ 序列mask 31 | 32 | :param sz: mask大小 33 | :return: mask 34 | """ 35 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 36 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 37 | return mask 38 | 39 | 40 | def save_checkpoint(checkpoint_dir: str, optimizer: torch.optim.Optimizer = None, model: torch.nn.Module = None, 41 | encoder: torch.nn.Module = None, decoder: torch.nn.Module = None) -> NoReturn: 42 | """ 保存模型检查点 43 | 44 | :param checkpoint_dir: 检查点保存路径 45 | :param optimizer: 优化器 46 | :param model: 模型 47 | :param encoder: encoder模型 48 | :param decoder: decoder模型 49 | :return: 无返回值 50 | """ 51 | checkpoint_path = checkpoint_dir + "checkpoint" 52 | version = 1 53 | if os.path.exists(checkpoint_path): 54 | with open(checkpoint_path, "r", encoding="utf-8") as file: 55 | info = json.load(file) 56 | version = info["version"] + 1 57 | 58 | model_dict = {} 59 | if model is not None: 60 | model_dict["model_state_dict"] = model.state_dict() 61 | if encoder is not None: 62 | model_dict["encoder_state_dict"] = encoder.state_dict() 63 | if decoder is not None: 64 | model_dict["decoder_state_dict"] = decoder.state_dict() 65 | model_dict["optimizer_state_dict"] = optimizer.state_dict() 66 | 67 | model_checkpoint_path = "checkpoint-{}.pth".format(version) 68 | torch.save(model_dict, checkpoint_dir + model_checkpoint_path) 69 | with open(checkpoint_path, "w", encoding="utf-8") as 
file: 70 | file.write(json.dumps({ 71 | "version": version, 72 | "model_checkpoint_path": model_checkpoint_path, 73 | "last_preserved_timestamp": time.time() 74 | })) 75 | 76 | 77 | def load_checkpoint(checkpoint_dir: str, execute_type: str, optimizer: torch.optim.Optimizer = None, 78 | model: torch.nn.Module = None, encoder: torch.nn.Module = None, 79 | decoder: torch.nn.Module = None) -> Tuple: 80 | """加载检查点恢复模型,同时支持Encoder-Decoder结构加载 81 | 82 | :param checkpoint_dir: 检查点保存路径 83 | :param execute_type: 执行类型 84 | :param optimizer: 优化器 85 | :param model: 模型 86 | :param encoder: encoder模型 87 | :param decoder: decoder模型 88 | :return: 恢复的各模型检查点细节 89 | """ 90 | checkpoint_path = checkpoint_dir + "checkpoint" 91 | 92 | if not os.path.exists(checkpoint_path) and execute_type != "train" and execute_type != "pre_treat": 93 | print("没有检查点,请先执行train模式") 94 | exit(0) 95 | elif not os.path.exists(checkpoint_path): 96 | return model, encoder, decoder, optimizer 97 | 98 | with open(checkpoint_path, "r", encoding="utf-8") as file: 99 | checkpoint_info = json.load(file) 100 | 101 | model_checkpoint_path = checkpoint_dir + checkpoint_info["model_checkpoint_path"] 102 | 103 | checkpoint = torch.load(model_checkpoint_path) 104 | if model is not None: 105 | model.load_state_dict(checkpoint["model_state_dict"]) 106 | if encoder is not None: 107 | encoder.load_state_dict(checkpoint["encoder_state_dict"]) 108 | if decoder is not None: 109 | decoder.load_state_dict(checkpoint["decoder_state_dict"]) 110 | 111 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 112 | 113 | return model, encoder, decoder, optimizer 114 | -------------------------------------------------------------------------------- /dialogue/debug.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from dialogue.tensorflow.utils import load_tokenizer 3 | from dialogue.tensorflow.utils import preprocess_request 4 | from dialogue.tensorflow.beamsearch import BeamSearch 5 | 6 | 7 | # MODEL_DIR = tempfile.gettempdir() 8 | # version = 1 9 | # export_path = os.path.join(MODEL_DIR, str(version)) 10 | # print("export_path = {}\n".format(export_path)) 11 | # 12 | # tf.keras.models.save_model() 13 | 14 | @tf.function(autograph=True, experimental_relax_shapes=True) 15 | def _inference_one_step(decoder, dec_input: tf.Tensor, enc_output: tf.Tensor, padding_mask: tf.Tensor): 16 | """ 单个推断步 17 | 18 | :param dec_input: decoder输入 19 | :param enc_output: encoder输出 20 | :param padding_mask: encoder的padding mask 21 | :return: 单个token结果 22 | """ 23 | predictions = decoder(inputs=[dec_input, enc_output, padding_mask]) 24 | predictions = tf.nn.softmax(predictions, axis=-1) 25 | predictions = predictions[:, -1:, :] 26 | predictions = tf.squeeze(predictions, axis=1) 27 | 28 | return predictions 29 | 30 | 31 | def inference(encoder, decoder, request: str, beam_size: int, dict_path, max_sentence, 32 | start_sign: str = "", end_sign: str = "") -> str: 33 | """ 对话推断模块 34 | 35 | :param request: 输入句子 36 | :param beam_size: beam大小 37 | :param start_sign: 句子开始标记 38 | :param end_sign: 句子结束标记 39 | :return: 返回历史指标数据 40 | """ 41 | tokenizer = load_tokenizer(dict_path) 42 | 43 | enc_input = preprocess_request(sentence=request, tokenizer=tokenizer, 44 | max_length=max_sentence, start_sign=start_sign, end_sign=end_sign) 45 | enc_output, padding_mask = encoder(inputs=enc_input) 46 | dec_input = tf.expand_dims([tokenizer.word_index.get(start_sign)], 0) 47 | 48 | beam_search_container = 
BeamSearch(beam_size=beam_size, max_length=max_sentence, worst_score=0) 49 | beam_search_container.reset(enc_output=enc_output, dec_input=dec_input, remain=padding_mask) 50 | enc_output, dec_input, padding_mask = beam_search_container.get_search_inputs() 51 | 52 | for t in range(max_sentence): 53 | predictions = _inference_one_step(decoder=decoder, dec_input=dec_input, 54 | enc_output=enc_output, padding_mask=padding_mask) 55 | 56 | beam_search_container.expand(predictions=predictions, end_sign=tokenizer.word_index.get(end_sign)) 57 | # 注意了,如果BeamSearch容器里的beam_size为0了,说明已经找到了相应数量的结果,直接跳出循环 58 | if beam_search_container.beam_size == 0: 59 | break 60 | enc_output, dec_input, padding_mask = beam_search_container.get_search_inputs() 61 | 62 | beam_search_result = beam_search_container.get_result(top_k=3) 63 | result = "" 64 | # 从容器中抽取序列,生成最终结果 65 | for i in range(len(beam_search_result)): 66 | temp = beam_search_result[i].numpy() 67 | text = tokenizer.sequences_to_texts(temp) 68 | text[0] = text[0].replace(start_sign, "").replace(end_sign, "").replace(" ", "") 69 | result = "<" + text[0] + ">" + result 70 | return result 71 | 72 | if __name__ == '__main__': 73 | # encoder = encoder(vocab_size=1500, num_layers=2, units=512, embedding_dim=256, num_heads=8, dropout=0.1) 74 | # decoder = decoder(vocab_size=1500, num_layers=2, units=512, embedding_dim=256, num_heads=8, dropout=0.1) 75 | # 76 | # checkpoint_manager = load_checkpoint( 77 | # checkpoint_dir=r"D:\DengBoCong\Project\nlp-dialogue\dialogue\checkpoints\tensorflow\transformer", 78 | # execute_type="chat", encoder=encoder, decoder=decoder, checkpoint_save_size=None 79 | # ) 80 | 81 | # request = "你去那儿竟然不喊我生气了,快点给我道歉" 82 | # response = inference(encoder=encoder, decoder=decoder, request=request, beam_size=3, 83 | # dict_path=r"D:\DengBoCong\Project\nlp-dialogue\dialogue\data\preprocess\transformer_dict.json", 84 | # max_sentence=40) 85 | # print("Agent: ", response) 86 | # encoder.save(r"D:\DengBoCong\Project\nlp-dialogue\dialogue\data\encoder") 87 | # decoder.save(r"D:\DengBoCong\Project\nlp-dialogue\dialogue\data\decoder") 88 | encoder_save = tf.keras.models.load_model(r"D:\DengBoCong\Project\nlp-dialogue\dialogue\data\encoder") 89 | decoder_save = tf.keras.models.load_model(r"D:\DengBoCong\Project\nlp-dialogue\dialogue\data\decoder") 90 | request = "你去那儿竟然不喊我生气了,快点给我道歉" 91 | response = inference(encoder=encoder_save, decoder=decoder_save, request=request, beam_size=3, 92 | dict_path=r"D:\DengBoCong\Project\nlp-dialogue\dialogue\data\preprocess\transformer_dict.json", 93 | max_sentence=40) 94 | print("Agent: ", response) 95 | -------------------------------------------------------------------------------- /dialogue/pytorch/load_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Dataset加载模块,内含各模型针对性的以及公用性的数据加载方法 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from dialogue.tools.read_data import read_data 22 | from dialogue.tools import load_tokenizer 23 | from typing import Tuple 24 | from torch.utils.data import DataLoader 25 | from torch.utils.data import Dataset 26 | 27 | 28 | class PairDataset(Dataset): 29 | """ 专门用于问答对形式的数据集构建的dataset,用于配合DataLoader使用 """ 30 | 31 | def __init__(self, first_tensor, second_tensor, third_tensor): 32 | """ Dataset预留三个数据位置 """ 33 | self.first_tensor = first_tensor 34 | self.second_tensor = second_tensor 35 | self.third_tensor = third_tensor 36 | 37 | def __getitem__(self, item): 38 | return self.first_tensor[item], self.second_tensor[item], self.third_tensor[item] 39 | 40 | def __len__(self): 41 | return len(self.first_tensor) 42 | 43 | 44 | def load_data(dict_path: str, batch_size: int, train_data_type: str, valid_data_type: str, 45 | max_sentence: int, valid_data_split: float = 0.0, train_data_path: str = "", valid_data_path: str = "", 46 | max_train_data_size: int = 0, max_valid_data_size: int = 0, num_workers: int = 2, **kwargs) -> Tuple: 47 | """ 数据加载方法 48 | 49 | :param dict_path: 字典路径 50 | :param batch_size: Dataset加载批大小 51 | :param train_data_type: 读取训练数据类型,单轮/多轮... 52 | :param valid_data_type: 读取验证数据类型,单轮/多轮... 53 | :param max_sentence: 单个句子最大长度 54 | :param valid_data_split: 用于从训练数据中划分验证数据 55 | :param train_data_path: 文本数据路径 56 | :param valid_data_path: 验证数据文本路径 57 | :param max_train_data_size: 最大训练数据量 58 | :param max_valid_data_size: 最大验证数据量 59 | :param num_workers: 数据加载器的工作线程 60 | :return: 训练Dataset、验证Dataset、训练数据总共的步数、验证数据总共的步数和检查点前缀 61 | """ 62 | tokenizer = load_tokenizer(dict_path=dict_path) 63 | 64 | train_flag = True # 是否开启训练标记 65 | train_steps_per_epoch = 0 66 | train_first, train_second, train_third = None, None, None 67 | 68 | valid_flag = True # 是否开启验证标记 69 | valid_steps_per_epoch = 0 70 | valid_first, valid_second, valid_third = None, None, None 71 | 72 | if train_data_path != "": 73 | train_first, train_second, train_third = read_data( 74 | data_path=train_data_path, max_data_size=max_train_data_size, 75 | max_sentence=max_sentence, data_type=train_data_type, tokenizer=tokenizer, **kwargs 76 | ) 77 | else: 78 | train_flag = False 79 | 80 | if valid_data_path != "": 81 | print("读取验证对话对...") 82 | valid_first, valid_second, valid_third = read_data( 83 | data_path=valid_data_path, max_data_size=max_valid_data_size, 84 | max_sentence=max_sentence, data_type=valid_data_type, tokenizer=tokenizer, **kwargs 85 | ) 86 | elif valid_data_split != 0.0: 87 | train_size = int(len(train_first) * (1.0 - valid_data_split)) 88 | valid_first = train_first[train_size:] 89 | valid_second = train_second[train_size:] 90 | valid_third = train_third[train_size:] 91 | train_first = train_first[:train_size] 92 | train_second = train_second[:train_size] 93 | train_third = train_third[:train_size] 94 | else: 95 | valid_flag = False 96 | 97 | if train_flag: 98 | train_dataset = PairDataset(train_first, train_second, train_third) 99 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 100 | shuffle=True, drop_last=True, num_workers=num_workers) 101 | train_steps_per_epoch = len(train_first) // batch_size 102 | else: 103 | train_loader = None 104 | 105 | if valid_flag: 106 | valid_dataset = PairDataset(valid_first, valid_second, 
valid_third) 107 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, 108 | shuffle=False, drop_last=True, num_workers=num_workers) 109 | valid_steps_per_epoch = len(valid_first) // batch_size 110 | else: 111 | valid_loader = None 112 | 113 | return train_loader, valid_loader, train_steps_per_epoch, valid_steps_per_epoch 114 | -------------------------------------------------------------------------------- /dialogue/tensorflow/scheduled_sampling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """应用scheduled_sampling的transformer模型核心core 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from dialogue.tensorflow.nlu.model import decoder 23 | from dialogue.tensorflow.nlu.model import encoder 24 | from dialogue.tensorflow.utils import combine_mask 25 | from dialogue.tensorflow.utils import create_padding_mask 26 | 27 | 28 | def gumbel_softmax(inputs: tf.Tensor, alpha: float): 29 | """ 30 | 按照论文中的公式,实现GumbelSoftmax,具体见论文公式 31 | :param inputs: 输入 32 | :param alpha: 温度 33 | :return: 混合Gumbel噪音后,做softmax以及argmax之后的输出 34 | """ 35 | uniform = tf.random.uniform(shape=tf.shape(inputs), maxval=1, minval=0) 36 | # 以给定输入的形状采样Gumbel噪声 37 | gumbel_noise = -tf.math.log(-tf.math.log(uniform)) 38 | # 将Gumbel噪声添加到输入中,输入第三维就是分数 39 | gumbel_outputs = inputs + gumbel_noise 40 | gumbel_outputs = tf.cast(gumbel_outputs, dtype=tf.float32) 41 | # 在给定温度下,进行softmax并返回 42 | gumbel_outputs = tf.nn.softmax(alpha * gumbel_outputs) 43 | gumbel_outputs = tf.argmax(gumbel_outputs, axis=-1) 44 | return tf.cast(gumbel_outputs, dtype=tf.float32) 45 | 46 | 47 | def embedding_mix(gumbel_inputs: tf.Tensor, inputs: tf.Tensor): 48 | """ 49 | 将输入和gumbel噪音混合嵌入,线性衰减 50 | :param gumbel_inputs: 噪音输入 51 | :param inputs: 输入 52 | :return: 混合嵌入 53 | """ 54 | probability = tf.random.uniform(shape=tf.shape(inputs), maxval=1, minval=0, dtype=tf.float32) 55 | return tf.where(probability < 0.3, x=gumbel_inputs, y=inputs) 56 | 57 | 58 | def transformer_scheduled_sample(vocab_size, num_layers, units, d_model, num_heads, 59 | dropout, alpha=1.0, name="transformer_scheduled_sample") -> tf.keras.Model: 60 | """ 61 | Transformer应用Scheduled Sample 62 | :param vocab_size: token大小 63 | :param num_layers: 编码解码层的数量 64 | :param units: 单元大小 65 | :param d_model: 词嵌入维度 66 | :param num_heads:多头注意力的头部层数量 67 | :param dropout: dropout的权重 68 | :param alpha: 温度 69 | :param name: 名称 70 | :return: Scheduled Sample的Transformer 71 | """ 72 | inputs = tf.keras.Input(shape=(None,), name="inputs") 73 | dec_inputs = tf.keras.Input(shape=(None,), name="dec_inputs") 74 | 75 | # 使用了Lambda将方法包装成层,为的是满足函数式API的需要 76 | enc_padding_mask = tf.keras.layers.Lambda( 77 | 
create_padding_mask, output_shape=(1, 1, None), 78 | name="enc_padding_mask" 79 | )(inputs) 80 | 81 | look_ahead_mask = tf.keras.layers.Lambda( 82 | combine_mask, output_shape=(1, None, None), 83 | name="look_ahead_mask" 84 | )(dec_inputs) 85 | 86 | dec_padding_mask = tf.keras.layers.Lambda( 87 | create_padding_mask, output_shape=(1, 1, None), 88 | name="dec_padding_mask" 89 | )(inputs) 90 | 91 | enc_outputs = encoder( 92 | vocab_size=vocab_size, num_layers=num_layers, units=units, 93 | d_model=d_model, num_heads=num_heads, dropout=dropout 94 | )(inputs=[inputs, enc_padding_mask]) 95 | 96 | transformer_decoder = decoder( 97 | vocab_size=vocab_size, num_layers=num_layers, units=units, 98 | d_model=d_model, num_heads=num_heads, dropout=dropout 99 | ) 100 | 101 | dec_first_outputs = transformer_decoder(inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask]) 102 | 103 | # dec_outputs的几种方式 104 | # 1. dec_outputs = tf.argmax(dec_outputs, axis=-1) # 使用这个方式的话,就是直接返回最大的概率用来作为decoder的inputs 105 | # 2. tf.layers.Sparsemax(axis=-1)(dec_outputs) # 使用Sparsemax的方法,具体公式参考论文 106 | # 3. tf.math.top_k() # 混合top-k嵌入,使用得分最高的5个词汇词嵌入的加权平均值。 107 | # 4. 使用GumbelSoftmax的方法,具体公式参考论文,下面就用GumbelSoftmax方法 108 | # 这里使用论文的第四种方法:GumbelSoftmax 109 | gumbel_outputs = gumbel_softmax(dec_first_outputs, alpha=alpha) 110 | dec_first_outputs = embedding_mix(gumbel_outputs, dec_inputs) 111 | 112 | dec_second_outputs = transformer_decoder(inputs=[dec_first_outputs, enc_outputs, look_ahead_mask, dec_padding_mask]) 113 | outputs = tf.keras.layers.Dense(units=vocab_size, name="outputs")(dec_second_outputs) 114 | return tf.keras.Model(inputs=[inputs, dec_inputs], outputs=outputs, name=name) 115 | -------------------------------------------------------------------------------- /dialogue/pytorch/seq2seq/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """seq2seq的Pytorch实现核心core 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import torch 22 | import torch.nn as nn 23 | from typing import Tuple 24 | from dialogue.pytorch.layers import BahdanauAttention 25 | 26 | 27 | class Encoder(nn.Module): 28 | """ seq2seq的encoder """ 29 | 30 | def __init__(self, vocab_size: int, embedding_dim: int, enc_units: int, num_layers: int, 31 | dropout: float, cell_type: str = "lstm", if_bidirectional: bool = True) -> None: 32 | """ 33 | :param vocab_size: 词汇量大小 34 | :param embedding_dim: 词嵌入维度 35 | :param enc_units: encoder单元大小 36 | :param num_layers: encoder中内部RNN层数 37 | :param dropout: 采样率 38 | :param if_bidirectional: 是否双向 39 | :param cell_type: cell类型,lstm/gru, 默认lstm 40 | :return: Seq2Seq的Encoder 41 | """ 42 | super(Encoder, self).__init__() 43 | self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim) 44 | 45 | if cell_type == "lstm": 46 | self.rnn = nn.LSTM(input_size=embedding_dim, hidden_size=enc_units, 47 | num_layers=num_layers, bidirectional=if_bidirectional) 48 | elif cell_type == "gru": 49 | self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=enc_units, 50 | num_layers=num_layers, bidirectional=if_bidirectional) 51 | 52 | self.dropout = nn.Dropout(p=dropout) 53 | 54 | def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 55 | """ 56 | :param inputs: encoder的输入 57 | """ 58 | inputs = self.embedding(inputs) 59 | dropout = self.dropout(inputs) 60 | outputs, (state, _) = self.rnn(dropout) 61 | # 这里使用了双向GRU,所以这里将两个方向的特征层合并起来,维度将会是units * 2 62 | state = torch.cat((state[-2, :, :], state[-1, :, :]), dim=1) 63 | return outputs, state 64 | 65 | 66 | class Decoder(nn.Module): 67 | """ seq2seq的decoder 68 | 69 | :param vocab_size: 词汇量大小 70 | :param embedding_dim: 词嵌入维度 71 | :param enc_units: encoder单元大小 72 | :param dec_units: decoder单元大小 73 | :param num_layers: encoder中内部RNN层数 74 | :param dropout: 采样率 75 | :param cell_type: cell类型,lstm/gru, 默认lstm 76 | :param if_bidirectional: 是否双向 77 | :return: Seq2Seq的Encoder 78 | """ 79 | 80 | def __init__(self, vocab_size: int, embedding_dim: int, enc_units: int, dec_units: int, num_layers: int, 81 | dropout: float, cell_type: str = "lstm", if_bidirectional: bool = True) -> None: 82 | super(Decoder, self).__init__() 83 | self.vocab_size = vocab_size 84 | self.attention = BahdanauAttention(enc_units=enc_units, dec_units=dec_units) 85 | self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim) 86 | if cell_type == "lstm": 87 | self.rnn = nn.LSTM(input_size=enc_units * 2 + embedding_dim, hidden_size=dec_units, 88 | num_layers=num_layers, bidirectional=if_bidirectional) 89 | elif cell_type == "gru": 90 | self.rnn = nn.GRU(input_size=enc_units * 2 + embedding_dim, hidden_size=dec_units, 91 | num_layers=num_layers, bidirectional=if_bidirectional) 92 | self.fc = nn.Linear(in_features=2 * enc_units + 2 * dec_units + embedding_dim, out_features=vocab_size) 93 | self.dropout = nn.Dropout(dropout) 94 | 95 | def forward(self, inputs: torch.Tensor, hidden: torch.Tensor, 96 | enc_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 97 | """ 98 | :param inputs: decoder的输入 99 | :param hidden: encoder的hidden 100 | :param enc_output: encoder的输出 101 | """ 102 | embedding = self.embedding(inputs) 103 | embedding = self.dropout(embedding) 104 | context_vector, 
attention_weights = self.attention(hidden, enc_output) 105 | 106 | rnn_input = torch.cat((embedding, torch.unsqueeze(context_vector, dim=0)), dim=-1) 107 | rnn_output, (dec_state, _) = self.rnn(rnn_input) 108 | output = self.fc(torch.cat((embedding, context_vector.unsqueeze(dim=0), rnn_output), dim=-1)) 109 | 110 | return output, dec_state.squeeze(0) 111 | -------------------------------------------------------------------------------- /README.CN.md: -------------------------------------------------------------------------------- 1 | # NLP-Dialogue | Still Work 2 | 3 | [![Blog](https://img.shields.io/badge/blog-@DengBoCong-blue.svg?style=social)](https://www.zhihu.com/people/dengbocong) 4 | [![Paper Support](https://img.shields.io/badge/paper-repo-blue.svg?style=social)](https://github.com/DengBoCong/nlp-paper) 5 | ![Stars Thanks](https://img.shields.io/badge/Stars-thanks-brightgreen.svg?style=social&logo=trustpilot) 6 | ![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=social&logo=appveyor) 7 | 8 | # 项目正在优化架构,可执行代码已经标记tag 9 | 10 | 一个能够部署执行的全流程对话系统 11 | + TensorFlow模型 12 | + Transformer 13 | + Seq2Seq 14 | + SMN检索式模型 15 | + Scheduled Sampling的Transformer 16 | + GPT2 17 | + Task Dialogue 18 | + Pytorch模型 19 | + Transformer 20 | + Seq2Seq 21 | 22 | # 项目说明 23 | 24 | 本项目奔着构建一个能够在线部署对话系统,同时包含开放域和面向任务型两种对话系统,针对相关模型进行复现,论文阅读笔记放置另一个项目:[nlp-paper](https://github.com/DengBoCong/nlp-paper),项目中使用TensorFlow和Pytorch进行实现。 25 | 26 | # 语料 27 | 仓库中的[data](https://github.com/DengBoCong/nlp-dialogue/tree/main/dialogue/data)目录下放着各语料的玩具数据,可用于验证系统执行性,完整语料以及Paper可以在[这里](https://github.com/DengBoCong/nlp-paper)查看 28 | 29 | + LCCC 30 | + CrossWOZ 31 | + 小黄鸡 32 | + 豆瓣 33 | + Ubuntu 34 | + 微博 35 | + 青云 36 | + 贴吧 37 | 38 | # 执行说明 39 | 40 | + Linux执行run.sh,项目工程目录检查执行check.sh(或check.py) 41 | + 根目录下的actuator.py为总执行入口,通过调用如下指令格式执行(执行前注意安装requirements.txt): 42 | ``` 43 | python actuator.py --version [Options] --model [Options] ... 
44 | ``` 45 | + 通过根目录下的actuator.py进行执行时,`--version`、`--model`和`--act`为必传参数,其中`--version`为代码版本`tf/torch`,`--model`为执行对应的模型`transformer/smn...`,而act为执行模式(缺省状态下为`pre_treat`模式),更详细指令参数参见各模型下的`actuator.py`或config目录下的对应json配置文件。 46 | + `--act`执行模式说明如下: 47 | + pre_treat模式为文本预处理模式,如果在没有分词结果集以及字典的情况下,需要先运行pre_treat模式 48 | + train模式为训练模式 49 | + evaluate模式为指标评估模式 50 | + chat模式为对话模式,chat模式下运行时,输入ESC即退出对话。 51 | + 正常执行顺序为pre_treat->train->evaluate->chat 52 | + 各模型下单独有一个actuator.py,可以绕开外层耦合进行执行开发,不过执行时注意调整工程目录路径 53 | 54 | # 目录结构说明 55 | + dialogue下为相关模型的核心代码放置位置,方便日后进行封装打包等 56 | + checkpoints为检查点保存位置 57 | + config为配置文件保存目录 58 | + data为原始数据储存位置,同时,在模型执行过程中产生的中间数据文件也保存在此目录下 59 | + models为模型保存目录 60 | + tensorflow及pytorch放置模型构建以及各模组执行的核心代码 61 | + preprocess_corpus.py为语料处理脚本,对各语料进行单轮和多轮对话的处理,并规范统一接口调用 62 | + read_data.py用于load_dataset.py的数据加载格式调用 63 | + metrics.py为各项指标脚本 64 | + tools.py为工具脚本,保存有分词器、日志操作、检查点保存/加载脚本等 65 | + docs下放置文档说明,包括模型论文阅读笔记 66 | + docker(mobile)用于服务端(移动终端)部署脚本 67 | + server为UI服务界面,使用flask进行构建使用,执行对应的server.py即可 68 | + tools为预留工具目录 69 | + actuator.py(run.sh)为总执行器入口 70 | + check.py(check.sh)为工程目录检查脚本 71 | 72 | 73 | # SMN模型运行说明 74 | SMN检索式对话系统使用前需要准备solr环境,solr部署系统环境推荐Linux,工具推荐使用容器部署(推荐Docker),并准备: 75 | + Solr(8.6.3) 76 | + pysolr(3.9.0) 77 | 78 | 以下提供简要说明,更详细可参见文章:[搞定检索式对话系统的候选response检索--使用pysolr调用Solr](https://zhuanlan.zhihu.com/p/300165220) 79 | ## Solr环境 80 | 需要保证solr在线上运行稳定,以及方便后续维护,请使用DockerFile进行部署,DockerFile获取地址:[docker-solr](https://github.com/docker-solr/docker-solr) 81 | 82 | 仅测试模型使用,可使用如下最简构建指令: 83 | ``` 84 | docker pull solr:8.6.3 85 | # 然后启动solr 86 | docker run -itd --name solr -p 8983:8983 solr:8.6.3 87 | # 然后创建core核心选择器,这里取名smn(可选) 88 | docker exec -it --user=solr solr bin/solr create_core -c smn 89 | ``` 90 | 91 | 关于solr中分词工具有IK Analyzer、Smartcn、拼音分词器等等,需要下载对应jar,然后在Solr核心配置文件managed-schema中添加配置。 92 | 93 | **特别说明**:如果使用TF-IDF,还需要在managed-schema中开启相似度配置。 94 | ## Python中使用说明 95 | 线上部署好Solr之后,在Python中使用pysolr进行连接使用: 96 | ``` 97 | pip install pysolr 98 | ``` 99 | 100 | 添加索引数据(一般需要先安全检查)方式如下。将回复数据添加索引,responses是一个json,形式如:[{},{},{},...],里面每个对象构建按照你回复的需求即可: 101 | ``` 102 | solr = pysolr.Solr(url=solr_server, always_commit=True, timeout=10) 103 | # 安全检查 104 | solr.ping() 105 | solr.add(docs=responses) 106 | ``` 107 | 108 | 查询方式如下,以TF-IDF查询所有语句query语句方式如下: 109 | ``` 110 | {!func}sum(product(idf(utterance,key1),tf(utterance,key1),product(idf(utterance,key2),tf(utterance,key2),...) 111 | ``` 112 | 113 | 使用前需要先将数据添加至Solr,在本SMN模型中使用,先执行pre_treat模式即可。 114 | 115 | # Demo概览 116 | 117 | 118 | 119 | # 参考代码和文献 120 | 1. [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/250946855):Transformer的开山之作,值得精读 | Ashish et al,2017 121 | 2. [Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-Based Chatbots](https://arxiv.org/pdf/1612.01627v2.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/270554147):SMN检索式对话模型,多层多粒度提取信息 | Devlin et al,2018 122 | 3. [Massive Exploration of Neural Machine Translation Architectures](https://arxiv.org/pdf/1703.03906.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/328801239):展示了以NMT架构超参数为例的首次大规模分析,实验为构建和扩展NMT体系结构带来了新颖的见解和实用建议。 | Denny et al,2017 123 | 4. [Scheduled Sampling for Transformers](https://arxiv.org/pdf/1906.07651.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/267146739):在Transformer应用Scheduled Sampling | Mihaylova et al,2019 124 | 125 | # License 126 | Licensed under the Apache License, Version 2.0. Copyright 2021 DengBoCong. [Copy of the license](). 
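# Appendix: assembling the Solr TF-IDF query (sketch)

The `{!func}sum(product(idf(...),tf(...)),...)` query shown in the SMN section above can be built programmatically. The snippet below is only a minimal sketch: the local Solr URL and the keyword list are illustrative assumptions, and in practice the keyword list could come from `get_tf_idf_top_k` in `dialogue/tensorflow/utils.py`.

```
import pysolr


def build_tfidf_query(keys):
    # one product(idf, tf) term per keyword; {!func} makes Solr score documents by their sum
    terms = ",".join(
        "product(idf(utterance,{0}),tf(utterance,{0}))".format(key) for key in keys
    )
    return "{!func}sum(" + terms + ")"


solr = pysolr.Solr(url="http://localhost:8983/solr/smn/", always_commit=True, timeout=10)
solr.ping()  # safety check, as above
candidates = solr.search(q=build_tfidf_query(["吃饭", "电影"]), rows=10)
```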
-------------------------------------------------------------------------------- /dialogue/tools/read_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """用于加载预处理好的数据 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from typing import Tuple 23 | from dialogue.tools import pad_sequences 24 | from dialogue.tools import Tokenizer 25 | 26 | 27 | def read_data(data_path: str, max_data_size: int, max_sentence: int, 28 | data_type: str, tokenizer: Tokenizer, **kwargs) -> Tuple: 29 | """ 中转读取数据 30 | 31 | :param data_path: 分词文本路径 32 | :param max_data_size: 读取的数据量大小 33 | :param max_sentence: 最大序列长度 34 | :param data_type: 读取数据类型,单轮/多轮 35 | :param tokenizer: 传入现有的分词器,默认重新生成 36 | :return: 输入序列张量、目标序列张量和分词器 37 | """ 38 | operation = { 39 | "read_single_data": lambda: _read_single_data( 40 | data_path=data_path, max_data_size=max_data_size, max_sentence=max_sentence, tokenizer=tokenizer), 41 | "read_multi_turn_data": lambda: _read_multi_turn_data( 42 | data_path=data_path, max_data_size=max_data_size, max_utterance=kwargs.get("max_utterance"), 43 | max_sentence=max_sentence, tokenizer=tokenizer) 44 | } 45 | 46 | return operation.get(data_type)() 47 | 48 | 49 | def _read_single_data(data_path: str, max_data_size: int, 50 | max_sentence: int, tokenizer: Tokenizer) -> Tuple: 51 | """ 读取单轮问答数据,将input和target进行分词后,与样本权重一同返回 52 | 53 | :param data_path: 分词文本路径 54 | :param max_data_size: 读取的数据量大小 55 | :param max_sentence: 最大序列长度 56 | :param tokenizer: 传入现有的分词器,默认重新生成 57 | :return: 输入序列张量、目标序列张量和分词器 58 | """ 59 | if not os.path.exists(data_path): 60 | print("不存在已经分词好的文件,请检查数据集或执行pre_treat模式") 61 | exit(0) 62 | 63 | with open(data_path, "r", encoding="utf-8") as file: 64 | sample_weights = [] 65 | qa_pairs = [] 66 | count = 0 # 用于处理数据计数 67 | 68 | for line in file: 69 | # 文本数据中的问答对权重通过在问答对尾部添加“<|>”配置 70 | temp = line.strip().strip("\n").replace("/", "").split("<|>") 71 | qa_pairs.append([sentence for sentence in temp[0].split("\t")]) 72 | # 如果没有配置对应问答对权重,则默认为1. 
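            # e.g. the raw line "今天 天气 如何\t还 不错<|>2.0" is parsed into the pair ["今天 天气 如何", "还 不错"]
            # with sample weight 2.0; a line without the "<|>" suffix falls back to weight 1.0 below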
73 | if len(temp) == 1: 74 | sample_weights.append(float(1)) 75 | else: 76 | sample_weights.append(float(temp[1])) 77 | 78 | count += 1 79 | if max_data_size == count: 80 | break 81 | 82 | (input_lang, target_lang) = zip(*qa_pairs) 83 | 84 | input_tensor = tokenizer.texts_to_sequences(input_lang) 85 | target_tensor = tokenizer.texts_to_sequences(target_lang) 86 | 87 | input_tensor = pad_sequences(input_tensor, maxlen=max_sentence, padding="post") 88 | target_tensor = pad_sequences(target_tensor, maxlen=max_sentence, padding="post") 89 | 90 | return input_tensor, target_tensor, sample_weights 91 | 92 | 93 | def _read_multi_turn_data(data_path: str, max_data_size: int, max_utterance: int, 94 | max_sentence: int, tokenizer: Tokenizer) -> Tuple: 95 | """ 读取多轮对话数据,将utterance和response进行分词后,同label等数据一并返回 96 | 97 | :param data_path: 分词文本路径 98 | :param max_data_size: 读取的数据量大小 99 | :param max_utterance: 每轮对话最大对话数 100 | :param max_sentence: 单个句子最大长度 101 | :param tokenizer: 传入现有的分词器,默认重新生成 102 | :return: 输入序列张量、目标序列张量和分词器 103 | """ 104 | if not os.path.exists(data_path): 105 | print("不存在已经分词好的文件,请检查数据集或执行pre_treat模式") 106 | exit(0) 107 | 108 | history = [] # 用于保存每轮对话历史语句 109 | response = [] # 用于保存每轮对话的回答 110 | label = [] # 用于保存每轮对话的标签 111 | count = 0 # 用于处理数据计数 112 | 113 | with open(data_path, "r", encoding="utf-8") as file: 114 | for line in file: 115 | apart = line.strip().strip("\n").replace("/", "").split("\t") 116 | label.append(int(apart[0])) 117 | response.append(apart[-1]) 118 | del apart[0] 119 | del apart[-1] 120 | history.append(apart) 121 | 122 | count += 1 123 | if max_data_size == count: 124 | break 125 | 126 | response = tokenizer.texts_to_sequences(response) 127 | response = pad_sequences(response, maxlen=max_sentence, padding="post") 128 | 129 | utterances = [] 130 | for utterance in history: 131 | # 注意了,这边要取每轮对话的最后max_utterances数量的语句 132 | utterance_padding = tokenizer.texts_to_sequences(utterance)[-max_utterance:] 133 | utterance_len = len(utterance_padding) 134 | # 如果当前轮次中的历史语句不足max_utterances数量,需要在尾部进行填充 135 | if utterance_len != max_utterance: 136 | utterance_padding = [[0]] * (max_utterance - utterance_len) + utterance_padding 137 | utterances.append(pad_sequences(utterance_padding, maxlen=max_sentence, padding="post").tolist()) 138 | 139 | return utterances, response, label 140 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/task_chatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import tensorflow as tf 5 | import model.model as task 6 | from common.kb import load_kb 7 | 8 | sys.path.append(sys.path[0][:-10]) 9 | from model.chatter import Chatter 10 | import common.data_utils as _data 11 | from common.common import CmdParser 12 | import config.get_config as _config 13 | from common.pre_treat import preprocess_raw_task_data 14 | 15 | 16 | class TaskChatter(Chatter): 17 | """ 18 | Task模型的聊天器 19 | """ 20 | 21 | def __init__(self, checkpoint_dir, beam_size): 22 | super().__init__(checkpoint_dir, beam_size) 23 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 24 | self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt") 25 | self.optimizer = tf.keras.optimizers.RMSprop() 26 | self.train_loss = tf.keras.metrics.Mean(name='train_loss') 27 | 28 | def _init_loss_accuracy(self): 29 | print('待完善') 30 | 31 | def _train_step(self, inp, tar, step_loss): 32 | print('待完善') 33 | 34 | def _create_predictions(self, inputs, 
dec_input, t): 35 | print('待完善') 36 | 37 | def train(self, dict_fn, data_fn, start_sign, end_sign, max_train_data_size): 38 | _, _, lang_tokenizer = _data.load_dataset(dict_fn=dict_fn, data_fn=data_fn, start_sign=start_sign, 39 | end_sign=end_sign, max_train_data_size=max_train_data_size) 40 | data_load = _data.load_data(_config.dialogues_train, _config.max_length, _config.database, _config.ontology, 41 | lang_tokenizer.word_index, _config.max_train_data_size, _config.kb_indicator_len) 42 | 43 | model = task.task(_config.units, data_load.onto, 44 | _config.vocab_size, _config.embedding_dim, _config.max_length) 45 | 46 | checkpoint = tf.train.Checkpoint(model=model, optimizer=self.optimizer) 47 | ckpt = tf.io.gfile.listdir(self.checkpoint_dir) 48 | if ckpt: 49 | checkpoint.restore(tf.train.latest_checkpoint(self.checkpoint_dir)).expect_partial() 50 | 51 | sample_sum = len(data_load) 52 | for epoch in range(_config.epochs): 53 | print('Epoch {}/{}'.format(epoch + 1, _config.epochs)) 54 | start_time = time.time() 55 | 56 | batch_sum = 0 57 | 58 | while (True): 59 | _, _, _, usr_utts, _, state_gt, kb_indicator, _ = data_load.next() 60 | if data_load.cur == 0: 61 | break 62 | kb_indicator = tf.convert_to_tensor(kb_indicator) 63 | with tf.GradientTape() as tape: 64 | state_preds = model(inputs=[usr_utts, kb_indicator]) 65 | loss = 0 66 | for key in state_preds: 67 | loss += tf.keras.losses.SparseCategoricalCrossentropy( 68 | from_logits=True)(state_gt[key], state_preds[key]) 69 | gradients = tape.gradient(loss, model.trainable_variables) 70 | self.optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 71 | self.train_loss(loss) 72 | kb = load_kb(_config.database, "name") 73 | 74 | batch_sum = batch_sum + len(usr_utts) 75 | print('\r', '{}/{} [==================================]'.format(batch_sum, sample_sum), end='', 76 | flush=True) 77 | step_time = (time.time() - start_time) 78 | sys.stdout.write(' - {:.4f}s/step - loss: {:.4f}\n' 79 | .format(step_time, self.train_loss.result())) 80 | sys.stdout.flush() 81 | checkpoint.save(file_prefix=self.checkpoint_prefix) 82 | print('训练结束') 83 | 84 | 85 | def main(): 86 | parser = CmdParser(version='%task chatbot V1.0') 87 | parser.add_option("-t", "--type", action="store", type="string", 88 | dest="type", default="pre_treat", 89 | help="execute type, pre_treat/train/chat") 90 | (options, args) = parser.parse_args() 91 | 92 | chatter = TaskChatter(checkpoint_dir=_config.task_train_data, beam_size=_config.beam_size) 93 | 94 | if options.type == 'train': 95 | chatter.train(dict_fn=_config.dict_fn, 96 | data_fn=_config.dialogues_tokenized, 97 | start_sign='', 98 | end_sign='', 99 | max_train_data_size=0) 100 | elif options.type == 'chat': 101 | print('Agent: 你好!结束聊天请输入ESC。') 102 | while True: 103 | req = input('User: ') 104 | if req == 'ESC': 105 | print('Agent: 再见!') 106 | exit(0) 107 | # response = chatter.respond(req) 108 | response = '待完善' 109 | print('Agent: ', response) 110 | elif options.type == 'pre_treat': 111 | preprocess_raw_task_data(raw_data=_config.dialogues_train, 112 | tokenized_data=_config.dialogues_tokenized, 113 | semi_dict=_config.semi_dict, 114 | database=_config.database, 115 | ontology=_config.ontology) 116 | else: 117 | parser.error(msg='') 118 | 119 | 120 | if __name__ == "__main__": 121 | """ 122 | TaskModel入口:指令需要附带运行参数 123 | cmd:python task_chatter.py -t/--type [执行模式] 124 | 执行类别:pre_treat/train/chat 125 | 126 | chat模式下运行时,输入exit即退出对话 127 | """ 128 | main() 129 | 
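# ------------------------------------------------------------------------------
# Minimal usage sketch for the single-turn loader defined in
# dialogue/tools/read_data.py above. The dictionary and corpus paths are
# illustrative assumptions (taken from configs/configs.json), and
# "read_single_data" selects the single-turn branch of read_data().
from dialogue.tools import load_tokenizer
from dialogue.tools.read_data import read_data

tokenizer = load_tokenizer(dict_path="data/preprocess/transformer_dict.json")
input_tensor, target_tensor, sample_weights = read_data(
    data_path="data/preprocess/single_tokenized.txt", max_data_size=100,
    max_sentence=40, data_type="read_single_data", tokenizer=tokenizer
)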
-------------------------------------------------------------------------------- /dialogue/tensorflow/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TensorFlow版本下的公共layers 16 | """ 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def bahdanau_attention(units: int, query_dim: int, value_dim: int, d_type: tf.dtypes.DType = tf.float32, 22 | name: str = "bahdanau_attention") -> tf.keras.Model: 23 | """Bahdanau Attention实现 24 | 25 | :param units: 26 | :param query_dim: query最后一个维度 27 | :param value_dim: value最后一个维度 28 | :param d_type: 运算精度 29 | :param name: 名称 30 | """ 31 | query = tf.keras.Input(shape=(query_dim,), dtype=d_type, name="{}_query".format(name)) 32 | value = tf.keras.Input(shape=(None, value_dim), dtype=d_type, name="{}_value".format(name)) 33 | hidden_with_time_axis = tf.expand_dims(query, 1) 34 | 35 | state = tf.keras.layers.Dense(units=units, dtype=d_type, name="{}_state_dense".format(name))(value) 36 | hidden = tf.keras.layers.Dense(units=units, dtype=d_type, name="{}_hidden_dense".format(name))( 37 | hidden_with_time_axis) 38 | effect = tf.nn.tanh(x=state + hidden, name="{}_tanh".format(name)) 39 | score = tf.keras.layers.Dense(units=1, dtype=d_type, name="{}_score_dense".format(name))(effect) 40 | 41 | attention_weights = tf.nn.softmax(logits=score, axis=1, name="{}_softmax".format(name)) 42 | context_vector = attention_weights * value 43 | context_vector = tf.reduce_sum(input_tensor=context_vector, axis=1, name="{}_reduce_sum".format(name)) 44 | 45 | return tf.keras.Model(inputs=[query, value], outputs=[context_vector, attention_weights]) 46 | 47 | 48 | # 点积注意力 49 | def scaled_dot_product_attention(q, k, v, mask=None): 50 | """计算注意力权重。 51 | q, k, v 必须具有匹配的前置维度。 52 | k, v 必须有匹配的倒数第二个维度,例如:seq_len_k = seq_len_v。 53 | 虽然 mask 根据其类型(填充或前瞻)有不同的形状, 54 | 但是 mask 必须能进行广播转换以便求和。 55 | 56 | 参数: 57 | q: 请求的形状 == (..., seq_len_q, depth) 58 | k: 主键的形状 == (..., seq_len_k, depth) 59 | v: 数值的形状 == (..., seq_len_v, depth_v) 60 | mask: Float 张量,其形状能转换成 61 | (..., seq_len_q, seq_len_k)。默认为None。 62 | 63 | 返回值: 64 | 输出,注意力权重 65 | """ 66 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 67 | 68 | # 缩放 matmul_qk 69 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 70 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 71 | 72 | # 将 mask 加入到缩放的张量上。 73 | if mask is not None: 74 | scaled_attention_logits += (mask * -1e9) 75 | 76 | # softmax 在最后一个轴(seq_len_k)上归一化,因此分数相加等于1。 77 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 78 | 79 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 80 | 81 | return output, attention_weights 82 | 83 | 84 | # 多头注意力层 85 | class MultiHeadAttention(tf.keras.layers.Layer): 86 | def __init__(self, d_model, 
num_heads): 87 | super(MultiHeadAttention, self).__init__() 88 | self.num_heads = num_heads 89 | self.d_model = d_model 90 | 91 | assert d_model % self.num_heads == 0 92 | 93 | self.depth = d_model // self.num_heads 94 | 95 | self.wq = tf.keras.layers.Dense(d_model) 96 | self.wk = tf.keras.layers.Dense(d_model) 97 | self.wv = tf.keras.layers.Dense(d_model) 98 | 99 | self.dense = tf.keras.layers.Dense(d_model) 100 | 101 | def split_heads(self, x, batch_size): 102 | """分拆最后一个维度到 (num_heads, depth). 103 | 转置结果使得形状为 (batch_size, num_heads, seq_len, depth) 104 | """ 105 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 106 | return tf.transpose(x, perm=[0, 2, 1, 3]) 107 | 108 | def call(self, v, k, q, mask=None): 109 | batch_size = tf.shape(q)[0] 110 | 111 | q = self.wq(q) # (batch_size, seq_len, d_model) 112 | k = self.wk(k) # (batch_size, seq_len, d_model) 113 | v = self.wv(v) # (batch_size, seq_len, d_model) 114 | 115 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 116 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 117 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 118 | 119 | # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) 120 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 121 | scaled_attention, attention_weights = scaled_dot_product_attention( 122 | q, k, v, mask) 123 | 124 | scaled_attention = tf.transpose(scaled_attention, 125 | perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) 126 | 127 | concat_attention = tf.reshape(scaled_attention, 128 | (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) 129 | 130 | output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model) 131 | 132 | return output, attention_weights 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

NLP-Dialogue

2 | 3 |
4 | 5 | [![Blog](https://img.shields.io/badge/blog-@DengBoCong-blue.svg?style=social)](https://www.zhihu.com/people/dengbocong) 6 | [![Paper Support](https://img.shields.io/badge/paper-repo-blue.svg?style=social)](https://github.com/DengBoCong/nlp-paper) 7 | ![Stars Thanks](https://img.shields.io/badge/Stars-thanks-brightgreen.svg?style=social&logo=trustpilot) 8 | ![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=social&logo=appveyor) 9 | 10 | [comment]: <> ([![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)]()) 11 | 12 |
13 | 14 |

15 | 16 | [English](https://github.com/DengBoCong/nlp-dialogue) | [中文](https://github.com/DengBoCong/nlp-dialogue/blob/main/README.CN.md) 17 | 18 |

19 | 20 | # 架构 21 | 开放域生成问答模型 22 | Leveraging Passage Retrieval with Generative Models for Open Domain Question Answerin 23 | 24 | 检索 25 | Dense Passage Retrieval for Open-Domain Question Answering 26 | 27 | # 工具 28 | APScheduler 29 | 30 | # 项目正在优化架构,可执行代码已经标记tag 31 | 32 | 一个能够部署执行的全流程对话系统 33 | + TensorFlow模型 34 | + Transformer 35 | + Seq2Seq 36 | + SMN检索式模型 37 | + Scheduled Sampling的Transformer 38 | + GPT2 39 | + Task Dialogue 40 | + Pytorch模型 41 | + Transformer 42 | + Seq2Seq 43 | 44 | # 项目说明 45 | 46 | 本项目奔着构建一个能够在线部署对话系统,同时包含开放域和面向任务型两种对话系统,针对相关模型进行复现,论文阅读笔记放置另一个项目:[nlp-paper](https://github.com/DengBoCong/nlp-paper),项目中使用TensorFlow和Pytorch进行实现。 47 | 48 | # 语料 49 | 仓库中的[data](https://github.com/DengBoCong/nlp-dialogue/tree/main/dialogue/data)目录下放着各语料的玩具数据,可用于验证系统执行性,完整语料以及Paper可以在[这里](https://github.com/DengBoCong/nlp-paper)查看 50 | 51 | + LCCC 52 | + CrossWOZ 53 | + 小黄鸡 54 | + 豆瓣 55 | + Ubuntu 56 | + 微博 57 | + 青云 58 | + 贴吧 59 | 60 | # 执行说明 61 | 62 | + Linux执行run.sh,项目工程目录检查执行check.sh(或check.py) 63 | + 根目录下的actuator.py为总执行入口,通过调用如下指令格式执行(执行前注意安装requirements.txt): 64 | ``` 65 | python actuator.py --version [Options] --model [Options] ... 66 | ``` 67 | + 通过根目录下的actuator.py进行执行时,`--version`、`--model`和`--act`为必传参数,其中`--version`为代码版本`tf/torch`,`--model`为执行对应的模型`transformer/smn...`,而act为执行模式(缺省状态下为`pre_treat`模式),更详细指令参数参见各模型下的`actuator.py`或config目录下的对应json配置文件。 68 | + `--act`执行模式说明如下: 69 | + pre_treat模式为文本预处理模式,如果在没有分词结果集以及字典的情况下,需要先运行pre_treat模式 70 | + train模式为训练模式 71 | + evaluate模式为指标评估模式 72 | + chat模式为对话模式,chat模式下运行时,输入ESC即退出对话。 73 | + 正常执行顺序为pre_treat->train->evaluate->chat 74 | + 各模型下单独有一个actuator.py,可以绕开外层耦合进行执行开发,不过执行时注意调整工程目录路径 75 | 76 | # 目录结构说明 77 | + dialogue下为相关模型的核心代码放置位置,方便日后进行封装打包等 78 | + checkpoints为检查点保存位置 79 | + config为配置文件保存目录 80 | + data为原始数据储存位置,同时,在模型执行过程中产生的中间数据文件也保存在此目录下 81 | + models为模型保存目录 82 | + tensorflow及pytorch放置模型构建以及各模组执行的核心代码 83 | + preprocess_corpus.py为语料处理脚本,对各语料进行单轮和多轮对话的处理,并规范统一接口调用 84 | + read_data.py用于load_dataset.py的数据加载格式调用 85 | + metrics.py为各项指标脚本 86 | + tools.py为工具脚本,保存有分词器、日志操作、检查点保存/加载脚本等 87 | + docs下放置文档说明,包括模型论文阅读笔记 88 | + docker(mobile)用于服务端(移动终端)部署脚本 89 | + server为UI服务界面,使用flask进行构建使用,执行对应的server.py即可 90 | + tools为预留工具目录 91 | + actuator.py(run.sh)为总执行器入口 92 | + check.py(check.sh)为工程目录检查脚本 93 | 94 | 95 | # SMN模型运行说明 96 | SMN检索式对话系统使用前需要准备solr环境,solr部署系统环境推荐Linux,工具推荐使用容器部署(推荐Docker),并准备: 97 | + Solr(8.6.3) 98 | + pysolr(3.9.0) 99 | 100 | 以下提供简要说明,更详细可参见文章:[搞定检索式对话系统的候选response检索--使用pysolr调用Solr](https://zhuanlan.zhihu.com/p/300165220) 101 | ## Solr环境 102 | 需要保证solr在线上运行稳定,以及方便后续维护,请使用DockerFile进行部署,DockerFile获取地址:[docker-solr](https://github.com/docker-solr/docker-solr) 103 | 104 | 仅测试模型使用,可使用如下最简构建指令: 105 | ``` 106 | docker pull solr:8.6.3 107 | # 然后启动solr 108 | docker run -itd --name solr -p 8983:8983 solr:8.6.3 109 | # 然后创建core核心选择器,这里取名smn(可选) 110 | docker exec -it --user=solr solr bin/solr create_core -c smn 111 | ``` 112 | 113 | 关于solr中分词工具有IK Analyzer、Smartcn、拼音分词器等等,需要下载对应jar,然后在Solr核心配置文件managed-schema中添加配置。 114 | 115 | **特别说明**:如果使用TF-IDF,还需要在managed-schema中开启相似度配置。 116 | ## Python中使用说明 117 | 线上部署好Solr之后,在Python中使用pysolr进行连接使用: 118 | ``` 119 | pip install pysolr 120 | ``` 121 | 122 | 添加索引数据(一般需要先安全检查)方式如下。将回复数据添加索引,responses是一个json,形式如:[{},{},{},...],里面每个对象构建按照你回复的需求即可: 123 | ``` 124 | solr = pysolr.Solr(url=solr_server, always_commit=True, timeout=10) 125 | # 安全检查 126 | solr.ping() 127 | solr.add(docs=responses) 128 | ``` 129 | 130 | 查询方式如下,以TF-IDF查询所有语句query语句方式如下: 131 | ``` 132 | 
{!func}sum(product(idf(utterance,key1),tf(utterance,key1),product(idf(utterance,key2),tf(utterance,key2),...) 133 | ``` 134 | 135 | 使用前需要先将数据添加至Solr,在本SMN模型中使用,先执行pre_treat模式即可。 136 | 137 | # Demo概览 138 | 139 | 140 | 141 | # 参考代码和文献 142 | 1. [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/250946855):Transformer的开山之作,值得精读 | Ashish et al,2017 143 | 2. [Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-Based Chatbots](https://arxiv.org/pdf/1612.01627v2.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/270554147):SMN检索式对话模型,多层多粒度提取信息 | Devlin et al,2018 144 | 3. [Massive Exploration of Neural Machine Translation Architectures](https://arxiv.org/pdf/1703.03906.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/328801239):展示了以NMT架构超参数为例的首次大规模分析,实验为构建和扩展NMT体系结构带来了新颖的见解和实用建议。 | Denny et al,2017 145 | 4. [Scheduled Sampling for Transformers](https://arxiv.org/pdf/1906.07651.pdf) | [阅读笔记](https://zhuanlan.zhihu.com/p/267146739):在Transformer应用Scheduled Sampling | Mihaylova et al,2019 146 | 147 | # License 148 | Licensed under the Apache License, Version 2.0. Copyright 2021 DengBoCong. [Copy of the license](). 149 | -------------------------------------------------------------------------------- /docs/Attention_Is_All_You_Need.md: -------------------------------------------------------------------------------- 1 | 2 | 提示:阅读论文时进行相关思想、结构、优缺点,内容进行提炼和记录,论文和相关引用会标明出处。 3 | 4 | # 前言 5 | 6 | > 标题:Attention Is All You Need\ 7 | > 原文链接:[Link](https://arxiv.org/pdf/1706.03762.pdf)\ 8 | > 转载请注明:DengBoCong 9 | 10 | # Abstract 11 | 序列转导模型基于复杂的递归或卷积神经网络,包括编码器和解码器,表现最佳的模型还通过注意力机制连接编码器和解码器。我们提出了一种新的简单网络架构,即Transformer,它完全基于注意力机制,完全消除了重复和卷积。在两个机器翻译任务上进行的实验表明,这些模型在质量上具有优势,同时具有更高的可并行性,并且所需的训练时间大大减少。我们的模型在WMT 2014英语到德语的翻译任务上达到了28.4 BLEU,比包括集成学习在内的现有最佳结果提高了2 BLEU。在2014年WMT英语到法语翻译任务中,我们的模型在八个GPU上进行了3.5天的训练后,创造了新的单模型最新BLEU分数41.8,比文献中最好的模型的训练成本更小。我们展示了Transformer通过将其成功应用于具有大量训练数据和有限训练数据的英语解析,将其很好地概括了其他任务。 12 | # Introduction 13 | 在Transformer出现之前,RNN、LSTM、GRU等在序列模型和转导问题的方法中占据了稳固的地位,比如语言模型、机器翻译等,人们一直在努力扩大循环语言模型和编码器-解码器体系结构的界限。递归模型通常沿输入和输出序列的符号位置考虑计算。将位置与计算时间中的步骤对齐,它们根据先前的隐藏状态ht-1和位置t的输入生成一系列隐藏状态ht。这种固有的顺序性导致了没办法并行化进行训练,这在较长的序列长度上变得至关重要。最近的工作通过分解技巧和条件计算大大提高了计算效率,同时在后者的情况下还提高了模型性能,但是,顺序计算的基本约束仍然存在。注意力机制已成为各种任务中引人注目的序列建模和转导模型不可或缺的一部分,允许对依赖项进行建模,而无需考虑它们在输入或输出序列中的距离。在这项工作中,我们提出了一种Transformer,一种避免重复的模型体系结构,而是完全依赖于注意力机制来绘制输入和输出之间的全局依存关系。 14 | 15 | # Background 16 | 减少顺序计算的目标也构成了扩展神经GPU,ByteNet和ConvS2S的基础,它们全部使用卷积神经网络作为基本构建块,并行计算所有输入和输出的隐藏表示。在这些模型中,关联来自两个任意输入或输出位置的信号所需的操作数在位置之间的距离中增加,对于ConvS2S线性增长,而对于ByteNet则对数增长,这使得学习远处位置之间的依存关系变得更加困难。在Transformer中,此操作被减少为恒定的操作次数,尽管以平均注意力加权位置为代价,导致有效分辨率降低,但是我们用多头注意力抵消了这种代价。 17 | 18 | Self-attention(有时称为d intra-attention)是一种与单个序列的不同位置相关的注意力机制,目的是计算序列的表示形式。Self-attention已成功用于各种任务中,包括阅读理解,抽象摘要和学习与任务无关的句子表示。Transformer是第一个完全依靠Self-attention来计算其输入和输出表示的转导模型,而无需使用序列对齐的RNN或卷积。 19 | 20 | # Model Architecture 21 | Transformer依旧是遵循encoder-decoder结构,其模型的每一步都是自回归的,在生成下一个模型时,会将先前生成的符号用作附加输入。在此基础上,使用堆叠式Self-attention和point-wise,并在encoder和decoder中使用全连接层,结构图如下: 22 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200917160031494.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RCQ18xMjE=,size_16,color_FFFFFF,t_70#pic_center) 23 | ## Encoder and Decoder Stacks 24 | + Encoder 25 | + 编码器由$N = 6$个相同层的堆栈组成,每层有两个子层,分别是Self-attention机制和位置完全连接的前馈网络 26 | + 每个子层周围都使用残差连接并进行归一化,也就是说每个子层的输出为$LayerNorm(x+Sublayer(x))$ 27 | + 为了促进这些残差连连接,模型中的所有子层以及嵌入层均产生尺寸为dmodel = 512的输出 28 | 29 | + Decoder 30 | 
+ 解码器还由N = 6个相同层的堆栈组成 31 | + 除了每个编码器层中的两个子层之外,解码器还插入一个第三子层,该子层对编码器堆栈的输出执行多头注意力 32 | + 对编码器堆栈的输出执行多头注意力时,要注意使用mask,保证预测只能依赖于小于当前位置的已知输出。 33 | + 每个子层周围都使用残差连接并进行归一化 34 | 35 | ## Attention 36 | 注意力方法可以描述为将query和一组key-value映射到输出,其中query,key,value和输出都是向量。输出是计算value的加权总和,其中分配给每个value的权重是通过query与相应key的方法来计算的。 37 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200917162420181.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RCQ18xMjE=,size_16,color_FFFFFF,t_70#pic_center) 38 | ### Scaled Dot-Product Attention 39 | 它的输入是$d_k$维的queries和keys组成,使用所有key和query做点积,并除以$\sqrt{d_k}$,然后应用softmax函数获得value的权重,公式如下: 40 | $$Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_k}})V$$ 41 | + 常用注意力方法 42 | + 相加(在更大的$d_k$下,效果更好) 43 | + 点积(更快一些) 44 | + 所以为了在较大的$d_k$下,点积也能工作的好,在公式中才使用了$\frac{1}{\sqrt{d_k}}$ 45 | 46 | ### Multi-Head Attention 47 | 多头注意力使模型可以共同关注来自不同位置的不同表示子空间的信息,最后取平均: 48 | $$ 49 | MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^{O}\\ 50 | head_1=Attention(QW_{i}^{Q},K_{i}^{K},V_{i}^{V}) 51 | $$ 52 | 论文中使用$h=8$注意力层,其中$d_k=d_v=\frac{d_{model}}{h}=64$ 53 | ### Applications of Attention in our Model 54 | Transformer以三种不同方式使用多头注意力: 55 | + 在“encoder-decoder注意”层中,queries来自先前的decoder层,而keys和values来自encoder的输出,这允许解码器中的每个位置都参与输入序列中的所有位置。 56 | + encoder包含self-attention层。 在 self-attention层中,所有key,value和query都来自同一位置,在这种情况下,是编码器中前一层的输出。 57 | + 类似地,decoder中的self-attention层允许decoder中的每个位置都参与decoder中直至并包括该位置的所有位置。我们需要阻止decoder中的向左信息流,以保留自回归属性。 58 | 59 | ## Position-wise Feed-Forward Networks 60 | 除了关注子层之外,我们的encoder和decoder中的每个层还包含一个完全连接的前馈网络,该网络分别应用于每个位置。 这由两个线性变换组成,两个线性变换之间有ReLU激活。 61 | $$ 62 | FNN(x)=max(0,xW_1+b_1)W_2+b_2 63 | $$ 64 | 虽然线性变换在不同位置上相同,但是它们使用不同的参数 65 | ## Embeddings and Softmax 66 | 与其他序列转导模型类似,使用学习的嵌入将输入标记和输出标记转换为维dmodel的向量。我们还使用通常学习的线性变换和softmax函数将解码器输出转换为预测的下一个token概率 67 | ## Positional Encoding 68 | 位置编码的维数dmodel与嵌入的维数相同,因此可以将两者相加,位置编码有很多选择,可以学习和固定。在这项工作中,我们使用不同频率的正弦和余弦函数,其中pos是位置,i是维度。 69 | $$ 70 | PE_{(pos,2i)}=sin(pos/10000^{2i/d_{model}}) 71 | PE_{(pos,2i+1)}=sin(pos/10000^{2i/d_{model}}) 72 | $$ 73 | 也就是说,位置编码的每个维度对应于一个正弦曲线,波长形成从2π到10000·2π的几何级数。当然还有其他的方法,不过选择正弦曲线版本是因为它可以使模型外推到比训练期间遇到的序列长的序列长度 74 | 75 | # Why Self-Attention 76 | 考虑一下三点: 77 | + 每层的总计算复杂度 78 | + 可以并行化的计算量,以所需的最少顺序操作数衡量 79 | + 网络中远程依赖关系之间的路径长度,在许多序列转导任务中,学习远程依赖性是一项关键挑战。影响学习这种依赖性的能力的一个关键因素是网络中前向和后向信号必须经过的路径长度。输入和输出序列中位置的任意组合之间的这些路径越短,学习远程依赖关系就越容易 80 | 81 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200917175632580.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RCQ18xMjE=,size_16,color_FFFFFF,t_70#pic_center) 82 | 作为附带的好处,自我关注可以产生更多可解释的模型 83 | # Training 84 | ## Training Data and Batching 85 | 我们对标准WMT 2014英语-德语数据集进行了培训,该数据集包含约450万个句子对。句子是使用字节对编码的,字节对编码具有大约37000个token的共享源目标词汇。 86 | ## Hardware and Schedule 87 | 大型模型接受了300,000步(3.5天)的训练。 88 | ## Optimizer 89 | 我们使用Adam优化器,其中β1= 0.9,β2= 0.98和$\xi $= 10-9。 根据公式,我们在训练过程中改变了学习率: 90 | $$lrate=d_{model}^{-0.5}\cdot min(step\_num^{-0.5},step\_num\cdot warmup\_steps^{-1.5})$$ 91 | 这对应于第一个warmup_steps训练步骤的线性增加学习率,此后与步骤数的平方根的平方成反比地降低学习率,我们使用的warmup_steps=4000。 92 | ## Regularization 93 | + Residual Dropout 94 | + Label Smoothing 95 | 96 | # Results 97 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200917203338187.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RCQ18xMjE=,size_16,color_FFFFFF,t_70#pic_center) 98 | # Conclusion 99 | 
在这项工作中,我们介绍了Transformer,这是完全基于注意力的第一个序列转导模型,用多头自注意力代替了编码器-解码器体系结构中最常用的循环层。对于翻译任务,与基于循环层或卷积层的体系结构相比,可以比在体系结构上更快地训练Transformer。 在WMT 2014英语到德语和WMT 2014英语到法语的翻译任务中,我们都达到了最新水平。 在前一项任务中,我们最好的模型甚至胜过所有先前报告。我们对基于注意力的模型的未来感到兴奋,并计划将其应用于其他任务。 我们计划将Transformer扩展到涉及文本以外的涉及输入和输出形式的问题,并研究局部受限的注意机制,以有效处理大型输入和输出,例如图像,音频和视频。 使生成减少连续性是我们的另一个研究目标。 -------------------------------------------------------------------------------- /dialogue/tensorflow/task/common/pre_treat.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | from collections import defaultdict 5 | from common.data_utils import tokenize_en 6 | from nltk.tokenize import RegexpTokenizer 7 | 8 | 9 | class Delexicalizer: 10 | """ 11 | 去词化器 12 | """ 13 | 14 | def __init__(self, info_slots, semi_dict, values, replaces): 15 | self.info_slots = info_slots # 所有informable槽位信息 16 | self.semi_dict = semi_dict # 语句中槽位同义词替换字典 17 | self.values = values # 数据库中所有requestable槽位信息 18 | self.replaces = replaces 19 | 20 | self.inv_info_slots = self._inverse_dict(self.info_slots, '%s') # informable槽值对字典 21 | self.inv_values = self._inverse_dict(self.values, ' ', 22 | func=lambda x: x.upper()) # requestable槽值对字典,槽位已同义化 23 | self.inv_semi_dict = self._inverse_dict(self.semi_dict, '%s') 24 | 25 | self.inv_semi_dict = {k: " " % self.inv_info_slots[v].upper() 26 | if v in self.inv_info_slots else " " % v.upper() for k, v in self.inv_semi_dict.items()} 27 | 28 | self.num_matcher = re.compile(r' \d{1,2}([., ])') 29 | self.post_matcher = re.compile( 30 | r'( [.]?c\.b[.]?[ ]?\d[ ]?[,]?[ ]?\d[.]?[ ]?[a-z][\.]?[ ]?[a-z][\.]?)|( cb\d\d[a-z]{2})') 31 | self.phone_matcher = re.compile(r'[ (](#?0)?(\d{10}|\d{4}[ ]\d{5,6}|\d{3}-\d{3}-\d{4})[ ).,]') 32 | self.street_matcher = re.compile( 33 | r' (([a-z]+)?\d{1,3}([ ]?-[ ]?\d+)? 
)?[a-z]+ (street|road|avenue)(, (city [a-z]+))?') 34 | 35 | def _inverse_dict(self, d, fmt="%s ", func=str): 36 | """ 37 | 将字典中key和value转换工具 38 | """ 39 | inv = {} 40 | for k, vs in d.items(): 41 | for v in vs: 42 | inv[v.lower()] = fmt % (func(k)) 43 | return inv 44 | 45 | def delex(self, sent): 46 | """ 47 | 将句子去词化 48 | """ 49 | sent = ' ' + sent.lower() 50 | sent = self.post_matcher.sub(' ', sent) 51 | sent = " , ".join(sent.split(",")) 52 | 53 | # for r, v in self.replaces: 54 | # sent = sent.replace(" " + r + " ", " " + v + " ") 55 | 56 | sent = sent.replace(' ', ' ') 57 | 58 | sent = self.phone_matcher.sub(' ', sent) 59 | for v in sorted(self.inv_values.keys(), key=len, reverse=True): 60 | sent = sent.replace(v, self.inv_values[v]) 61 | 62 | sent = self.street_matcher.sub(' ', sent) 63 | for v in sorted(self.inv_semi_dict.keys(), key=len, reverse=True): 64 | sent = sent.replace(v, self.inv_semi_dict[v]) 65 | 66 | sent = self.num_matcher.sub(' ', sent) 67 | 68 | sent = sent.replace(' ', ' ') 69 | 70 | return sent.strip() 71 | 72 | 73 | def create_delexicaliser(semi_dict_fn, kb_fn, onto_fn, req_slots=["address", "phone", "postcode", "name"]): 74 | """ 75 | 去词化器创建工具 76 | """ 77 | semi_dict = defaultdict(list) 78 | values = defaultdict(list) 79 | 80 | with open(kb_fn) as file: 81 | kb = json.load(file) 82 | 83 | with open(semi_dict_fn) as file: 84 | semi_dict = json.load(file) 85 | 86 | with open(onto_fn) as file: 87 | onto_data = json.load(file) 88 | 89 | for entry in kb: 90 | for slot in req_slots: 91 | if slot in entry: 92 | values[slot].append(entry[slot]) 93 | 94 | # slots = ["area", "food", "pricerange", "address", "phone", "postcode", "name"] 95 | return Delexicalizer(onto_data['informable'], semi_dict, values, '') 96 | 97 | 98 | def convert_delex(diag_fn, delex_fn, output_fn): 99 | """ 100 | 系统回复槽位生成,将结果保存在一个文件中 101 | """ 102 | with open(diag_fn) as file: 103 | dialogues = json.load(file) 104 | 105 | with open(delex_fn) as file: 106 | delexed = file.readlines() 107 | 108 | delex_iter = iter(delexed) 109 | for diag_idx, diag in enumerate(dialogues): 110 | for turn_idx, turn in enumerate(diag['diaglogue']): 111 | dialogues[diag_idx]['diaglogue'][turn_idx]['system_transcript'] = next(delex_iter).replace("\t", "").strip() 112 | 113 | with open(output_fn, 'w', encoding='utf-8') as file: 114 | file.write(json.dumps(dialogues, indent=4, ensure_ascii=False)) 115 | 116 | 117 | def preprocess_raw_task_data(raw_data, tokenized_data, semi_dict, database, ontology): 118 | """ 119 | 专门针对task标注数据的client和agent对话的token数据处理 120 | :param raw_data: 原始对话数据路径 121 | :param tokenized_data: 生成token数据保存路径 122 | :return: 123 | """ 124 | # 首先判断原数据集是否存在,不存在则退出 125 | if not os.path.exists(raw_data): 126 | print('数据集不存在,请添加数据集!') 127 | exit() 128 | 129 | pairs = [] 130 | delex = create_delexicaliser(semi_dict, database, ontology) 131 | tokenizer = RegexpTokenizer(r'<[a-z][.\w]+>|[^<]+') 132 | 133 | with open(raw_data, encoding='utf-8') as file: 134 | pair_count = 0 135 | dialogues = json.load(file) 136 | 137 | for diag in dialogues: 138 | for turn in diag['dialogue']: 139 | user = tokenize_en(turn['transcript'].lower(), tokenizer) 140 | system = tokenize_en(delex.delex(turn['system_transcript']).lower(), tokenizer) 141 | pairs.append([user, system]) 142 | pair_count += 1 143 | if pair_count % 1000 == 0: 144 | print('已处理:', pair_count, '个问答对') 145 | 146 | print('读取完毕,处理中...') 147 | 148 | train_tokenized = open(tokenized_data, 'w', encoding='utf-8') 149 | for i in range(len(pairs)): 150 | train_tokenized.write(' 
'.join(pairs[i][0]) + '\t' + ' '.join(pairs[i][1]) + '\n') 151 | if i % 1000 == 0: 152 | print('处理进度:', i) 153 | 154 | train_tokenized.close() 155 | -------------------------------------------------------------------------------- /dialogue/tensorflow/seq2seq/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """seq2seq模型核心core 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from dialogue.tensorflow.layers import bahdanau_attention 23 | from dialogue.tools import log_operator 24 | 25 | 26 | def rnn_layer(units: int, input_feature_dim: int, cell_type: str = "lstm", if_bidirectional: bool = True, 27 | d_type: tf.dtypes.DType = tf.float32, name: str = "rnn_layer") -> tf.keras.Model: 28 | """ RNNCell层,其中可定义cell类型,是否双向 29 | 30 | :param units: cell单元数 31 | :param input_feature_dim: 输入的特征维大小 32 | :param cell_type: cell类型,lstm/gru, 默认lstm 33 | :param if_bidirectional: 是否双向 34 | :param d_type: 运算精度 35 | :param name: 名称 36 | :return: Multi-layer RNN 37 | """ 38 | inputs = tf.keras.Input(shape=(None, input_feature_dim), dtype=d_type, name="{}_inputs".format(name)) 39 | if cell_type == "lstm": 40 | rnn = tf.keras.layers.LSTM(units=units, return_sequences=True, return_state=True, 41 | recurrent_initializer="glorot_uniform", dtype=d_type, 42 | name="{}_lstm_cell".format(name)) 43 | elif cell_type == "gru": 44 | rnn = tf.keras.layers.GRU(units=units, return_sequences=True, return_state=True, 45 | recurrent_initializer="glorot_uniform", dtype=d_type, name="{}_gru_cell".format(name)) 46 | else: 47 | print("cell执行了类型执行出错,定位细节参见log") 48 | log_operator(level="INFO").info("cell执行了类型执行出错") 49 | 50 | if if_bidirectional: 51 | rnn = tf.keras.layers.Bidirectional(layer=rnn, dtype=d_type, name="{}_biRnn".format(name)) 52 | 53 | rnn_outputs = rnn(inputs) 54 | outputs = rnn_outputs[0] 55 | states = outputs[:, -1, :] 56 | 57 | return tf.keras.Model(inputs=inputs, outputs=[outputs, states]) 58 | 59 | 60 | def encoder(vocab_size: int, embedding_dim: int, enc_units: int, 61 | num_layers: int, cell_type: str, if_bidirectional: bool = True, 62 | d_type: tf.dtypes.DType = tf.float32, name: str = "encoder") -> tf.keras.Model: 63 | """ 64 | :param vocab_size: 词汇量大小 65 | :param embedding_dim: 词嵌入维度 66 | :param enc_units: 单元大小 67 | :param num_layers: encoder中内部RNN层数 68 | :param cell_type: cell类型,lstm/gru, 默认lstm 69 | :param if_bidirectional: 是否双向 70 | :param d_type: 运算精度 71 | :param name: 名称 72 | :return: Seq2Seq的Encoder 73 | """ 74 | inputs = tf.keras.Input(shape=(None,), dtype=d_type, name="{}_inputs".format(name)) 75 | outputs = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, 76 | dtype=d_type, 
name="{}_embedding".format(name))(inputs) 77 | 78 | for i in range(num_layers): 79 | outputs, states = rnn_layer( 80 | units=enc_units, input_feature_dim=outputs.shape[-1], cell_type=cell_type, 81 | d_type=d_type, if_bidirectional=if_bidirectional, name="{}_rnn_{}".format(name, i) 82 | )(outputs) 83 | 84 | return tf.keras.Model(inputs=inputs, outputs=[outputs, states]) 85 | 86 | 87 | def decoder(vocab_size: int, embedding_dim: int, dec_units: int, enc_units: int, num_layers: int, 88 | cell_type: str, d_type: tf.dtypes.DType = tf.float32, name: str = "decoder") -> tf.keras.Model: 89 | """ 90 | :param vocab_size: 词汇量大小 91 | :param embedding_dim: 词嵌入维度 92 | :param dec_units: decoder单元大小 93 | :param enc_units: encoder单元大小 94 | :param num_layers: encoder中内部RNN层数 95 | :param cell_type: cell类型,lstm/gru, 默认lstm 96 | :param d_type: 运算精度 97 | :param name: 名称 98 | :return: Seq2Seq的Decoder 99 | """ 100 | inputs = tf.keras.Input(shape=(None,), dtype=d_type, name="{}_inputs".format(name)) 101 | enc_output = tf.keras.Input(shape=(None, enc_units), dtype=d_type, name="{}_enc_output".format(name)) 102 | dec_hidden = tf.keras.Input(shape=(enc_units,), dtype=d_type, name="{}_dec_hidden".format(name)) 103 | 104 | embeddings = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, 105 | dtype=d_type, name="{}_embedding".format(name))(inputs) 106 | context_vector, attention_weight = bahdanau_attention( 107 | units=dec_units, d_type=d_type, query_dim=enc_units, value_dim=enc_units)(inputs=[dec_hidden, enc_output]) 108 | outputs = tf.concat(values=[tf.expand_dims(input=context_vector, axis=1), embeddings], axis=-1) 109 | 110 | for i in range(num_layers): 111 | # Decoder中不允许使用双向 112 | outputs, states = rnn_layer(units=dec_units, input_feature_dim=outputs.shape[-1], cell_type=cell_type, 113 | if_bidirectional=False, d_type=d_type, name="{}_rnn_{}".format(name, i))(outputs) 114 | 115 | outputs = tf.reshape(tensor=outputs, shape=(-1, outputs.shape[-1])) 116 | outputs = tf.keras.layers.Dense(units=vocab_size, dtype=d_type, name="{}_outputs_dense".format(name))(outputs) 117 | 118 | return tf.keras.Model(inputs=[inputs, enc_output, dec_hidden], outputs=[outputs, states, attention_weight]) 119 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/model/chatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import jieba 5 | import tensorflow as tf 6 | from pathlib import Path 7 | import common.data_utils as _data 8 | from configs import configs as _config 9 | from utils.beamsearch import BeamSearch 10 | 11 | 12 | class Chatter(object): 13 | """" 14 | 面向使用者的聊天器基类 15 | 该类及其子类实现和用户间的聊天,即接收聊天请求,产生回复。 16 | 不同模型或方法实现的聊天子类化该类。 17 | """ 18 | 19 | def __init__(self, checkpoint_dir, beam_size): 20 | """ 21 | Transformer聊天器初始化,用于加载模型 22 | """ 23 | self.checkpoint_dir = checkpoint_dir 24 | self.beam_search_container = BeamSearch( 25 | beam_size=beam_size, 26 | max_length=_config.max_length, 27 | worst_score=0 28 | ) 29 | is_exist = Path(checkpoint_dir) 30 | if not is_exist.exists(): 31 | os.makedirs(checkpoint_dir, exist_ok=True) 32 | self.ckpt = tf.io.gfile.listdir(checkpoint_dir) 33 | 34 | def respond(self, req): 35 | """ 对外部聊天请求进行回复 36 | 子类需要利用模型进行推断和搜索以产生回复。 37 | :param req: 外部聊天请求字符串 38 | :return: 系统回复字符串 39 | """ 40 | pass 41 | 42 | def _init_loss_accuracy(self): 43 | """ 44 | 初始化损失 45 | """ 46 | pass 47 | 48 | def _train_step(self, inp, tar, step_loss): 49 | """ 50 | 
模型训练步方法,需要返回时间步损失 51 | """ 52 | pass 53 | 54 | def _create_predictions(self, inputs, dec_input, t): 55 | """ 56 | 使用模型预测下一个Token的id 57 | """ 58 | pass 59 | 60 | def train(self, checkpoint, dict_fn, data_fn, start_sign, end_sign, max_train_data_size): 61 | """ 62 | 对模型进行训练 63 | """ 64 | dataset, checkpoint_prefix, steps_per_epoch = self._treat_dataset(dict_fn, data_fn, start_sign, end_sign, 65 | max_train_data_size) 66 | 67 | for epoch in range(_config.epochs): 68 | print('Epoch {}/{}'.format(epoch + 1, _config.epochs)) 69 | start_time = time.time() 70 | 71 | self._init_loss_accuracy() 72 | 73 | step_loss = [0] 74 | batch_sum = 0 75 | sample_sum = 0 76 | for (batch, (inp, tar)) in enumerate(dataset.take(steps_per_epoch)): 77 | self._train_step(inp, tar, step_loss) 78 | batch_sum = batch_sum + len(inp) 79 | sample_sum = steps_per_epoch * len(inp) 80 | sys.stdout.write('{}/{} [==================================]'.format(batch_sum, sample_sum)) 81 | sys.stdout.flush() 82 | 83 | step_time = (time.time() - start_time) 84 | sys.stdout.write(' - {:.4f}s/step - loss: {:.4f}\n' 85 | .format(step_time, step_loss[0])) 86 | sys.stdout.flush() 87 | checkpoint.save(file_prefix=checkpoint_prefix) 88 | 89 | print('训练结束') 90 | 91 | def respond(self, req, dict_fn): 92 | # 对req进行初步处理 93 | token = _data.load_token_dict(dict_fn=dict_fn) 94 | inputs, dec_input = self._pre_treat_inputs(req, token) 95 | self.beam_search_container.init_variables(inputs=inputs, dec_input=dec_input) 96 | inputs, dec_input = self.beam_search_container.get_variables() 97 | for t in range(_config.max_length_tar): 98 | predictions = self._create_predictions(inputs, dec_input, t) 99 | self.beam_search_container.add(predictions=predictions, end_sign=token.get('end')) 100 | if self.beam_search_container.beam_size == 0: 101 | break 102 | 103 | inputs, dec_input = self.beam_search_container.get_variables() 104 | beam_search_result = self.beam_search_container.get_result() 105 | result = '' 106 | # 从容器中抽取序列,生成最终结果 107 | for i in range(len(beam_search_result)): 108 | temp = beam_search_result[i].numpy() 109 | text = _data.sequences_to_texts(temp, token) 110 | text[0] = text[0].replace('start', '').replace('end', '').replace(' ', '') 111 | result = '<' + text[0] + '>' + result 112 | return result 113 | 114 | def _pre_treat_inputs(self, sentence, token): 115 | # 分词 116 | sentence = " ".join(jieba.cut(sentence)) 117 | # 添加首尾符号 118 | sentence = _data.preprocess_sentence(sentence) 119 | # 将句子转成token列表 120 | inputs = [token.get(i, 3) for i in sentence.split(' ')] 121 | # 填充 122 | inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=_config.max_length_inp, padding='post') 123 | # 转成Tensor 124 | inputs = tf.convert_to_tensor(inputs) 125 | # decoder的input就是开始符号 126 | dec_input = tf.expand_dims([token['start']], 0) 127 | return inputs, dec_input 128 | 129 | def _treat_dataset(self, dict_fn, data_fn, start_sign, end_sign, max_train_data_size): 130 | input_tensor, target_tensor, _ = _data.load_dataset(dict_fn=dict_fn, 131 | data_fn=data_fn, 132 | start_sign=start_sign, 133 | end_sign=end_sign, 134 | max_train_data_size=max_train_data_size) 135 | dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).cache().shuffle( 136 | _config.BUFFER_SIZE).prefetch(tf.data.experimental.AUTOTUNE) 137 | dataset = dataset.batch(_config.BATCH_SIZE, drop_remainder=True) 138 | checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt") 139 | print('训练开始,正在准备数据中...') 140 | step_per_epoch = len(input_tensor) // _config.BATCH_SIZE 141 | 
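        # drop_remainder=True above discards the final partial batch, so the floor division gives the exact number of full batches yielded per epoch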
142 | return dataset, checkpoint_prefix, step_per_epoch 143 | -------------------------------------------------------------------------------- /dialogue/tensorflow/smn/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """smn检索式模型实现核心core 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | def accumulate(units: int, embedding_dim: int, max_utterance: int, max_sentence: int, 25 | d_type: tf.dtypes.DType = tf.float32, name: str = "accumulate") -> tf.keras.Model: 26 | """ SMN的语义抽取层,主要是对匹配对的两个相似度矩阵进行计算,并返回最终的最后一层GRU的状态,用于计算分数 27 | 28 | :param units: GRU单元数 29 | :param embedding_dim: embedding维度 30 | :param max_utterance: 每轮最大语句数 31 | :param max_sentence: 句子最大长度 32 | :param d_type: 运算精度 33 | :param name: 名称 34 | :return: GRU的状态 35 | """ 36 | utterance_inputs = tf.keras.Input(shape=(max_utterance, max_sentence, embedding_dim), 37 | dtype=d_type, name="{}_utterance_inputs".format(name)) 38 | response_inputs = tf.keras.Input(shape=(max_sentence, embedding_dim), 39 | dtype=d_type, name="{}_response_inputs".format(name)) 40 | a_matrix = tf.keras.initializers.GlorotNormal()(shape=(units, units), dtype=d_type) 41 | 42 | # 这里对response进行GRU的Word级关系建模,这里用正交矩阵初始化内核权重矩阵,用于输入的线性变换。 43 | response_gru = tf.keras.layers.GRU(units=units, return_sequences=True, kernel_initializer="orthogonal", 44 | dtype=d_type, name="{}_gru".format(name))(response_inputs) 45 | conv2d_layer = tf.keras.layers.Conv2D( 46 | filters=8, kernel_size=(3, 3), padding="valid", kernel_initializer="he_normal", 47 | activation="relu", dtype=d_type, name="{}_conv2d".format(name) 48 | ) 49 | max_pooling2d_layer = tf.keras.layers.MaxPooling2D( 50 | pool_size=(3, 3), strides=(3, 3), padding="valid", dtype=d_type, name="{}_pooling2d".format(name) 51 | ) 52 | dense_layer = tf.keras.layers.Dense( 53 | 50, activation="tanh", kernel_initializer="glorot_normal", dtype=d_type, name="{}_dense".format(name) 54 | ) 55 | 56 | # 这里需要做一些前提工作,因为我们要针对每个batch中的每个utterance进行运算,所 57 | # 以我们需要将batch中的utterance序列进行拆分,使得batch中的序列顺序一一匹配 58 | utterance_embeddings = tf.unstack(utterance_inputs, num=max_utterance, axis=1, name="{}_unstack".format(name)) 59 | matching_vectors = [] 60 | for index, utterance_input in enumerate(utterance_embeddings): 61 | # 求解第一个相似度矩阵,公式见论文 62 | matrix1 = tf.matmul(utterance_input, response_inputs, transpose_b=True, name="{}_matmul_{}".format(name, index)) 63 | utterance_gru = tf.keras.layers.GRU(units, return_sequences=True, kernel_initializer="orthogonal", 64 | dtype=d_type, name="{}_gru_{}".format(name, index))(utterance_input) 65 | matrix2 = tf.einsum("aij,jk->aik", utterance_gru, a_matrix) 66 | # matrix2 = tf.matmul(utterance_gru, a_matrix) 67 
| # 求解第二个相似度矩阵 68 | matrix2 = tf.matmul(matrix2, response_gru, transpose_b=True) 69 | matrix = tf.stack([matrix1, matrix2], axis=3) 70 | 71 | conv_outputs = conv2d_layer(matrix) 72 | pooling_outputs = max_pooling2d_layer(conv_outputs) 73 | flatten_outputs = tf.keras.layers.Flatten(dtype=d_type, name="{}_flatten_{}".format(name, index))( 74 | pooling_outputs) 75 | 76 | matching_vector = dense_layer(flatten_outputs) 77 | matching_vectors.append(matching_vector) 78 | 79 | vector = tf.stack(matching_vectors, axis=1, name="{}_stack".format(name)) 80 | outputs = tf.keras.layers.GRU( 81 | units, kernel_initializer="orthogonal", dtype=d_type, name="{}_gru_outputs".format(name) 82 | )(vector) 83 | 84 | return tf.keras.Model(inputs=[utterance_inputs, response_inputs], outputs=outputs) 85 | 86 | 87 | def smn(units: int, vocab_size: int, embedding_dim: int, max_utterance: int, max_sentence: int, 88 | d_type: tf.dtypes.DType = tf.float32, name: str = "smn") -> tf.keras.Model: 89 | """ SMN的模型,在这里将输入进行accumulate之后,得到匹配对的向量,然后通过这些向量计算最终的分类概率 90 | 91 | :param units: GRU单元数 92 | :param vocab_size: embedding词汇量 93 | :param embedding_dim: embedding维度 94 | :param max_utterance: 每轮最大语句数 95 | :param max_sentence: 句子最大长度 96 | :param d_type: 运算精度 97 | :param name: 名称 98 | :return: 匹配对打分 99 | """ 100 | utterances = tf.keras.Input(shape=(max_utterance, max_sentence), dtype=d_type, name="{}_utterance".format(name)) 101 | responses = tf.keras.Input(shape=(max_sentence,), dtype=d_type, name="{}_response".format(name)) 102 | 103 | embeddings = tf.keras.layers.Embedding(vocab_size, embedding_dim, dtype=d_type, name="{}_embedding".format(name)) 104 | utterances_embeddings = embeddings(utterances) 105 | responses_embeddings = embeddings(responses) 106 | 107 | accumulate_outputs = accumulate( 108 | units=units, embedding_dim=embedding_dim, max_utterance=max_utterance, 109 | max_sentence=max_sentence, d_type=d_type, name="{}_accumulate".format(name) 110 | )(inputs=[utterances_embeddings, responses_embeddings]) 111 | 112 | outputs = tf.keras.layers.Dense( 113 | 2, kernel_initializer="glorot_normal", dtype=d_type, name="{}_dense_outputs".format(name) 114 | )(accumulate_outputs) 115 | 116 | outputs = tf.keras.layers.Softmax(axis=-1, dtype=d_type, name="{}_softmax".format(name))(outputs) 117 | 118 | return tf.keras.Model(inputs=[utterances, responses], outputs=outputs) 119 | -------------------------------------------------------------------------------- /dialogue/tensorflow/gpt2/gpt2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def creat_padding_mask(inputs): 5 | """ 6 | 对input中的padding单位进行mask 7 | :param inputs: 句子序列输入 8 | :return: 填充部分标记 9 | """ 10 | mask = tf.cast(tf.math.equal(inputs, 0), dtype=tf.float32) 11 | return mask[:, tf.newaxis, tf.newaxis, :] 12 | 13 | 14 | def creat_look_ahead_mask(inputs): 15 | sequence_length = tf.shape(inputs)[1] 16 | look_ahead_mask = 1 - \ 17 | tf.linalg.band_part(tf.ones((sequence_length, sequence_length)), -1, 0) 18 | padding_mask = creat_padding_mask(inputs) 19 | return tf.maximum(look_ahead_mask, padding_mask) 20 | 21 | 22 | def positional_encoding(position, deep): 23 | i = tf.range(deep, dtype=tf.float32)[tf.newaxis, :] 24 | position = tf.range(position, dtype=tf.float32)[:, tf.newaxis] 25 | angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(deep, tf.float32)) 26 | angle_rads = position * angles 27 | 28 | sines = tf.math.sin(angle_rads[:, 0::2]) 29 | cosines = tf.math.cos(angle_rads[:, 1::2]) 30 | 
pos_encoding = tf.concat([sines, cosines], axis=-1) 31 | pos_encoding = pos_encoding[tf.newaxis, ...] 32 | return tf.cast(pos_encoding, tf.float32) 33 | 34 | 35 | def split_heads(inputs, batch_size, num, deep): 36 | depth = deep // num 37 | inputs = tf.reshape(inputs, (batch_size, -1, num, depth)) 38 | return tf.transpose(inputs, perm=[0, 2, 1, 3]) 39 | 40 | 41 | def positional_encoding_layer(position, deep): 42 | inputs = tf.keras.Input(shape=(None, deep)) 43 | pos_encoding = positional_encoding(position, deep) 44 | outputs = inputs + pos_encoding[:, :tf.shape(inputs)[1], :] 45 | return tf.keras.Model(inputs=inputs, outputs=outputs) 46 | 47 | 48 | def encoder(vocab_size, deep, dropout): 49 | inputs = tf.keras.Input(shape=(None,)) 50 | embedding = tf.keras.layers.Embedding(vocab_size, deep)(inputs) 51 | embedding *= tf.math.sqrt(tf.cast(deep, tf.float32)) 52 | embedding = positional_encoding_layer(vocab_size, deep)(embedding) 53 | outputs = tf.keras.layers.Dropout(rate=dropout)(embedding) 54 | 55 | return tf.keras.Model(inputs=inputs, outputs=outputs) 56 | 57 | 58 | def self_attention(query, key, value, mask): 59 | matmul = tf.matmul(query, key, transpose_b=True) 60 | deep = tf.cast(tf.shape(key)[-1], tf.float32) 61 | scaled_attention = matmul / tf.math.sqrt(deep) 62 | 63 | if mask is not None: 64 | scaled_attention += (mask * -1e9) 65 | attention_weight = tf.nn.softmax(scaled_attention, axis=-1) 66 | output = tf.matmul(attention_weight, value) 67 | return output 68 | 69 | 70 | def attention(deep, num): 71 | query = tf.keras.Input(shape=(None, deep)) 72 | key = tf.keras.Input(shape=(None, deep)) 73 | value = tf.keras.Input(shape=(None, deep)) 74 | mask = tf.keras.Input(shape=(1, None, None)) 75 | batch_size = tf.shape(query)[0] 76 | 77 | query_fc = tf.keras.layers.Dense(units=deep)(query) 78 | key_fc = tf.keras.layers.Dense(units=deep)(key) 79 | value_fc = tf.keras.layers.Dense(units=deep)(value) 80 | 81 | query_fc = split_heads(query_fc, batch_size, num, deep) 82 | key_fc = split_heads(key_fc, batch_size, num, deep) 83 | value_fc = split_heads(value_fc, batch_size, num, deep) 84 | 85 | attention_state = self_attention(query_fc, key_fc, value_fc, mask) 86 | attention_state = tf.transpose(attention_state, perm=[0, 2, 1, 3]) 87 | concat_attention = tf.reshape(attention_state, (batch_size, -1, deep)) 88 | output = tf.keras.layers.Dense(units=deep)(concat_attention) 89 | 90 | return tf.keras.Model(inputs=[query, key, value, mask], outputs=output) 91 | 92 | 93 | def block(units, deep, num, dropout): 94 | inputs = tf.keras.Input(shape=(None, deep)) 95 | mask = tf.keras.Input(shape=(1, None, None)) 96 | 97 | attention_state = attention(deep, num)( 98 | inputs=[inputs, inputs, inputs, mask]) 99 | attention_state = tf.keras.layers.LayerNormalization( 100 | epsilon=1e-6)(attention_state + inputs) 101 | outputs = tf.keras.layers.Dense( 102 | units=units, activation="relu")(attention_state) 103 | outputs = tf.keras.layers.Dense(units=deep)(outputs) 104 | outputs = tf.keras.layers.Dropout(rate=dropout)(outputs) 105 | outputs = tf.keras.layers.LayerNormalization( 106 | epsilon=1e-6)(outputs + attention_state) 107 | 108 | return tf.keras.Model(inputs=[inputs, mask], outputs=outputs) 109 | 110 | 111 | def decoder(num_layers, num_heads, units, deep, dropout): 112 | inputs = tf.keras.Input(shape=(None, deep)) 113 | mask = tf.keras.Input(shape=(1, None, None)) 114 | 115 | outputs = inputs 116 | for i in range(num_layers): 117 | outputs = block(units, deep, num_heads, dropout)( 118 | inputs=[outputs, mask]) 
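    # all num_layers blocks share this single mask input; each block applies multi-head
    # self-attention followed by a position-wise feed-forward sub-layer, with residual
    # connections and layer normalization around both sub-layers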
119 | return tf.keras.Model(inputs=[inputs, mask], outputs=outputs) 120 | 121 | 122 | class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 123 | """ 124 | 优化器将 Adam 优化器与自定义的学习速率调度程序配合使用,这里直接参考了官网的实现 125 | 因为是公式的原因,其实大同小异 126 | """ 127 | 128 | def __init__(self, d_model, warmup_steps=2000): 129 | super(CustomSchedule, self).__init__() 130 | self.d_model = d_model 131 | self.d_model = tf.cast(self.d_model, tf.float32) 132 | self.warmup_steps = warmup_steps 133 | 134 | def __call__(self, step): 135 | arg1 = tf.math.rsqrt(step) 136 | arg2 = step * (self.warmup_steps ** -1.5) 137 | return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) 138 | 139 | 140 | def gpt2(vocab_size, num_layers, units, deep, num_heads, dropout): 141 | inputs = tf.keras.Input(shape=(None,)) 142 | outputs = encoder(vocab_size=vocab_size, deep=deep, 143 | dropout=dropout)(inputs) 144 | 145 | mask = tf.keras.layers.Lambda( 146 | creat_look_ahead_mask, output_shape=(1, None, None))(inputs) 147 | 148 | for i in range(num_layers): 149 | outputs = block(units, deep, num_heads, dropout)( 150 | inputs=[outputs, mask]) 151 | output = tf.keras.layers.Dense(units=vocab_size)(outputs) 152 | 153 | return tf.keras.Model(inputs=inputs, outputs=output) 154 | -------------------------------------------------------------------------------- /docs/Massive_Exploration_of_Neural_Machine_Translation_Architectures.md: -------------------------------------------------------------------------------- 1 | # 前言 2 | 3 | > 标题:Massive Exploration of Neural Machine Translation Architectures\ 4 | > 原文链接:[Link](https://arxiv.org/pdf/1703.03906.pdf)\ 5 | > Github:[NLP相关Paper笔记和代码复现](https://github.com/DengBoCong/nlp-paper)\ 6 | > 说明:阅读论文时进行相关思想、结构、优缺点,内容进行提炼和记录,论文和相关引用会标明出处,引用之处如有侵权,烦请告知删除。\ 7 | > 转载请注明:DengBoCong 8 | 9 | # 介绍 10 | 在计算机视觉中通常会在大型超参数空间中进行扫描,但对于NMT模型而言,这样的探索成本过高,从而限制了研究人员完善的架构和超参数选择。更改超参数成本很大,在这篇论文中,展示了以NMT架构超参数为例的首次大规模分析,实验为构建和扩展NMT体系结构带来了新颖的见解和实用建议。本文工作探索NMT架构的常见变体,并了解哪些架构选择最重要,同时展示所有实验的BLEU分数,perplexities,模型大小和收敛时间,包括每个实验多次运行中计算出的方差数。论文主要贡献如下: 11 | + 展示了以NMT架构超参数为例的首次大规模分析,实验为构建和扩展NMT体系结构带来了新颖的见解和实用建议。例如,深层编码器比解码器更难优化,密度残差连接比常规的残差连接具有更好的性能,LSTM优于GRU,并且调整好的BeamSearch对于获得最新的结果至关重要。 12 | + 确定了随机初始化和细微的超参数变化对指标(例如BLEU)的影响程度,有助于研究人员从随机噪声中区分出具有统计学意义的结果。 13 | + 发布了基于TensorFlow的[开源软件包](https://github.com/google/seq2seq/),该软件包专为实现可再现的先进sequence-to-sequence模型而设计。 14 | 15 | 本篇论文使用的sequence-to-sequence结构如下,对应的编号是论文中章节阐述部分。 16 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20201203111623761.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RCQ18xMjE=,size_16,color_FFFFFF,t_70) 17 | 编码器方法 $f_{enc}$ 将 $x=(x_1,...,x_m)$ 的源tokens序列作为输入,并产生状态序列 $h=(h_1,...,h_m)$。在base model中, $f_{enc}$ 是双向RNN,状态 $h_i$ 对应于由后向RNN和前向RNN的状态的concatenation,$h_i = [\overrightarrow{h_i};\overleftarrow{h_i}]$。解码器 $f_{dec}$ 是RNN,可根据 $h$ 预测目标序列 $y =(y_1,...,y_k)$ 的概率。根据解码器RNN中的循环状态 $s_i$、前一个单词 $y_{ NoReturn: 31 | """ 32 | :param beam_size: beam大小 33 | :param max_length: 句子最大长度 34 | :param worst_score: 最差分数 35 | """ 36 | self.BEAM_SIZE = beam_size # 保存原始beam大小,用于重置 37 | self.MAX_LEN = max_length - 1 38 | self.MIN_SCORE = worst_score # 保留原始worst_score,用于重置 39 | 40 | self.candidates = [] # 保存中间状态序列的容器,元素格式为(score, sequence)类型为(float, []) 41 | self.result = [] # 用来保存已经遇到结束符的序列 42 | self.result_plus = [] # 用来保存已经遇到结束符的带概率分布的序列 43 | self.candidates_plus = [] # 保存已经遇到结束符的序列及概率分布 44 | 45 | self.enc_output = None 46 | self.remain = None 47 | self.dec_inputs = None 48 | self.beam_size = None 49 | self.worst_score = 
None 50 | 51 | def __len__(self): 52 | """当前候选结果数 53 | """ 54 | return len(self.candidates) 55 | 56 | def reset(self, enc_output: tf.Tensor, dec_input: tf.Tensor, remain: tf.Tensor) -> NoReturn: 57 | """重置搜索 58 | 59 | :param enc_output: 已经序列化的输入句子 60 | :param dec_input: 解码器输入序列 61 | :param remain: 预留decoder输入 62 | :return: 无返回值 63 | """ 64 | self.candidates = [] # 保存中间状态序列的容器,元素格式为(score, sequence)类型为(float, []) 65 | self.candidates_plus = [] # 保存已经遇到结束符的序列及概率分布,元素为(score, tensor),tensor的shape为(seq_len, vocab_size) 66 | self.candidates.append((1, dec_input)) 67 | self.enc_output = enc_output 68 | self.remain = remain 69 | self.dec_inputs = dec_input 70 | self.beam_size = self.BEAM_SIZE # 新一轮中,将beam_size重置为原beam大小 71 | self.worst_score = self.MIN_SCORE # 新一轮中,worst_score重置 72 | self.result = [] # 用来保存已经遇到结束符的序列 73 | self.result_plus = [] # 用来保存已经遇到结束符的带概率分布的序列元素为tensor, tensor的shape为(seq_len, vocab_size) 74 | 75 | def get_search_inputs(self) -> Tuple: 76 | """为下一步预测生成输入 77 | 78 | :return: enc_output, dec_inputs, remain 79 | """ 80 | # 生成多beam输入 81 | enc_output = self.enc_output 82 | remain = self.remain 83 | self.dec_inputs = self.candidates[0][1] 84 | for i in range(1, len(self)): 85 | enc_output = tf.concat([enc_output, self.enc_output], 0) 86 | remain = tf.concat([remain, self.remain], 0) 87 | self.dec_inputs = tf.concat([self.dec_inputs, self.candidates[i][1]], axis=0) 88 | 89 | return enc_output, self.dec_inputs, remain 90 | 91 | def _reduce_end(self, end_sign: str) -> NoReturn: 92 | """ 当序列遇到了结束token,需要将该序列从容器中移除 93 | 94 | :param end_sign: 句子结束标记 95 | :return: 无返回值 96 | """ 97 | for idx, (s, dec) in enumerate(self.candidates): 98 | temp = dec.numpy() 99 | if temp[0][-1] == end_sign: 100 | self.result.append(self.candidates[idx]) 101 | self.result_plus.append(self.candidates_plus[idx]) 102 | del self.candidates[idx] 103 | del self.candidates_plus[idx] 104 | self.beam_size -= 1 105 | 106 | def expand(self, predictions, end_sign) -> NoReturn: 107 | """ 根据预测结果对候选进行扩展 108 | 往容器中添加预测结果,在本方法中对预测结果进行整理、排序的操作 109 | 110 | :param predictions: 传入每个时间步的模型预测值 111 | :param end_sign: 句子结束标记 112 | :return: 无返回值 113 | """ 114 | prev_candidates = copy.deepcopy(self.candidates) 115 | prev_candidates_plus = copy.deepcopy(self.candidates_plus) 116 | self.candidates.clear() 117 | self.candidates_plus.clear() 118 | predictions = predictions.numpy() 119 | predictions_plus = copy.deepcopy(predictions) 120 | # 在batch_size*beam_size个prediction中找到分值最高的beam_size个 121 | for i in range(self.dec_inputs.shape[0]): # 外循环遍历batch_size(batch_size的值其实就是之前选出的候选数量) 122 | for _ in range(self.beam_size): # 内循环遍历选出beam_size个概率最大位置 123 | token_index = tf.argmax(input=predictions[i], axis=0) # predictions.shape --> (batch_size, vocab_size) 124 | # 计算分数 125 | score = prev_candidates[i][0] * predictions[i][token_index] 126 | predictions[i][token_index] = 0 127 | # 判断容器容量以及分数比较 128 | if len(self) < self.beam_size or score > self.worst_score: 129 | self.candidates.append((score, tf.concat( 130 | [prev_candidates[i][1], tf.constant([[token_index.numpy()]], shape=(1, 1))], axis=-1))) 131 | if len(prev_candidates_plus) == 0: 132 | self.candidates_plus.append((score, predictions_plus)) 133 | else: 134 | self.candidates_plus.append( 135 | (score, tf.concat([prev_candidates_plus[i][1], [predictions_plus[i]]], axis=0))) 136 | if len(self) > self.beam_size: 137 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.candidates)]) 138 | del self.candidates[sorted_scores[0][1]] 139 | del self.candidates_plus[sorted_scores[0][1]] 
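                        # the weakest candidate has just been dropped; the next-lowest score becomes the new admission threshold (worst_score)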
140 | self.worst_score = sorted_scores[1][0] 141 | else: 142 | self.worst_score = min(score, self.worst_score) 143 | self._reduce_end(end_sign=end_sign) 144 | 145 | def get_result(self, top_k=1) -> List: 146 | """获得概率最高的top_k个结果 147 | 148 | :param top_k: 输出结果数量 149 | :return: 概率最高的top_k个结果 150 | """ 151 | if not self.result: 152 | self.result = self.candidates 153 | results = [element[1] for element in sorted(self.result)[-top_k:]] 154 | return results 155 | 156 | def get_result_plus(self, top_k=1) -> List: 157 | """获得概率最高的top_k个结果 158 | 159 | :param top_k: 输出结果数量 160 | :return: 概率最高的top_k个带概率的结果 161 | """ 162 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.result)], reverse=True) 163 | results_plus = [] 164 | for i in range(top_k): 165 | results_plus.append(self.result_plus[sorted_scores[i][1]][1]) 166 | 167 | return results_plus 168 | -------------------------------------------------------------------------------- /dialogue/pytorch/beamsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """BeamSearch组件 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import torch 23 | from typing import List 24 | from typing import NoReturn 25 | from typing import Tuple 26 | 27 | 28 | class BeamSearch(object): 29 | 30 | def __init__(self, beam_size, max_length, worst_score) -> NoReturn: 31 | """ 32 | :param beam_size: beam大小 33 | :param max_length: 句子最大长度 34 | :param worst_score: 最差分数 35 | """ 36 | self.BEAM_SIZE = beam_size # 保存原始beam大小,用于重置 37 | self.MAX_LEN = max_length - 1 38 | self.MIN_SCORE = worst_score # 保留原始worst_score,用于重置 39 | 40 | self.candidates = [] # 保存中间状态序列的容器,元素格式为(score, sequence)类型为(float, []) 41 | self.result = [] # 用来保存已经遇到结束符的序列 42 | self.result_plus = [] # 用来保存已经遇到结束符的带概率分布的序列 43 | self.candidates_plus = [] # 保存已经遇到结束符的序列及概率分布 44 | 45 | self.enc_output = None 46 | self.remain = None 47 | self.dec_inputs = None 48 | self.beam_size = None 49 | self.worst_score = None 50 | 51 | def __len__(self): 52 | """当前候选结果数 53 | """ 54 | return len(self.candidates) 55 | 56 | def reset(self, enc_output: torch.Tensor, dec_input: torch.Tensor, remain: torch.Tensor) -> NoReturn: 57 | """重置搜索 58 | 59 | :param enc_output: 已经序列化的输入句子 60 | :param dec_input: 解码器输入序列 61 | :param remain: 预留decoder输入 62 | :return: 无返回值 63 | """ 64 | self.candidates = [] # 保存中间状态序列的容器,元素格式为(score, sequence)类型为(float, []) 65 | self.candidates_plus = [] # 保存已经遇到结束符的序列及概率分布,元素为(score, tensor),tensor的shape为(seq_len, vocab_size) 66 | self.candidates.append((1, dec_input)) 67 | self.enc_output = enc_output 68 | self.remain = remain 69 | self.dec_inputs = dec_input 70 | self.beam_size = self.BEAM_SIZE # 新一轮中,将beam_size重置为原beam大小 71 | 
self.worst_score = self.MIN_SCORE # 新一轮中,worst_score重置 72 | self.result = [] # 用来保存已经遇到结束符的序列 73 | self.result_plus = [] # 用来保存已经遇到结束符的带概率分布的序列元素为tensor, tensor的shape为(seq_len, vocab_size) 74 | 75 | def get_search_inputs(self) -> Tuple: 76 | """为下一步预测生成输入 77 | 78 | :return: enc_output, dec_inputs, remain 79 | """ 80 | # 生成多beam输入 81 | enc_output = self.enc_output 82 | remain = self.remain 83 | self.dec_inputs = self.candidates[0][1] 84 | for i in range(1, len(self)): 85 | enc_output = torch.cat((enc_output, self.enc_output), dim=0) 86 | remain = torch.cat((remain, self.remain), dim=0) 87 | self.dec_inputs = torch.cat((self.dec_inputs, self.candidates[i][1]), dim=0) 88 | 89 | return enc_output, self.dec_inputs, remain 90 | 91 | def _reduce_end(self, end_sign: str) -> NoReturn: 92 | """ 当序列遇到了结束token,需要将该序列从容器中移除 93 | 94 | :param end_sign: 句子结束标记 95 | :return: 无返回值 96 | """ 97 | for idx, (s, dec) in enumerate(self.candidates): 98 | temp = dec.numpy() 99 | if temp[0][-1] == end_sign: 100 | self.result.append(self.candidates[idx]) 101 | # self.result_plus.append(self.candidates_plus[idx]) 102 | del self.candidates[idx] 103 | # del self.candidates_plus[idx] 104 | self.beam_size -= 1 105 | 106 | def expand(self, predictions, end_sign) -> NoReturn: 107 | """ 根据预测结果对候选进行扩展 108 | 往容器中添加预测结果,在本方法中对预测结果进行整理、排序的操作 109 | 110 | :param predictions: 传入每个时间步的模型预测值 111 | :param end_sign: 句子结束标记 112 | :return: 无返回值 113 | """ 114 | prev_candidates = copy.deepcopy(self.candidates) 115 | prev_candidates_plus = copy.deepcopy(self.candidates_plus) 116 | self.candidates.clear() 117 | self.candidates_plus.clear() 118 | # predictions = predictions.numpy() 119 | predictions_plus = copy.deepcopy(predictions) 120 | # 在batch_size*beam_size个prediction中找到分值最高的beam_size个 121 | for i in range(self.dec_inputs.shape[0]): # 外循环遍历batch_size(batch_size的值其实就是之前选出的候选数量) 122 | for _ in range(self.beam_size): # 内循环遍历选出beam_size个概率最大位置 123 | token_index = torch.argmax(input=predictions[i], dim=0) # predictions.shape -> (batch_size, vocab_size) 124 | score = prev_candidates[i][0] * predictions[i][token_index] # 计算分数 125 | predictions[i][token_index] = 0 126 | # 判断容器容量以及分数比较 127 | if len(self) < self.beam_size or score > self.worst_score: 128 | self.candidates.append( 129 | (score, torch.cat((prev_candidates[i][1], torch.reshape(token_index, shape=(1, 1))), dim=-1)) 130 | ) 131 | # if len(prev_candidates_plus) == 0: 132 | # self.candidates_plus.append((score, predictions_plus)) 133 | # else: 134 | # self.candidates_plus.append( 135 | # (score, torch.cat((prev_candidates_plus[i][1], [predictions_plus[i]]), dim=0)) 136 | # ) 137 | if len(self) > self.beam_size: 138 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.candidates)]) 139 | del self.candidates[sorted_scores[0][1]] 140 | # del self.candidates_plus[sorted_scores[0][1]] 141 | self.worst_score = sorted_scores[1][0] 142 | else: 143 | self.worst_score = min(score, self.worst_score) 144 | self._reduce_end(end_sign=end_sign) 145 | 146 | def get_result(self, top_k=1) -> List: 147 | """获得概率最高的top_k个结果 148 | 149 | :param top_k: 输出结果数量 150 | :return: 概率最高的top_k个结果 151 | """ 152 | if not self.result: 153 | self.result = self.candidates 154 | results = [element[1] for element in sorted(self.result)[-top_k:]] 155 | return results 156 | 157 | def get_result_plus(self, top_k=1) -> List: 158 | """获得概率最高的top_k个结果 159 | 160 | :param top_k: 输出结果数量 161 | :return: 概率最高的top_k个带概率的结果 162 | """ 163 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.result)], 
reverse=True) 164 | results_plus = [] 165 | for i in range(top_k): 166 | results_plus.append(self.result_plus[sorted_scores[i][1]][1]) 167 | 168 | return results_plus 169 | -------------------------------------------------------------------------------- /dialogue/tensorflow/task/data/ontology.json: -------------------------------------------------------------------------------- 1 | { 2 | "requestable": [ 3 | "address", 4 | "area", 5 | "food", 6 | "phone", 7 | "pricerange", 8 | "postcode", 9 | "signature", 10 | "name" 11 | ], 12 | "method": [ 13 | "none", 14 | "byconstraints", 15 | "byname", 16 | "finished", 17 | "byalternatives" 18 | ], 19 | "informable": { 20 | "food": [ 21 | "afghan", 22 | "african", 23 | "afternoon tea", 24 | "asian oriental", 25 | "australasian", 26 | "australian", 27 | "austrian", 28 | "barbeque", 29 | "basque", 30 | "belgian", 31 | "bistro", 32 | "brazilian", 33 | "british", 34 | "canapes", 35 | "cantonese", 36 | "caribbean", 37 | "catalan", 38 | "chinese", 39 | "christmas", 40 | "corsica", 41 | "creative", 42 | "crossover", 43 | "cuban", 44 | "danish", 45 | "eastern european", 46 | "english", 47 | "eritrean", 48 | "european", 49 | "french", 50 | "fusion", 51 | "gastropub", 52 | "german", 53 | "greek", 54 | "halal", 55 | "hungarian", 56 | "indian", 57 | "indonesian", 58 | "international", 59 | "irish", 60 | "italian", 61 | "jamaican", 62 | "japanese", 63 | "korean", 64 | "kosher", 65 | "latin american", 66 | "lebanese", 67 | "light bites", 68 | "malaysian", 69 | "mediterranean", 70 | "mexican", 71 | "middle eastern", 72 | "modern american", 73 | "modern eclectic", 74 | "modern european", 75 | "modern global", 76 | "molecular gastronomy", 77 | "moroccan", 78 | "new zealand", 79 | "north african", 80 | "north american", 81 | "north indian", 82 | "northern european", 83 | "panasian", 84 | "persian", 85 | "polish", 86 | "polynesian", 87 | "portuguese", 88 | "romanian", 89 | "russian", 90 | "scandinavian", 91 | "scottish", 92 | "seafood", 93 | "singaporean", 94 | "south african", 95 | "south indian", 96 | "spanish", 97 | "sri lankan", 98 | "steakhouse", 99 | "swedish", 100 | "swiss", 101 | "thai", 102 | "the americas", 103 | "traditional", 104 | "turkish", 105 | "tuscan", 106 | "unusual", 107 | "vegetarian", 108 | "venetian", 109 | "vietnamese", 110 | "welsh", 111 | "world" 112 | ], 113 | "pricerange": [ 114 | "cheap", 115 | "moderate", 116 | "expensive" 117 | ], 118 | "name": [ 119 | "ali baba", 120 | "anatolia", 121 | "ask", 122 | "backstreet bistro", 123 | "bangkok city", 124 | "bedouin", 125 | "bloomsbury restaurant", 126 | "caffe uno", 127 | "cambridge lodge restaurant", 128 | "charlie chan", 129 | "chiquito restaurant bar", 130 | "city stop restaurant", 131 | "clowns cafe", 132 | "cocum", 133 | "cote", 134 | "cotto", 135 | "curry garden", 136 | "curry king", 137 | "curry prince", 138 | "curry queen", 139 | "da vinci pizzeria", 140 | "da vince pizzeria", 141 | "darrys cookhouse and wine shop", 142 | "de luca cucina and bar", 143 | "dojo noodle bar", 144 | "don pasquale pizzeria", 145 | "efes restaurant", 146 | "eraina", 147 | "fitzbillies restaurant", 148 | "frankie and bennys", 149 | "galleria", 150 | "golden house", 151 | "golden wok", 152 | "gourmet burger kitchen", 153 | "graffiti", 154 | "grafton hotel restaurant", 155 | "hakka", 156 | "hk fusion", 157 | "hotel du vin and bistro", 158 | "india house", 159 | "j restaurant", 160 | "jinling noodle bar", 161 | "kohinoor", 162 | "kymmoy", 163 | "la margherita", 164 | "la mimosa", 165 | "la raza", 166 | "la tasca", 
167 | "lan hong house", 168 | "little seoul", 169 | "loch fyne", 170 | "mahal of cambridge", 171 | "maharajah tandoori restaurant", 172 | "meghna", 173 | "meze bar restaurant", 174 | "michaelhouse cafe", 175 | "midsummer house restaurant", 176 | "nandos", 177 | "nandos city centre", 178 | "panahar", 179 | "peking restaurant", 180 | "pipasha restaurant", 181 | "pizza express", 182 | "pizza express fen ditton", 183 | "pizza hut", 184 | "pizza hut city centre", 185 | "pizza hut cherry hinton", 186 | "pizza hut fen ditton", 187 | "prezzo", 188 | "rajmahal", 189 | "restaurant alimentum", 190 | "restaurant one seven", 191 | "restaurant two two", 192 | "rice boat", 193 | "rice house", 194 | "riverside brasserie", 195 | "royal spice", 196 | "royal standard", 197 | "saffron brasserie", 198 | "saigon city", 199 | "saint johns chop house", 200 | "sala thong", 201 | "sesame restaurant and bar", 202 | "shanghai family restaurant", 203 | "shiraz restaurant", 204 | "sitar tandoori", 205 | "stazione restaurant and coffee bar", 206 | "taj tandoori", 207 | "tandoori palace", 208 | "tang chinese", 209 | "thanh binh", 210 | "the cambridge chop house", 211 | "the copper kettle", 212 | "the cow pizza kitchen and bar", 213 | "the gandhi", 214 | "the gardenia", 215 | "the golden curry", 216 | "the good luck chinese food takeaway", 217 | "the hotpot", 218 | "the lucky star", 219 | "the missing sock", 220 | "the nirala", 221 | "the oak bistro", 222 | "the river bar steakhouse and grill", 223 | "the slug and lettuce", 224 | "the varsity restaurant", 225 | "travellers rest", 226 | "ugly duckling", 227 | "venue", 228 | "wagamama", 229 | "yippee noodle bar", 230 | "yu garden", 231 | "zizzi cambridge" 232 | ], 233 | "area": [ 234 | "centre", 235 | "north", 236 | "west", 237 | "south", 238 | "east" 239 | ] 240 | } 241 | } -------------------------------------------------------------------------------- /dialogue/tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """全局公用工具模块 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import sys 24 | import json 25 | import time 26 | import jieba 27 | import logging 28 | import numpy as np 29 | import matplotlib.pyplot as plt 30 | import matplotlib.ticker as ticker 31 | from collections import defaultdict 32 | from collections import OrderedDict 33 | 34 | 35 | def log_operator(level: str, log_file: str = None, 36 | log_format: str = "[%(levelname)s] - [%(asctime)s] - [file: %(filename)s] - " 37 | "[function: %(funcName)s] - [%(message)s]") -> logging.Logger: 38 | """ 日志操作方法,日志级别有"CRITICAL","FATAL","ERROR","WARN","WARNING","INFO","DEBUG","NOTSET" 39 | CRITICAL = 50, FATAL = CRITICAL, ERROR = 40, WARNING = 30, WARN = WARNING, INFO = 20, DEBUG = 10, NOTSET = 0 40 | 41 | :param log_file: 日志路径 42 | :param level: 日志级别 43 | :param log_format: 日志信息格式 44 | :return: 日志记录器 45 | """ 46 | if log_file is None: 47 | log_file = os.path.abspath(__file__)[ 48 | :os.path.abspath(__file__).rfind("\\dialogue\\")] + "\\dialogue\\data\\preprocess\\runtime.logs" 49 | 50 | logger = logging.getLogger() 51 | logger.setLevel(level) 52 | file_handler = logging.FileHandler(log_file, encoding="utf-8") 53 | file_handler.setLevel(level=level) 54 | formatter = logging.Formatter(log_format) 55 | file_handler.setFormatter(formatter) 56 | logger.addHandler(file_handler) 57 | 58 | return logger 59 | 60 | 61 | def show_history(history: dict, save_dir: str, valid_freq: int): 62 | """ 用于显示历史指标趋势以及保存历史指标图表图 63 | 64 | :param history: 历史指标 65 | :param save_dir: 历史指标显示图片保存位置 66 | :param valid_freq: 验证频率 67 | :return: 无返回值 68 | """ 69 | train_x_axis = [i + 1 for i in range(len(history["train_loss"]))] 70 | valid_x_axis = [(i + 1) * valid_freq for i in range(len(history["valid_loss"]))] 71 | 72 | figure, axis = plt.subplots(1, 1) 73 | tick_spacing = 1 74 | if len(history["train_loss"]) > 20: 75 | tick_spacing = len(history["train_loss"]) // 20 76 | plt.plot(train_x_axis, history["train_loss"], label="train_loss", marker=".") 77 | plt.plot(train_x_axis, history["train_accuracy"], label="train_accuracy", marker=".") 78 | plt.plot(valid_x_axis, history["valid_loss"], label="valid_loss", marker=".", linestyle="--") 79 | plt.plot(valid_x_axis, history["valid_accuracy"], label="valid_accuracy", marker=".", linestyle="--") 80 | plt.xticks(valid_x_axis) 81 | plt.xlabel("epoch") 82 | plt.legend() 83 | 84 | axis.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing)) 85 | 86 | save_path = save_dir + time.strftime("%Y_%m_%d_%H_%M_%S_", time.localtime(time.time())) 87 | if not os.path.exists(save_dir): 88 | os.makedirs(save_dir, exist_ok=True) 89 | plt.savefig(save_path) 90 | plt.show() 91 | 92 | 93 | class ProgressBar(object): 94 | """ 进度条工具 """ 95 | 96 | EXECUTE = "%(current)d/%(total)d %(bar)s (%(percent)3d%%) %(metrics)s" 97 | DONE = "%(current)d/%(total)d %(bar)s - %(time).4fs/step %(metrics)s" 98 | 99 | def __init__(self, total: int = 100, num: int = 1, width: int = 30, fmt: str = EXECUTE, 100 | symbol: str = "=", remain: str = ".", output=sys.stderr): 101 | """ 102 | :param total: 执行总的次数 103 | :param num: 每执行一次任务数量级 104 | :param width: 进度条符号数量 105 | :param fmt: 进度条格式 106 | :param symbol: 进度条完成符号 107 | :param remain: 进度条未完成符号 108 | :param output: 错误输出 109 | """ 110 | assert len(symbol) == 1 111 | self.args = {} 112 | self.metrics = "" 113 | 
self.total = total
114 |         self.num = num
115 |         self.width = width
116 |         self.symbol = symbol
117 |         self.remain = remain
118 |         self.output = output
119 |         self.fmt = re.sub(r"(?P<name>%\(.+?\))d", r"\g<name>%dd" % len(str(total)), fmt)  # pad the %(...)d fields to the width of total
120 | 
121 |     def __call__(self, current: int, metrics: str):
122 |         """
123 |         :param current: number of steps executed so far
124 |         :param metrics: metrics string appended after the progress bar
125 |         """
126 |         self.metrics = metrics
127 |         percent = current / float(self.total)
128 |         size = int(self.width * percent)
129 |         bar = "[" + self.symbol * size + ">" + self.remain * (self.width - size - 1) + "]"
130 | 
131 |         self.args = {
132 |             "total": self.total * self.num,
133 |             "bar": bar,
134 |             "current": current * self.num,
135 |             "percent": percent * 100,
136 |             "metrics": metrics
137 |         }
138 |         print("\r" + self.fmt % self.args, file=self.output, end="")
139 | 
140 |     def reset(self, total: int, num: int, width: int = 30, fmt: str = EXECUTE,
141 |               symbol: str = "=", remain: str = ".", output=sys.stderr):
142 |         """Reset the internal attributes
143 | 
144 |         :param total: total number of steps
145 |         :param num: amount of work per step
146 |         :param width: number of symbols in the progress bar
147 |         :param fmt: progress bar format
148 |         :param symbol: symbol for the completed portion of the bar
149 |         :param remain: symbol for the unfinished portion of the bar
150 |         :param output: error output stream
151 |         """
152 |         self.__init__(total=total, num=num, width=width, fmt=fmt,
153 |                       symbol=symbol, remain=remain, output=output)
154 | 
155 |     def done(self, step_time: float, fmt=DONE):
156 |         """
157 |         :param step_time: time taken to finish this step
158 |         :param fmt: progress bar format used once execution is done
159 |         """
160 |         self.args["bar"] = "[" + self.symbol * self.width + "]"
161 |         self.args["time"] = step_time
162 |         print("\r" + fmt % self.args + "\n", file=self.output, end="")
163 | 
164 | 
165 | def get_dict_string(data: dict, prefix: str = "- ", precision: str = ": {:.4f} "):
166 |     """Convert dict data into a key-value string
167 | 
168 |     :param data: dict of metrics
169 |     :param prefix: prefix prepended to each key
170 |     :param precision: print format for each key-value pair
171 |     :return: the assembled string
172 |     """
173 |     result = ""
174 |     for key, value in data.items():
175 |         result += (prefix + key + precision).format(value)
176 | 
177 |     return result
178 | 
179 | # preprocess_request below relies on the Keras text utilities; these imports are assumed from the texts_to_sequences / padding="post" usage
180 | from tensorflow.keras.preprocessing.sequence import pad_sequences
181 | from tensorflow.keras.preprocessing.text import Tokenizer
182 | 
183 | def preprocess_request(sentence: str, max_length: int, tokenizer: Tokenizer,
184 |                        start_sign: str = "", end_sign: str = ""):
185 |     """ Process the input sentence for the reply feature and return the sequence consumed by the model
186 | 
187 |     :param sentence: sentence to process
188 |     :param max_length: maximum length of a single sentence
189 |     :param tokenizer: tokenizer
190 |     :param start_sign: start-of-sentence token
191 |     :param end_sign: end-of-sentence token
192 |     :return: the processed sentence and decoder input
193 |     """
194 |     sentence = " ".join(jieba.cut(sentence))
195 |     sentence = start_sign + " " + sentence + " " + end_sign
196 | 
197 |     inputs = tokenizer.texts_to_sequences([sentence])
198 |     inputs = pad_sequences(inputs, maxlen=max_length, padding="post")
199 | 
200 |     return inputs
201 | 
-------------------------------------------------------------------------------- /dialogue/pytorch/modules.py: --------------------------------------------------------------------------------
1 | # Copyright 2021 DengBoCong. All Rights Reserved.
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License. 
14 | # ============================================================================== 15 | """模型功能顶层封装类,包含train、evaluate等等模式 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import abc 22 | import time 23 | import torch 24 | from torch.utils.data import DataLoader 25 | from torch.optim import Optimizer 26 | from dialogue.pytorch.load_dataset import load_data 27 | from dialogue.pytorch.utils import save_checkpoint 28 | from dialogue.tools import get_dict_string 29 | from dialogue.tools import ProgressBar 30 | from typing import AnyStr 31 | from typing import Dict 32 | from typing import NoReturn 33 | from typing import Tuple 34 | 35 | 36 | class Modules(abc.ABC): 37 | def __init__(self, batch_size: int, max_sentence: int, train_data_type: str, valid_data_type: str, 38 | dict_path: str = "", num_workers: int = 2, model: torch.nn.Module = None, 39 | encoder: torch.nn.Module = None, decoder: torch.nn.Module = None, 40 | device: torch.device = None) -> NoReturn: 41 | """model以及(encoder,decoder)两类模型传其中一种即可,具体在各自继承之后的训练步中使用 42 | Note: 43 | a): 模型训练指标中,保证至少返回到当前batch为止的平均训练损失 44 | 45 | :param batch_size: Dataset加载批大小 46 | :param max_sentence: 最大句子长度 47 | :param train_data_type: 读取训练数据类型,单轮/多轮... 48 | :param valid_data_type: 读取验证数据类型,单轮/多轮... 49 | :param dict_path: 字典路径,若使用phoneme则不用传 50 | :param num_workers: 数据加载器的工作线程 51 | :param model: 模型 52 | :param encoder: encoder模型 53 | :param decoder: decoder模型 54 | :param device: 指定运行设备 55 | :return: 返回历史指标数据 56 | """ 57 | self.batch_size = batch_size 58 | self.max_sentence = max_sentence 59 | self.train_data_type = train_data_type 60 | self.valid_data_type = valid_data_type 61 | self.dict_path = dict_path 62 | self.num_workers = num_workers 63 | self.model = model 64 | self.encoder = encoder 65 | self.decoder = decoder 66 | self.device = device 67 | 68 | @abc.abstractmethod 69 | def _train_step(self, batch_dataset: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 70 | optimizer: Optimizer, *args, **kwargs) -> Dict: 71 | """该方法用于定于训练步中,模型实际训练的核心代码(在train方法中使用) 72 | 73 | Note: 74 | a): 返回所得指标字典 75 | b): batch_dataset、optimizer为模型训练必需 76 | """ 77 | 78 | raise NotImplementedError("Must be implemented in subclasses.") 79 | 80 | @abc.abstractmethod 81 | def _valid_step(self, loader: DataLoader, steps_per_epoch: int, 82 | progress_bar: ProgressBar, *args, **kwargs) -> Dict: 83 | """ 该方法用于定义验证模型逻辑 84 | 85 | Note: 86 | a): 返回所得指标字典 87 | b): DataLoader为模型验证必需 88 | """ 89 | 90 | raise NotImplementedError("Must be implemented in subclasses.") 91 | 92 | def train(self, optimizer: torch.optim.Optimizer, train_data_path: str, epochs: int, checkpoint_save_freq: int, 93 | checkpoint_dir: str = "", valid_data_split: float = 0.0, max_train_data_size: int = 0, 94 | valid_data_path: str = "", max_valid_data_size: int = 0, history: dict = {}, **kwargs) -> Dict: 95 | """ 训练模块 96 | 97 | :param optimizer: 优化器 98 | :param train_data_path: 文本数据路径 99 | :param epochs: 训练周期 100 | :param checkpoint_save_freq: 检查点保存频率 101 | :param checkpoint_dir: 检查点保存目录路径 102 | :param valid_data_split: 用于从训练数据中划分验证数据 103 | :param max_train_data_size: 最大训练数据量 104 | :param valid_data_path: 验证数据文本路径 105 | :param max_valid_data_size: 最大验证数据量 106 | :param history: 用于保存训练过程中的历史指标数据 107 | :return: 返回历史指标数据 108 | """ 109 | print('训练开始,正在准备数据中...') 110 | train_loader, valid_loader, train_steps_per_epoch, valid_steps_per_epoch = load_data( 111 | dict_path=self.dict_path, batch_size=self.batch_size, 
train_data_type=self.train_data_type, 112 | valid_data_type=self.valid_data_type, max_sentence=self.max_sentence, valid_data_split=valid_data_split, 113 | train_data_path=train_data_path, valid_data_path=valid_data_path, max_train_data_size=max_train_data_size, 114 | max_valid_data_size=max_valid_data_size, num_workers=self.num_workers, **kwargs 115 | ) 116 | 117 | progress_bar = ProgressBar() 118 | 119 | for epoch in range(epochs): 120 | print("Epoch {}/{}".format(epoch + 1, epochs)) 121 | start_time = time.time() 122 | 123 | progress_bar.reset(total=train_steps_per_epoch, num=self.batch_size) 124 | 125 | for (batch, batch_dataset) in enumerate(train_loader): 126 | train_metrics = self._train_step(batch_dataset=batch_dataset, optimizer=optimizer, **kwargs) 127 | 128 | progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metrics)) 129 | 130 | progress_bar.done(step_time=time.time() - start_time) 131 | 132 | for key, value in train_metrics.items(): 133 | history[key].append(value) 134 | 135 | if (epoch + 1) % checkpoint_save_freq == 0: 136 | save_checkpoint(checkpoint_dir=checkpoint_dir, optimizer=optimizer, 137 | model=self.model, encoder=self.encoder, decoder=self.decoder) 138 | 139 | if valid_steps_per_epoch == 0 or valid_loader is None: 140 | print("验证数据量过小,小于batch_size,已跳过验证轮次") 141 | else: 142 | valid_metrics = self._valid_step(loader=valid_loader, progress_bar=progress_bar, 143 | steps_per_epoch=valid_steps_per_epoch, **kwargs) 144 | 145 | for key, value in valid_metrics.items(): 146 | history[key].append(value) 147 | 148 | print("训练结束") 149 | return history 150 | 151 | def evaluate(self, valid_data_path: str = "", max_valid_data_size: int = 0, **kwargs) -> NoReturn: 152 | """ 验证模块 153 | 154 | :param valid_data_path: 验证数据文本路径 155 | :param max_valid_data_size: 最大验证数据量 156 | :return: 返回历史指标数据 157 | """ 158 | print("验证开始,正在准备数据中") 159 | _, valid_loader, _, valid_steps_per_epoch = load_data( 160 | dict_path=self.dict_path, batch_size=self.batch_size, train_data_type=self.train_data_type, 161 | valid_data_type=self.valid_data_type, max_sentence=self.max_sentence, valid_data_path=valid_data_path, 162 | max_valid_data_size=max_valid_data_size, num_workers=self.num_workers, **kwargs 163 | ) 164 | 165 | progress_bar = ProgressBar() 166 | _ = self._valid_step(loader=valid_loader, progress_bar=progress_bar, 167 | steps_per_epoch=valid_steps_per_epoch, **kwargs) 168 | 169 | print("验证结束") 170 | 171 | @abc.abstractmethod 172 | def inference(self, *args, **kwargs) -> AnyStr: 173 | """ 对话推断模块 174 | """ 175 | 176 | raise NotImplementedError("Must be implemented in subclasses.") 177 | -------------------------------------------------------------------------------- /dialogue/tensorflow/seq2seq/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """seq2seq的模型功能实现,包含train模式、evaluate模式、chat模式 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import time 22 | import tensorflow as tf 23 | from typing import AnyStr 24 | from typing import Dict 25 | from typing import NoReturn 26 | from typing import Tuple 27 | from dialogue.tensorflow.beamsearch import BeamSearch 28 | from dialogue.tensorflow.modules import Modules 29 | from dialogue.tensorflow.optimizers import loss_func_mask 30 | from dialogue.tools import load_tokenizer 31 | from dialogue.tools import preprocess_request 32 | from dialogue.tools import ProgressBar 33 | 34 | 35 | class Seq2SeqModule(Modules): 36 | def __init__(self, loss_metric: tf.keras.metrics.Mean = None, batch_size: int = 0, buffer_size: int = 0, 37 | accuracy_metric: tf.keras.metrics.SparseCategoricalAccuracy = None, max_sentence: int = 0, 38 | train_data_type: str = "", valid_data_type: str = "", dict_path: str = "", 39 | model: tf.keras.Model = None, encoder: tf.keras.Model = None, decoder: tf.keras.Model = None): 40 | super(Seq2SeqModule, self).__init__( 41 | loss_metric=loss_metric, accuracy_metric=accuracy_metric, train_data_type=train_data_type, 42 | valid_data_type=valid_data_type, batch_size=batch_size, buffer_size=buffer_size, max_sentence=max_sentence, 43 | dict_path=dict_path, model=model, encoder=encoder, decoder=decoder 44 | ) 45 | 46 | def _save_model(self, **kwargs) -> NoReturn: 47 | self.encoder.save(filepath=kwargs["encoder_save_path"]) 48 | self.decoder.save(filepath=kwargs["decoder_save_path"]) 49 | print("模型已保存为SaveModel格式") 50 | 51 | @tf.function(autograph=True) 52 | def _train_step(self, batch_dataset: tuple, optimizer: tf.optimizers.Adam, *args, **kwargs) -> Dict: 53 | """训练步 54 | 55 | :param batch_dataset: 训练步的当前batch数据 56 | :param optimizer: 优化器 57 | :return: 返回所得指标字典 58 | """ 59 | loss = 0. 
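        # Teacher forcing: after the loss for step t is computed with loss_func_mask, the ground-truth
        # token targets[:, t] (not the model prediction) is fed back as the next decoder input, so the
        # decoder is conditioned on the reference sequence throughout training.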
60 | inputs, targets, weights = batch_dataset 61 | 62 | with tf.GradientTape() as tape: 63 | enc_output, states = self.encoder(inputs=inputs) 64 | dec_input = tf.expand_dims(input=[kwargs.get("start_sign", 2)] * self.batch_size, axis=1) 65 | for t in range(1, self.max_sentence): 66 | predictions, states, _ = self.decoder(inputs=[dec_input, enc_output, states]) 67 | loss += loss_func_mask(real=targets[:, t], pred=predictions, weights=weights) 68 | self.accuracy_metric(targets[:, t], predictions) 69 | dec_input = tf.expand_dims(targets[:, t], 1) 70 | 71 | self.loss_metric(loss) 72 | variables = self.encoder.trainable_variables + self.decoder.trainable_variables 73 | gradients = tape.gradient(target=loss, sources=variables) 74 | optimizer.apply_gradients(zip(gradients, variables)) 75 | 76 | return {"train_loss": self.loss_metric.result(), "train_accuracy": self.accuracy_metric.result()} 77 | 78 | def _valid_step(self, dataset: tf.data.Dataset, steps_per_epoch: int, 79 | progress_bar: ProgressBar, *args, **kwargs) -> Dict: 80 | """ 验证步 81 | 82 | :param dataset: 验证步的dataset 83 | :param steps_per_epoch: 验证总步数 84 | :param progress_bar: 进度管理器 85 | :return: 返回所得指标字典 86 | """ 87 | print("验证轮次") 88 | start_time = time.time() 89 | self.loss_metric.reset_states() 90 | self.accuracy_metric.reset_states() 91 | progress_bar = ProgressBar(total=steps_per_epoch, num=self.batch_size) 92 | 93 | for (batch, (inputs, target, _)) in enumerate(dataset.take(steps_per_epoch)): 94 | loss = self._valid_one_step(inputs=inputs, target=target, **kwargs) 95 | self.loss_metric(loss) 96 | progress_bar(current=batch + 1, metrics="- train_loss: {:.4f} - train_accuracy: {:.4f}" 97 | .format(self.loss_metric.result(), self.accuracy_metric.result())) 98 | 99 | progress_bar.done(step_time=time.time() - start_time) 100 | 101 | return {"valid_loss": self.loss_metric.result(), "valid_accuracy": self.accuracy_metric.result()} 102 | 103 | @tf.function(autograph=True) 104 | def _valid_one_step(self, inputs: tf.Tensor, target: tf.Tensor, **kwargs) -> Tuple: 105 | loss = 0 106 | enc_output, states = self.encoder(inputs=inputs) 107 | dec_input = tf.expand_dims(input=[kwargs.get("start_sign", 2)] * self.batch_size, axis=1) 108 | for t in range(1, self.max_sentence): 109 | predictions, states, _ = self.decoder(inputs=[dec_input, enc_output, states]) 110 | loss += loss_func_mask(real=target[:, t], pred=predictions) 111 | dec_input = tf.expand_dims(target[:, t], 1) 112 | 113 | self.accuracy_metric(target[:, t], predictions) 114 | 115 | return loss 116 | 117 | def inference(self, request: str, beam_size: int, start_sign: str = "", end_sign: str = "") -> AnyStr: 118 | """ 对话推断模块 119 | 120 | :param request: 输入句子 121 | :param beam_size: beam大小 122 | :param start_sign: 句子开始标记 123 | :param end_sign: 句子结束标记 124 | :return: 返回历史指标数据 125 | """ 126 | tokenizer = load_tokenizer(self.dict_path) 127 | 128 | enc_input = preprocess_request(sentence=request, tokenizer=tokenizer, 129 | max_length=self.max_sentence, start_sign=start_sign, end_sign=end_sign) 130 | enc_output, states = self.encoder(inputs=enc_input) 131 | dec_input = tf.expand_dims([tokenizer.word_index.get(start_sign)], 0) 132 | 133 | beam_search_container = BeamSearch(beam_size=beam_size, max_length=self.max_sentence, worst_score=0) 134 | beam_search_container.reset(enc_output=enc_output, dec_input=dec_input, remain=states) 135 | enc_output, dec_input, states = beam_search_container.get_search_inputs() 136 | 137 | for t in range(self.max_sentence): 138 | predictions, _, _ = 
self.decoder(inputs=[dec_input, enc_output, states]) 139 | predictions = tf.nn.softmax(predictions, axis=-1) 140 | 141 | beam_search_container.expand(predictions=predictions, end_sign=tokenizer.word_index.get(end_sign)) 142 | if beam_search_container.beam_size == 0: 143 | break 144 | 145 | enc_output, dec_input, states = beam_search_container.get_search_inputs() 146 | dec_input = tf.expand_dims(input=dec_input[:, -1], axis=-1) 147 | 148 | beam_search_result = beam_search_container.get_result(top_k=3) 149 | result = "" 150 | # 从容器中抽取序列,生成最终结果 151 | for i in range(len(beam_search_result)): 152 | temp = beam_search_result[i].numpy() 153 | text = tokenizer.sequences_to_texts(temp) 154 | text[0] = text[0].replace(start_sign, "").replace(end_sign, "").replace(" ", "") 155 | result = "<" + text[0] + ">" + result 156 | return result 157 | -------------------------------------------------------------------------------- /dialogue/tensorflow/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DengBoCong. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """模型功能顶层封装类,包含train、evaluate等等模式 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import abc 22 | import time 23 | import tensorflow as tf 24 | from dialogue.tensorflow.load_dataset import load_data 25 | from dialogue.tools import get_dict_string 26 | from dialogue.tools import ProgressBar 27 | from typing import AnyStr 28 | from typing import Dict 29 | from typing import NoReturn 30 | 31 | 32 | class Modules(abc.ABC): 33 | def __init__(self, loss_metric: tf.keras.metrics.Mean, accuracy_metric: tf.keras.metrics.SparseCategoricalAccuracy, 34 | batch_size: int, buffer_size: int, max_sentence: int, train_data_type: str, valid_data_type: str, 35 | dict_path: str = "", model: tf.keras.Model = None, encoder: tf.keras.Model = None, 36 | decoder: tf.keras.Model = None) -> NoReturn: 37 | """model以及(encoder,decoder)两类模型传其中一种即可,具体在各自继承之后的训练步中使用 38 | Note: 39 | a): 模型训练指标中,损失器和精度器必传,保证至少返回到当前batch为止的平均训练损失和训练精度 40 | 41 | :param loss_metric: 损失计算器 42 | :param accuracy_metric: 精度计算器 43 | :param batch_size: Dataset加载批大小 44 | :param buffer_size: Dataset加载缓存大小 45 | :param max_sentence: 最大句子长度 46 | :param train_data_type: 读取训练数据类型,单轮/多轮... 47 | :param valid_data_type: 读取验证数据类型,单轮/多轮... 
48 | :param dict_path: 字典路径,若使用phoneme则不用传 49 | :param model: 模型 50 | :param encoder: encoder模型 51 | :param decoder: decoder模型 52 | :return: 返回历史指标数据 53 | """ 54 | self.loss_metric = loss_metric 55 | self.accuracy_metric = accuracy_metric 56 | self.batch_size = batch_size 57 | self.buffer_size = buffer_size 58 | self.max_sentence = max_sentence 59 | self.train_data_type = train_data_type 60 | self.valid_data_type = valid_data_type 61 | self.dict_path = dict_path 62 | self.model = model 63 | self.encoder = encoder 64 | self.decoder = decoder 65 | 66 | @abc.abstractmethod 67 | def _train_step(self, batch_dataset: tuple, optimizer: tf.optimizers.Adam, *args, **kwargs) -> Dict: 68 | """该方法用于定于训练步中,模型实际训练的核心代码(在train方法中使用) 69 | 70 | Note: 71 | a): 返回所得指标字典 72 | b): batch_dataset、optimizer为模型训练必需 73 | """ 74 | 75 | raise NotImplementedError("Must be implemented in subclasses.") 76 | 77 | @abc.abstractmethod 78 | def _valid_step(self, dataset: tf.data.Dataset, steps_per_epoch: int, 79 | progress_bar: ProgressBar, *args, **kwargs) -> Dict: 80 | """ 该方法用于定义验证模型逻辑 81 | 82 | Note: 83 | a): 返回所得指标字典 84 | b): dataset为模型验证必需 85 | """ 86 | 87 | raise NotImplementedError("Must be implemented in subclasses.") 88 | 89 | @abc.abstractmethod 90 | def _save_model(self, **kwargs) -> NoReturn: 91 | """ 将模型保存为SaveModel格式 92 | 93 | Note: 94 | 如果不在train之后保存SaveModel,子类继承实现这个方法时,直接pass即可 95 | """ 96 | 97 | raise NotImplementedError("Must be implemented in subclasses.") 98 | 99 | def train(self, optimizer: tf.optimizers.Adam, checkpoint: tf.train.CheckpointManager, train_data_path: str, 100 | epochs: int, checkpoint_save_freq: int, valid_data_split: float = 0.0, max_train_data_size: int = 0, 101 | valid_data_path: str = "", max_valid_data_size: int = 0, history: dict = {}, **kwargs) -> Dict: 102 | """ 训练模块 103 | 104 | :param optimizer: 优化器 105 | :param checkpoint: 检查点管理器 106 | :param train_data_path: 文本数据路径 107 | :param epochs: 训练周期 108 | :param checkpoint_save_freq: 检查点保存频率 109 | :param valid_data_split: 用于从训练数据中划分验证数据 110 | :param max_train_data_size: 最大训练数据量 111 | :param valid_data_path: 验证数据文本路径 112 | :param max_valid_data_size: 最大验证数据量 113 | :param history: 用于保存训练过程中的历史指标数据 114 | :return: 返回历史指标数据 115 | """ 116 | print("训练开始,正在准备数据中") 117 | train_dataset, valid_dataset, train_steps_per_epoch, valid_steps_per_epoch = load_data( 118 | dict_path=self.dict_path, train_data_path=train_data_path, buffer_size=self.buffer_size, 119 | batch_size=self.batch_size, max_sentence=self.max_sentence, valid_data_split=valid_data_split, 120 | valid_data_path=valid_data_path, max_train_data_size=max_train_data_size, 121 | valid_data_type=self.valid_data_type, max_valid_data_size=max_valid_data_size, 122 | train_data_type=self.train_data_type, **kwargs 123 | ) 124 | 125 | progress_bar = ProgressBar() 126 | 127 | for epoch in range(epochs): 128 | print("Epoch {}/{}".format(epoch + 1, epochs)) 129 | 130 | start_time = time.time() 131 | self.loss_metric.reset_states() 132 | self.accuracy_metric.reset_states() 133 | progress_bar.reset(total=train_steps_per_epoch, num=self.batch_size) 134 | 135 | for (batch, batch_dataset) in enumerate(train_dataset.take(train_steps_per_epoch)): 136 | train_metrics = self._train_step(batch_dataset=batch_dataset, optimizer=optimizer, **kwargs) 137 | 138 | progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metrics)) 139 | 140 | progress_bar.done(step_time=time.time() - start_time) 141 | 142 | for key, value in train_metrics.items(): 143 | history[key].append(value) 144 | 145 | if (epoch + 
1) % checkpoint_save_freq == 0: 146 | checkpoint.save() 147 | 148 | if valid_steps_per_epoch == 0 or valid_dataset is None: 149 | print("验证数据量过小,小于batch_size,已跳过验证轮次") 150 | else: 151 | valid_metrics = self._valid_step(dataset=valid_dataset, progress_bar=progress_bar, 152 | steps_per_epoch=valid_steps_per_epoch, **kwargs) 153 | 154 | for key, value in valid_metrics.items(): 155 | history[key].append(value) 156 | 157 | print("训练结束") 158 | self._save_model(**kwargs) 159 | return history 160 | 161 | def evaluate(self, valid_data_path: str = "", max_valid_data_size: int = 0, **kwargs) -> NoReturn: 162 | """ 验证模块 163 | 164 | :param valid_data_path: 验证数据文本路径 165 | :param max_valid_data_size: 最大验证数据量 166 | :return: 返回历史指标数据 167 | """ 168 | print("验证开始,正在准备数据中") 169 | _, valid_dataset, _, valid_steps_per_epoch = load_data( 170 | dict_path=self.dict_path, valid_data_path=valid_data_path, valid_data_type=self.valid_data_type, 171 | buffer_size=self.buffer_size, train_data_type=self.train_data_type, batch_size=self.batch_size, 172 | max_sentence=self.max_sentence, max_valid_data_size=max_valid_data_size, **kwargs 173 | ) 174 | 175 | progress_bar = ProgressBar() 176 | _ = self._valid_step(dataset=valid_dataset, progress_bar=progress_bar, 177 | steps_per_epoch=valid_steps_per_epoch, **kwargs) 178 | 179 | print("验证结束") 180 | 181 | @abc.abstractmethod 182 | def inference(self, *args, **kwargs) -> AnyStr: 183 | """ 对话推断模块 184 | """ 185 | 186 | raise NotImplementedError("Must be implemented in subclasses.") 187 | --------------------------------------------------------------------------------
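Usage sketch (illustrative only, not a file from the repository): the snippet below shows how the
PyTorch BeamSearch container and a step-wise encoder/decoder pair are meant to interact, mirroring
the loop in Seq2SeqModule.inference on the TensorFlow side; the encoder/decoder call signatures and
the start_id/end_id token ids are assumptions, not part of the project.

import torch

from dialogue.pytorch.beamsearch import BeamSearch


def beam_decode(encoder, decoder, enc_input, start_id, end_id, max_sentence=40, beam_size=3):
    """Decoding driver built around the BeamSearch container (sketch, CPU tensors assumed)."""
    searcher = BeamSearch(beam_size=beam_size, max_length=max_sentence, worst_score=0)
    enc_output, states = encoder(enc_input)  # assumed encoder signature
    searcher.reset(enc_output=enc_output, dec_input=torch.tensor([[start_id]]), remain=states)
    enc_output, dec_input, states = searcher.get_search_inputs()

    for _ in range(max_sentence):
        # assumed decoder signature: one step over the last token, returning (batch, vocab_size) scores
        predictions, states, _ = decoder(dec_input[:, -1:], enc_output, states)
        predictions = torch.softmax(predictions, dim=-1)
        searcher.expand(predictions=predictions, end_sign=end_id)
        if searcher.beam_size == 0:  # every beam has emitted the end token
            break
        enc_output, dec_input, states = searcher.get_search_inputs()

    return searcher.get_result(top_k=1)  # highest-scoring token id sequence(s)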