├── README.md ├── code ├── bert4keras_5_8 │ ├── __init__.py │ ├── backend.py │ ├── layers.py │ ├── models.py │ ├── optimizers.py │ ├── snippets.py │ └── tokenizers.py ├── cfg.py ├── conversation_client.py ├── conversation_server.py ├── data_deal │ ├── base_input.py │ ├── input_ct.py │ ├── input_goal.py │ ├── input_rc.py │ ├── pre_trans.py │ └── trans_output.py ├── model │ ├── bert_lm.py │ ├── extract_embedding.py │ ├── model_context.py │ ├── model_goal.py │ ├── model_rc.py │ └── model_recall.py ├── predict │ ├── check_predict.py │ ├── check_predict_lm_ct.py │ ├── predict_final.py │ ├── predict_lm.py │ └── predict_lm_ct.py ├── score_fn.py ├── test_answer.py ├── test_client.py ├── train │ ├── train_bert_lm.py │ ├── train_ct.py │ ├── train_goal.py │ ├── train_rc.py │ └── z_t.py └── utils │ ├── sif.py │ └── snippet.py └── data └── roberta ├── bert_config.json └── vocab.txt /README.md: -------------------------------------------------------------------------------- 1 | # bd-chat-2020 2 | 2020语言与智能技术竞赛:面向推荐的对话任务 3 | 第二名 强行跳大 团队 4 | 5 | # 介绍 6 | * 最终只使用了集成Goal预测、文本回复的单向注意力Bert模型。此项目还附带了一些抛弃的尝试方案,包含了基础的阅读理解模型(model/model_rc.py)、分类模型(model/model_context.py,model/model_goal.py)、QA召回(检索)模型(model/model_recall.py)的实现。 7 | * 训练的数据存放位置和输出位置可以参考 cfg.py。训练的顺序是先运行 data_deal/pre_trans.py,再运行train/train_bert_lm.py。 8 | * 预测test集的代码是code/predict/predict_lm.py,人工评估阶段使用的是code/predict/predict_final.py里的预测 9 | 10 | # 数据 11 | https://aistudio.baidu.com/aistudio/competition/detail/48 12 | 里面的推荐对话任务数据即是,格式一致 13 | 14 | # 依赖 15 | ```text 16 | keras==2.3.1 17 | tensorflow-gpu==1.14.0 18 | easydict 19 | nltk==3.5 20 | ``` 21 | 22 | # 文件介绍: 23 | * conversation_client.py: 类似官方的那种client 24 | * test_client.py: 交互式的client 25 | * conversation_server.py: 类似官方的服务端。 26 | 27 | # 参考 28 | * bert4keras:https://github.com/bojone/bert4keras 29 | 30 | # 注意 31 | * 不要覆盖 data/roberta 里的 vocab.txt 和 config 32 | -------------------------------------------------------------------------------- /code/bert4keras_5_8/__init__.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | 3 | import sys 4 | import warnings 5 | 6 | 7 | __version__ = '0.5.8' 8 | 9 | 10 | class Legacy1: 11 | """向后兼容 12 | """ 13 | def __init__(self): 14 | import bert4keras_5_8.models 15 | self.models = bert4keras_5_8.models 16 | self.__all__ = [k for k in dir(self.models) if k[0] != '_'] 17 | 18 | def __getattr__(self, attr): 19 | """使得 from bert4keras_5_8.bert import xxx 20 | 等价于from bert4keras_5_8.models import xxx 21 | """ 22 | warnings.warn('bert4keras_5_8.bert has been renamed as bert4keras_5_8.models.') 23 | warnings.warn('please use bert4keras_5_8.models.') 24 | return getattr(self.models, attr) 25 | 26 | 27 | Legacy1.__name__ = 'bert4keras_5_8.bert' 28 | sys.modules[Legacy1.__name__] = Legacy1() 29 | del Legacy1 30 | 31 | 32 | class Legacy2: 33 | """向后兼容 34 | """ 35 | def __init__(self): 36 | import bert4keras_5_8.tokenizers 37 | self.tokenizers = bert4keras_5_8.tokenizers 38 | self.__all__ = [k for k in dir(self.tokenizers) if k[0] != '_'] 39 | 40 | def __getattr__(self, attr): 41 | """使得 from bert4keras_5_8.tokenizer import xxx 42 | 等价于from bert4keras_5_8.tokenizers import xxx 43 | """ 44 | warnings.warn('bert4keras_5_8.tokenizer has been renamed as bert4keras_5_8.tokenizers.') 45 | warnings.warn('please use bert4keras_5_8.tokenizers.') 46 | return getattr(self.tokenizers, attr) 47 | 48 | 49 | Legacy2.__name__ = 'bert4keras_5_8.tokenizer' 50 | sys.modules[Legacy2.__name__] = Legacy2() 51 | del Legacy2 52 | -------------------------------------------------------------------------------- /code/bert4keras_5_8/backend.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 分离后端函数,主要是为了同时兼容原生keras和tf.keras 3 | # 通过设置环境变量TF_KERAS=1来切换tf.keras 4 | 5 | import os, sys 6 | from distutils.util import strtobool 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | 11 | # 判断是tf.keras还是纯keras的标记 12 | is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0')) 13 | 14 | if is_tf_keras: 15 | import tensorflow.keras as keras 16 | import tensorflow.keras.backend as K 17 | sys.modules['keras'] = keras 18 | else: 19 | import keras 20 | import keras.backend as K 21 | 22 | 23 | def gelu_erf(x): 24 | """基于Erf直接计算的gelu函数 25 | """ 26 | return 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0))) 27 | 28 | 29 | def gelu_tanh(x): 30 | """基于Tanh近似计算的gelu函数 31 | """ 32 | cdf = 0.5 * (1.0 + K.tanh( 33 | (np.sqrt(2 / np.pi) * (x + 0.044715 * K.pow(x, 3))))) 34 | return x * cdf 35 | 36 | 37 | def set_gelu(version): 38 | """设置gelu版本 39 | """ 40 | version = version.lower() 41 | assert version in ['erf', 'tanh'], 'gelu version must be erf or tanh' 42 | if version == 'erf': 43 | keras.utils.get_custom_objects()['gelu'] = gelu_erf 44 | else: 45 | keras.utils.get_custom_objects()['gelu'] = gelu_tanh 46 | 47 | 48 | def piecewise_linear(t, schedule): 49 | """分段线性函数 50 | 其中schedule是形如{1000: 1, 2000: 0.1}的字典, 51 | 表示 t ∈ [0, 1000]时,输出从0均匀增加至1,而 52 | t ∈ [1000, 2000]时,输出从1均匀降低到0.1,最后 53 | t > 2000时,保持0.1不变。 54 | """ 55 | schedule = sorted(schedule.items()) 56 | if schedule[0][0] != 0: 57 | schedule = [(0, 0.)] + schedule 58 | 59 | x = K.constant(schedule[0][1], dtype=K.floatx()) 60 | t = K.cast(t, K.floatx()) 61 | for i in range(len(schedule)): 62 | t_begin = schedule[i][0] 63 | x_begin = x 64 | if i != len(schedule) - 1: 65 | dx = schedule[i + 1][1] - schedule[i][1] 66 | dt = schedule[i + 1][0] - schedule[i][0] 67 | slope = 1. 
* dx / dt 68 | x = schedule[i][1] + slope * (t - t_begin) 69 | else: 70 | x = K.constant(schedule[i][1], dtype=K.floatx()) 71 | x = K.switch(t >= t_begin, x, x_begin) 72 | 73 | return x 74 | 75 | 76 | def search_layer(inputs, name, exclude_from=None): 77 | """根据inputs和name来搜索层 78 | 说明:inputs为某个层或某个层的输出;name为目标层的名字。 79 | 实现:根据inputs一直往上递归搜索,直到发现名字为name的层为止; 80 | 如果找不到,那就返回None。 81 | """ 82 | if exclude_from is None: 83 | exclude_from = set() 84 | 85 | if isinstance(inputs, keras.layers.Layer): 86 | layer = inputs 87 | else: 88 | layer = inputs._keras_history[0] 89 | 90 | if layer.name == name: 91 | return layer 92 | elif layer in exclude_from: 93 | return None 94 | else: 95 | exclude_from.add(layer) 96 | if isinstance(layer, keras.models.Model): 97 | model = layer 98 | for layer in model.layers: 99 | if layer.name == name: 100 | return layer 101 | inbound_layers = layer._inbound_nodes[0].inbound_layers 102 | if not isinstance(inbound_layers, list): 103 | inbound_layers = [inbound_layers] 104 | if len(inbound_layers) > 0: 105 | for layer in inbound_layers: 106 | layer = search_layer(layer, name, exclude_from) 107 | if layer is not None: 108 | return layer 109 | 110 | 111 | def sequence_masking(x, mask, mode=0, axis=None): 112 | """为序列条件mask的函数 113 | mask: 形如(batch_size, seq_len)的0-1矩阵; 114 | mode: 如果是0,则直接乘以mask; 115 | 如果是1,则在padding部分减去一个大正数。 116 | axis: 序列所在轴,默认为1; 117 | """ 118 | if mask is None or mode not in [0, 1]: 119 | return x 120 | else: 121 | if axis is None: 122 | axis = 1 123 | if axis == -1: 124 | axis = K.ndim(x) - 1 125 | assert axis > 0, 'axis muse be greater than 0' 126 | for _ in range(axis - 1): 127 | mask = K.expand_dims(mask, 1) 128 | for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1): 129 | mask = K.expand_dims(mask, K.ndim(mask)) 130 | if mode == 0: 131 | return x * mask 132 | else: 133 | return x - (1 - mask) * 1e12 134 | 135 | 136 | def batch_gather(params, indices): 137 | """同tf旧版本的batch_gather 138 | """ 139 | try: 140 | return tf.gather(params, indices, batch_dims=-1) 141 | except Exception as e1: 142 | try: 143 | return tf.batch_gather(params, indices) 144 | except Exception as e2: 145 | raise ValueError('%s\n%s\n' % (e1.message, e2.message)) 146 | 147 | 148 | def pool1d(x, 149 | pool_size, 150 | strides=1, 151 | padding='valid', 152 | data_format=None, 153 | pool_mode='max'): 154 | """向量序列的pool函数 155 | """ 156 | x = K.expand_dims(x, 1) 157 | x = K.pool2d(x, 158 | pool_size=(1, pool_size), 159 | strides=(1, strides), 160 | padding=padding, 161 | data_format=data_format, 162 | pool_mode=pool_mode) 163 | return x[:, 0] 164 | 165 | 166 | def divisible_temporal_padding(x, n): 167 | """将一维向量序列右padding到长度能被n整除 168 | """ 169 | r_len = K.shape(x)[1] % n 170 | p_len = K.switch(r_len > 0, n - r_len, 0) 171 | return K.temporal_padding(x, (0, p_len)) 172 | 173 | 174 | def swish(x): 175 | """swish函数(这样封装过后才有 __name__ 属性) 176 | """ 177 | return tf.nn.swish(x) 178 | 179 | 180 | def leaky_relu(x, alpha=0.2): 181 | """leaky relu函数(这样封装过后才有 __name__ 属性) 182 | """ 183 | return tf.nn.leaky_relu(x, alpha=alpha) 184 | 185 | 186 | def symbolic(f): 187 | """恒等装饰器(兼容旧版本keras用) 188 | """ 189 | return f 190 | 191 | 192 | # 给旧版本keras新增symbolic方法(装饰器), 193 | # 以便兼容optimizers.py中的代码 194 | K.symbolic = getattr(K, 'symbolic', None) or symbolic 195 | 196 | custom_objects = { 197 | 'gelu_erf': gelu_erf, 198 | 'gelu_tanh': gelu_tanh, 199 | 'gelu': gelu_erf, 200 | 'swish': swish, 201 | 'leaky_relu': leaky_relu, 202 | } 203 | 204 | keras.utils.get_custom_objects().update(custom_objects) 205 | 
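backend.py switches between native keras and tf.keras via the `TF_KERAS` environment variable, registers the gelu/swish/leaky_relu activations in Keras's custom objects, and provides `piecewise_linear`, which the optimizer wrappers below use for warmup schedules. A minimal usage sketch (illustrative only, not code from this repository; the schedule numbers are made-up values):

```python
# Setting TF_KERAS=1 in the environment before this import makes the whole
# library run on tf.keras instead of native keras (see is_tf_keras above).
from bert4keras_5_8.backend import K, set_gelu, piecewise_linear

set_gelu('tanh')  # replace the registered 'gelu' (erf version) with the tanh approximation

# piecewise_linear turns a {step: value} dict into a scheduled scalar:
# it rises linearly from 0 to 1.0 over the first 1000 steps, then falls to
# 0.1 by step 2000 and stays there.
multiplier = piecewise_linear(500, {1000: 1.0, 2000: 0.1})
print(K.eval(multiplier))  # -> 0.5 at step 500
```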
-------------------------------------------------------------------------------- /code/bert4keras_5_8/optimizers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 优化相关 3 | 4 | import tensorflow as tf 5 | from bert4keras_5_8.backend import keras, K, is_tf_keras 6 | from bert4keras_5_8.snippets import is_string, string_matching 7 | from bert4keras_5_8.snippets import is_one_of 8 | from bert4keras_5_8.backend import piecewise_linear 9 | import re 10 | 11 | 12 | class Adam(keras.optimizers.Optimizer): 13 | """重新定义Adam优化器,便于派生出新的优化器 14 | (tensorflow的optimizer_v2类) 15 | """ 16 | def __init__(self, 17 | learning_rate=0.001, 18 | beta_1=0.9, 19 | beta_2=0.999, 20 | epsilon=1e-6, 21 | bias_correction=True, 22 | name='Adam', 23 | **kwargs): 24 | kwargs['name'] = name 25 | super(Adam, self).__init__(**kwargs) 26 | self._set_hyper('learning_rate', learning_rate) 27 | self._set_hyper('beta_1', beta_1) 28 | self._set_hyper('beta_2', beta_2) 29 | self.epsilon = epsilon or K.epislon() 30 | self.bias_correction = bias_correction 31 | 32 | def _create_slots(self, var_list): 33 | for var in var_list: 34 | self.add_slot(var, 'm') 35 | self.add_slot(var, 'v') 36 | 37 | def _resource_apply_op(self, grad, var, indices=None): 38 | # 准备变量 39 | var_dtype = var.dtype.base_dtype 40 | lr_t = self._decayed_lr(var_dtype) 41 | m = self.get_slot(var, 'm') 42 | v = self.get_slot(var, 'v') 43 | beta_1_t = self._get_hyper('beta_1', var_dtype) 44 | beta_2_t = self._get_hyper('beta_2', var_dtype) 45 | epsilon_t = K.cast(self.epsilon, var_dtype) 46 | local_step = K.cast(self.iterations + 1, var_dtype) 47 | beta_1_t_power = K.pow(beta_1_t, local_step) 48 | beta_2_t_power = K.pow(beta_2_t, local_step) 49 | 50 | # 更新公式 51 | if indices is None: 52 | m_t = K.update(m, beta_1_t * m + (1 - beta_1_t) * grad) 53 | v_t = K.update(v, beta_2_t * v + (1 - beta_2_t) * grad**2) 54 | else: 55 | mv_ops = [K.update(m, beta_1_t * m), K.update(v, beta_2_t * v)] 56 | with tf.control_dependencies(mv_ops): 57 | m_t = self._resource_scatter_add(m, indices, 58 | (1 - beta_1_t) * grad) 59 | v_t = self._resource_scatter_add(v, indices, 60 | (1 - beta_2_t) * grad**2) 61 | 62 | # 返回算子 63 | with tf.control_dependencies([m_t, v_t]): 64 | if self.bias_correction: 65 | m_t = m_t / (1. - beta_1_t_power) 66 | v_t = v_t / (1. 
- beta_2_t_power) 67 | var_t = var - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) 68 | return K.update(var, var_t) 69 | 70 | def _resource_apply_dense(self, grad, var): 71 | return self._resource_apply_op(grad, var) 72 | 73 | def _resource_apply_sparse(self, grad, var, indices): 74 | return self._resource_apply_op(grad, var, indices) 75 | 76 | def get_config(self): 77 | config = { 78 | 'learning_rate': self._serialize_hyperparameter('learning_rate'), 79 | 'beta_1': self._serialize_hyperparameter('beta_1'), 80 | 'beta_2': self._serialize_hyperparameter('beta_2'), 81 | 'epsilon': self.epsilon, 82 | } 83 | base_config = super(Adam, self).get_config() 84 | return dict(list(base_config.items()) + list(config.items())) 85 | 86 | 87 | def extend_with_weight_decay(base_optimizer, name=None): 88 | """返回新的优化器类,加入权重衰减 89 | """ 90 | class new_optimizer(base_optimizer): 91 | """带有权重衰减的优化器 92 | """ 93 | def __init__(self, 94 | weight_decay_rate, 95 | exclude_from_weight_decay=None, 96 | *args, 97 | **kwargs): 98 | super(new_optimizer, self).__init__(*args, **kwargs) 99 | self.weight_decay_rate = weight_decay_rate 100 | self.exclude_from_weight_decay = exclude_from_weight_decay or [] 101 | if not hasattr(self, 'learning_rate'): 102 | self.learning_rate = self.lr 103 | 104 | @K.symbolic 105 | def get_updates(self, loss, params): 106 | old_update = K.update 107 | 108 | def new_update(x, new_x): 109 | if is_one_of(x, params) and self._do_weight_decay(x): 110 | new_x = new_x - self.learning_rate * self.weight_decay_rate * x 111 | return old_update(x, new_x) 112 | 113 | K.update = new_update 114 | updates = super(new_optimizer, self).get_updates(loss, params) 115 | K.update = old_update 116 | 117 | return updates 118 | 119 | def _do_weight_decay(self, w): 120 | return (not string_matching(w.name, 121 | self.exclude_from_weight_decay)) 122 | 123 | def get_config(self): 124 | config = { 125 | 'weight_decay_rate': self.weight_decay_rate, 126 | 'exclude_from_weight_decay': self.exclude_from_weight_decay 127 | } 128 | base_config = super(new_optimizer, self).get_config() 129 | return dict(list(base_config.items()) + list(config.items())) 130 | 131 | if is_string(name): 132 | new_optimizer.__name__ = name 133 | keras.utils.get_custom_objects()[name] = new_optimizer 134 | 135 | return new_optimizer 136 | 137 | 138 | def extend_with_weight_decay_v2(base_optimizer, name=None): 139 | """返回新的优化器类,加入权重衰减 140 | """ 141 | class new_optimizer(base_optimizer): 142 | """带有权重衰减的优化器 143 | """ 144 | def __init__(self, 145 | weight_decay_rate, 146 | exclude_from_weight_decay=None, 147 | *args, 148 | **kwargs): 149 | super(new_optimizer, self).__init__(*args, **kwargs) 150 | self.weight_decay_rate = weight_decay_rate 151 | self.exclude_from_weight_decay = exclude_from_weight_decay or [] 152 | 153 | def _resource_apply_op(self, grad, var, indices=None): 154 | old_update = K.update 155 | 156 | def new_update(x, new_x): 157 | if x is var and self._do_weight_decay(x): 158 | lr_t = self._decayed_lr(x.dtype.base_dtype) 159 | new_x = new_x - lr_t * self.weight_decay_rate * x 160 | return old_update(x, new_x) 161 | 162 | K.update = new_update 163 | op = super(new_optimizer, 164 | self)._resource_apply_op(grad, var, indices) 165 | K.update = old_update 166 | 167 | return op 168 | 169 | def _do_weight_decay(self, w): 170 | return (not string_matching(w.name, 171 | self.exclude_from_weight_decay)) 172 | 173 | def get_config(self): 174 | config = { 175 | 'weight_decay_rate': self.weight_decay_rate, 176 | 'exclude_from_weight_decay': 
self.exclude_from_weight_decay 177 | } 178 | base_config = super(new_optimizer, self).get_config() 179 | return dict(list(base_config.items()) + list(config.items())) 180 | 181 | if is_string(name): 182 | new_optimizer.__name__ = name 183 | keras.utils.get_custom_objects()[name] = new_optimizer 184 | 185 | return new_optimizer 186 | 187 | 188 | def extend_with_layer_adaptation(base_optimizer, name=None): 189 | """返回新的优化器类,加入层自适应学习率 190 | """ 191 | class new_optimizer(base_optimizer): 192 | """带有层自适应学习率的优化器 193 | 用每一层参数的模长来校正当前参数的学习率 194 | https://arxiv.org/abs/1904.00962 195 | """ 196 | def __init__(self, exclude_from_layer_adaptation=None, *args, 197 | **kwargs): 198 | super(new_optimizer, self).__init__(*args, **kwargs) 199 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation or [] 200 | if not hasattr(self, 'learning_rate'): 201 | self.learning_rate = self.lr 202 | 203 | @K.symbolic 204 | def get_updates(self, loss, params): 205 | old_update = K.update 206 | 207 | def new_update(x, new_x): 208 | if is_one_of(x, params) and self._do_layer_adaptation(x): 209 | dx = new_x - x 210 | lr_t = K.clip(self.learning_rate, K.epsilon(), 1e10) 211 | x_norm = tf.norm(x) 212 | g_norm = tf.norm(dx / lr_t) 213 | ratio = K.switch( 214 | x_norm > 0., 215 | K.switch(g_norm > K.epsilon(), x_norm / g_norm, 1.), 216 | 1.) 217 | new_x = x + dx * ratio 218 | return old_update(x, new_x) 219 | 220 | K.update = new_update 221 | updates = super(new_optimizer, self).get_updates(loss, params) 222 | K.update = old_update 223 | 224 | return updates 225 | 226 | def _do_layer_adaptation(self, w): 227 | return (not string_matching(w.name, 228 | self.exclude_from_layer_adaptation)) 229 | 230 | def get_config(self): 231 | config = { 232 | 'exclude_from_layer_adaptation': 233 | self.exclude_from_layer_adaptation 234 | } 235 | base_config = super(new_optimizer, self).get_config() 236 | return dict(list(base_config.items()) + list(config.items())) 237 | 238 | if is_string(name): 239 | new_optimizer.__name__ = name 240 | keras.utils.get_custom_objects()[name] = new_optimizer 241 | 242 | return new_optimizer 243 | 244 | 245 | def extend_with_layer_adaptation_v2(base_optimizer, name=None): 246 | """返回新的优化器类,加入层自适应学习率 247 | """ 248 | class new_optimizer(base_optimizer): 249 | """带有层自适应学习率的优化器 250 | 用每一层参数的模长来校正当前参数的学习率 251 | https://arxiv.org/abs/1904.00962 252 | """ 253 | def __init__(self, exclude_from_layer_adaptation=None, *args, 254 | **kwargs): 255 | super(new_optimizer, self).__init__(*args, **kwargs) 256 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation or [] 257 | 258 | def _resource_apply_op(self, grad, var, indices=None): 259 | old_update = K.update 260 | 261 | def new_update(x, new_x): 262 | if x is var and self._do_layer_adaptation(x): 263 | dx = new_x - x 264 | lr_t = self._decayed_lr(x.dtype.base_dtype) 265 | lr_t = K.clip(lr_t, K.epsilon(), 1e10) 266 | x_norm = tf.norm(x) 267 | g_norm = tf.norm(dx / lr_t) 268 | ratio = K.switch( 269 | x_norm > 0., 270 | K.switch(g_norm > K.epsilon(), x_norm / g_norm, 1.), 271 | 1.) 
272 | new_x = x + dx * ratio 273 | return old_update(x, new_x) 274 | 275 | K.update = new_update 276 | op = super(new_optimizer, 277 | self)._resource_apply_op(grad, var, indices) 278 | K.update = old_update 279 | 280 | return op 281 | 282 | def _do_layer_adaptation(self, w): 283 | return (not string_matching(w.name, 284 | self.exclude_from_layer_adaptation)) 285 | 286 | def get_config(self): 287 | config = { 288 | 'exclude_from_layer_adaptation': 289 | self.exclude_from_layer_adaptation 290 | } 291 | base_config = super(new_optimizer, self).get_config() 292 | return dict(list(base_config.items()) + list(config.items())) 293 | 294 | if is_string(name): 295 | new_optimizer.__name__ = name 296 | keras.utils.get_custom_objects()[name] = new_optimizer 297 | 298 | return new_optimizer 299 | 300 | 301 | def extend_with_piecewise_linear_lr(base_optimizer, name=None): 302 | """返回新的优化器类,加入分段线性学习率 303 | """ 304 | class new_optimizer(base_optimizer): 305 | """带有分段线性学习率的优化器 306 | 其中schedule是形如{1000: 1, 2000: 0.1}的字典, 307 | 表示0~1000步内学习率线性地从零增加到100%,然后 308 | 1000~2000步内线性地降到10%,2000步以后保持10% 309 | """ 310 | def __init__(self, lr_schedule, *args, **kwargs): 311 | super(new_optimizer, self).__init__(*args, **kwargs) 312 | self.lr_schedule = {int(i): j for i, j in lr_schedule.items()} 313 | 314 | @K.symbolic 315 | def get_updates(self, loss, params): 316 | lr_multiplier = piecewise_linear(self.iterations, self.lr_schedule) 317 | 318 | old_update = K.update 319 | 320 | def new_update(x, new_x): 321 | if is_one_of(x, params): 322 | new_x = x + (new_x - x) * lr_multiplier 323 | return old_update(x, new_x) 324 | 325 | K.update = new_update 326 | updates = super(new_optimizer, self).get_updates(loss, params) 327 | K.update = old_update 328 | 329 | return updates 330 | 331 | def get_config(self): 332 | config = {'lr_schedule': self.lr_schedule} 333 | base_config = super(new_optimizer, self).get_config() 334 | return dict(list(base_config.items()) + list(config.items())) 335 | 336 | if is_string(name): 337 | new_optimizer.__name__ = name 338 | keras.utils.get_custom_objects()[name] = new_optimizer 339 | 340 | return new_optimizer 341 | 342 | 343 | def extend_with_piecewise_linear_lr_v2(base_optimizer, name=None): 344 | """返回新的优化器类,加入分段线性学习率 345 | """ 346 | class new_optimizer(base_optimizer): 347 | """带有分段线性学习率的优化器 348 | 其中schedule是形如{1000: 1, 2000: 0.1}的字典, 349 | 表示0~1000步内学习率线性地从零增加到100%,然后 350 | 1000~2000步内线性地降到10%,2000步以后保持10% 351 | """ 352 | def __init__(self, lr_schedule, *args, **kwargs): 353 | super(new_optimizer, self).__init__(*args, **kwargs) 354 | self.lr_schedule = {int(i): j for i, j in lr_schedule.items()} 355 | 356 | def _decayed_lr(self, var_dtype): 357 | lr_multiplier = piecewise_linear(self.iterations, self.lr_schedule) 358 | lr_t = super(new_optimizer, self)._decayed_lr(var_dtype) 359 | return lr_t * K.cast(lr_multiplier, var_dtype) 360 | 361 | def get_config(self): 362 | config = {'lr_schedule': self.lr_schedule} 363 | base_config = super(new_optimizer, self).get_config() 364 | return dict(list(base_config.items()) + list(config.items())) 365 | 366 | if is_string(name): 367 | new_optimizer.__name__ = name 368 | keras.utils.get_custom_objects()[name] = new_optimizer 369 | 370 | return new_optimizer 371 | 372 | 373 | def extend_with_gradient_accumulation(base_optimizer, name=None): 374 | """返回新的优化器类,加入梯度累积 375 | """ 376 | class new_optimizer(base_optimizer): 377 | """带有梯度累积的优化器 378 | """ 379 | def __init__(self, grad_accum_steps, *args, **kwargs): 380 | super(new_optimizer, 
self).__init__(*args, **kwargs) 381 | self.grad_accum_steps = grad_accum_steps 382 | self._first_get_gradients = True 383 | 384 | def get_gradients(self, loss, params): 385 | if self._first_get_gradients: 386 | self._first_get_gradients = False 387 | return super(new_optimizer, self).get_gradients(loss, params) 388 | else: 389 | return [ag / self.grad_accum_steps for ag in self.accum_grads] 390 | 391 | @K.symbolic 392 | def get_updates(self, loss, params): 393 | # 更新判据 394 | cond = K.equal(self.iterations % self.grad_accum_steps, 0) 395 | cond = K.cast(cond, K.floatx()) 396 | # 获取梯度 397 | grads = self.get_gradients(loss, params) 398 | self.accum_grads = [ 399 | K.zeros(K.int_shape(p), 400 | dtype=K.dtype(p), 401 | name='accum_grad_%s' % i) for i, p in enumerate(params) 402 | ] 403 | 404 | old_update = K.update 405 | 406 | def new_update(x, new_x): 407 | new_x = cond * new_x + (1 - cond) * x 408 | return old_update(x, new_x) 409 | 410 | K.update = new_update 411 | updates = super(new_optimizer, self).get_updates(loss, params) 412 | K.update = old_update 413 | 414 | # 累积梯度 415 | with tf.control_dependencies(updates): 416 | accum_updates = [ 417 | K.update(ag, g + (1 - cond) * ag) 418 | for g, ag in zip(grads, self.accum_grads) 419 | ] 420 | 421 | return accum_updates 422 | 423 | def get_config(self): 424 | config = {'grad_accum_steps': self.grad_accum_steps} 425 | base_config = super(new_optimizer, self).get_config() 426 | return dict(list(base_config.items()) + list(config.items())) 427 | 428 | if is_string(name): 429 | new_optimizer.__name__ = name 430 | keras.utils.get_custom_objects()[name] = new_optimizer 431 | 432 | return new_optimizer 433 | 434 | 435 | def extend_with_gradient_accumulation_v2(base_optimizer, name=None): 436 | """返回新的优化器类,加入梯度累积 437 | """ 438 | class new_optimizer(base_optimizer): 439 | """带有梯度累积的优化器 440 | """ 441 | def __init__(self, grad_accum_steps, *args, **kwargs): 442 | super(new_optimizer, self).__init__(*args, **kwargs) 443 | self.grad_accum_steps = grad_accum_steps 444 | 445 | def _create_slots(self, var_list): 446 | super(new_optimizer, self)._create_slots(var_list) 447 | for var in var_list: 448 | self.add_slot(var, 'ag') 449 | 450 | def _resource_apply_op(self, grad, var, indices=None): 451 | # 更新判据 452 | cond = K.equal(self.iterations % self.grad_accum_steps, 0) 453 | # 获取梯度 454 | ag = self.get_slot(var, 'ag') 455 | 456 | old_update = K.update 457 | 458 | def new_update(x, new_x): 459 | new_x = K.switch(cond, new_x, x) 460 | return old_update(x, new_x) 461 | 462 | K.update = new_update 463 | ag_t = ag / self.grad_accum_steps 464 | op = super(new_optimizer, self)._resource_apply_op(ag_t, var) 465 | K.update = old_update 466 | 467 | # 累积梯度 468 | with tf.control_dependencies([op]): 469 | ag_t = K.switch(cond, K.zeros_like(ag), ag) 470 | with tf.control_dependencies([K.update(ag, ag_t)]): 471 | if indices is None: 472 | ag_t = K.update(ag, ag + grad) 473 | else: 474 | ag_t = self._resource_scatter_add(ag, indices, grad) 475 | 476 | return ag_t 477 | 478 | def get_config(self): 479 | config = {'grad_accum_steps': self.grad_accum_steps} 480 | base_config = super(new_optimizer, self).get_config() 481 | return dict(list(base_config.items()) + list(config.items())) 482 | 483 | if is_string(name): 484 | new_optimizer.__name__ = name 485 | keras.utils.get_custom_objects()[name] = new_optimizer 486 | 487 | return new_optimizer 488 | 489 | 490 | def extend_with_lookahead(base_optimizer, name=None): 491 | """返回新的优化器类,加入look ahead 492 | """ 493 | class 
new_optimizer(base_optimizer): 494 | """带有look ahead的优化器 495 | https://arxiv.org/abs/1907.08610 496 | steps_per_slow_update: 即论文中的k; 497 | slow_step_size: 即论文中的alpha。 498 | """ 499 | def __init__(self, 500 | steps_per_slow_update=5, 501 | slow_step_size=0.5, 502 | *args, 503 | **kwargs): 504 | super(new_optimizer, self).__init__(*args, **kwargs) 505 | self.steps_per_slow_update = steps_per_slow_update 506 | self.slow_step_size = slow_step_size 507 | 508 | @K.symbolic 509 | def get_updates(self, loss, params): 510 | updates = super(new_optimizer, self).get_updates(loss, params) 511 | 512 | k, alpha = self.steps_per_slow_update, self.slow_step_size 513 | cond = K.equal(self.iterations % k, 0) 514 | slow_vars = [ 515 | K.zeros(K.int_shape(p), 516 | dtype=K.dtype(p), 517 | name='slow_var_%s' % i) for i, p in enumerate(params) 518 | ] 519 | 520 | with tf.control_dependencies(updates): 521 | slow_updates = [ 522 | K.update(q, K.switch(cond, q + alpha * (p - q), q)) 523 | for p, q in zip(params, slow_vars) 524 | ] 525 | with tf.control_dependencies(slow_updates): 526 | copy_updates = [ 527 | K.update(p, K.switch(cond, q, p)) 528 | for p, q in zip(params, slow_vars) 529 | ] 530 | 531 | return copy_updates 532 | 533 | def get_config(self): 534 | config = { 535 | 'steps_per_slow_update': self.steps_per_slow_update, 536 | 'slow_step_size': self.slow_step_size 537 | } 538 | base_config = super(new_optimizer, self).get_config() 539 | return dict(list(base_config.items()) + list(config.items())) 540 | 541 | if is_string(name): 542 | new_optimizer.__name__ = name 543 | keras.utils.get_custom_objects()[name] = new_optimizer 544 | 545 | return new_optimizer 546 | 547 | 548 | def extend_with_lookahead_v2(base_optimizer, name=None): 549 | """返回新的优化器类,加入look ahead 550 | """ 551 | class new_optimizer(base_optimizer): 552 | """带有look ahead的优化器 553 | https://arxiv.org/abs/1907.08610 554 | steps_per_slow_update: 即论文中的k; 555 | slow_step_size: 即论文中的alpha。 556 | """ 557 | def __init__(self, 558 | steps_per_slow_update=5, 559 | slow_step_size=0.5, 560 | *args, 561 | **kwargs): 562 | super(new_optimizer, self).__init__(*args, **kwargs) 563 | self.steps_per_slow_update = steps_per_slow_update 564 | self.slow_step_size = slow_step_size 565 | 566 | def _create_slots(self, var_list): 567 | super(new_optimizer, self)._create_slots(var_list) 568 | for var in var_list: 569 | self.add_slot(var, 'slow_var') 570 | 571 | def _resource_apply_op(self, grad, var, indices=None): 572 | op = super(new_optimizer, 573 | self)._resource_apply_op(grad, var, indices) 574 | 575 | k, alpha = self.steps_per_slow_update, self.slow_step_size 576 | cond = K.equal(self.iterations % k, 0) 577 | slow_var = self.get_slot(var, 'slow_var') 578 | slow_var_t = slow_var + alpha * (var - slow_var) 579 | 580 | with tf.control_dependencies([op]): 581 | slow_update = K.update(slow_var, 582 | K.switch(cond, slow_var_t, slow_var)) 583 | with tf.control_dependencies([slow_update]): 584 | copy_update = K.update(var, K.switch(cond, slow_var, var)) 585 | 586 | return copy_update 587 | 588 | def get_config(self): 589 | config = { 590 | 'steps_per_slow_update': self.steps_per_slow_update, 591 | 'slow_step_size': self.slow_step_size 592 | } 593 | base_config = super(new_optimizer, self).get_config() 594 | return dict(list(base_config.items()) + list(config.items())) 595 | 596 | if is_string(name): 597 | new_optimizer.__name__ = name 598 | keras.utils.get_custom_objects()[name] = new_optimizer 599 | 600 | return new_optimizer 601 | 602 | 603 | def 
extend_with_lazy_optimization(base_optimizer, name=None): 604 | """返回新的优化器类,加入懒惰更新 605 | """ 606 | class new_optimizer(base_optimizer): 607 | """带有懒惰更新的优化器 608 | 使得部分权重(尤其是embedding)只有在梯度不等于0时 609 | 才发生更新。 610 | """ 611 | def __init__(self, include_in_lazy_optimization=None, *args, **kwargs): 612 | super(new_optimizer, self).__init__(*args, **kwargs) 613 | self.include_in_lazy_optimization = include_in_lazy_optimization or [] 614 | self._first_get_gradients = True 615 | 616 | def get_gradients(self, loss, params): 617 | if self._first_get_gradients: 618 | self._first_get_gradients = False 619 | return super(new_optimizer, self).get_gradients(loss, params) 620 | else: 621 | return [self.grads[p] for p in params] 622 | 623 | @K.symbolic 624 | def get_updates(self, loss, params): 625 | self.grads = dict(zip(params, self.get_gradients(loss, params))) 626 | 627 | old_update = K.update 628 | 629 | def new_update(x, new_x): 630 | if is_one_of(x, params) and self._do_lazy_optimization(x): 631 | g = self.grads[x] 632 | r = K.any(K.not_equal(g, 0.), axis=-1, keepdims=True) 633 | new_x = x + (new_x - x) * K.cast(r, K.floatx()) 634 | return old_update(x, new_x) 635 | 636 | K.update = new_update 637 | updates = super(new_optimizer, self).get_updates(loss, params) 638 | K.update = old_update 639 | 640 | return updates 641 | 642 | def _do_lazy_optimization(self, w): 643 | return string_matching(w.name, self.include_in_lazy_optimization) 644 | 645 | def get_config(self): 646 | config = { 647 | 'include_in_lazy_optimization': 648 | self.include_in_lazy_optimization 649 | } 650 | base_config = super(new_optimizer, self).get_config() 651 | return dict(list(base_config.items()) + list(config.items())) 652 | 653 | if is_string(name): 654 | new_optimizer.__name__ = name 655 | keras.utils.get_custom_objects()[name] = new_optimizer 656 | 657 | return new_optimizer 658 | 659 | 660 | def extend_with_lazy_optimization_v2(base_optimizer, name=None): 661 | """返回新的优化器类,加入懒惰更新 662 | """ 663 | class new_optimizer(base_optimizer): 664 | """带有懒惰更新的优化器 665 | 使得部分权重(尤其是embedding)只有在梯度不等于0时 666 | 才发生更新。 667 | """ 668 | def __init__(self, include_in_lazy_optimization=None, *args, **kwargs): 669 | super(new_optimizer, self).__init__(*args, **kwargs) 670 | self.include_in_lazy_optimization = include_in_lazy_optimization or [] 671 | 672 | def _resource_apply_op(self, grad, var, indices=None): 673 | old_update = K.update 674 | 675 | def new_update(x, new_x): 676 | if x is var and self._do_lazy_optimization(x): 677 | if indices is None: 678 | r = K.any(K.not_equal(grad, 0.), 679 | axis=-1, 680 | keepdims=True) 681 | new_x = x + (new_x - x) * K.cast(r, K.floatx()) 682 | return old_update(x, new_x) 683 | else: 684 | return self._resource_scatter_add( 685 | x, indices, K.gather(new_x - x, indices)) 686 | return old_update(x, new_x) 687 | 688 | K.update = new_update 689 | op = super(new_optimizer, 690 | self)._resource_apply_op(grad, var, indices) 691 | K.update = old_update 692 | 693 | return op 694 | 695 | def _do_lazy_optimization(self, w): 696 | return string_matching(w.name, self.include_in_lazy_optimization) 697 | 698 | def get_config(self): 699 | config = { 700 | 'include_in_lazy_optimization': 701 | self.include_in_lazy_optimization 702 | } 703 | base_config = super(new_optimizer, self).get_config() 704 | return dict(list(base_config.items()) + list(config.items())) 705 | 706 | if is_string(name): 707 | new_optimizer.__name__ = name 708 | keras.utils.get_custom_objects()[name] = new_optimizer 709 | 710 | return new_optimizer 
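All of the `extend_with_*` factories above follow the same pattern: each takes an optimizer class, returns a subclass that adds one behaviour (weight decay, layer adaptation, piecewise-linear LR, gradient accumulation, lookahead, lazy updates), and optionally registers the new class under a name. Since each wrapper only adds constructor arguments, they compose freely. A sketch of the usual AdamW-with-warmup construction (illustrative usage, not code from this repository; the hyper-parameter values are assumptions):

```python
from bert4keras_5_8.optimizers import (
    Adam,
    extend_with_weight_decay,
    extend_with_piecewise_linear_lr,
)

# Each call wraps the previous class; the final class accepts the union of
# the wrappers' constructor arguments plus the base optimizer's own ones.
AdamW = extend_with_weight_decay(Adam, 'AdamW')
AdamWLR = extend_with_piecewise_linear_lr(AdamW, 'AdamWLR')

optimizer = AdamWLR(
    learning_rate=5e-5,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=['Norm', 'bias'],  # regex-matched against weight names
    lr_schedule={1000: 1.0, 10000: 0.1},         # warm up over 1k steps, then decay to 10%
)
# pass `optimizer` to model.compile(...) as usual
```

Under `TF_KERAS=1` the `*_v2` variants are substituted for these factories at the bottom of the module with the same constructor arguments, so the sketch above is unchanged.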
711 | 712 | 713 | class ExponentialMovingAverage(keras.callbacks.Callback): 714 | """对模型权重进行指数滑动平均(作为Callback来使用) 715 | """ 716 | def __init__(self, momentum=0.999): 717 | self.momentum = momentum 718 | 719 | def set_model(self, model): 720 | """绑定模型,并初始化参数 721 | """ 722 | super(ExponentialMovingAverage, self).set_model(model) 723 | self.ema_weights = [K.zeros(K.shape(w)) for w in model.weights] 724 | self.old_weights = K.batch_get_value(model.weights) 725 | K.batch_set_value(zip(self.ema_weights, self.old_weights)) 726 | self.updates = [] 727 | for w1, w2 in zip(self.ema_weights, model.weights): 728 | op = K.moving_average_update(w1, w2, self.momentum) 729 | self.updates.append(op) 730 | 731 | def on_batch_end(self, batch, logs=None): 732 | """每个batch后自动执行 733 | """ 734 | K.batch_get_value(self.updates) 735 | 736 | def apply_ema_weights(self): 737 | """备份原模型权重,然后将平均权重应用到模型上去。 738 | """ 739 | self.old_weights = K.batch_get_value(self.model.weights) 740 | ema_weights = K.batch_get_value(self.ema_weights) 741 | K.batch_set_value(zip(self.model.weights, ema_weights)) 742 | 743 | def reset_old_weights(self): 744 | """恢复模型到旧权重。 745 | """ 746 | K.batch_set_value(zip(self.model.weights, self.old_weights)) 747 | 748 | 749 | if is_tf_keras: 750 | extend_with_weight_decay = extend_with_weight_decay_v2 751 | extend_with_layer_adaptation = extend_with_layer_adaptation_v2 752 | extend_with_piecewise_linear_lr = extend_with_piecewise_linear_lr_v2 753 | extend_with_gradient_accumulation = extend_with_gradient_accumulation_v2 754 | extend_with_lookahead = extend_with_lookahead_v2 755 | extend_with_lazy_optimization = extend_with_lazy_optimization_v2 756 | else: 757 | Adam = keras.optimizers.Adam 758 | -------------------------------------------------------------------------------- /code/bert4keras_5_8/snippets.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 代码合集 3 | 4 | import six 5 | import logging 6 | import numpy as np 7 | import re 8 | import sys 9 | 10 | _open_ = open 11 | is_py2 = six.PY2 12 | 13 | if not is_py2: 14 | basestring = str 15 | 16 | 17 | def is_string(s): 18 | """判断是否是字符串 19 | """ 20 | return isinstance(s, basestring) 21 | 22 | 23 | def strQ2B(ustring): 24 | """全角符号转对应的半角符号 25 | """ 26 | rstring = '' 27 | for uchar in ustring: 28 | inside_code = ord(uchar) 29 | # 全角空格直接转换 30 | if inside_code == 12288: 31 | inside_code = 32 32 | # 全角字符(除空格)根据关系转化 33 | elif (inside_code >= 65281 and inside_code <= 65374): 34 | inside_code -= 65248 35 | rstring += unichr(inside_code) 36 | return rstring 37 | 38 | 39 | def string_matching(s, keywords): 40 | """判断s是否至少包含keywords中的至少一个字符串 41 | """ 42 | for k in keywords: 43 | if re.search(k, s): 44 | return True 45 | return False 46 | 47 | 48 | def convert_to_unicode(text, encoding='utf-8'): 49 | """字符串转换为unicode格式(假设输入为utf-8格式) 50 | """ 51 | if is_py2: 52 | if isinstance(text, str): 53 | text = text.decode(encoding, 'ignore') 54 | else: 55 | if isinstance(text, bytes): 56 | text = text.decode(encoding, 'ignore') 57 | return text 58 | 59 | 60 | def convert_to_str(text, encoding='utf-8'): 61 | """字符串转换为str格式(假设输入为utf-8格式) 62 | """ 63 | if is_py2: 64 | if isinstance(text, unicode): 65 | text = text.encode(encoding, 'ignore') 66 | else: 67 | if isinstance(text, bytes): 68 | text = text.decode(encoding, 'ignore') 69 | return text 70 | 71 | 72 | class open: 73 | """模仿python自带的open函数,主要是为了同时兼容py2和py3 74 | """ 75 | 76 | def __init__(self, name, mode='r', encoding=None): 77 | if is_py2: 78 | self.file = _open_(name, mode) 79 | else: 80 | self.file = _open_(name, mode, encoding=encoding) 81 | self.encoding = encoding 82 | 83 | def __iter__(self): 84 | for l in self.file: 85 | if self.encoding: 86 | l = convert_to_unicode(l, self.encoding) 87 | yield l 88 | 89 | def read(self): 90 | text = self.file.read() 91 | if self.encoding: 92 | text = convert_to_unicode(text, self.encoding) 93 | return text 94 | 95 | def write(self, text): 96 | if self.encoding: 97 | text = convert_to_str(text, self.encoding) 98 | self.file.write(text) 99 | 100 | def flush(self): 101 | self.file.flush() 102 | 103 | def close(self): 104 | self.file.close() 105 | 106 | def __enter__(self): 107 | return self 108 | 109 | def __exit__(self, type, value, tb): 110 | self.close() 111 | 112 | 113 | class Progress: 114 | """显示进度,自己简单封装,比tqdm更可控一些 115 | iterable: 可迭代的对象; 116 | period: 显示进度的周期; 117 | steps: iterable可迭代的总步数,相当于len(iterable) 118 | """ 119 | 120 | def __init__(self, iterable, period=1, steps=None, desc=None): 121 | self.iterable = iterable 122 | self.period = period 123 | if hasattr(iterable, '__len__'): 124 | self.steps = len(iterable) 125 | else: 126 | self.steps = steps 127 | self.desc = desc 128 | if self.steps: 129 | self._format_ = u'%s/%s passed' % ('%s', self.steps) 130 | else: 131 | self._format_ = u'%s passed' 132 | if self.desc: 133 | self._format_ = self.desc + ' - ' + self._format_ 134 | self.logger = logging.getLogger() 135 | 136 | def __iter__(self): 137 | for i, j in enumerate(self.iterable): 138 | if (i + 1) % self.period == 0: 139 | self.logger.info(self._format_ % (i + 1)) 140 | yield j 141 | 142 | 143 | def parallel_apply(func, 144 | iterable, 145 | workers, 146 | max_queue_size, 147 | callback=None, 148 | dummy=False): 149 | """多进程或多线程地将func应用到iterable的每个元素中。 150 | 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是 151 | 输出可能是func(c), func(a), func(b)。 152 | 参数: 153 | dummy: False是多进程/线性,True则是多线程/线性; 
154 | callback: 处理单个输出的回调函数; 155 | """ 156 | if dummy: 157 | from multiprocessing.dummy import Pool, Queue 158 | else: 159 | from multiprocessing import Pool, Queue 160 | 161 | in_queue, out_queue = Queue(max_queue_size), Queue() 162 | 163 | def worker_step(in_queue, out_queue): 164 | # 单步函数包装成循环执行 165 | while True: 166 | d = in_queue.get() 167 | r = func(d) 168 | out_queue.put(r) 169 | 170 | # 启动多进程/线程 171 | pool = Pool(workers, worker_step, (in_queue, out_queue)) 172 | 173 | if callback is None: 174 | results = [] 175 | 176 | # 后处理函数 177 | def process_out_queue(): 178 | out_count = 0 179 | for _ in range(out_queue.qsize()): 180 | d = out_queue.get() 181 | out_count += 1 182 | if callback is None: 183 | results.append(d) 184 | else: 185 | callback(d) 186 | return out_count 187 | 188 | # 存入数据,取出结果 189 | in_count, out_count = 0, 0 190 | for d in iterable: 191 | in_count += 1 192 | while True: 193 | try: 194 | in_queue.put(d, block=False) 195 | break 196 | except six.moves.queue.Full: 197 | out_count += process_out_queue() 198 | if in_count % max_queue_size == 0: 199 | out_count += process_out_queue() 200 | 201 | while out_count != in_count: 202 | out_count += process_out_queue() 203 | 204 | pool.terminate() 205 | 206 | if callback is None: 207 | return results 208 | 209 | 210 | def sequence_padding(inputs, length=None, padding=0): 211 | """Numpy函数,将序列padding到同一长度 212 | """ 213 | if length is None: 214 | length = max([len(x) for x in inputs]) 215 | 216 | outputs = np.array([ 217 | np.concatenate([x, [padding] * (length - len(x))]) 218 | if len(x) < length else x[:length] for x in inputs 219 | ]) 220 | return outputs 221 | 222 | 223 | def is_one_of(x, ys): 224 | """判断x是否在ys之中 225 | 等价于x in ys,但有些情况下x in ys会报错 226 | """ 227 | for y in ys: 228 | if x is y: 229 | return True 230 | return False 231 | 232 | 233 | class DataGenerator(object): 234 | """数据生成器模版 235 | """ 236 | 237 | def __init__(self, data, batch_size=32): 238 | self.data = data 239 | self.batch_size = batch_size 240 | self.steps = len(self.data) // self.batch_size 241 | if len(self.data) % self.batch_size != 0: 242 | self.steps += 1 243 | 244 | def __len__(self): 245 | return self.steps 246 | 247 | def __iter__(self, random=False): 248 | raise NotImplementedError 249 | 250 | def forfit(self): 251 | while True: 252 | for d in self.__iter__(True): 253 | yield d 254 | 255 | 256 | def softmax(x, axis=-1): 257 | """numpy版softmax 258 | """ 259 | x = x - x.max(axis=axis, keepdims=True) 260 | x = np.exp(x) 261 | return x / x.sum(axis=axis, keepdims=True) 262 | 263 | 264 | class AutoRegressiveDecoder(object): 265 | """通用自回归生成模型解码基类 266 | 包含beam search和random sample两种策略 267 | """ 268 | 269 | def __init__(self, start_id, end_id, maxlen): 270 | self.start_id = start_id 271 | self.end_id = end_id 272 | self.maxlen = maxlen 273 | if start_id is None: 274 | self.first_output_ids = np.empty((1, 0), dtype=int) 275 | else: 276 | self.first_output_ids = np.array([[self.start_id]]) 277 | 278 | @staticmethod 279 | def set_rtype(default='probas'): 280 | """用来给predict方法加上rtype参数,并作相应的处理 281 | """ 282 | 283 | def predict_decorator(predict): 284 | def new_predict(self, inputs, output_ids, step, rtype=default): 285 | assert rtype in ['probas', 'logits'] 286 | result = predict(self, inputs, output_ids, step) 287 | if default == 'probas': 288 | if rtype == 'probas': 289 | return result 290 | else: 291 | return np.log(result + 1e-12) 292 | else: 293 | if rtype == 'probas': 294 | return softmax(result, -1) 295 | else: 296 | return result 297 | 298 | return 
new_predict 299 | 300 | return predict_decorator 301 | 302 | def predict(self, inputs, output_ids, step, rtype='logits'): 303 | """用户需自定义递归预测函数 304 | rtype为字符串logits或probas,用户定义的时候,应当根据rtype来 305 | 返回不同的结果,rtype=probas时返回归一化的概率,rtype=logits时 306 | 则返回softmax前的结果或者概率对数。 307 | """ 308 | raise NotImplementedError 309 | 310 | def beam_search(self, inputs, topk): 311 | """beam search解码 312 | 说明:这里的topk即beam size; 313 | 返回:最优解码序列。 314 | """ 315 | inputs = [np.array([i]) for i in inputs] 316 | output_ids, output_scores = self.first_output_ids, np.zeros(1) 317 | for step in range(self.maxlen): 318 | scores = self.predict(inputs, output_ids, step, 'logits') # 计算当前得分 319 | if step == 0: # 第1步预测后将输入重复topk次 320 | inputs = [np.repeat(i, topk, axis=0) for i in inputs] 321 | scores = output_scores.reshape((-1, 1)) + scores # 综合累积得分 322 | indices = scores.argpartition(-topk, axis=None)[-topk:] # 仅保留topk 323 | indices_1 = indices // scores.shape[1] # 行索引 324 | indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 325 | output_ids = np.concatenate([output_ids[indices_1], indices_2], 1) # 更新输出 326 | output_scores = np.take_along_axis(scores, indices, axis=None) # 更新得分 327 | best_one = output_scores.argmax() # 得分最大的那个 328 | if indices_2[best_one, 0] == self.end_id: # 如果已经终止 329 | return output_ids[best_one] # 直接输出 330 | else: # 否则,只保留未完成部分 331 | flag = (indices_2[:, 0] != self.end_id) # 标记未完成序列 332 | if not flag.all(): # 如果有已完成的 333 | inputs = [i[flag] for i in inputs] # 扔掉已完成序列 334 | output_ids = output_ids[flag] # 扔掉已完成序列 335 | output_scores = output_scores[flag] # 扔掉已完成序列 336 | topk = flag.sum() # topk相应变化 337 | # 达到长度直接输出 338 | return output_ids[output_scores.argmax()] 339 | 340 | def nucleus_sample(self, inputs, n, p=0.95, temperature=False, topk=None, min_k=1): 341 | inputs = [np.array([i]) for i in inputs] 342 | output_ids = self.first_output_ids 343 | results = [] 344 | for step in range(self.maxlen): 345 | probas = self.predict(inputs, output_ids, step, 'probas') # 计算当前概率 346 | if step == 0: # 第1步预测后将结果重复n次 347 | probas = np.repeat(probas, n, axis=0) 348 | inputs = [np.repeat(i, n, axis=0) for i in inputs] 349 | output_ids = np.repeat(output_ids, n, axis=0) 350 | 351 | indices = (-probas).argsort(axis=1) 352 | probas = np.take_along_axis(probas, indices, axis=1) 353 | 354 | if temperature and isinstance(temperature, (int, float)): 355 | probas = probas / temperature 356 | probas = probas - np.max(probas, axis=1, keepdims=True) 357 | probas = np.exp(probas) 358 | probas = probas / np.sum(probas, axis=1, keepdims=True) 359 | 360 | x = np.cumsum(probas, axis=1) 361 | m = (x < p) * 1.0 362 | m[:, :min_k] = 1.0 363 | probas = np.multiply(probas, m) 364 | if topk is not None: 365 | probas = probas[:, :topk] 366 | 367 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 368 | # probas /= (1.0 + 1e-4) * probas.sum(axis=1, keepdims=True) # 重新归一化 369 | sample_func = lambda _p: np.random.choice(len(_p), p=_p) # 按概率采样函数 370 | sample_ids = np.apply_along_axis(sample_func, 1, probas) # 执行采样 371 | sample_ids = sample_ids.reshape((-1, 1)) # 对齐形状 372 | sample_ids = np.take_along_axis(indices, sample_ids, axis=1) # 对齐原id 373 | 374 | output_ids = np.concatenate([output_ids, sample_ids], 1) # 更新输出 375 | flag = (sample_ids[:, 0] == self.end_id) # 标记已完成序列 376 | if flag.any(): # 如果有已完成的 377 | for ids in output_ids[flag]: # 存好已完成序列 378 | results.append(ids) 379 | flag = (flag == False) # 标记未完成序列 380 | inputs = [i[flag] for i in inputs] # 只保留未完成部分输入 381 | output_ids = output_ids[flag] # 只保留未完成部分候选集 382 | if 
len(output_ids) == 0: 383 | break 384 | # 如果还有未完成序列,直接放入结果 385 | for ids in output_ids: 386 | results.append(ids) 387 | # 返回结果 388 | return results 389 | 390 | def random_sample(self, inputs, n, topk=None): 391 | """随机采样n个结果 392 | 说明:非None的topk表示每一步只从概率最高的topk个中采样; 393 | 返回:n个解码序列组成的list。 394 | """ 395 | inputs = [np.array([i]) for i in inputs] 396 | output_ids = self.first_output_ids 397 | results = [] 398 | for step in range(self.maxlen): 399 | probas = self.predict(inputs, output_ids, step, 'probas') # 计算当前概率 400 | if step == 0: # 第1步预测后将结果重复n次 401 | probas = np.repeat(probas, n, axis=0) 402 | inputs = [np.repeat(i, n, axis=0) for i in inputs] 403 | output_ids = np.repeat(output_ids, n, axis=0) 404 | if topk is not None: 405 | indices = probas.argpartition(-topk, axis=1)[:, -topk:] # 仅保留topk 406 | probas = np.take_along_axis(probas, indices, axis=1) # topk概率 407 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 408 | sample_func = lambda p: np.random.choice(len(p), p=p) # 按概率采样函数 409 | sample_ids = np.apply_along_axis(sample_func, 1, probas) # 执行采样 410 | sample_ids = sample_ids.reshape((-1, 1)) # 对齐形状 411 | if topk is not None: 412 | sample_ids = np.take_along_axis(indices, sample_ids, axis=1) # 对齐原id 413 | output_ids = np.concatenate([output_ids, sample_ids], 1) # 更新输出 414 | flag = (sample_ids[:, 0] == self.end_id) # 标记已完成序列 415 | if flag.any(): # 如果有已完成的 416 | for ids in output_ids[flag]: # 存好已完成序列 417 | results.append(ids) 418 | flag = (flag == False) # 标记未完成序列 419 | inputs = [i[flag] for i in inputs] # 只保留未完成部分输入 420 | output_ids = output_ids[flag] # 只保留未完成部分候选集 421 | if len(output_ids) == 0: 422 | break 423 | # 如果还有未完成序列,直接放入结果 424 | for ids in output_ids: 425 | results.append(ids) 426 | # 返回结果 427 | return results 428 | 429 | 430 | class Hook: 431 | """注入uniout模块,实现import时才触发 432 | """ 433 | 434 | def __init__(self, module): 435 | self.module = module 436 | 437 | def __getattr__(self, attr): 438 | """使得 from bert4keras_5_8.backend import uniout 439 | 等效于 import uniout (自动识别Python版本,Python3 440 | 下则无操作。) 441 | """ 442 | if attr == 'uniout': 443 | if is_py2: 444 | import uniout 445 | else: 446 | return getattr(self.module, attr) 447 | 448 | 449 | Hook.__name__ = __name__ 450 | sys.modules[__name__] = Hook(sys.modules[__name__]) 451 | del Hook 452 | -------------------------------------------------------------------------------- /code/bert4keras_5_8/tokenizers.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 工具函数 3 | 4 | import unicodedata, re 5 | from bert4keras_5_8.snippets import is_string, is_py2 6 | from bert4keras_5_8.snippets import open 7 | 8 | 9 | def load_vocab(dict_path, encoding='utf-8', simplified=False, startwith=None, max_num=None): 10 | """从bert的词典文件中读取词典 11 | """ 12 | token_dict = {} 13 | with open(dict_path, encoding=encoding) as reader: 14 | for line in reader: 15 | token = line.strip() 16 | token_dict[token] = len(token_dict) 17 | if max_num is not None: 18 | if len(token_dict) >= max_num: 19 | break 20 | 21 | if simplified: # 过滤冗余部分token 22 | new_token_dict, keep_tokens = {}, [] 23 | startwith = startwith or [] 24 | for t in startwith: 25 | new_token_dict[t] = len(new_token_dict) 26 | keep_tokens.append(token_dict[t]) 27 | 28 | for t, _ in sorted(token_dict.items(), key=lambda s: s[1]): 29 | if t not in new_token_dict: 30 | keep = True 31 | if len(t) > 1: 32 | for c in (t[2:] if t[:2] == '##' else t): 33 | if (Tokenizer._is_cjk_character(c) 34 | or Tokenizer._is_punctuation(c)): 35 | keep = False 36 | break 37 | if keep: 38 | new_token_dict[t] = len(new_token_dict) 39 | keep_tokens.append(token_dict[t]) 40 | 41 | return new_token_dict, keep_tokens 42 | else: 43 | return token_dict 44 | 45 | 46 | class BasicTokenizer(object): 47 | """分词器基类 48 | """ 49 | def __init__(self, do_lower_case=False): 50 | """初始化 51 | """ 52 | self._token_pad = '[PAD]' 53 | self._token_cls = '[CLS]' 54 | self._token_sep = '[SEP]' 55 | self._token_unk = '[UNK]' 56 | self._token_mask = '[MASK]' 57 | self._do_lower_case = do_lower_case 58 | 59 | def tokenize(self, text, add_cls=True, add_sep=True, max_length=None): 60 | """分词函数 61 | """ 62 | if self._do_lower_case: 63 | if is_py2: 64 | text = unicode(text) 65 | text = unicodedata.normalize('NFD', text) 66 | text = ''.join( 67 | [ch for ch in text if unicodedata.category(ch) != 'Mn']) 68 | text = text.lower() 69 | 70 | tokens = self._tokenize(text) 71 | if add_cls: 72 | tokens.insert(0, self._token_cls) 73 | if add_sep: 74 | tokens.append(self._token_sep) 75 | 76 | if max_length is not None: 77 | self.truncate_sequence(max_length, tokens, None, -2) 78 | 79 | return tokens 80 | 81 | def token_to_id(self, token): 82 | """token转换为对应的id 83 | """ 84 | raise NotImplementedError 85 | 86 | def tokens_to_ids(self, tokens): 87 | """token序列转换为对应的id序列 88 | """ 89 | return [self.token_to_id(token) for token in tokens] 90 | 91 | def truncate_sequence(self, 92 | max_length, 93 | first_sequence, 94 | second_sequence=None, 95 | pop_index=-1): 96 | """截断总长度 97 | """ 98 | if second_sequence is None: 99 | second_sequence = [] 100 | 101 | while True: 102 | total_length = len(first_sequence) + len(second_sequence) 103 | if total_length <= max_length: 104 | break 105 | elif len(first_sequence) > len(second_sequence): 106 | first_sequence.pop(pop_index) 107 | else: 108 | second_sequence.pop(pop_index) 109 | 110 | def encode(self, 111 | first_text, 112 | second_text=None, 113 | max_length=None, 114 | first_length=None, 115 | second_length=None): 116 | """输出文本对应token id和segment id 117 | 如果传入first_length,则强行padding第一个句子到指定长度; 118 | 同理,如果传入second_length,则强行padding第二个句子到指定长度。 119 | """ 120 | if is_string(first_text): 121 | first_tokens = self.tokenize(first_text) 122 | else: 123 | first_tokens = first_text 124 | 125 | if second_text is None: 126 | second_tokens = None 127 | elif is_string(second_text): 128 | second_tokens = self.tokenize(second_text, add_cls=False) 129 | else: 130 | second_tokens = second_text 131 | 132 | if max_length is not None: 133 
| self.truncate_sequence(max_length, first_tokens, second_tokens, -2) 134 | 135 | first_token_ids = self.tokens_to_ids(first_tokens) 136 | if first_length is not None: 137 | first_token_ids = first_token_ids[:first_length] 138 | first_token_ids.extend([self._token_pad_id] * 139 | (first_length - len(first_token_ids))) 140 | first_segment_ids = [0] * len(first_token_ids) 141 | 142 | if second_text is not None: 143 | second_token_ids = self.tokens_to_ids(second_tokens) 144 | if second_length is not None: 145 | second_token_ids = second_token_ids[:second_length] 146 | second_token_ids.extend( 147 | [self._token_pad_id] * 148 | (second_length - len(second_token_ids))) 149 | second_segment_ids = [1] * len(second_token_ids) 150 | 151 | first_token_ids.extend(second_token_ids) 152 | first_segment_ids.extend(second_segment_ids) 153 | 154 | return first_token_ids, first_segment_ids 155 | 156 | def id_to_token(self, i): 157 | """id序列为对应的token 158 | """ 159 | raise NotImplementedError 160 | 161 | def ids_to_tokens(self, ids): 162 | """id序列转换为对应的token序列 163 | """ 164 | return [self.id_to_token(i) for i in ids] 165 | 166 | def decode(self, ids): 167 | """转为可读文本 168 | """ 169 | raise NotImplementedError 170 | 171 | def _tokenize(self, text): 172 | """基本分词函数 173 | """ 174 | raise NotImplementedError 175 | 176 | 177 | class Tokenizer(BasicTokenizer): 178 | """Bert原生分词器 179 | 纯Python实现,代码修改自keras_bert的tokenizer实现 180 | """ 181 | def __init__(self, token_dict, do_lower_case=False): 182 | """初始化 183 | """ 184 | super(Tokenizer, self).__init__(do_lower_case) 185 | if is_string(token_dict): 186 | token_dict = load_vocab(token_dict) 187 | 188 | self._token_dict = token_dict 189 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 190 | for token in ['pad', 'cls', 'sep', 'unk', 'mask']: 191 | try: 192 | _token_id = token_dict[getattr(self, '_token_%s' % token)] 193 | setattr(self, '_token_%s_id' % token, _token_id) 194 | except: 195 | pass 196 | self._vocab_size = len(token_dict) 197 | 198 | def token_to_id(self, token): 199 | """token转换为对应的id 200 | """ 201 | return self._token_dict.get(token, self._token_unk_id) 202 | 203 | def id_to_token(self, i): 204 | """id转换为对应的token 205 | """ 206 | return self._token_dict_inv[i] 207 | 208 | def decode(self, ids, tokens=None): 209 | """转为可读文本 210 | """ 211 | tokens = tokens or self.ids_to_tokens(ids) 212 | tokens = [token for token in tokens if not self._is_special(token)] 213 | 214 | text, flag = '', False 215 | for i, token in enumerate(tokens): 216 | if token[:2] == '##': 217 | text += token[2:] 218 | elif len(token) == 1 and self._is_cjk_character(token): 219 | text += token 220 | elif len(token) == 1 and self._is_punctuation(token): 221 | text += token 222 | text += ' ' 223 | elif i > 0 and self._is_cjk_character(text[-1]): 224 | text += token 225 | else: 226 | text += ' ' 227 | text += token 228 | 229 | text = re.sub(' +', ' ', text) 230 | text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) 231 | punctuation = self._cjk_punctuation() + '+-/={(<[' 232 | punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) 233 | punctuation_regex = '(%s) ' % punctuation_regex 234 | text = re.sub(punctuation_regex, '\\1', text) 235 | text = re.sub('(\d\.) 
(\d)', '\\1\\2', text) 236 | 237 | return text.strip() 238 | 239 | def _tokenize(self, text): 240 | """基本分词函数 241 | """ 242 | spaced = '' 243 | for ch in text: 244 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 245 | spaced += ' ' + ch + ' ' 246 | elif self._is_space(ch): 247 | spaced += ' ' 248 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 249 | continue 250 | else: 251 | spaced += ch 252 | 253 | tokens = [] 254 | for word in spaced.strip().split(): 255 | tokens.extend(self._word_piece_tokenize(word)) 256 | 257 | return tokens 258 | 259 | def _word_piece_tokenize(self, word): 260 | """word内分成subword 261 | """ 262 | if word in self._token_dict: 263 | return [word] 264 | 265 | tokens = [] 266 | start, stop = 0, 0 267 | while start < len(word): 268 | stop = len(word) 269 | while stop > start: 270 | sub = word[start:stop] 271 | if start > 0: 272 | sub = '##' + sub 273 | if sub in self._token_dict: 274 | break 275 | stop -= 1 276 | if start == stop: 277 | stop += 1 278 | tokens.append(sub) 279 | start = stop 280 | 281 | return tokens 282 | 283 | @staticmethod 284 | def _is_space(ch): 285 | """空格类字符判断 286 | """ 287 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 288 | unicodedata.category(ch) == 'Zs' 289 | 290 | @staticmethod 291 | def _is_punctuation(ch): 292 | """标点符号类字符判断(全/半角均在此内) 293 | """ 294 | code = ord(ch) 295 | return 33 <= code <= 47 or \ 296 | 58 <= code <= 64 or \ 297 | 91 <= code <= 96 or \ 298 | 123 <= code <= 126 or \ 299 | unicodedata.category(ch).startswith('P') 300 | 301 | @staticmethod 302 | def _cjk_punctuation(): 303 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\xb7\uff01\uff1f\uff61\u3002' 304 | 305 | @staticmethod 306 | def _is_cjk_character(ch): 307 | """CJK类字符判断(包括中文字符也在此列) 308 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 309 | """ 310 | code = ord(ch) 311 | return 0x4E00 <= code <= 0x9FFF or \ 312 | 0x3400 <= code <= 0x4DBF or \ 313 | 0x20000 <= code <= 0x2A6DF or \ 314 | 0x2A700 <= code <= 0x2B73F or \ 315 | 0x2B740 <= code <= 0x2B81F or \ 316 | 0x2B820 <= code <= 0x2CEAF or \ 317 | 0xF900 <= code <= 0xFAFF or \ 318 | 0x2F800 <= code <= 0x2FA1F 319 | 320 | @staticmethod 321 | def _is_control(ch): 322 | """控制类字符判断 323 | """ 324 | return unicodedata.category(ch) in ('Cc', 'Cf') 325 | 326 | @staticmethod 327 | def _is_special(ch): 328 | """判断是不是有特殊含义的符号 329 | """ 330 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') 331 | 332 | 333 | class SpTokenizer(BasicTokenizer): 334 | """基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。 335 | """ 336 | def __init__(self, sp_model_path, do_lower_case=False): 337 | super(SpTokenizer, self).__init__(do_lower_case) 338 | import sentencepiece as spm 339 | self.sp_model = spm.SentencePieceProcessor() 340 | self.sp_model.Load(sp_model_path) 341 | self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id()) 342 | self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id()) 343 | self._token_pad_id = self.sp_model.piece_to_id(self._token_pad) 344 | self._token_cls_id = self.sp_model.piece_to_id(self._token_cls) 345 | self._token_sep_id 
= self.sp_model.piece_to_id(self._token_sep) 346 | self._token_unk_id = self.sp_model.piece_to_id(self._token_unk) 347 | self._token_mask_id = self.sp_model.piece_to_id(self._token_mask) 348 | self._vocab_size = self.sp_model.get_piece_size() 349 | 350 | def token_to_id(self, token): 351 | """token转换为对应的id 352 | """ 353 | return self.sp_model.piece_to_id(token) 354 | 355 | def id_to_token(self, i): 356 | """id转换为对应的token 357 | """ 358 | return self.sp_model.id_to_piece(i) 359 | 360 | def decode(self, ids): 361 | """转为可读文本 362 | """ 363 | ids = [i for i in ids if not self._is_special(i)] 364 | return self.sp_model.decode_ids(ids) 365 | 366 | def _tokenize(self, text): 367 | """基本分词函数 368 | """ 369 | tokens = self.sp_model.encode_as_pieces(text) 370 | return tokens 371 | 372 | def _is_special(self, i): 373 | """判断是不是有特殊含义的符号 374 | """ 375 | return self.sp_model.is_control(i) or \ 376 | self.sp_model.is_unknown(i) or \ 377 | self.sp_model.is_unused(i) 378 | -------------------------------------------------------------------------------- /code/cfg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/4/24 22:01 6 | @File :cfg.py 7 | @Desc : 8 | """ 9 | import os 10 | from easydict import EasyDict as edict 11 | 12 | join = os.path.join 13 | 14 | MAIN_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 15 | DATA_PATH = join(MAIN_PATH, 'data') 16 | MID_PATH = join(MAIN_PATH, 'mid') 17 | MODEL_PATH = join(MAIN_PATH, 'model') 18 | BERT_PATH = join(DATA_PATH, 'roberta') 19 | OUT_PATH = join(MAIN_PATH, 'output') 20 | FILE_DICT = { 21 | 'train': join(DATA_PATH, 'train/train.txt'), 22 | 'dev': join(DATA_PATH, 'dev/dev.txt'), 23 | 'test': join(DATA_PATH, 'test_1/test_1.txt'), 24 | 'test2': join(DATA_PATH, 'test_2/test_2.txt'), 25 | } 26 | 27 | 28 | data_num = { 29 | 0: 6618, 30 | 1: 946, 31 | 2: 4645, 32 | 3: 13666, 33 | } 34 | train_list = [0, 1] 35 | totle_sample = 0 36 | for t in train_list: 37 | totle_sample += data_num[t] 38 | TAG = 'd2' 39 | 40 | 41 | def __get_config(): 42 | _config = edict() 43 | return _config 44 | 45 | 46 | def __get_logger(): 47 | import logging 48 | import datetime 49 | if not os.path.isdir(os.path.join(MAIN_PATH, 'logs')): 50 | os.makedirs(os.path.join(MAIN_PATH, 'logs')) 51 | 52 | LOG_PATH = os.path.join(MAIN_PATH, 'logs/log_{}.txt'.format(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))) 53 | 54 | logging.basicConfig(filename=LOG_PATH, 55 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO, 56 | filemode='w', datefmt='%Y-%m-%d%I:%M:%S %p') 57 | _logger = logging.getLogger(__name__) 58 | 59 | # 添加日志输出到控制台 60 | console = logging.StreamHandler() 61 | _logger.addHandler(console) 62 | _logger.setLevel(logging.INFO) 63 | 64 | return _logger 65 | 66 | 67 | config = __get_config() 68 | logger = __get_logger() -------------------------------------------------------------------------------- /code/conversation_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ################################################################################ 4 | # 5 | # Copyright (c) 2019 Baidu.com, Inc. 
All Rights Reserved 6 | # 7 | ################################################################################ 8 | """ 9 | File: conversation_client.py 10 | """ 11 | 12 | from __future__ import print_function 13 | 14 | import sys 15 | import socket 16 | import importlib 17 | 18 | importlib.reload(sys) 19 | 20 | SERVER_IP = "127.0.0.1" 21 | SERVER_PORT = 8601 22 | 23 | def conversation_client(text): 24 | """ 25 | conversation_client 26 | """ 27 | mysocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 28 | mysocket.connect((SERVER_IP, SERVER_PORT)) 29 | mysocket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096 * 5) 30 | mysocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096 * 5) 31 | 32 | mysocket.sendall(text.encode()) 33 | result = mysocket.recv(4096 * 5).decode() 34 | 35 | mysocket.close() 36 | 37 | return result 38 | 39 | 40 | def main(): 41 | """ 42 | main 43 | """ 44 | if len(sys.argv) < 2: 45 | print("Usage: " + sys.argv[0] + " eval_file") 46 | exit() 47 | 48 | for line in open(sys.argv[1], encoding='utf-8'): 49 | response = conversation_client(line.strip()) 50 | print(response) 51 | 52 | 53 | if __name__ == '__main__': 54 | try: 55 | main() 56 | except KeyboardInterrupt: 57 | print("\nExited from the program ealier!") 58 | -------------------------------------------------------------------------------- /code/conversation_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ################################################################################ 4 | # 5 | # Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved 6 | # 7 | ################################################################################ 8 | """ 9 | File: conversation_server.py 10 | """ 11 | 12 | from __future__ import print_function 13 | 14 | import sys 15 | sys.path.append("../") 16 | import socket 17 | import importlib 18 | from _thread import start_new_thread 19 | from predict.predict_final import FinalPredict 20 | 21 | importlib.reload(sys) 22 | 23 | SERVER_IP = "127.0.0.1" 24 | SERVER_PORT = 8601 25 | 26 | print("starting conversation server ...") 27 | print("binding socket ...") 28 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 29 | s.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096 * 20) 30 | s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096 * 20) 31 | bufsize = s.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) 32 | print( "Buffer size [After]: %d" %bufsize) 33 | #Bind socket to local host and port 34 | try: 35 | s.bind((SERVER_IP, SERVER_PORT)) 36 | except socket.error as msg: 37 | print("Bind failed. 
Error Code : " + str(msg[0]) + " Message " + msg[1]) 38 | exit() 39 | #Start listening on socket 40 | s.listen(10) 41 | print("bind socket success !") 42 | 43 | print("loading model...") 44 | model = FinalPredict() 45 | print("load model success !") 46 | 47 | print("start conversation server success !") 48 | 49 | 50 | def clientthread(conn, addr): 51 | """ 52 | client thread 53 | """ 54 | logstr = "addr:" + addr[0]+ "_" + str(addr[1]) 55 | try: 56 | #Receiving from client 57 | param_ori = conn.recv(4096 * 20) 58 | param = param_ori.decode('utf-8', "ignore") 59 | # logstr += "\tparam:" + param 60 | if param is not None: 61 | response = model.predict(param.strip()) 62 | logstr += "\tresponse:" + response 63 | conn.sendall(response.encode()) 64 | conn.close() 65 | print(logstr + "\n") 66 | except Exception as e: 67 | print(logstr + "\n", e) 68 | print('==========') 69 | conn.close() 70 | raise 71 | 72 | 73 | while True: 74 | conn, addr = s.accept() 75 | start_new_thread(clientthread, (conn, addr)) 76 | s.close() 77 | -------------------------------------------------------------------------------- /code/data_deal/input_ct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/19 22:39 6 | @File :input_ct.py 7 | @Desc : 8 | """ 9 | from data_deal.base_input import * 10 | import random 11 | 12 | 13 | class CT_Tokenizer(Tokenizer): 14 | def truncate_sequence(self, 15 | max_length, 16 | first_sequence, 17 | second_sequence=None, 18 | pop_index=-1): 19 | """截断总长度 20 | """ 21 | if second_sequence is None: 22 | second_sequence = [] 23 | 24 | while True: 25 | total_length = len(first_sequence) + len(second_sequence) 26 | if total_length <= max_length: 27 | break 28 | elif len(first_sequence) > len(second_sequence): 29 | first_sequence.pop(1) 30 | else: 31 | second_sequence.pop(pop_index) 32 | 33 | 34 | class CTInput(BaseInput): 35 | def __init__(self, *args, **kwargs): 36 | super(CTInput, self).__init__(*args, **kwargs) 37 | 38 | self.last_sample_num = None 39 | 40 | # 模型相关 41 | self.token_dict, self.keep_tokens = load_vocab( 42 | join(BERT_PATH, 'vocab.txt'), 43 | startwith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '[unused2]'], 44 | simplified=True, max_num=9000) 45 | logger.info('Len of token_dict:{}'.format(len(self.token_dict))) 46 | self.tokenizer = CT_Tokenizer(self.token_dict) 47 | self.batch_size = 8 48 | 49 | self._label_context = { 50 | '0': 0, 51 | '1': 1, 52 | } 53 | 54 | def generator(self, batch_size=4, data_type=0, need_shuffle=False, cycle=False, need_douban=True): 55 | if not isinstance(data_type, list): 56 | data_type = [data_type] 57 | data_files = [] 58 | for t in data_type: 59 | if t not in self.data_dict.keys(): 60 | raise ValueError('data_type {} not in dict: {}'.format(t, self.data_dict.keys())) 61 | data_files.append(self.data_dict[t]) 62 | X, S, L = [], [], [] 63 | sample_iter = self.get_sample(data_files, need_shuffle=need_shuffle, cycle=cycle) 64 | if need_douban: 65 | douban_iter = self._get_douban(join(DATA_PATH, 'douban_train.txt'), cycle=True) 66 | else: 67 | douban_iter = None 68 | info = True 69 | while True: 70 | if not need_douban or random.random() < 0.3: 71 | sample = next(sample_iter) 72 | bot_first = self.reader._check_bot_first(sample['goal']) 73 | if bot_first is None: 74 | continue 75 | add_n = 1 if bot_first else 0 76 | context_str = 'conversation' if 'conversation' in sample.keys() else 'history' 77 | if 
len(sample[context_str]) < 2: 78 | continue 79 | end_n = random.randint(2, len(sample[context_str])) 80 | if end_n % 2 != add_n: 81 | if end_n == 2: 82 | end_n = 3 83 | else: 84 | end_n -= 1 85 | # label 86 | context = sample[context_str][:end_n] 87 | label = '1' 88 | if random.random() < 0.5: 89 | if end_n <= len(sample[context_str]) - 2: 90 | context.pop(-1) 91 | context.append(sample[context_str][end_n + 1]) 92 | label = '0' 93 | context = [re.sub(self.reader.goal_num_comp, '', s).replace(' ', '') for s in context] 94 | else: 95 | sample = next(douban_iter) 96 | context = sample[1:] 97 | label = sample[0] 98 | x, s, l = self.encode(context, label) 99 | if info: 100 | logger.info('input: {}'.format(' '.join(self.tokenizer.ids_to_tokens(x)))) 101 | info = False 102 | if x is None: 103 | continue 104 | X.append(x) 105 | S.append(s) 106 | L.append(l) 107 | if len(X) >= batch_size: 108 | X = sequence_padding(X) 109 | S = sequence_padding(S) 110 | yield [X, S], L 111 | X, S, L = [], [], [] 112 | 113 | def encode(self, ori_context, label=None): 114 | if label: 115 | label = self._label_context[label] 116 | ori_context = list(map(lambda _s: str(_s).strip().replace(' ', ''), ori_context)) 117 | context = [] 118 | for i, sentence in enumerate(ori_context[:-1]): 119 | context.extend(self.tokenizer.tokenize(sentence, add_cls=(i == 0), add_sep=(i >= len(ori_context) - 2))) 120 | if i < len(ori_context) - 2: 121 | context.append('[unused2]') 122 | x, s = self.tokenizer.encode(first_text=context, 123 | second_text=ori_context[-1], max_length=128) 124 | return x, s, label 125 | 126 | def _get_douban(self, file_path, cycle=True): 127 | with open(file_path, mode='r', encoding='utf-8') as fr: 128 | while True: 129 | line = fr.readline() 130 | _rn = 0 131 | while not line: 132 | line = fr.readline() 133 | if _rn > 10: 134 | if cycle: 135 | fr.seek(0) 136 | else: 137 | raise StopIteration 138 | _rn += 1 139 | line = line.strip().split('\t') 140 | yield line 141 | 142 | 143 | def test(): 144 | rc = CTInput() 145 | rc_it = rc.generator(4) 146 | i = 0 147 | for [X, S], L in rc_it: 148 | for x, l in zip(X, L): 149 | print(rc.tokenizer.ids_to_tokens(x)) 150 | print(l) 151 | print() 152 | i += 1 153 | if i % 30 == 0: 154 | if input('C?') == 'q': 155 | break 156 | 157 | 158 | if __name__ == '__main__': 159 | test() -------------------------------------------------------------------------------- /code/data_deal/input_goal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/8 21:08 6 | @File :goal_input.py 7 | @Desc : 8 | """ 9 | from data_deal.base_input import * 10 | 11 | 12 | class GoalInput(BaseInput): 13 | def __init__(self): 14 | super(GoalInput, self).__init__() 15 | self.max_len = 360 16 | 17 | def encode_predict(self, sample: dict): 18 | context, goals, turns = self.reader.trans_sample(sample, need_bot_trans=False) 19 | if context is None: 20 | return None, None 21 | token_ids = [] 22 | segs = [] 23 | 24 | token_ids.extend(self.tokenizer.encode(sample['situation'])[0]) 25 | segs.extend([0] * len(token_ids)) 26 | turn = False 27 | for i, sentence, goal, turn in zip(list(range(len(context))), context, goals, turns): 28 | this_goal = self.reader.all_goals[goal if goal > 0 else 1] 29 | goal_tokens, _ = self.tokenizer.encode(this_goal) 30 | goal_tokens = goal_tokens[1:-1] 31 | token_ids += goal_tokens 32 | token_ids += [self.tokenizer._token_goal_id] 33 | 34 | segs += [0 if turn 
else 1] * (len(goal_tokens) + 1) 35 | 36 | sen_tokens, _ = self.tokenizer.encode(sentence) 37 | sen_tokens = sen_tokens[1:] 38 | token_ids += sen_tokens 39 | if turn: 40 | segs += [1] * len(sen_tokens) 41 | else: 42 | segs += [0] * len(sen_tokens) 43 | if turn: 44 | raise ValueError('last turn is not user') 45 | if len(token_ids) > self.max_len: 46 | token_ids = token_ids[:1] + token_ids[1 - self.max_len:] 47 | segs = segs[:1] + segs[1 - self.max_len:] 48 | return token_ids, segs 49 | 50 | def encode(self, sample: dict, need_goal_mask=True): 51 | context, goals, turns = self.reader.trans_sample(sample, need_bot_trans=False) 52 | if context is None: 53 | return None, None 54 | token_ids = [] 55 | segs = [] 56 | 57 | token_ids.extend(self.tokenizer.encode(sample['situation'])[0]) 58 | segs.extend([0] * len(token_ids)) 59 | for i, sentence, goal, turn in zip(list(range(len(context))), context, goals, turns): 60 | if turn and goal != 0: # 未知只有test才有 61 | if len(token_ids) > self.max_len: 62 | token_ids = token_ids[:1] + token_ids[1 - self.max_len:] 63 | segs = segs[:1] + segs[1 - self.max_len:] 64 | yield token_ids.copy(), segs.copy(), goal 65 | if need_goal_mask and i > 0: 66 | if goal > 1 and random.random() < 0.5: 67 | if 'history' not in sample.keys(): 68 | goal = 1 69 | this_goal = self.reader.all_goals[goal if goal > 0 else 1] 70 | goal_tokens, _ = self.tokenizer.encode(this_goal) 71 | goal_tokens = goal_tokens[1:-1] 72 | token_ids += goal_tokens 73 | token_ids += [self.tokenizer._token_goal_id] 74 | 75 | segs += [0 if turn else 1] * (len(goal_tokens) + 1) 76 | 77 | sen_tokens, _ = self.tokenizer.encode(sentence) 78 | sen_tokens = sen_tokens[1:] 79 | token_ids += sen_tokens 80 | if turn: 81 | segs += [1] * len(sen_tokens) 82 | else: 83 | segs += [0] * len(sen_tokens) 84 | 85 | def generator(self, batch_size=12, data_type=0, need_shuffle=False, cycle=False): 86 | data_dict = { 87 | 0: join(DATA_PATH, 'train/train.txt'), 88 | 1: join(DATA_PATH, 'dev/dev.txt'), 89 | 2: join(DATA_PATH, 'test_1/test_1.txt'), 90 | 3: join(DATA_PATH, 'test_2/test_2.txt'), 91 | } 92 | if not isinstance(data_type, list): 93 | data_type = [data_type] 94 | data_files = [] 95 | for t in data_type: 96 | if t not in data_dict.keys(): 97 | raise ValueError('data_type {} not in dict: {}'.format(t, data_dict.keys())) 98 | data_files.append(data_dict[t]) 99 | X, S, L = [], [], [] 100 | sample_iter = self.get_sample(data_files, need_shuffle=need_shuffle, cycle=cycle) 101 | while True: 102 | sample = next(sample_iter) 103 | piece_iter = self.encode(sample) 104 | for x, s, l in piece_iter: 105 | if x is None: 106 | continue 107 | X.append(x) 108 | S.append(s) 109 | L.append(l) 110 | if len(X) >= batch_size: 111 | X = sequence_padding(X) 112 | S = sequence_padding(S) 113 | yield [X, S], L 114 | X, S, L = [], [], [] 115 | -------------------------------------------------------------------------------- /code/data_deal/input_rc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | @Time : 2020/5/10 9:59 6 | @Author : Apple QXTD 7 | @File : input_rc.py 8 | @Desc: : 9 | """ 10 | from data_deal.base_input import * 11 | from data_deal.trans_output import TransOutput 12 | 13 | 14 | class RCInput(BaseInput): 15 | def __init__(self, *args, **kwargs): 16 | super(RCInput, self).__init__(*args, **kwargs) 17 | 18 | self.last_sample_num = None 19 | self.dict_path = join(BERT_PATH, 'vocab.txt') 20 | 21 | token_dict, self.keep_tokens = 
load_vocab( 22 | dict_path=self.dict_path, 23 | simplified=True, 24 | startwith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'], 25 | ) 26 | self.tokenizer = Tokenizer(token_dict, do_lower_case=True) 27 | 28 | self.max_p_len = 368 - 64 - 46 - 5 29 | self.max_q_len = 64 30 | self.max_a_len = 46 31 | self.batch_size = 4 32 | self.need_evaluate = True 33 | self.out_trans = TransOutput() 34 | 35 | if self.from_pre_trans: 36 | self.data_dict = { 37 | 0: join(DATA_PATH, 'trans', 'trans_rc_0.txt'), 38 | 1: join(DATA_PATH, 'trans', 'trans_rc_1.txt'), 39 | 2: join(DATA_PATH, 'trans', 'trans_rc_2.txt'), 40 | 3: join(DATA_PATH, 'trans', 'trans_rc_3.txt'), 41 | } 42 | self._a = 0 43 | self._b = 0 44 | 45 | def generator(self, batch_size=4, data_type=0, need_shuffle=False, cycle=False): 46 | if not isinstance(data_type, list): 47 | data_type = [data_type] 48 | data_files = [] 49 | for t in data_type: 50 | if t not in self.data_dict.keys(): 51 | raise ValueError('data_type {} not in dict: {}'.format(t, self.data_dict.keys())) 52 | data_files.append(self.data_dict[t]) 53 | X, S, A = [], [], [] 54 | sample_iter = self.get_sample(data_files, need_shuffle=need_shuffle, cycle=cycle) 55 | while True: 56 | sample = next(sample_iter) 57 | for x, s, a in self.encode(sample): 58 | if x is None: 59 | continue 60 | X.append(x) 61 | S.append(s) 62 | A.append(a) 63 | # dx = self.tokenizer.decode(x) 64 | # da = self.tokenizer.decode(a) 65 | # print(dx) 66 | # print(da) 67 | if len(X) >= batch_size: 68 | X = sequence_padding(X) 69 | S = sequence_padding(S) 70 | A = sequence_padding(A, self.max_a_len) 71 | yield [X, S], A 72 | X, S, A = [], [], [] 73 | 74 | def get_rc_sample(self, sample): 75 | if self.from_pre_trans: 76 | results = [sample] 77 | else: 78 | context, goals, turns, ori_replace_dict = self.reader.trans_sample(sample, need_replace_dict=True) 79 | if context is None: 80 | return [] 81 | context_str = 'conversation' if 'conversation' in sample.keys() else 'history' 82 | history = sample[context_str] 83 | results = [] 84 | for i, sentence, goal, turn, ori_rp_dict in zip( 85 | list(range(len(context))), context, goals, turns, ori_replace_dict): 86 | if not turn: 87 | continue 88 | replace_dict = self.out_trans.search_choices(sample, sentence, history=history[:i + 1]) 89 | if len(replace_dict) > 0: 90 | results.append( 91 | { 92 | 'history': history[:i] + [sentence], 93 | 'replace_dict': replace_dict, 94 | 'result': ori_rp_dict, 95 | } 96 | ) 97 | return results 98 | 99 | def encode(self, sample: dict): 100 | samples = self.get_rc_sample(sample) 101 | for sample in samples: 102 | for q_key, answer in sample['result'].items(): 103 | if q_key not in sample['replace_dict'].keys(): 104 | continue 105 | self._a += 1 106 | context = sample['replace_dict'][q_key] # it's a list 107 | context = '|'.join(context).replace(' ', '') # 全部使用 | 作为分隔符 108 | question = '|'.join(sample['history']).replace(' ', '') # 全部使用 | 作为分隔符 109 | answer = answer.replace(' ', '') 110 | question += '|{}'.format(q_key) # question额外添加询问的标记 111 | # 如果长度超了,就截取 112 | if len(context) > self.max_p_len - 5: 113 | answer_start = self.dynamic_find(context, answer) 114 | if answer_start < 0: 115 | continue 116 | trunc_res = self.trans_sample((context, question, answer, answer_start)) 117 | if trunc_res is None: 118 | continue 119 | context, question, answer = trunc_res[:3] 120 | # 编码 121 | a_token_ids, _ = self.tokenizer.encode(answer, max_length=self.max_a_len + 1) 122 | q_token_ids, _ = self.tokenizer.encode(question) 123 | while len(q_token_ids) > 
self.max_q_len + 1: 124 | q_token_ids.pop(1) 125 | p_token_ids, _ = self.tokenizer.encode(context, max_length=self.max_p_len + 1) 126 | token_ids = [self.tokenizer._token_cls_id] 127 | token_ids += ([self.tokenizer._token_mask_id] * self.max_a_len) 128 | token_ids += [self.tokenizer._token_sep_id] 129 | token_ids += (q_token_ids[1:] + p_token_ids[1:]) 130 | segment_ids = [0] * len(token_ids) 131 | self._b += 1 132 | yield token_ids, segment_ids, a_token_ids[1:] 133 | 134 | def trans_sample(self, sample): 135 | context, question, answer, answer_start = sample 136 | if len(question) > self.max_q_len: 137 | question = question[-self.max_q_len:] 138 | if len(answer) > self.max_a_len: 139 | answer = answer[:self.max_a_len] 140 | if len(context) - len(answer) > 220: 141 | tail_len = len(context) - len(answer) - answer_start 142 | if tail_len > answer_start: # 截取尾部的文本 143 | tail_index = random.randint(-tail_len, int(-tail_len / 2)) 144 | context = context[:tail_index] 145 | answer_start = self.dynamic_find(context, answer) 146 | if answer_start < 0: 147 | answer = '' 148 | else: 149 | end_index = random.randint(int(answer_start / 2), answer_start) 150 | context = context[end_index:] 151 | answer_start = self.dynamic_find(context, answer) 152 | if answer_start < 0: 153 | answer = '' 154 | if len(context) > self.max_p_len: 155 | if answer_start < 0: 156 | context = context[:self.max_p_len] 157 | else: 158 | offset = len(context) - self.max_p_len 159 | if answer_start >= offset: 160 | context = context[offset:] 161 | answer_start -= offset 162 | elif len(answer) + answer_start <= len(context) - offset: 163 | context = context[:self.max_p_len] 164 | else: 165 | # answer 最大长度64 这种情况不存在 166 | return None 167 | return context, question, answer, answer_start 168 | 169 | def dynamic_find(self, sentence, piece): 170 | answer_start = -1 171 | p = 0 172 | while answer_start < 0: 173 | if p * 2 >= len(piece) - 4: 174 | break 175 | answer_start = sentence.find(piece[p * 2:]) 176 | p += 1 177 | if answer_start > 0: 178 | answer_start = max(0, answer_start - p * 2 + 2) 179 | return answer_start 180 | 181 | def test(): 182 | rc = RCInput() 183 | rc_it = rc.generator(4) 184 | i = 0 185 | for s in rc_it: 186 | i += 1 187 | print('i: ', i) 188 | print(rc._a) 189 | print(rc._b) 190 | if i % 30 == 0: 191 | if input('C?') == 'q': 192 | break 193 | 194 | 195 | if __name__ == '__main__': 196 | test() -------------------------------------------------------------------------------- /code/data_deal/pre_trans.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/9 19:46 6 | @File :pre_trans.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from data_deal.base_input import * 12 | 13 | 14 | def trans(data_type): 15 | data_dict = { 16 | 0: join(DATA_PATH, 'train/train.txt'), 17 | 1: join(DATA_PATH, 'dev/dev.txt'), 18 | 2: join(DATA_PATH, 'test_1/test_1.txt'), 19 | 3: join(DATA_PATH, 'test_2/test_2.txt'), 20 | } 21 | output_dir = join(DATA_PATH, 'trans') 22 | if not os.path.isdir(output_dir): 23 | os.makedirs(output_dir) 24 | 25 | output_path = join(output_dir, 'trans_{}.txt'.format(data_type)) 26 | 27 | data_input = BaseInput() 28 | 29 | all_data = [] 30 | 31 | data_iter = data_input.get_sample(data_dict[data_type], need_shuffle=False, cycle=False) 32 | sn = 0 33 | for sample in data_iter: 34 | context, goals, turns, 
unused_goals, replace_dicts = data_input.reader.trans_sample( 35 | sample, return_rest_goals=True, need_replace_dict=True) 36 | sample.update( 37 | { 38 | 'context': context, 39 | 'goals': goals, 40 | 'turns': turns, 41 | 'unused_goals': unused_goals, 42 | 'replace_dicts': replace_dicts, 43 | } 44 | ) 45 | all_data.append(sample) 46 | sn += 1 47 | if sn % 58 == 0: 48 | print('\rnum {}'.format(sn), end=' ') 49 | # if sn > 30: 50 | # break 51 | print('\nOver: ', sn) 52 | with open(output_path, encoding='utf-8', mode='w') as fw: 53 | for data in all_data: 54 | fw.writelines(json.dumps( 55 | data, 56 | ensure_ascii=False, 57 | # indent=4, separators=(',',':') 58 | ) + '\n') 59 | 60 | def trans_v2(data_type): 61 | output_dir = join(DATA_PATH, 'trans') 62 | if not os.path.isdir(output_dir): 63 | os.makedirs(output_dir) 64 | 65 | input_path = join(output_dir, 'trans_{}.txt'.format(data_type)) 66 | output_path = join(output_dir, 'trans_{}_trim.txt'.format(data_type)) 67 | 68 | data_input = BaseInput() 69 | 70 | all_data = [] 71 | 72 | data_iter = data_input.get_sample(input_path, need_shuffle=False, cycle=False) 73 | sn = 0 74 | change = 0 75 | for sample in data_iter: 76 | turns = sample['turns'] 77 | if not (turns is None or len(turns) == 0): 78 | if not turns[0]: 79 | turns = turns[1:] 80 | if len(turns) > sum(turns) * 2: 81 | context, goals, turns, unused_goals, replace_dicts = data_input.reader.trans_sample( 82 | sample, return_rest_goals=True, need_replace_dict=True) 83 | sample.update( 84 | { 85 | 'context': context, 86 | 'goals': goals, 87 | 'turns': turns, 88 | 'unused_goals': unused_goals, 89 | 'replace_dicts': replace_dicts, 90 | } 91 | ) 92 | change += 1 93 | all_data.append(sample) 94 | sn += 1 95 | if sn % 58 == 0: 96 | print('\rnum {} change {}'.format(sn, change), end=' ') 97 | # if sn > 30: 98 | # break 99 | print('\nOver: ', sn) 100 | with open(output_path, encoding='utf-8', mode='w') as fw: 101 | for data in all_data: 102 | fw.writelines(json.dumps( 103 | data, 104 | ensure_ascii=False, 105 | # indent=4, separators=(',',':') 106 | ) + '\n') 107 | 108 | 109 | if __name__ == '__main__': 110 | trans(0) 111 | trans(1) 112 | trans(2) 113 | trans(3) -------------------------------------------------------------------------------- /code/data_deal/trans_output.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/4/29 20:16 6 | @File :trans_output.py 7 | @Desc : 8 | """ 9 | from cfg import * 10 | from data_deal.base_input import BaseRead 11 | import re 12 | import numpy as np 13 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 14 | 15 | 16 | class TransOutput(object): 17 | def __init__(self, rc_tag=''): 18 | self.reader = BaseRead() 19 | self.search_comp = re.compile('\[[PCKpck]-[^\[\]]+\]') 20 | self.search_comp_date = re.compile('\[[kK]-\[n\]\[n\]\[n\]\]') 21 | self._enc_comp = re.compile('\[[PCKpck]-[^\[\]]{,8}$') 22 | self.goal_comp = re.compile('\[[0-9]\]') 23 | self.date_comp = re.compile('\d{1,2} ?月\d{1,2} ?[日号]') 24 | self._birthday_comp = re.compile('\d[\d ]{3,}(-)[\d ]+(-)[\d ]+') 25 | self.replace_couple = [ 26 | (' ', ''), 27 | ('同意', '喜欢'), 28 | ('没有接受', '拒绝'), 29 | ('接受', '喜欢'), 30 | ('喜好', '喜欢'), 31 | ] 32 | if rc_tag != '': 33 | from model.model_rc import BertCL 34 | self.rc_model_cls = BertCL(tag=rc_tag, is_predict=True) 35 | else: 36 | self.rc_model_cls = None 37 | 38 | def clean_sentence(self, sentence): 39 | for k, v 
in self.replace_couple: 40 | sentence = sentence.replace(k, v) 41 | return sentence 42 | 43 | def pre_trans(self, answer): 44 | if '生日' in answer and len(answer) < 20: 45 | for d in self.date_comp.findall(answer): 46 | answer = answer.replace(d, '[k-生日]') 47 | return answer 48 | 49 | def trans_output(self, sample: dict, answer: str): 50 | answer = self.clean_sentence(answer) 51 | answer = self.pre_trans(answer) 52 | user_profile = {} 53 | for k, v in sample['user_profile'].items(): 54 | clean_k = self.clean_sentence(k) 55 | user_profile[clean_k] = user_profile.get(clean_k, []) 56 | if not isinstance(v, list): 57 | v = [v] 58 | user_profile[clean_k].extend(v) 59 | kg_dict = {} 60 | for k in sample['knowledge']: 61 | clean_p = self.clean_sentence(self.reader.clean_kg_type(k[1])) 62 | kg_dict[clean_p] = kg_dict.get(clean_p, []) 63 | kg_dict[clean_p].append({ 64 | 'S': k[0], 65 | 'O': k[2], 66 | }) 67 | 68 | # 寻找goal 69 | context_str = 'conversation' if 'conversation' in sample.keys() else 'history' 70 | exist_goals = [] 71 | for s in sample[context_str]: 72 | exist_goals.extend(self.goal_comp.findall(s)) 73 | if len(exist_goals) == 0: 74 | max_goal = 0 75 | else: 76 | max_goal = max([s[1] for s in exist_goals]) 77 | goal = sample['goal'] 78 | idx = goal.find('[{}]'.format(max_goal)) 79 | if idx < 0: 80 | idx = goal.find('[{}]'.format(int(max_goal) + 1)) 81 | if idx >= 0: 82 | goal = goal[idx:] 83 | else: 84 | goal = '' 85 | 86 | replace_items = self.search_comp.findall(answer) 87 | # 新闻的回复修正。新闻训练时候提取不干净,会出现多个碎片句子。这时候会出现: [K-新闻] [P-喜欢的明星] 88 | spe_items = {'[K-新闻]', '[k-新闻]', '[P-喜欢的明星]', '[p-喜欢的明星]'} 89 | inner = set(replace_items).intersection(spe_items) 90 | if len(inner) >= 2: 91 | # 寻找 b_i 92 | dots = list(',。?!,.?!') 93 | b_i = answer.find('[P-喜欢的明星]') 94 | if b_i < 0: 95 | b_i = answer.find('[p-喜欢的明星]') 96 | while b_i > 1: 97 | if answer[b_i - 1] in dots: 98 | break 99 | b_i -= 1 100 | # 寻找 e_i 101 | e_i = answer.find('[K-新闻]') 102 | if e_i < 0: 103 | e_i = answer.find('[k-新闻]') 104 | while e_i > 1: 105 | if answer[e_i - 1] in dots: 106 | break 107 | e_i -= 1 108 | answer = answer[:b_i] + answer[e_i:] 109 | replace_items = self.search_comp.findall(answer) 110 | 111 | replace_items.extend(self.search_comp_date.findall(answer)) 112 | replace_cuple = [] 113 | last_choice = None 114 | exist_none_replace = False 115 | for rp_item in replace_items: 116 | choice = None 117 | last_rp_len = len(replace_cuple) 118 | if rp_item[1] in ['p', 'P']: 119 | choice = user_profile.get(rp_item[3:-1], '') 120 | if isinstance(choice, list): 121 | choice = self.judge_choices(choice, sample, goal=goal, last_choice=last_choice) 122 | replace_cuple.append((rp_item, choice)) 123 | elif rp_item[1] in ['k', 'K']: 124 | # 新闻和评论额外的进行判定 125 | choices = kg_dict.get(rp_item[3:-1], []) 126 | obj = [d['O'] for d in choices] 127 | sbj = [d['S'] for d in choices] 128 | choice = self.judge_choices(obj, sample, sbj, goal=goal, last_choice=last_choice, 129 | identifier=rp_item, response=answer) 130 | replace_cuple.append((rp_item, choice)) 131 | elif rp_item[1] in ['c', 'C']: 132 | choices = kg_dict.get(rp_item[3:-1], []) 133 | choices = [d['S'] for d in choices] 134 | choice = self.judge_choices(choices, sample, goal=goal, last_choice=last_choice) 135 | replace_cuple.append((rp_item, choice)) 136 | else: 137 | logger.info('=' * 20) 138 | logger.info('Error rp item: {}'.format(rp_item)) 139 | logger.info('KG: {}'.format(kg_dict)) 140 | logger.info('P: {}'.format(user_profile)) 141 | replace_cuple.append((rp_item, '')) 142 | if 
choice is not None and len(choice) > 0: 143 | last_choice = choice 144 | if len(replace_cuple) == last_rp_len or choice == '': 145 | exist_none_replace = True 146 | 147 | all_tags = [] # 有些可能有重复的? 148 | for k, v in replace_cuple: 149 | if k in all_tags: 150 | continue 151 | all_tags.append(k) 152 | start = answer.find(k) 153 | if start < 0: 154 | continue 155 | if '[n][n][n]' in k: 156 | if '~' in v: 157 | v = v.replace('~', '转') 158 | answer = answer[:start] + v + answer[start + len(k):] 159 | # 清除多余的标记 160 | for k in self.search_comp.findall(answer): 161 | answer = answer.replace(k, '') 162 | for k in self._enc_comp.findall(answer): 163 | answer = answer.replace(k, '') 164 | # 日期转换 165 | sp = self._birthday_comp.search(answer) 166 | if sp: 167 | idx_0 = sp.group().find('-') 168 | idx_1 = sp.group()[idx_0 + 1:].find('-') + idx_0 + 1 169 | sp_str = list(sp.group()) 170 | sp_str[idx_0] = '年' 171 | sp_str[idx_1] = '月' 172 | sp_str = ''.join(sp_str) + '日' 173 | answer_after = answer[:sp.span()[0]] + sp_str 174 | if sp.span()[1] < len(answer) and answer[sp.span()[1]] == '号': 175 | answer_after = answer_after[:-1] + answer[sp.span()[1]:] 176 | elif sp.span()[1] + 1 < len(answer) and answer[sp.span()[1]:sp.span()[1] + 2] == ' 号': 177 | answer_after = answer_after[:-1] + answer[sp.span()[1] + 1:] 178 | else: 179 | answer_after += answer[sp.span()[1]:] 180 | answer = answer_after 181 | # 标点符号去重 182 | answer = re.sub('([,,.!??!。])+', '\\1', answer) 183 | return answer.replace(' ', ''), exist_none_replace 184 | 185 | def search_choices(self, sample: dict, answer: str, history:list): 186 | """纯粹给训练做样本删选 history需要包含正确的answer""" 187 | answer = self.clean_sentence(answer) 188 | answer = self.pre_trans(answer) 189 | user_profile = {} 190 | for k, v in sample['user_profile'].items(): 191 | clean_k = self.clean_sentence(k) 192 | user_profile[clean_k] = user_profile.get(clean_k, []) 193 | if not isinstance(v, list): 194 | v = [v] 195 | user_profile[clean_k].extend(v) 196 | kg_dict = {} 197 | for k in sample['knowledge']: 198 | clean_p = self.clean_sentence(self.reader.clean_kg_type(k[1])) 199 | kg_dict[clean_p] = kg_dict.get(clean_p, []) 200 | kg_dict[clean_p].append({ 201 | 'S': k[0], 202 | 'O': k[2], 203 | }) 204 | 205 | replace_items = self.search_comp.findall(answer) 206 | replace_items.extend(self.search_comp_date.findall(answer)) 207 | replace_dict = {} 208 | for rp_item in replace_items: 209 | if rp_item[1] in ['p', 'P']: 210 | choices = user_profile.get(rp_item[3:-1], '') 211 | if not isinstance(choices, list): 212 | choices = None 213 | choices = self.filter_choices(choices, history[:-1]) 214 | elif rp_item[1] in ['k', 'K']: 215 | choices = kg_dict.get(rp_item[3:-1], []) 216 | obj = [d['O'] for d in choices] 217 | sbj = [d['S'] for d in choices] 218 | choices = self.filter_choices(obj, history, sbj) 219 | elif rp_item[1] in ['c', 'C']: 220 | choices = kg_dict.get(rp_item[3:-1], []) 221 | choices = [d['S'] for d in choices] 222 | choices = self.filter_choices(choices, history[:-1]) 223 | else: 224 | logger.info('=' * 20) 225 | logger.info('Error rp item: {}'.format(rp_item)) 226 | logger.info('KG: {}'.format(kg_dict)) 227 | logger.info('P: {}'.format(user_profile)) 228 | continue 229 | if choices is not None: 230 | replace_dict[rp_item] = choices 231 | return replace_dict 232 | 233 | def judge_choices(self, choices: list, sample: dict, sbj=None, goal='', 234 | last_choice=None, identifier=None, response=None): 235 | if len(choices) == 0: 236 | return '' 237 | if len(choices) == 1: 238 | return 
choices[0] 239 | if sbj is not None: 240 | assert len(sbj) == len(choices) 241 | scores = [0] * len(choices) 242 | context_str = 'conversation' if 'conversation' in sample.keys() else 'history' 243 | for i, choice in enumerate(choices): 244 | # 上下文 245 | check_word = choice if sbj is None else sbj[i] 246 | if sbj is not None: # 如果上一个选择和这个主题相同,就给个最高分数加成 247 | if check_word == last_choice: 248 | scores[i] += 20 249 | context = sample[context_str] 250 | for j in range(1, len(context) + 1): 251 | sentence = context[-j] 252 | scores[i] += (self._get_score(check_word, sentence) / min(j, 4)) 253 | if sbj is not None: 254 | if choice.replace(' ', '') in sentence.replace(' ', ''): 255 | scores[i] -= 2 / min(j, 4) 256 | if last_choice is not None: 257 | scores[i] += (self._get_score(check_word, last_choice) * 1.3) 258 | # 和goal 的 匹配 259 | if goal != '': 260 | scores[i] += (self._get_score(check_word, goal) / 4) 261 | # 内容的匹配 262 | if sbj is not None: 263 | scores[i] += self.bleu(choice.replace(' ', ''), goal.replace(' ', '')) 264 | # 如果obj存在,优先选择subject内容短的 265 | if sbj is not None: 266 | scores[i] += 2 / (len(choice) + 5) 267 | if identifier is not None and response is not None and \ 268 | identifier in ['[K-新闻]', '[K-评论]', '[k-新闻]', '[k-评论]'] and self.rc_model_cls is not None: 269 | history = sample[context_str] 270 | score_gap = np.array(scores).mean() 271 | cands = [] 272 | for c, s in zip(choices, scores): 273 | if c in cands: 274 | continue 275 | if s >= score_gap: 276 | cands.append(c) 277 | gen_res = self.get_rc_result(history, response, cands, identifier) 278 | if len(gen_res) < 4: 279 | index = np.array(scores).argmax() 280 | result = choices[index] 281 | else: 282 | for i, choice in enumerate(choices): 283 | if gen_res in choice: 284 | scores[i] += 1 285 | index = np.array(scores).argmax() 286 | result = choices[index] 287 | if gen_res in result: # 修正,补全残缺的话 288 | b_i = result.find(gen_res) 289 | e_i = b_i + len(gen_res) 290 | dot_str = list(',.?!:,。?!:') 291 | while e_i < len(result): 292 | if result[e_i] not in dot_str: 293 | e_i += 1 294 | else: 295 | break 296 | result = result[b_i:e_i] 297 | else: 298 | index = np.array(scores).argmax() 299 | result = choices[index] 300 | return result 301 | 302 | def filter_choices(self, choices: list, history: list, sbj=None): 303 | """纯粹给训练做样本删选""" 304 | if choices is None: 305 | return None 306 | if len(choices) == 0: 307 | return None 308 | if len(choices) == 1: 309 | return None 310 | if sbj is not None: 311 | assert len(sbj) == len(choices) 312 | scores = [0] * len(choices) 313 | for i, choice in enumerate(choices): 314 | # 上下文 315 | check_word = choice if sbj is None else sbj[i] 316 | for j in range(1, len(history) + 1): 317 | sentence = history[-j].replace(' ', '') 318 | scores[i] += (self._get_score(check_word, sentence) / min(j, 4)) 319 | # 内容的部分就不做比较 320 | max_score = max(scores) 321 | keep_choices = [] 322 | for c, s in zip(choices, scores): 323 | if c in keep_choices: 324 | continue 325 | if s >= max_score: 326 | keep_choices.append(c) 327 | if len(keep_choices) <= 1: 328 | return None 329 | return keep_choices 330 | 331 | def get_rc_result(self, history:list, response:str, cands:list, identifier:str): 332 | question = '|'.join(history) + '|{}|{}'.format(response, identifier) 333 | context = '|'.join(cands) 334 | with self.rc_model_cls.session.graph.as_default(): 335 | with self.rc_model_cls.session.as_default(): 336 | predict_answer = self.rc_model_cls.predict(question, context) 337 | return predict_answer 338 | 339 | def 
_get_score(self, choice, sentence): 340 | choice_clean = choice.replace(' ', '') 341 | sentence = sentence.replace(' ', '') 342 | if choice_clean in sentence: 343 | return 2.0 344 | else: 345 | c = choice.split(' ') 346 | if len(c) > 1: 347 | for c_ in c: 348 | if c_.replace(' ', '') in sentence: 349 | return 1.0 350 | return 0.0 351 | 352 | def bleu(self, sen0, sen1): 353 | return sentence_bleu([list(sen0)], list(sen1), smoothing_function=SmoothingFunction().method1) 354 | 355 | def edit_distance(self, word1, word2): 356 | if word1 == word2: 357 | return 0 358 | len1 = len(word1) 359 | len2 = len(word2) 360 | dp = np.zeros((len1 + 1, len2 + 1)) 361 | for i in range(len1 + 1): 362 | dp[i][0] = i 363 | for j in range(len2 + 1): 364 | dp[0][j] = j 365 | for i in range(1, len1 + 1): 366 | for j in range(1, len2 + 1): 367 | delta = 0 if word1[i - 1] == word2[j - 1] else 1 368 | dp[i][j] = min(dp[i - 1][j - 1] + delta, min(dp[i - 1][j] + 1, dp[i][j - 1] + 1)) 369 | return int(dp[len1][len2]) 370 | -------------------------------------------------------------------------------- /code/model/bert_lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/4/28 22:07 6 | @File :bert_lm.py 7 | @Desc : 8 | """ 9 | from bert4keras_5_8.models import build_transformer_model 10 | from bert4keras_5_8.backend import keras, K, tf 11 | from bert4keras_5_8.optimizers import Adam, extend_with_gradient_accumulation 12 | from bert4keras_5_8.snippets import AutoRegressiveDecoder 13 | from data_deal.base_input import BaseInput 14 | import numpy as np 15 | from cfg import * 16 | 17 | 18 | class Response(AutoRegressiveDecoder): 19 | """基于随机采样的故事续写 20 | """ 21 | 22 | def __init__(self, model, session, data_deal:BaseInput, *args, **kwargs): 23 | self.model = model 24 | self.data_deal = data_deal 25 | self.max_len = 512 26 | self.session = session 27 | super(Response, self).__init__(*args, **kwargs) 28 | 29 | @AutoRegressiveDecoder.set_rtype('probas') 30 | def predict(self, inputs, output_ids, step): 31 | token_ids = inputs[0] 32 | segment_ids = inputs[1] 33 | token_ids = np.concatenate([token_ids, output_ids], 1) 34 | segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) 35 | if token_ids.shape[1] > self.max_len: 36 | token_ids = token_ids[:, -self.max_len:] 37 | segment_ids = segment_ids[:, -self.max_len:] 38 | with self.session.graph.as_default(): 39 | with self.session.as_default(): 40 | res = self.model.predict([token_ids, segment_ids])[:, -1] 41 | return res 42 | 43 | def generate(self, sample, goals=None, need_goal=True, force_goal=False, random=False): 44 | if goals is None: 45 | goals = [] 46 | token_ids, segs, goal_index = self.data_deal.encode_predict_final( 47 | sample, goals, need_goal=need_goal, force_goal=force_goal, silent=False) 48 | if random: 49 | if token_ids is None: 50 | return [] 51 | res = self.nucleus_sample([token_ids, segs], 3, topk=3) 52 | res = [self.data_deal.tokenizer.decode(r) for r in res] 53 | else: 54 | if token_ids is None: 55 | return '' 56 | res = self.beam_search([token_ids, segs], 1) 57 | res = self.data_deal.tokenizer.decode(res) 58 | # if goal_index: 59 | # if isinstance(res, list): 60 | # res = ['[{}]{}'.format(goal_index, s) for s in res] 61 | # else: 62 | # res = '[{}]{}'.format(goal_index, res) 63 | return res 64 | 65 | def check_goal_end(self, sample, end_id): 66 | token_ids, segs = self.data_deal.encode_predict_final(sample, 
cand_goals=[], need_goal=False) 67 | with self.session.graph.as_default(): 68 | with self.session.as_default(): 69 | score = self.model.predict([[token_ids], [segs]])[0, -1] 70 | m = np.argmax(score) 71 | if m != end_id: 72 | return True 73 | else: 74 | return False 75 | 76 | def goal_generate(self, sample, n=5): 77 | token_ids, segs, goal_index = self.data_deal.encode_predict_final(sample, cand_goals=[], need_goal=False) 78 | results = self.nucleus_sample([token_ids, segs], n=n, topk=20) 79 | return [self.data_deal.tokenizer.decode(res) for res in results] 80 | 81 | 82 | class BertLM(object): 83 | def __init__(self, keep_tokens, load_path=None): 84 | keras.backend.clear_session() 85 | gpu_config = tf.ConfigProto() 86 | gpu_config.gpu_options.allow_growth = True 87 | keras.backend.set_session(tf.Session(config=gpu_config)) 88 | self.session = keras.backend.get_session() 89 | 90 | need_load = False 91 | if load_path and os.path.exists(load_path): 92 | need_load = True 93 | 94 | self.model = build_transformer_model( 95 | join(BERT_PATH, 'bert_config.json'), 96 | None if need_load else join(BERT_PATH, 'bert_model.ckpt'), 97 | application='lm', 98 | keep_tokens=keep_tokens, # 只保留keep_tokens中的字,精简原字表 99 | ) 100 | self.model.summary() 101 | 102 | if need_load: 103 | logger.info('=' * 15 + 'Load from checkpoint: {}'.format(load_path)) 104 | self.model.load_weights(load_path) 105 | 106 | def compile(self): 107 | # 交叉熵作为loss,并mask掉输入部分的预测 108 | y_true = self.model.input[0][:, 1:] # 目标tokens 109 | y_mask = self.model.input[1][:, 1:] # 目标mask 110 | y_mask = K.cast(y_mask, K.floatx()) # 转为浮点型 111 | y_pred = self.model.output[:, :-1] # 预测tokens,预测与目标错开一位 112 | cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred) 113 | cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask) 114 | self.model.add_loss(cross_entropy) 115 | opt = extend_with_gradient_accumulation(Adam)(learning_rate=0.000015, grad_accum_steps=2) 116 | self.model.compile(optimizer=opt) 117 | -------------------------------------------------------------------------------- /code/model/extract_embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/18 20:51 6 | @File :extract_embedding.py 7 | @Desc : 8 | """ 9 | from bert4keras_5_8.models import build_transformer_model 10 | from bert4keras_5_8.backend import keras, tf 11 | from bert4keras_5_8.tokenizers import Tokenizer 12 | from bert4keras_5_8.snippets import sequence_padding 13 | from cfg import * 14 | 15 | 16 | class BertEmb(object): 17 | def __init__(self): 18 | keras.backend.clear_session() 19 | gpu_config = tf.ConfigProto() 20 | gpu_config.gpu_options.allow_growth = True 21 | keras.backend.set_session(tf.Session(config=gpu_config)) 22 | self.session = keras.backend.get_session() 23 | 24 | self.tokenizer = Tokenizer(join(BERT_PATH, 'vocab.txt'), do_lower_case=True) # 建立分词器 25 | self.model = build_transformer_model( 26 | join(BERT_PATH, 'bert_config.json'), 27 | join(BERT_PATH, 'bert_model.ckpt'), 28 | ) 29 | 30 | def get_embedding(self, sentences): 31 | X, S = [], [] 32 | for sentence in sentences: 33 | token_ids, segment_ids = self.tokenizer.encode(sentence) 34 | X.append(token_ids) 35 | S.append(segment_ids) 36 | X = sequence_padding(X) 37 | S = sequence_padding(S) 38 | with self.session.graph.as_default(): 39 | with self.session.as_default(): 40 | result = self.model.predict([X, S]) 41 | return result 
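42 | 
43 | 
44 | if __name__ == '__main__':
45 |     # Minimal usage sketch (illustrative only): the sentences are placeholder inputs, and the
46 |     # run assumes the pretrained roberta checkpoint referenced by BERT_PATH has been downloaded.
47 |     bert_emb = BertEmb()
48 |     embeddings = bert_emb.get_embedding(['你好,帮我推荐一部电影吧', '周末有什么好看的新闻?'])
49 |     # The encoder returns one hidden vector per token (768-d for this config), so the shape is
50 |     # (batch_size, longest_tokenized_length, hidden_size) after padding.
51 |     print(embeddings.shape)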
-------------------------------------------------------------------------------- /code/model/model_context.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/18 21:01 6 | @File :model_context.py 7 | @Desc : 8 | """ 9 | from bert4keras_5_8.backend import keras, search_layer, K, tf 10 | from bert4keras_5_8.models import build_transformer_model 11 | from bert4keras_5_8.optimizers import Adam 12 | from bert4keras_5_8.layers import Lambda, Dense, Input 13 | from bert4keras_5_8.snippets import sequence_padding 14 | from utils.snippet import adversarial_training 15 | import re 16 | from cfg import * 17 | 18 | 19 | class ModelContext(object): 20 | def __init__(self, keep_tokens, load_path=None): 21 | keras.backend.clear_session() 22 | gpu_config = tf.ConfigProto() 23 | gpu_config.gpu_options.allow_growth = True 24 | keras.backend.set_session(tf.Session(config=gpu_config)) 25 | self.session = keras.backend.get_session() 26 | 27 | need_load = False 28 | if load_path and os.path.exists(load_path): 29 | need_load = True 30 | 31 | bert = build_transformer_model( 32 | join(BERT_PATH, 'bert_config.json'), 33 | None if need_load else join(BERT_PATH, 'bert_model.ckpt'), 34 | keep_tokens=keep_tokens, # 只保留keep_tokens中的字,精简原字表 35 | return_keras_model=False, 36 | ) 37 | 38 | layers_out_lambda = Lambda(lambda x: x[:, 0]) 39 | layers_out_dense = Dense(units=2, 40 | activation='softmax', 41 | kernel_initializer=bert.initializer) 42 | 43 | output = layers_out_lambda(bert.model.output) 44 | output = layers_out_dense(output) 45 | 46 | self.model = keras.models.Model(bert.model.input, output, name='Final-Model') 47 | if need_load: 48 | logger.info('=' * 15 + 'Load from checkpoint: {}'.format(load_path)) 49 | self.model.load_weights(load_path) 50 | self.data_deal = None 51 | 52 | def compile(self): 53 | self.model.compile( 54 | loss='sparse_categorical_crossentropy', 55 | optimizer=Adam(2e-5), 56 | metrics=['sparse_categorical_accuracy'] 57 | ) 58 | adversarial_training(self.model, 'Embedding-Token', 0.3) 59 | 60 | def predict(self, contexts): 61 | if self.data_deal is None: 62 | from data_deal.input_ct import CTInput 63 | self.data_deal = CTInput(from_pre_trans=False) 64 | X, S = [], [] 65 | for context in contexts: 66 | context = [re.sub(self.data_deal.reader.goal_num_comp, '', s).replace(' ', '') for s in context] 67 | x, s, l = self.data_deal.encode(ori_context=context) 68 | X.append(x) 69 | S.append(s) 70 | X = sequence_padding(X) 71 | S = sequence_padding(S) 72 | with self.session.graph.as_default(): 73 | with self.session.as_default(): 74 | R = self.model.predict([X, S]) 75 | return R[:, 1] -------------------------------------------------------------------------------- /code/model/model_goal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/8 21:56 6 | @File :goal_predict.py 7 | @Desc : 8 | """ 9 | 10 | from bert4keras_5_8.models import build_transformer_model 11 | from bert4keras_5_8.backend import keras, K, tf, search_layer 12 | from bert4keras_5_8.optimizers import Adam, extend_with_gradient_accumulation 13 | from bert4keras_5_8.layers import Lambda, Dense 14 | from utils.snippet import adversarial_training 15 | import numpy as np 16 | from cfg import * 17 | 18 | 19 | class BertGoal(object): 20 | def __init__(self, keep_tokens, num_classes, 
load_path=None): 21 | keras.backend.clear_session() 22 | gpu_config = tf.ConfigProto() 23 | gpu_config.gpu_options.allow_growth = True 24 | keras.backend.set_session(tf.Session(config=gpu_config)) 25 | self.session = keras.backend.get_session() 26 | 27 | need_load = False 28 | if load_path and os.path.exists(load_path): 29 | need_load = True 30 | bert = build_transformer_model( 31 | config_path=join(BERT_PATH, 'bert_config.json'), 32 | checkpoint_path=None if need_load else join(BERT_PATH, 'bert_model.ckpt'), 33 | return_keras_model=False, 34 | keep_tokens=keep_tokens, 35 | ) 36 | output = Lambda(lambda x: x[:, 0])(bert.model.output) 37 | output = Dense( 38 | units=num_classes, 39 | activation='softmax', 40 | kernel_initializer=bert.initializer 41 | )(output) 42 | 43 | self.model = keras.models.Model(bert.model.input, output) 44 | # self.model.summary() 45 | 46 | if need_load: 47 | logger.info('=' * 15 + 'Load from checkpoint: {}'.format(load_path)) 48 | self.model.load_weights(load_path) 49 | self.data_deal = None 50 | 51 | def predict(self, sample:dict): 52 | if self.data_deal is None: 53 | from data_deal.input_goal import GoalInput 54 | self.data_deal = GoalInput() 55 | x, s = self.data_deal.encode_predict(sample) 56 | res = self.model.predict([[x], [s]])[0] 57 | return self.data_deal.reader.all_goals[np.argmax(res)] 58 | 59 | def compile(self): 60 | opt = extend_with_gradient_accumulation(Adam)(learning_rate=0.000015, grad_accum_steps=2) 61 | self.model.compile( 62 | loss='sparse_categorical_crossentropy', 63 | optimizer=opt, 64 | metrics=['sparse_categorical_accuracy'], 65 | ) 66 | adversarial_training(self.model, 'Embedding-Token', 0.5) -------------------------------------------------------------------------------- /code/model/model_rc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | @Time : 2020/5/10 9:58 6 | @Author : Apple QXTD 7 | @File : model_rc.py 8 | @Desc: : 9 | """ 10 | from cfg import * 11 | import re 12 | import numpy as np 13 | from bert4keras_5_8.backend import keras, K, tf, search_layer 14 | from bert4keras_5_8.models import build_transformer_model 15 | from bert4keras_5_8.snippets import sequence_padding 16 | from keras.layers import Lambda 17 | from keras.models import Model 18 | from data_deal.input_rc import RCInput 19 | from bert4keras_5_8.optimizers import Adam, extend_with_gradient_accumulation 20 | 21 | 22 | class BertCL: 23 | def __init__(self, tag='d', is_predict=False, load_path=None): 24 | self.save_path = join(MODEL_PATH, 'rc_' + tag, 'trained.h5') 25 | 26 | keras.backend.clear_session() 27 | gpu_config = tf.ConfigProto() 28 | gpu_config.gpu_options.allow_growth = True 29 | self.session = tf.Session(config=gpu_config) 30 | keras.backend.set_session(self.session) 31 | 32 | self.data_deal = RCInput() 33 | self.config_path = join(BERT_PATH, 'bert_config.json') 34 | self.checkpoint_path = join(BERT_PATH, 'bert_model.ckpt') 35 | 36 | self.tokenizer = self.data_deal.tokenizer 37 | self.max_p_len = self.data_deal.max_p_len 38 | self.max_q_len = self.data_deal.max_q_len 39 | self.max_a_len = self.data_deal.max_a_len 40 | self.batch_size = self.data_deal.batch_size 41 | 42 | model = build_transformer_model( 43 | self.config_path, 44 | None if is_predict else self.checkpoint_path, 45 | model='bert', 46 | with_mlm=True, 47 | keep_tokens=self.data_deal.keep_tokens, # 只保留keep_tokens中的字,精简原字表 48 | ) 49 | output = Lambda(lambda x: x[:, 1:self.max_a_len + 
1])(model.output) 50 | self.model = Model(model.input, output) 51 | # self.model.summary() 52 | if load_path: 53 | logger.info('Load from init checkpoint {} .'.format(load_path)) 54 | self.model.load_weights(load_path) 55 | elif os.path.exists(self.save_path): 56 | logger.info('Load from init checkpoint {} .'.format(self.save_path)) 57 | self.model.load_weights(self.save_path) 58 | 59 | def predict(self, question, contexts, return_items=False, in_passage=True): 60 | if isinstance(contexts, str): 61 | contexts = [contexts] 62 | passages = [] 63 | if len(question) == 0: 64 | return None 65 | for context in contexts: 66 | add = True 67 | while add: 68 | # 每间隔200进行拆分 69 | passages.append(context[:self.max_p_len]) 70 | if len(context) <= self.max_p_len: 71 | add = False 72 | else: 73 | context = context[200:] 74 | if len(passages) == 0: 75 | return None 76 | answer = self.gen_answer(question, passages, in_passage=in_passage) 77 | answer = self.max_in_dict(answer) 78 | if not return_items: 79 | if answer is None: 80 | answer = '' 81 | else: 82 | answer = answer[0][0] 83 | return answer 84 | 85 | @staticmethod 86 | def get_ngram_set(x, n): 87 | """生成ngram合集,返回结果格式是: 88 | {(n-1)-gram: set([n-gram的第n个字集合])} 89 | """ 90 | result = {} 91 | for i in range(len(x) - n + 1): 92 | k = tuple(x[i:i + n]) 93 | if k[:-1] not in result: 94 | result[k[:-1]] = set() 95 | result[k[:-1]].add(k[-1]) 96 | return result 97 | 98 | def gen_answer(self, question, passages, in_passage=True): 99 | """由于是MLM模型,所以可以直接argmax解码。 100 | """ 101 | all_p_token_ids, token_ids, segment_ids = [], [], [] 102 | 103 | for passage in passages: 104 | passage = re.sub(u' |、|;|,', ',', passage) 105 | p_token_ids, _ = self.tokenizer.encode(passage, max_length=self.max_p_len + 1) 106 | q_token_ids, _ = self.tokenizer.encode(question, max_length=self.max_q_len + 1) 107 | all_p_token_ids.append(p_token_ids[1:]) 108 | token_ids.append([self.tokenizer._token_cls_id]) 109 | token_ids[-1] += ([self.tokenizer._token_mask_id] * self.max_a_len) 110 | token_ids[-1] += [self.tokenizer._token_sep_id] 111 | token_ids[-1] += (q_token_ids[1:] + p_token_ids[1:]) 112 | segment_ids.append([0] * len(token_ids[-1])) 113 | 114 | token_ids = sequence_padding(token_ids) 115 | segment_ids = sequence_padding(segment_ids) 116 | with self.session.graph.as_default(): 117 | with self.session.as_default(): 118 | probas = self.model.predict([token_ids, segment_ids], batch_size=3) 119 | results = {} 120 | for t, p in zip(all_p_token_ids, probas): 121 | a, score = tuple(), 0. 
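            # Greedy MLM decoding under an n-gram constraint: at step i the candidate tokens are
            # limited to those that extend the current prefix `a` into an (i+1)-gram of the passage
            # (plus [SEP] so decoding can stop early), which keeps the predicted answer a contiguous
            # piece of the passage; with in_passage=False the restriction is dropped and the raw
            # probabilities are used. `score` accumulates the probability of each chosen token.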
122 | for i in range(self.max_a_len): 123 | # pi是将passage以外的token的概率置零 124 | if in_passage: 125 | idxs = list(self.get_ngram_set(t, i + 1)[a]) 126 | if self.tokenizer._token_sep_id not in idxs: 127 | idxs.append(self.tokenizer._token_sep_id) 128 | pi = np.zeros_like(p[i]) 129 | pi[idxs] = p[i, idxs] 130 | else: 131 | pi = p[i] 132 | a = a + (pi.argmax(),) 133 | score += pi.max() 134 | if a[-1] == self.tokenizer._token_sep_id: 135 | break 136 | score = score / (i + 1) 137 | a = self.tokenizer.decode(a) 138 | if a: 139 | results[a] = results.get(a, []) + [score] 140 | results = { 141 | k: (np.array(v) ** 2).sum() / (sum(v) + 1) 142 | for k, v in results.items() 143 | } 144 | return results 145 | 146 | @staticmethod 147 | def max_in_dict(d): 148 | if d: 149 | return sorted(d.items(), key=lambda s: -s[1]) 150 | 151 | def compile(self): 152 | 153 | def masked_cross_entropy(y_true, y_pred): 154 | """交叉熵作为loss,并mask掉padding部分的预测 155 | """ 156 | y_true = K.reshape(y_true, [K.shape(y_true)[0], -1]) 157 | y_mask = K.cast(K.not_equal(y_true, 0), K.floatx()) 158 | cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred) 159 | cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask) 160 | return cross_entropy 161 | 162 | opt = extend_with_gradient_accumulation(Adam, name='accum')(grad_accum_steps=3, learning_rate=3e-5) 163 | self.model.compile(loss=masked_cross_entropy, optimizer=opt) 164 | 165 | 166 | def test(): 167 | m = BertLM() 168 | m.compile() 169 | 170 | 171 | if __name__ == '__main__': 172 | test() -------------------------------------------------------------------------------- /code/model/model_recall.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/18 19:42 6 | @File :model_recall.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | 13 | from cfg import * 14 | import numpy as np 15 | from utils.sif import Sentence2Vec 16 | from gensim.models import Word2Vec 17 | from sklearn.decomposition import PCA 18 | import joblib 19 | from utils.snippet import normalization 20 | from model.extract_embedding import BertEmb 21 | from data_deal.base_input import BaseInput 22 | from annoy import AnnoyIndex 23 | import json 24 | import re 25 | 26 | recall_path = join(MODEL_PATH, 'recall') 27 | if not os.path.isdir(recall_path): 28 | os.makedirs(recall_path) 29 | 30 | 31 | class RC_CFG(object): 32 | def __init__(self): 33 | self.max_seq_len = 128 34 | self.emd_dim = 768 35 | self.pca_dim = 188 36 | 37 | 38 | recall_config = RC_CFG() 39 | 40 | 41 | def train_step_1(): 42 | model_emb = BertEmb() 43 | az_comp = re.compile('[a-zA-Z0-9]+') 44 | num_comp = re.compile('[0-9]') 45 | start_num_comp = re.compile('\[\d\]') 46 | 47 | data_input = BaseInput(from_pre_trans=True) 48 | 49 | questions = [] 50 | answers = [] 51 | 52 | logger.info('calucate sentences ...') 53 | 54 | for sample in data_input.get_sample([0, 1], need_shuffle=False, cycle=False): 55 | # if len(questions) > 1000: 56 | # break 57 | context, turns = sample['context'], sample['turns'] 58 | if turns is None: 59 | continue 60 | for i, turn in enumerate(turns): 61 | if i == 0: 62 | continue 63 | if turn: 64 | ans = re.sub(start_num_comp, '', context[i]) 65 | q = re.sub(start_num_comp, '', context[i - 1]) 66 | if len(num_comp.findall(ans)) > 0: # 包含数字的回复全部丢弃 67 | continue 68 | questions.append(re.sub(az_comp, '', q)) 69 | 
answers.append(ans) 70 | 71 | print(f'questions: {questions[:2]}') 72 | print(f'answers: {answers[:2]}') 73 | print(f'len: {len(questions)}') 74 | logger.info('split sentences ...') 75 | splited_sentences = [] 76 | for doc in questions[:1000000]: 77 | splited_sentences.append(list(doc)) 78 | 79 | logger.info('train gensim ...') 80 | word_model = Word2Vec(splited_sentences, min_count=1, size=recall_config.emd_dim, iter=0) 81 | sif_model = Sentence2Vec(word_model, max_seq_len=recall_config.max_seq_len, components=2) 82 | logger.info('gensim train done .') 83 | del splited_sentences, word_model 84 | 85 | logger.info('get vecotrs and train pc...') 86 | 87 | # Memory will explode, rewrite the logic here 88 | sentence_vectors = [] 89 | vec_batch = 10000 90 | pca = PCA(n_components=recall_config.pca_dim, whiten=True, random_state=2112) 91 | 92 | pca_n = min(300000, len(questions)) 93 | has_pca_trained = False 94 | 95 | for b_i, e_i in zip(range(0, len(questions), vec_batch), range(vec_batch, len(questions) + vec_batch, vec_batch)): 96 | sentences_out = model_emb.get_embedding(questions[b_i:e_i]) 97 | splited_sentences = [] 98 | for doc in questions[b_i:e_i]: 99 | splited_sentences.append(list(doc)) 100 | sentences_out = sif_model.cal_output(splited_sentences, sentences_out) 101 | if e_i >= pca_n: 102 | if has_pca_trained: 103 | sentence_vectors.extend(normalization(pca.transform(sentences_out))) 104 | else: 105 | logger.info('Train PCA ... pca_n num: {}'.format(pca_n)) 106 | sentence_vectors.extend(sentences_out) 107 | pca.fit(np.stack(sentence_vectors[:pca_n])) 108 | sentence_vectors = list(normalization(pca.transform(np.stack(sentence_vectors)))) 109 | has_pca_trained = True 110 | else: 111 | sentence_vectors.extend(sentences_out) 112 | del sentences_out, splited_sentences 113 | logger.info(' complete one batch. 
batch_size: {} percent {:.2f}%'.format( 114 | vec_batch, (100 * min(len(questions), e_i) / len(questions)))) 115 | 116 | sentence_vectors = np.stack(sentence_vectors) 117 | 118 | sentences_emb = sif_model.train_pc(sentence_vectors) 119 | print(sentences_emb.shape) 120 | logger.info('train pc over.') 121 | 122 | logger.info('save model') 123 | joblib.dump(sif_model, os.path.join(recall_path, 'bert_sif.sif')) 124 | joblib.dump(pca, os.path.join(recall_path, 'bert_pca.pc')) 125 | json.dump(answers, open(join(recall_path, 'answers.json'), mode='w', encoding='utf-8'), 126 | ensure_ascii=False, indent=4, separators=(',', ':')) 127 | np.save(join(recall_path, 'sentences_emb'), sentences_emb) 128 | 129 | 130 | def train_step_2(): 131 | logger.info('train_step_2 ...') 132 | final_q_embs = np.load(join(recall_path, 'sentences_emb.npy')) 133 | 134 | annoy_model = AnnoyIndex(recall_config.pca_dim, metric='angular') 135 | logger.info('add annoy...') 136 | for i, emb in enumerate(final_q_embs): 137 | annoy_model.add_item(i, emb) 138 | logger.info('build annoy...') 139 | annoy_model.build(88) 140 | annoy_model.save(join(recall_path, 'annoy.an')) 141 | logger.info('build over...') 142 | 143 | 144 | class SearchEMb: 145 | def __init__(self, top_n=3): 146 | self.model_emb = BertEmb() 147 | self.sif = joblib.load(join(recall_path, 'bert_sif.sif')) 148 | self.pca = joblib.load(join(recall_path, 'bert_pca.pc')) 149 | self.answers = json.load(open(join(recall_path, 'answers.json'), encoding='utf-8')) 150 | self.annoy = AnnoyIndex(recall_config.pca_dim, metric='angular') 151 | self.annoy.load(join(recall_path, 'annoy.an')) 152 | 153 | self.az_comp = re.compile('[a-zA-Z0-9]+') 154 | self.start_num_comp = re.compile('\[\d\]') 155 | self.top_n = top_n 156 | 157 | def get_recall(self, sentence): 158 | sentence = re.sub(self.start_num_comp, '', sentence) 159 | sentence = re.sub(self.az_comp, '', sentence) 160 | res_indexs, distances = self.annoy.get_nns_by_vector(self.get_emb(sentence), self.top_n, include_distances=True) 161 | results = [] 162 | for idx in res_indexs: 163 | results.append(self.answers[idx]) 164 | return results, distances 165 | 166 | def get_emb(self, sentence): 167 | vectors = self.model_emb.get_embedding([sentence]) 168 | mid_vectors = self.sif.cal_output([list(sentence)], vectors) 169 | mid_vectors = normalization(self.pca.transform(mid_vectors)) 170 | return self.sif.predict_pc(mid_vectors)[0] 171 | 172 | 173 | if __name__ == '__main__': 174 | train_step_1() 175 | train_step_2() 176 | -------------------------------------------------------------------------------- /code/predict/check_predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | @Time : 2020/5/5 14:12 6 | @Author : Apple QXTD 7 | @File : check_predict.py 8 | @Desc: : 9 | """ 10 | import os, sys 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | 13 | from cfg import * 14 | from model.bert_lm import BertLM, Response 15 | from data_deal.base_input import BaseInput 16 | from data_deal.trans_output import TransOutput 17 | import jieba 18 | # from model.model_goal import BertGoal 19 | 20 | 21 | tag = TAG 22 | # tag = 'd4-6ep-ng' 23 | save_dir = join(MODEL_PATH, 'BertLM_' + tag) 24 | save_path = join(save_dir, 'trained.h5') 25 | data_input = BaseInput() 26 | model_cls = BertLM(data_input.keep_tokens, load_path=save_path) 27 | response = Response(model_cls.model, 28 | model_cls.session, 29 | data_input, 
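# The same LM weights drive two decoders here: `response` stops at [SEP]
# (maxlen=40) and produces the actual reply, while `goal_response` stops at the
# goal marker token (_token_goal_id, maxlen=10) and only predicts the next goal
# segment, which is then fed back into the reply generation. Typical flow
# further below in this script:
#     goals = goal_response.goal_generate(last_sample, n=4)
#     answer_res = response.generate(last_sample, goals=list(set(goals)))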
30 | start_id=None, 31 | end_id=data_input.tokenizer._token_sep_id, 32 | maxlen=40 33 | ) 34 | goal_response = Response(model_cls.model, 35 | model_cls.session, 36 | data_input, 37 | start_id=None, 38 | end_id=data_input.tokenizer._token_goal_id, 39 | maxlen=10 40 | ) 41 | out_trans = TransOutput(rc_tag='') 42 | 43 | goal_dir = join(MODEL_PATH, 'Goal_' + tag) 44 | goal_path = join(goal_dir, 'trained.h5') 45 | # goal_cls = BertGoal(data_input.keep_tokens, num_classes=len(data_input.reader.all_goals), load_path=goal_path) 46 | 47 | 48 | test_iter = data_input.get_sample( 49 | 2, 50 | need_shuffle=False, 51 | cycle=False 52 | ) 53 | 54 | 55 | def cal_participle(samp:dict): 56 | words = [] 57 | words.extend(samp['situation'].split(' ')) 58 | words.extend(samp['goal'].split(' ')) 59 | for k, v in samp['user_profile'].items(): 60 | if not isinstance(v, list): 61 | v = [v] 62 | for _v in v: 63 | words.extend(_v.split(' ')) 64 | for kg in samp['knowledge']: 65 | words.extend(kg[2].split(' ')) 66 | words = set(words) 67 | words = [w for w in words if len(w) > 1] 68 | return words 69 | 70 | 71 | skip = 28 72 | i = 0 73 | last_sample = None 74 | 75 | for sample in test_iter: 76 | i += 1 77 | if i <= 1: 78 | last_sample = sample 79 | continue 80 | if i <= skip: 81 | last_sample = sample 82 | continue 83 | 84 | samp_words = cal_participle(sample) 85 | for w in samp_words: 86 | jieba.add_word(w) 87 | goals = goal_response.goal_generate(last_sample, n=4) 88 | goals = list(set(goals)) 89 | # goals = [goal_cls.predict(last_sample)] 90 | answer_res = response.generate(last_sample, goals=goals) 91 | answer, tag = out_trans.trans_output(last_sample, answer_res) 92 | if tag: 93 | answer_res = response.generate(last_sample, goals=goals, random=True) 94 | for res in answer_res: 95 | answer, tag = out_trans.trans_output(last_sample, res) 96 | if not tag: 97 | break 98 | if tag: 99 | answer_res = response.generate(last_sample, goals=goals, force_goal=True, random=True) 100 | for res in answer_res: 101 | answer, tag = out_trans.trans_output(last_sample, res) 102 | if not tag: 103 | break 104 | e_i = 0 105 | if answer[0] == '[': 106 | for j in range(1, min(4, len(answer))): 107 | if answer[j] == ']': 108 | e_i = j + 1 109 | break 110 | answer = answer[:e_i] + ' ' + ' '.join(jieba.lcut(answer[e_i:])) 111 | last_sample = sample 112 | 113 | print('\n=====> Over: ', i) -------------------------------------------------------------------------------- /code/predict/check_predict_lm_ct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/19 22:44 6 | @File :check_predict_lm_ct.py 7 | @Desc : 8 | """ 9 | 10 | import os, sys 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | 13 | from cfg import * 14 | from model.bert_lm import BertLM, Response 15 | from data_deal.base_input import BaseInput 16 | from data_deal.input_ct import CTInput 17 | from data_deal.trans_output import TransOutput 18 | from model.model_context import ModelContext 19 | from model.model_recall import SearchEMb 20 | import jieba 21 | import time 22 | import numpy as np 23 | import re 24 | 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--type', type=int, 29 | help=r'default is 2', 30 | default=4) 31 | args = parser.parse_args(sys.argv[1:]) 32 | data_type = args.type 33 | save_dir = join(MODEL_PATH, 'BertLM_' + TAG) 34 | save_path = join(save_dir, 
'trained.h5') 35 | if not os.path.isdir(OUT_PATH): 36 | os.makedirs(OUT_PATH) 37 | output_dir = join(OUT_PATH, 'out_{}_{}_{}.txt'.format( 38 | data_type, TAG, time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time())))) 39 | 40 | data_input = BaseInput(from_pre_trans=True) 41 | model_cls = BertLM(data_input.keep_tokens, load_path=save_path) 42 | response = Response(model_cls.model, 43 | model_cls.session, 44 | data_input, 45 | start_id=None, 46 | end_id=data_input.tokenizer._token_sep_id, 47 | maxlen=40 48 | ) 49 | goal_response = Response(model_cls.model, 50 | model_cls.session, 51 | data_input, 52 | start_id=None, 53 | end_id=data_input.tokenizer._token_goal_id, 54 | maxlen=10 55 | ) 56 | out_trans = TransOutput(rc_tag='') 57 | search_rc = SearchEMb(top_n=3) 58 | 59 | ct_dir = join(MODEL_PATH, 'CT_' + TAG) 60 | ct_path = join(ct_dir, 'trained.h5') 61 | ct_input = CTInput(from_pre_trans=False) 62 | model_ct_cls = ModelContext(ct_input.keep_tokens, load_path=ct_path) 63 | del ct_input 64 | 65 | test_iter = data_input.get_sample( 66 | data_type, 67 | need_shuffle=False, 68 | cycle=False 69 | ) 70 | 71 | 72 | def cal_participle(samp: dict): 73 | words = [] 74 | words.extend(samp['situation'].split(' ')) 75 | words.extend(samp['goal'].split(' ')) 76 | for k, v in samp['user_profile'].items(): 77 | if not isinstance(v, list): 78 | v = [v] 79 | for _v in v: 80 | words.extend(_v.split(' ')) 81 | for kg in samp['knowledge']: 82 | words.extend(kg[2].split(' ')) 83 | words = set(words) 84 | words = [w for w in words if len(w) > 1] 85 | return words 86 | 87 | 88 | skip = 218 89 | i = 0 90 | for sample in test_iter: 91 | i += 1 92 | if i <= skip: 93 | continue 94 | samp_words = cal_participle(sample) 95 | for w in samp_words: 96 | jieba.add_word(w) 97 | 98 | goals = goal_response.goal_generate(sample, n=4) 99 | goals = list(set(goals)) 100 | history = sample['history'] 101 | final_answers = [] 102 | turn = 0 103 | while len(final_answers) <= 0: 104 | answer_res = response.generate(sample, goals=goals, random=True) 105 | score_mul = [1] * len(answer_res) 106 | if (len(history) > 1 and len(history[-1]) > 4 and '新闻' not in ''.join(history[-2:])) or turn > 0: 107 | rc_ans, rc_dis = search_rc.get_recall(history[-1]) 108 | answer_res.extend(rc_ans) 109 | score_mul = score_mul + np.minimum((1.0 - np.array(rc_dis)) * 0.5 + 1.0, 0.99).tolist() 110 | # 去重 转换 111 | mid_res_clean = [] 112 | mid_sc = [] 113 | for ans, sc in zip(answer_res, score_mul): 114 | sentence = re.sub(data_input.reader.goal_num_comp, '', ans) 115 | if sentence in mid_res_clean: 116 | continue 117 | trans_answer, tag = out_trans.trans_output(sample, ans) 118 | if tag: 119 | continue 120 | final_answers.append(trans_answer) 121 | mid_sc.append(sc) 122 | mid_res_clean.append(sentence) 123 | score_mul = mid_sc 124 | turn += 1 125 | if turn > 5: 126 | final_answers = ['是的呢'] 127 | score_mul = [1.0] 128 | logger.warning('No proper answer! 
\n{}'.format(history)) 129 | # CT score 130 | final_contexts = [history + [ans] for ans in final_answers] 131 | scores = model_ct_cls.predict(final_contexts) 132 | scores_md = np.multiply(scores, np.array(score_mul)) 133 | answer = final_answers[np.argmax(scores_md)] 134 | 135 | e_i = 0 136 | if answer[0] == '[': 137 | for j in range(1, 4): 138 | if answer[j] == ']': 139 | e_i = j + 1 140 | break 141 | answer_split = answer[:e_i] + ' ' + ' '.join(jieba.lcut(answer[e_i:])) 142 | print('\n=====> Over: ', i) 143 | -------------------------------------------------------------------------------- /code/predict/predict_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :apple.li 5 | @Time :2020/5/21 14:23 6 | @File :predict_final.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from cfg import * 12 | from model.bert_lm import BertLM, Response 13 | from data_deal.base_input import BaseInput 14 | from data_deal.trans_output import TransOutput 15 | import json 16 | import collections 17 | 18 | 19 | class FinalPredict(object): 20 | def __init__(self): 21 | save_dir = join(MODEL_PATH, 'BertLM_' + TAG) 22 | save_path = join(save_dir, 'trained.h5') 23 | 24 | data_input = BaseInput(from_pre_trans=False) 25 | model_cls = BertLM(data_input.keep_tokens, load_path=save_path) 26 | self.response = Response(model_cls.model, 27 | model_cls.session, 28 | data_input, 29 | start_id=None, 30 | end_id=data_input.tokenizer._token_sep_id, 31 | maxlen=40 32 | ) 33 | self.goal_response = Response(model_cls.model, 34 | model_cls.session, 35 | data_input, 36 | start_id=None, 37 | end_id=data_input.tokenizer._token_goal_id, 38 | maxlen=10 39 | ) 40 | self.out_trans = TransOutput(rc_tag='') 41 | 42 | def predict(self, text): 43 | 44 | try: 45 | sample = json.loads(text, encoding="utf-8", object_pairs_hook=collections.OrderedDict) 46 | except Exception: 47 | print('Error type: ', text) 48 | raise 49 | 50 | # 格式转换: 51 | sample['goal'] = self.strip_list(sample['goal']) 52 | sample['situation'] = self.strip_list(sample['situation']) 53 | if isinstance(sample['user_profile'], list): 54 | # 修正 : , 的问题 55 | ori_p = sample['user_profile'][0] 56 | new_p = ori_p 57 | # tag = False 58 | # for c in ori_p: 59 | # if c == ':': 60 | # if tag: 61 | # c = r'","' 62 | # tag = False 63 | # else: 64 | # tag = True 65 | # elif c == ',': 66 | # tag = False 67 | # new_p += c 68 | # print('=========== new_p: ', new_p) 69 | sample['user_profile'] = json.loads(new_p) 70 | print(sample['user_profile']) 71 | 72 | ct_str = 'history' if 'history' in sample.keys() else 'conversation' 73 | 74 | goal = sample["goal"] 75 | knowledge = sample["knowledge"] 76 | history = sample[ct_str] 77 | response = sample["response"] if "response" in sample else "null" 78 | assert 'user_profile' in sample.keys(), 'user_profile is needed !' 79 | assert 'situation' in sample.keys(), 'situation is needed !' 
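# `goal` and `situation` have already been flattened by strip_list above, which
# simply concatenates nested lists of strings, e.g.
# strip_list(['a', ['b', 'c']]) -> 'abc'. Below, the history is cleaned of
# 'bot:'/'Bot:' prefixes, the goal field is normalised into the flattened
# string format used in training, and the next goal segments are predicted via
# goal_generate(sample, n=4) before the reply itself is generated and
# post-processed by TransOutput.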
80 | 81 | # 清理history的格式 82 | history = [s.replace('bot:', '') for s in history] 83 | history = [s.replace('Bot:', '') for s in history] 84 | sample[ct_str] = history 85 | 86 | # 对goal进行格式转换。。 87 | if isinstance(goal, list): 88 | raise ValueError('goal 需要为类似test的原始格式的!') 89 | bot_first = True if len(history) % 2 == 0 else False 90 | goal_str = '' 91 | exist_goals = [] 92 | for i, goal_triple in enumerate(goal): 93 | if goal_triple[0] not in exist_goals: 94 | exist_goals.append(goal_triple[0]) 95 | if i == 0: 96 | goal_str = goal_str + '[{}] {} ( {} {} ) --> '.format( 97 | i + 1, goal_triple[0], 'Bot 主动' if bot_first else 'User 主动', goal_triple[2]) 98 | else: 99 | goal_str = goal_str + ' --> [{}] {} ( {} )'.format( 100 | i + 1, goal_triple[0], goal_triple[2]) 101 | sample['goal'] = goal_str 102 | 103 | goals = self.goal_response.goal_generate(sample, n=4) 104 | goals = list(set(goals)) 105 | print('goals: ', goals) 106 | answer_res = self.response.generate(sample, goals=goals) 107 | answer, tag = self.out_trans.trans_output(sample, answer_res) 108 | in_context = answer in sample[ct_str] 109 | if tag or in_context: 110 | print('Ori generation: {}'.format(answer_res)) 111 | answer_res = self.response.generate(sample, goals=goals, random=True, force_goal=in_context) 112 | print('More generation: {}'.format(answer_res)) 113 | for res in answer_res: 114 | answer, tag = self.out_trans.trans_output(sample, res) 115 | if not tag: 116 | break 117 | if tag: 118 | answer_res = self.response.generate(sample, goals=goals, force_goal=True, random=True) 119 | print('More More generation: {}'.format(answer_res)) 120 | for res in answer_res: 121 | answer, tag = self.out_trans.trans_output(sample, res) 122 | if not tag: 123 | break 124 | print() 125 | return answer 126 | 127 | def strip_list(self, value): 128 | res = '' 129 | if isinstance(value, list): 130 | for v in value: 131 | res += self.strip_list(v) 132 | else: 133 | res = value 134 | return res 135 | 136 | 137 | def test(): 138 | m = FinalPredict() 139 | s = r'{"situation": ["聊天 时间 : 晚上 20 : 00 , 在 家里"], "history": ["你好 啊"], "goal": [["[1] 问答 ( User 主动 按 『 参考 知识 』 问 『 周迅 』 的 信息 , Bot 回答 , User 满意 并 好评 ) --> ' \ 140 | r'[2] 关于 明星 的 聊天 ( Bot 主动 , 根据 给定 的 明星 信息 聊 『 周迅 』 相关 内容 , 至少 要 聊 2 轮 , 避免 话题 切换 太 僵硬 , 不够 自然 ) --> [3] 电影 推荐 ( Bot 主动 , Bot 使用 『 李米的猜想 』 ' \ 141 | r'的 某个 评论 当做 推荐 理由 来 推荐 『 李米的猜想 』 , User 先问 电影 『 国家 地区 、 导演 、 类型 、 主演 、 口碑 、 评分 』 中 的 一个 或 多个 , Bot 回答 , 最终 User 接受 ) --> [4] 再见"]], ' \ 142 | r'"knowledge": [["周迅", "主演", "李米的猜想"], ["李米的猜想", "评论", "疯狂 的 女人 疯狂 地爱 着 一个 男人"], ["李米的猜想", "评论", "故事 可以 , 配乐 更棒 。"], ["李米的猜想", "评论", "周迅 的 灵性 在 这部 片子 里 展露 无遗 。"], ' \ 143 | r'["李米的猜想", "评论", "放肆 的 哭 , 为 爱 付出"]], "user_profile": ["{\"姓名\": \"杨丽菲\", \"性别\": \"女\", \"居住地\": \"深圳\", \"年龄区间\": \"18-25\", \"职业状态\": \"学生\", \"喜欢 的 明星\": [\"周迅\"], ' \ 144 | r'\"喜欢 的 电影\": [\"苏州河\"], \"喜欢 的 poi\": [\"宅宅湘菜\"], \"同意 的 美食\": \" 剁椒鱼头\", \"同意 的 新闻\": \" 周迅 的新闻\", \"拒绝\": [\"音乐\"], \"接受 的 电影\": [\"巴尔扎克和小裁缝\", \"香港有个好莱坞\"], ' \ 145 | r'\"没有接受 的 电影\": [\"鸳鸯蝴蝶\"]}"]}' 146 | s2 = r'{"situation": ["聊天 时间 : 晚上 22 : 00 , 在 家里 聊天 主题 : 学习 退步"], "history": [], "goal": [["] 寒暄 ( Bot 主动 , 根据 给定 的 『 聊天 主题 』 寒暄 , 第一句 问候 要 带 User 名字 , 聊天 内容 不要 与 『 聊天 时间 』 矛盾 , 聊天 要 自然 , 不要 太 生硬 ) --> [2] 提问 ( Bot 主动 , 最 喜欢 谁 的 新闻 ? User 回答 ", " 最 喜欢 『 周杰伦 』 的 新闻 ) --> [3] 新闻 推荐 ( Bot 主动 , 推荐 『 周杰伦 』 的 新闻 『 台湾歌手 周杰伦 今天 被 聘请 成为 “ 中国 禁毒 宣传 形象大使 ” 。 周杰伦 表示 , 他 将 以 阳光 健康 的 形象 向 广大 青少年 发出 “ 拒绝 毒品 , 拥有 健康 ” 的 倡议 , 并 承诺 今后 将 积极 宣传 毒品 危害 , 倡导 全民 珍爱 生命 , 远离 毒品 。 转 , 与 周杰伦 一起 拒绝 毒品 ! 
』 , User 接受 。 需要 聊 2 轮 ) --> [4] 再见"]], "knowledge": [["金立国", "喜欢 的 新闻", "周杰伦"], ["周杰伦", "新闻", "台湾歌手 周杰伦 今天 被 聘请 成为 “ 中国 禁毒 宣传 形象大使 ” 。 周杰伦 表示 , 他 将 以 阳光 健康 的 形象 向 广大 青少年 发出 “ 拒绝 毒品 , 拥有 健康 ” 的 倡议 , 并 承诺 今后 将 积极 宣传 毒品 危害 , 倡导 全民 珍爱 生命 , 远离 毒品 。 转 , 与 周杰伦 一起 拒绝 毒品 !"]], "user_profile": ["{\"姓名\": \"金立国:性别\": \"男:居住地\": \"厦门:年龄区间\": \"小于18:职业状态\": \"学生:喜欢 的 明星\": 周杰伦, \"喜欢 的 音乐\": 淡水海边, \"喜欢 的 兴趣点\": 探炉烤鱼(湾悦城店), \"同意 的 美食\": \" 烤鱼:同意 的 新闻\": \" 周杰伦 的新闻:拒绝\": 电影, \"接受 的 音乐\": 眼泪成诗(Live):刀马旦:骑士精神:屋顶:花海:黄浦江深, \"没有接受 的 音乐\": 迷魂曲:雨下一整晚}"]}' 147 | s3 = r'{"situation": ["聊天 时间 : 上午 8 : 00 , 去 上班 路上 聊天 主题 : 工作 压力 大"], "history": [], "goal": [["[1] 寒暄 ( Bot 主动 , 根据 给定 的 『 聊天 主题 』 寒暄 , 第一句 问候 要 带 User 名字 , 聊天 内容 不要 与 『 聊天 时间 』 矛盾 , 聊天 要 自然 , 不要 太 生硬 ) --> [2] 提问 ( Bot 主动 , 问 User 最 喜欢 的 电影 名 ? User 回答 ", " 最 喜欢 『 刺客聂隐娘 』 ) --> [3] 提问 ( Bot 主动 , 问 User 最 喜欢 『 刺客聂隐娘 』 的 哪个 主演 , 不 可以 问 User 『 刺客聂隐娘 』 的 主演 是 谁 。 User 回答 ", " 最 喜欢 『 舒淇 』 ) --> [4] 关于 明星 的 聊天 ( Bot 主动 , 根据 给定 的 明星 信息 聊 『 舒淇 』 相关 内容 , 至少 要 聊 2 轮 , 避免 话题 切换 太 僵硬 , 不够 自然 ) --> [5] 电影 推荐 ( Bot 主动 , Bot 使用 『 千禧曼波之蔷薇的名字 』 的 某个 评论 当做 推荐 理由 来 推荐 『 千禧曼波之蔷薇的名字 』 , User 拒绝 , 拒绝 原因 可以 是 『 看过 、 暂时 不想 看 、 对 这个 电影 不感兴趣 或 其他 原因 』 ; Bot 使用 『 飞一般爱情小说 』 的 某个 评论 当做 推荐 理由 来 推荐 『 飞一般爱情小说 』 , User 先问 电影 『 国家 地区 、 导演 、 类型 、 主演 、 口碑 、 评分 』 中 的 一个 或 多个 , Bot 回答 , 最终 User 接受 。 注意 ", " 不要 在 一句 话 推荐 两个 电影 ) --> [6] 再见"]], "knowledge": [["王力宏", "获奖", "华语 电影 传媒 大奖 _ 观众 票选 最受 瞩目 表现"], ["王力宏", "获奖", "台湾 电影 金马奖 _ 金马奖 - 最佳 原创 歌曲"], ["王力宏", "获奖", "华语 电影 传媒 大奖 _ 观众 票选 最受 瞩目 男演员"], ["王力宏", "获奖", "香港电影 金像奖 _ 金像奖 - 最佳 新 演员"], ["王力宏", "出生地", "美国 纽约"], ["王力宏", "简介", "男明星"], ["王力宏", "简介", "很 认真 的 艺人"], ["王力宏", "简介", "一向 严谨"], ["王力宏", "简介", "“ 小将 ”"], ["王力宏", "简介", "好 偶像"], ["王力宏", "体重", "67kg"], ["王力宏", "成就", "全球 流行音乐 金榜 年度 最佳 男歌手"], ["王力宏", "成就", "加拿大 全国 推崇 男歌手"], ["王力宏", "成就", "第 15 届华鼎奖 全球 最佳 歌唱演员 奖"], ["王力宏", "成就", "MTV 亚洲 音乐 台湾 最 受欢迎 男歌手"], ["王力宏", "成就", "两届 金曲奖 国语 男 演唱 人奖"], ["王力宏", "评论", "力宏 必然 是 最 棒 的 ~ ~ !"], ["王力宏", "评论", "永远 的 FOREVER LOVE ~ ! !"], ["王力宏", "评论", "在 银幕 的 表演 和 做 娱乐节目 一样 无趣 , 装 逼成 性"], ["王力宏", "评论", "PERFECT MR . RIGHT ! 
!"], ["王力宏", "评论", "有些 歌 一直 唱進 心底 。 。 高學歷 又 有 才 華 。 。"], ["王力宏", "生日", "1976 - 5 - 17"], ["王力宏", "身高", "180cm"], ["王力宏", "星座", "金牛座"], ["王力宏", "血型", "O型"], ["王力宏", "演唱", "一首 简单 的 歌 ( Live )"], ["王力宏", "演唱", "KISS GOODBYE ( Live )"], ["一首简单的歌(Live)", "评论", "一首 简单 的 歌 , 却是 一首 最 不 简单 的 歌 。"], ["一首简单的歌(Live)", "评论", "你 唱 的 也好 好听 , 是 宝藏 啊 兔 兔"], ["一首简单的歌(Live)", "评论", "超爱 王力宏 的 歌 , 但是 , 唱起来 真难 呀 , 哈哈哈 哈哈哈 , 这才 是 大神 级别 的 歌手 !"], ["一首简单的歌(Live)", "评论", "97 年 的 我 , 不 知道 是否 有 同道中人 , 一直 喜欢 这些 歌"], ["一首简单的歌(Live)", "评论", "07 年 那 年初三 , 第一次 无意 从 同学 手机 中 听到 , 深深 被 吸引 , 一直 如此"], ["KISS", "GOODBYE(Live) 评论", "明明 不爱 我 了 为什么 不放过 我"], ["KISS", "GOODBYE(Live) 评论", "得不到 就是 得不到 不要 说 你 不 想要"], ["KISS", "GOODBYE(Live) 评论", "我 知道 你 无意 想 绿 我 , 只是 忘 了 说 分手 , 只能 说 我 还是 太嫩 了 , 没想到 还是 会 被 影响 到 心情 , 我 是 真的 深深 被 你 打败 了"], ["KISS GOODBYE(Live)", "评论", "《 Kiss Goodbye 》 是 一首 朴实无华 、 自然 悦耳 的 抒情歌 , 歌曲 充分 展现 了 王力宏 自创 的 Chinked - out 音乐风格 的 独特 魅力 。"], ["KISS GOODBYE(Live)", "评论", "《 Kiss Goodbye 》 是 王氏 情歌 的 催泪 之作 ;"], ["KISS GOODBYE(Live)", "评论", "这 首歌曲 表达 了 恋人 每 一次 的 分离 都 让 人 难以 释怀 , 每 一次 “ Kiss Goodbye ” 都 让 人 更 期待 下 一次 的 相聚 。"], ["KISS GOODBYE(Live)", "评论", "王力宏 在 这 首歌 里 写出 了 恋人们 的 心声 , 抒发 了 恋人 之间 互相 思念 对方 的 痛苦 。"]], "user_profile": ["{\"姓名\": \"周明奇\", \"性别\": \"男\", \"居住地\": \"桂林\", \"年龄区间\": \"大于50\", \"职业状态\": \"工作\", \"喜欢 的 明星\": [\"舒淇\", \"周杰伦\"], \"喜欢 的 电影\": [\"刺客聂隐娘\"], \"喜欢 的 音乐\": [\"兰亭序\"], \"同意 的 美食\": \" 烤鱼\", \"同意 的 新闻\": \" 舒淇 的新闻; 周杰伦 的新闻\", \"拒绝\": [\"poi\"], \"接受 的 电影\": [], \"接受 的 音乐\": [], \"没有接受 的 电影\": [], \"没有接受 的 音乐\": []}"]}' 148 | s4 = r'{"situation": ["聊天 时间 : 中午 12 : 00 , 在 学校 聊天 主题 : 学习 退步"], "history": [], "goal": [["[1] 寒暄 ( Bot 主动 , 根据 给定 的 『 聊天 主题 』 寒暄 , 第一句 问候 要 带 User 名字 , 聊天 内容 不要 与 『 聊天 时间 』 矛盾 , 聊天 要 自然 , 不要 太 生硬 ) --> [2] 提问 ( Bot 主动 , 问 User 最 喜欢 的 电影 名 ? User 回答 ", " 最 喜欢 『 一起飞 』 ) --> [3] 提问 ( Bot 主动 , 问 User 最 喜欢 『 一起飞 』 的 哪个 主演 , 不 可以 问 User 『 一起飞 』 的 主演 是 谁 。 User 回答 ", " 最 喜欢 『 林志颖 』 ) --> [4] 关于 明星 的 聊天 ( Bot 主动 , 根据 给定 的 明星 信息 聊 『 林志颖 』 相关 内容 , 至少 要 聊 2 轮 , 避免 话题 切换 太 僵硬 , 不够 自然 ) --> [5] 电影 推荐 ( Bot 主动 , Bot 使用 『 天庭外传 』 的 某个 评论 当做 推荐 理由 来 推荐 『 天庭外传 』 , User 拒绝 , 拒绝 原因 可以 是 『 看过 、 暂时 不想 看 、 对 这个 电影 不感兴趣 或 其他 原因 』 ; Bot 使用 『 一屋哨牙鬼 』 的 某个 评论 当做 推荐 理由 来 推荐 『 一屋哨牙鬼 』 , User 先问 电影 『 国家 地区 、 导演 、 类型 、 主演 、 口碑 、 评分 』 中 的 一个 或 多个 , Bot 回答 , 最终 User 接受 。 注意 ", " 不要 在 一句 话 推荐 两个 电影 ) --> [6] 再见"]], "knowledge": [["张晓佳", "喜欢", "一起飞"], ["张晓佳", "喜欢", "林志颖"], ["林志颖", "出生地", "中国 台湾"], ["林志颖", "简介", "典型 的 完美 主义者"], ["林志颖", "简介", "娱乐圈 不老 男神"], ["林志颖", "简介", "隐形 富豪"], ["林志颖", "简介", "明星 艺人"], ["林志颖", "简介", "“ 完美 奶爸 ”"], ["林志颖", "体重", "58kg"], ["林志颖", "成就", "20 08 年度 最佳 公益 慈善 明星 典范"], ["林志颖", "成就", "华鼎奖 偶像 励志 类 最佳 男演员"], ["林志颖", "成就", "1996 年 马英九 颁赠 反毒 大使 奖章"], ["林志颖", "成就", "台湾 第一位 授薪 职业 赛车手"], ["林志颖", "成就", "最具 影响力 全能 偶像 艺人"], ["林志颖", "评论", "看过 了 他 的 个人 心路历程 , 很 值得 敬佩 !"], ["林志颖", "评论", "喜欢 他 的 娃娃脸 。 精美 的 五官 佩服 他 的 经历 , 他 , 偶像 !"], ["林志颖", "评论", "不 知道 该 怎么 来 形容 他 , 太 完美 的 一个 人 了 ! 
!"], ["林志颖", "评论", "喜欢 他 在 《 变 身 男女 》 中 与 姚笛 的 对手戏 。"], ["林志颖", "评论", "perfect “ boy ”"], ["林志 颖", "生日", "1974 - 10 - 15"], ["林志颖", "身高", "172cm"], ["林志颖", "星座", "天秤座"], ["林志颖", "血型", "O型"], ["林志颖", "属相", "虎"], ["林志颖", "主演", "一起飞"], ["林志 颖", "主演", "天庭外传"], ["林志颖", "主演", "一屋哨牙鬼"], ["天庭外传", "评论", "喜欢 这部 电影"], ["天庭外传", "评论", "没事 笑一笑 , 青春 永不 老"], ["天庭外传", "评论", "930 ", " 只能 笑笑 了 。 观影 方式 ", " VCD"], ["天庭外传", "评论", "这片 是 暑假 看 的 枪版 VCD 啊啊啊"], ["一屋哨牙鬼", "评论", "当时 看 还是 蛮 搞笑 的"], ["一屋哨牙鬼", "评论", "要是 提前 个 二十年 看 应该 能 感觉 好笑 吧 , 现在 实在 是 看不下去 了 !"], ["一屋哨牙鬼", "评论", "只是 为了 看 一下 这些 知名演员 的 年轻 时代 至于 剧情 不敢恭维"], ["一屋哨牙鬼", "评论", "大 烂片 啊 , 要不是 给 那么 多 明星 面子 , 肯定 一分 都 不 给"], ["一屋哨牙鬼", "评论", "怀念 那个 随随便便 都 能 出 不过 不失 作品 的 香港电影 黄金 年代 。"], ["一屋哨牙鬼", "国家地区", "中国香港"], ["一屋哨牙鬼", "导演", "曹建南 曾志伟"], ["一屋哨牙鬼", "类型", "恐怖 喜剧"], ["一屋哨牙鬼", "主演", "林志颖 朱茵 张卫健"], ["一屋哨牙鬼", "口 碑", "口碑 还 可以"], ["一屋哨牙鬼", "评分", "6.2"]], "user_profile": ["{\"姓名\": \"张晓佳\", \"性别\": \"女\", \"居住地\": \"保定\", \"年龄区间\": \"小于18\", \"职业状态\": \"学生\", \"喜欢 的 明星\": [\"林志颖\"], \"喜欢 的 电影\": [\"一起飞\"], \"喜欢 的 新闻\": [\"林志颖 的新闻\"], \"同意 的 美食\": \" 麻辣烫\", \"同意 的 poi\": \" 金权道韩式自助烤肉火锅\", \"拒绝\": [\"音乐\"], \"接受 的 电影\": [], \"没有接受 的 电影\": []}"]}' 149 | # print(m.predict(s)) 150 | # print(m.predict(s2)) 151 | print(m.predict(s4)) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() -------------------------------------------------------------------------------- /code/predict/predict_lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/4/29 21:41 6 | @File :predict_lm.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | from cfg import * 13 | from model.bert_lm import BertLM, Response 14 | from data_deal.base_input import BaseInput 15 | from data_deal.trans_output import TransOutput 16 | # from model.model_goal import BertGoal 17 | import jieba 18 | import time 19 | 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--type', type=int, 24 | help=r'default is 2', 25 | default=2) 26 | args = parser.parse_args(sys.argv[1:]) 27 | data_type = args.type 28 | save_dir = join(MODEL_PATH, 'BertLM_' + TAG) 29 | save_path = join(save_dir, 'trained.h5') 30 | if not os.path.isdir(OUT_PATH): 31 | os.makedirs(OUT_PATH) 32 | output_dir = join(OUT_PATH, 'out_{}_{}_{}.txt'.format( 33 | data_type, TAG, time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time())))) 34 | 35 | data_input = BaseInput(from_pre_trans=True) 36 | model_cls = BertLM(data_input.keep_tokens, load_path=save_path) 37 | response = Response(model_cls.model, 38 | model_cls.session, 39 | data_input, 40 | start_id=None, 41 | end_id=data_input.tokenizer._token_sep_id, 42 | maxlen=40 43 | ) 44 | goal_response = Response(model_cls.model, 45 | model_cls.session, 46 | data_input, 47 | start_id=None, 48 | end_id=data_input.tokenizer._token_goal_id, 49 | maxlen=10 50 | ) 51 | # out_trans = TransOutput(rc_tag=TAG) 52 | out_trans = TransOutput(rc_tag='') 53 | goal_dir = join(MODEL_PATH, 'Goal_' + TAG) 54 | goal_path = join(goal_dir, 'trained.h5') 55 | # goal_cls = BertGoal(data_input.keep_tokens, num_classes=len(data_input.reader.all_goals), load_path=goal_path) 56 | 57 | test_iter = data_input.get_sample( 58 | data_type, 59 | need_shuffle=False, 60 | cycle=False 61 | ) 62 | 63 | 64 | def cal_participle(samp: dict): 65 | words = [] 66 | words.extend(samp['situation'].split(' ')) 67 
| words.extend(samp['goal'].split(' ')) 68 | for k, v in samp['user_profile'].items(): 69 | if not isinstance(v, list): 70 | v = [v] 71 | for _v in v: 72 | words.extend(_v.split(' ')) 73 | for kg in samp['knowledge']: 74 | words.extend(kg[2].split(' ')) 75 | words = set(words) 76 | words = [w for w in words if len(w) > 1] 77 | return words 78 | 79 | 80 | with open(output_dir, encoding='utf-8', mode='w') as fw: 81 | skip = 1374 82 | i = 0 83 | for sample in test_iter: 84 | i += 1 85 | # if i <= skip: 86 | # continue 87 | samp_words = cal_participle(sample) 88 | for w in samp_words: 89 | jieba.add_word(w) 90 | 91 | # has_goal = response.check_goal_end(sample, end_id=data_input.tokenizer._token_goal_id) 92 | # answer_res = response.generate(sample, has_goal=has_goal) 93 | goals = goal_response.goal_generate(sample, n=4) 94 | goals = list(set(goals)) 95 | # goals = [goal_cls.predict(sample)] 96 | answer_res = response.generate(sample, goals=goals) 97 | answer, tag = out_trans.trans_output(sample, answer_res) 98 | if tag: 99 | answer_res = response.generate(sample, goals=goals, random=True) 100 | for res in answer_res: 101 | answer, tag = out_trans.trans_output(sample, res) 102 | if not tag: 103 | break 104 | if tag: 105 | answer_res = response.generate(sample, goals=goals, force_goal=True, random=True) 106 | for res in answer_res: 107 | answer, tag = out_trans.trans_output(sample, res) 108 | if not tag: 109 | break 110 | e_i = 0 111 | if answer[0] == '[': 112 | for j in range(1, 4): 113 | if answer[j] == ']': 114 | e_i = j + 1 115 | break 116 | answer = answer[:e_i] + ' ' + ' '.join(jieba.lcut(answer[e_i:])) 117 | fw.writelines(answer + '\n') 118 | if i % 37 == 0: 119 | print('\rnum: {} '.format(i), end='') 120 | print('\n=====> Over: ', i) 121 | -------------------------------------------------------------------------------- /code/predict/predict_lm_ct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/19 21:55 6 | @File :predict_lm_ct.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | from cfg import * 13 | from model.bert_lm import BertLM, Response 14 | from data_deal.base_input import BaseInput 15 | from data_deal.input_ct import CTInput 16 | from data_deal.trans_output import TransOutput 17 | from model.model_context import ModelContext 18 | from model.model_recall import SearchEMb 19 | import jieba 20 | import time 21 | import numpy as np 22 | import re 23 | 24 | import argparse 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--type', type=int, 28 | help=r'default is 2', 29 | default=2) 30 | args = parser.parse_args(sys.argv[1:]) 31 | data_type = args.type 32 | save_dir = join(MODEL_PATH, 'BertLM_' + TAG) 33 | save_path = join(save_dir, 'trained.h5') 34 | if not os.path.isdir(OUT_PATH): 35 | os.makedirs(OUT_PATH) 36 | output_dir = join(OUT_PATH, 'out_{}_{}_{}.txt'.format( 37 | data_type, TAG, time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time())))) 38 | 39 | data_input = BaseInput(from_pre_trans=True) 40 | model_cls = BertLM(data_input.keep_tokens, load_path=save_path) 41 | response = Response(model_cls.model, 42 | model_cls.session, 43 | data_input, 44 | start_id=None, 45 | end_id=data_input.tokenizer._token_sep_id, 46 | maxlen=40 47 | ) 48 | goal_response = Response(model_cls.model, 49 | model_cls.session, 50 | data_input, 51 | start_id=None, 
52 | end_id=data_input.tokenizer._token_goal_id, 53 | maxlen=10 54 | ) 55 | out_trans = TransOutput(rc_tag='') 56 | search_rc = SearchEMb(top_n=3) 57 | 58 | ct_dir = join(MODEL_PATH, 'CT_' + TAG) 59 | ct_path = join(ct_dir, 'trained.h5') 60 | ct_input = CTInput(from_pre_trans=False) 61 | model_ct_cls = ModelContext(ct_input.keep_tokens, load_path=ct_path) 62 | del ct_input 63 | 64 | test_iter = data_input.get_sample( 65 | data_type, 66 | need_shuffle=False, 67 | cycle=False 68 | ) 69 | 70 | 71 | def cal_participle(samp: dict): 72 | words = [] 73 | words.extend(samp['situation'].split(' ')) 74 | words.extend(samp['goal'].split(' ')) 75 | for k, v in samp['user_profile'].items(): 76 | if not isinstance(v, list): 77 | v = [v] 78 | for _v in v: 79 | words.extend(_v.split(' ')) 80 | for kg in samp['knowledge']: 81 | words.extend(kg[2].split(' ')) 82 | words = set(words) 83 | words = [w for w in words if len(w) > 1] 84 | return words 85 | 86 | 87 | with open(output_dir, encoding='utf-8', mode='w') as fw: 88 | skip = 1374 89 | i = 0 90 | for sample in test_iter: 91 | i += 1 92 | # if i <= skip: 93 | # continue 94 | samp_words = cal_participle(sample) 95 | for w in samp_words: 96 | jieba.add_word(w) 97 | 98 | goals = goal_response.goal_generate(sample, n=4) 99 | goals = list(set(goals)) 100 | history = sample['history'] 101 | final_answers = [] 102 | turn = 0 103 | while len(final_answers) <= 0: 104 | answer_res = response.generate(sample, goals=goals, random=True) 105 | score_mul = [1] * len(answer_res) 106 | if (len(history) > 1 and len(history[-1]) > 4 and '新闻' not in ''.join(history[-2:])) or turn > 0: 107 | rc_ans, rc_dis = search_rc.get_recall(history[-1]) 108 | answer_res.extend(rc_ans) 109 | score_mul = score_mul + np.minimum((1.0 - np.array(rc_dis)) * 0.5 + 1.0, 0.99).tolist() 110 | # 去重 转换 111 | mid_res_clean = [] 112 | mid_sc = [] 113 | for ans, sc in zip(answer_res, score_mul): 114 | sentence = re.sub(data_input.reader.goal_num_comp, '', ans) 115 | if sentence in mid_res_clean: 116 | continue 117 | trans_answer, tag = out_trans.trans_output(sample, ans) 118 | if tag: 119 | continue 120 | final_answers.append(trans_answer) 121 | mid_sc.append(sc) 122 | mid_res_clean.append(sentence) 123 | score_mul = mid_sc 124 | turn += 1 125 | if turn > 5: 126 | final_answers = ['是的呢'] 127 | score_mul = [1.0] 128 | logger.warning('No proper answer! 
\n{}'.format(history)) 129 | # CT score 130 | final_contexts = [history + [ans] for ans in final_answers] 131 | scores = model_ct_cls.predict(final_contexts) 132 | scores_md = np.multiply(scores, np.array(score_mul)) 133 | answer = final_answers[np.argmax(scores_md)] 134 | 135 | e_i = 0 136 | if answer[0] == '[': 137 | for j in range(1, 4): 138 | if answer[j] == ']': 139 | e_i = j + 1 140 | break 141 | answer = answer[:e_i] + ' ' + ' '.join(jieba.lcut(answer[e_i:])) 142 | fw.writelines(answer + '\n') 143 | if i % 37 == 0: 144 | print('\rnum: {} '.format(i), end='') 145 | print('\n=====> Over: ', i) 146 | -------------------------------------------------------------------------------- /code/score_fn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :apple.li 5 | @Time :2020/6/4 13:56 6 | @File :score_fn.py 7 | @Desc : 8 | """ 9 | 10 | 11 | def test_score(file_path): 12 | with open(file_path, encoding='utf-8') as fr: 13 | score = 0.0 14 | sample_num = 0 15 | history = [] 16 | while True: 17 | line = fr.readline() 18 | if not line: 19 | break 20 | line = line.strip() 21 | if line == '': 22 | print('\n'.join(history)) 23 | while True: 24 | s = input('Score ? ') 25 | try: 26 | s = float(s) 27 | break 28 | except ValueError: 29 | print('"s" Must be float type !') 30 | continue 31 | score += s 32 | sample_num += 1 33 | print('Score: {:.2f}. Num: {}. Ave score: {:.4f}\n'.format(score, sample_num, score / sample_num)) 34 | history = [] 35 | else: 36 | history.append(line) 37 | print('Final ===> Score: {:.2f}. Num: {}. Ave score: {:.4f}\n\n'.format(score, sample_num, score / sample_num)) 38 | 39 | 40 | if __name__ == '__main__': 41 | # test_score('../test_1_sample.txt') # 0.979 42 | test_score('../test_2_sample.txt') # 0.9650 43 | 44 | """ 45 | [1] 你 告诉 我 一下 几点 了 可以 吗 ? 46 | 现在 是 上午 8 点 哦 。 47 | 那 还好 , 迟 不了 , 谢谢 你 了 。 48 | 不客气哦,对了今天济南晴转多云,南风,最高气温:24℃,最低气温:14℃,注意保暖哦。 49 | 50 | [1] 你好 啊 , 麻烦 问 一下 现在 几点 了 ? 51 | 现在 是 20 18 年 10 月 17 日 , 上午 7 : 00 。 52 | 好 嘞 , 谢 啦 , 有 你 真 好 哦 。 53 | [2] 嘿嘿 , 能 帮到 你 我 很 开心 呢 , 天气 方面 要 不要 看看 呀 ? 54 | 好 啊 , 正想 问 你 呢 。 55 | 成都今天阴转小雨,无持续风向,最高气温:18℃,最低气温:14℃,注意保暖哦。 56 | 57 | [1] 我 想 问 有 关于 周杰伦 的 新闻 吗 ? 58 | 当然有啦。18日,周杰伦发布了新歌《等你下课》,勾起了大家对青春的回忆。除了周杰伦,你的青春日记里是否还有这些歌手?陳奕迅所長张惠妹aMEI_feat_阿密特刘若英梁静茹孙燕姿五月天王力宏……你还记得那些骑车上学听歌的岁月吗? 
59 | """ 60 | -------------------------------------------------------------------------------- /code/test_answer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | from __future__ import print_function 4 | """ 5 | @Author :apple.li 6 | @Time :2020/6/4 13:39 7 | @File :test_answer.py 8 | @Desc : 9 | """ 10 | import sys 11 | import socket 12 | import importlib 13 | import json 14 | import re 15 | 16 | importlib.reload(sys) 17 | 18 | SERVER_IP = "127.0.0.1" 19 | SERVER_PORT = 8601 20 | goal_comp = re.compile('\[\d+\]\s*[^(]*') 21 | 22 | 23 | 24 | def conversation_client(text): 25 | """ 26 | conversation_client 27 | """ 28 | mysocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 29 | mysocket.connect((SERVER_IP, SERVER_PORT)) 30 | mysocket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096 * 5) 31 | mysocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096 * 5) 32 | 33 | mysocket.sendall(text.encode()) 34 | result = mysocket.recv(4096 * 5).decode() 35 | 36 | mysocket.close() 37 | 38 | return result 39 | 40 | 41 | def main(file_path, out_path): 42 | """ 43 | main 44 | """ 45 | import numpy as np 46 | 47 | all_lines = [line for line in open(file_path, encoding='utf-8')] 48 | all_lines = np.random.choice(all_lines, size=200, replace=False) 49 | with open(out_path, encoding='utf-8', mode='w') as fw: 50 | for line in all_lines: 51 | sample = json.loads(line, encoding="utf-8") 52 | history = sample['history'] 53 | response = conversation_client(json.dumps(sample, ensure_ascii=False)) 54 | history.append(response) 55 | fw.writelines('\n'.join(history) + '\n\n') 56 | 57 | 58 | if __name__ == '__main__': 59 | main('../test_1.txt', '../test_1_sample.txt') 60 | main('../test_2.txt', '../test_2_sample.txt') 61 | -------------------------------------------------------------------------------- /code/test_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | from __future__ import print_function 4 | """ 5 | @Author :apple.li 6 | @Time :2020/5/21 17:52 7 | @File :test_client.py 8 | @Desc :File: conversation_client.py 9 | """ 10 | import sys 11 | import socket 12 | import importlib 13 | import json 14 | import re 15 | 16 | importlib.reload(sys) 17 | 18 | SERVER_IP = "127.0.0.1" 19 | SERVER_PORT = 8601 20 | goal_comp = re.compile('\[\d+\]\s*[^(]*') 21 | 22 | 23 | 24 | def conversation_client(text): 25 | """ 26 | conversation_client 27 | """ 28 | mysocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 29 | mysocket.connect((SERVER_IP, SERVER_PORT)) 30 | mysocket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096 * 5) 31 | mysocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096 * 5) 32 | 33 | mysocket.sendall(text.encode()) 34 | result = mysocket.recv(4096 * 5).decode() 35 | 36 | mysocket.close() 37 | 38 | return result 39 | 40 | 41 | def main(): 42 | """ 43 | main 44 | """ 45 | if len(sys.argv) < 2: 46 | print("Usage: " + sys.argv[0] + " eval_file") 47 | exit() 48 | 49 | skip_n = 0 50 | for line in open(sys.argv[1], encoding='utf-8'): 51 | if skip_n > 0: 52 | skip_n -= 1 53 | continue 54 | sample = json.loads(line, encoding="utf-8") 55 | # all_goals = goal_comp.findall(sample['goal']) 56 | # all_goals = [[s[3:].strip(), '', ''] for s in all_goals] 57 | # sample['goal'] = all_goals 58 | print('\n\n' + '=' * 20) 59 | print('New goal: ', sample['goal']) 60 | history = sample['history'] 61 | 62 | print('Ori 
history:' + '\n'.join(history)) 63 | 64 | s = input('E :bot first; U: user first; C:continue; Z: quit; 其它:继续原有的样本对话历史\n') 65 | if s in ['E', 'e']: 66 | history = [] 67 | elif s in ['U', 'u']: 68 | s = input('Q:') 69 | history = [s] 70 | elif s in ['Z', 'z']: 71 | return 72 | elif s in ['continue', 'c', 'C']: 73 | continue 74 | else: 75 | try: 76 | s = int(s) 77 | skip_n = s 78 | print('Input: {}. Skip {} times'.format(s, skip_n)) 79 | continue 80 | except ValueError: 81 | print('Input: {}. Process dialogue'.format(s)) 82 | pass 83 | sample['history'] = history 84 | 85 | response = conversation_client(json.dumps(sample, ensure_ascii=False)) 86 | print('A: ', response) 87 | while True: 88 | question = input('Q:') 89 | if question in ['continue', 'c', 'C']: 90 | break 91 | if question in ['q', 'Q', 'Z', 'z']: 92 | return 93 | history.append(response) 94 | history.append(question) 95 | sample['history'] = history 96 | response = conversation_client(json.dumps(sample, ensure_ascii=False)) 97 | print('A: ', response) 98 | 99 | 100 | 101 | if __name__ == '__main__': 102 | try: 103 | main() 104 | except KeyboardInterrupt: 105 | print("\nExited from the program ealier!") 106 | -------------------------------------------------------------------------------- /code/train/train_bert_lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/4/28 22:21 6 | @File :train_bert_lm.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | from cfg import * 13 | from model.bert_lm import BertLM, Response 14 | from data_deal.base_input import BaseInput 15 | from bert4keras_5_8.backend import keras 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--init_epoch', type=int, 20 | help=r'init epoch, you don\'t know ?', 21 | default=0) 22 | parser.add_argument('--epoch', type=int, 23 | help=r'init epoch, you don\'t know ?', 24 | default=3) 25 | args = parser.parse_args(sys.argv[1:]) 26 | 27 | batch_size = 4 28 | steps_per_epoch = 120 29 | 30 | epoches = int(args.epoch * totle_sample / batch_size / steps_per_epoch) 31 | init_epoch = args.init_epoch 32 | 33 | save_dir = join(MODEL_PATH, 'BertLM_' + TAG) 34 | if not os.path.isdir(save_dir): 35 | os.makedirs(save_dir) 36 | save_path = join(save_dir, 'trained.h5') 37 | 38 | data_input = BaseInput(from_pre_trans=True) 39 | 40 | model_cls = BertLM(data_input.keep_tokens, load_path=save_path) 41 | model_cls.compile() 42 | 43 | 44 | class LogRecord(keras.callbacks.Callback): 45 | def __init__(self): 46 | super(LogRecord, self).__init__() 47 | self._step = 1 48 | self.lowest = 1e10 49 | self.test_iter = data_input.get_sample( 50 | 3, 51 | need_shuffle=False, 52 | cycle=True 53 | ) 54 | self.response = Response(model_cls.model, 55 | model_cls.session, 56 | data_input, 57 | start_id=None, 58 | end_id=data_input.tokenizer._token_sep_id, 59 | maxlen=30 60 | ) 61 | 62 | def on_epoch_end(self, epoch, logs=None): 63 | for i in range(2): 64 | sample = next(self.test_iter) 65 | res = self.response.generate(sample) 66 | logger.info('==============') 67 | logger.info('Context: {}'.format(sample['history'])) 68 | logger.info('Goal: {}'.format(sample['goal'])) 69 | logger.info('Answer: {}\n'.format(res)) 70 | for j in range(7): 71 | # 很多重复的 72 | next(self.test_iter) 73 | 74 | def on_batch_end(self, batch, logs=None): 75 | self._step += 1 76 | if 
self._step % 20 == 0: 77 | logger.info('step: {} loss: {} '.format(self._step, logs['loss'])) 78 | 79 | 80 | checkpoint_callback = keras.callbacks.ModelCheckpoint( 81 | save_path, monitor='val_loss', verbose=0, save_best_only=False, 82 | save_weights_only=True, mode='min', period=3) 83 | tensorboard_callback = keras.callbacks.TensorBoard( 84 | log_dir=join(save_dir, 'tf_logs'), histogram_freq=0, write_graph=False, 85 | write_grads=False, update_freq=320) 86 | 87 | model_cls.model.fit_generator( 88 | data_input.generator( 89 | batch_size=batch_size, 90 | data_type=train_list, 91 | need_shuffle=True, 92 | cycle=True 93 | ), 94 | validation_data=data_input.generator( 95 | batch_size=batch_size, 96 | data_type=1, 97 | need_shuffle=True, 98 | cycle=True 99 | ), 100 | validation_steps=10, 101 | validation_freq=1, 102 | steps_per_epoch=steps_per_epoch, 103 | epochs=epoches, 104 | initial_epoch=init_epoch, 105 | verbose=2, 106 | class_weight=None, 107 | callbacks=[ 108 | checkpoint_callback, 109 | tensorboard_callback, 110 | LogRecord() 111 | ] 112 | ) 113 | -------------------------------------------------------------------------------- /code/train/train_ct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/19 22:28 6 | @File :train_ct.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | from cfg import * 13 | from model.model_context import ModelContext 14 | from data_deal.input_ct import CTInput 15 | from bert4keras_5_8.backend import keras 16 | import numpy as np 17 | import argparse 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--init_epoch', type=int, 21 | help=r'init epoch, you don\'t know ?', 22 | default=0) 23 | parser.add_argument('--epoch', type=int, 24 | help=r'init epoch, you don\'t know ?', 25 | default=6) 26 | args = parser.parse_args(sys.argv[1:]) 27 | 28 | batch_size = 4 29 | steps_per_epoch = 600 30 | 31 | epoches = int(args.epoch * totle_sample / batch_size / steps_per_epoch * 4.67) 32 | init_epoch = args.init_epoch 33 | 34 | save_dir = join(MODEL_PATH, 'CT_' + TAG) 35 | if not os.path.isdir(save_dir): 36 | os.makedirs(save_dir) 37 | save_path = join(save_dir, 'trained.h5') 38 | 39 | data_input = CTInput(from_pre_trans=False) 40 | 41 | model_cls = ModelContext(data_input.keep_tokens, load_path=save_path) 42 | model_cls.compile() 43 | 44 | 45 | class LogRecord(keras.callbacks.Callback): 46 | def __init__(self): 47 | super(LogRecord, self).__init__() 48 | self._step = 1 49 | self.lowest = 1e10 50 | self.test_iter = data_input.generator( 51 | 4, 52 | data_type=3, 53 | need_shuffle=False, 54 | cycle=True, 55 | need_douban=False 56 | ) 57 | 58 | def on_epoch_end(self, epoch, logs=None): 59 | [X, S], L = next(self.test_iter) 60 | result = model_cls.model.predict([X, S]) 61 | for x, l, r in zip(X, L, result): 62 | print(' '.join(data_input.tokenizer.ids_to_tokens(x)).rstrip(' [PAD]')) 63 | print('label: {} predict: {}\n'.format(l, np.argmax(r))) 64 | 65 | def on_batch_end(self, batch, logs=None): 66 | self._step += 1 67 | if self._step % 60 == 0: 68 | logger.info('step: {} loss: {} '.format(self._step, logs['loss'])) 69 | 70 | 71 | checkpoint_callback = keras.callbacks.ModelCheckpoint( 72 | save_path, monitor='val_loss', verbose=0, save_best_only=False, 73 | save_weights_only=True, mode='min', period=2) 74 | tensorboard_callback = 
keras.callbacks.TensorBoard( 75 | log_dir=join(save_dir, 'tf_logs'), histogram_freq=0, write_graph=False, 76 | write_grads=False, update_freq=320) 77 | 78 | model_cls.model.fit_generator( 79 | data_input.generator( 80 | batch_size=batch_size, 81 | data_type=train_list, 82 | need_shuffle=True, 83 | cycle=True 84 | ), 85 | validation_data=data_input.generator( 86 | batch_size=batch_size, 87 | data_type=1, 88 | need_shuffle=True, 89 | cycle=True 90 | ), 91 | validation_steps=10, 92 | validation_freq=1, 93 | steps_per_epoch=steps_per_epoch, 94 | epochs=epoches, 95 | initial_epoch=init_epoch, 96 | verbose=2, 97 | class_weight=None, 98 | callbacks=[ 99 | checkpoint_callback, 100 | tensorboard_callback, 101 | LogRecord() 102 | ] 103 | ) 104 | -------------------------------------------------------------------------------- /code/train/train_goal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/8 22:20 6 | @File :train_goal.py 7 | @Desc : 8 | """ 9 | import os, sys 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | 13 | from cfg import * 14 | from model.model_goal import BertGoal 15 | from data_deal.input_goal import GoalInput 16 | from bert4keras_5_8.backend import keras 17 | import argparse 18 | import numpy as np 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--init_epoch', type=int, 22 | help=r'init epoch, you don\'t know ?', 23 | default=0) 24 | parser.add_argument('--epoch', type=int, 25 | help=r'init epoch, you don\'t know ?', 26 | default=3) 27 | args = parser.parse_args(sys.argv[1:]) 28 | 29 | batch_size = 4 30 | steps_per_epoch = 120 31 | 32 | epoches = int(args.epoch * totle_sample / batch_size / steps_per_epoch) 33 | init_epoch = args.init_epoch 34 | 35 | save_dir = join(MODEL_PATH, 'Goal_' + TAG) 36 | if not os.path.isdir(save_dir): 37 | os.makedirs(save_dir) 38 | save_path = join(save_dir, 'trained.h5') 39 | 40 | data_input = GoalInput() 41 | 42 | model_cls = BertGoal(data_input.keep_tokens, num_classes=len(data_input.reader.all_goals), load_path=save_path) 43 | model_cls.compile() 44 | 45 | 46 | class LogRecord(keras.callbacks.Callback): 47 | def __init__(self): 48 | super(LogRecord, self).__init__() 49 | self._step = 1 50 | self.lowest = 1e10 51 | self.test_iter = data_input.generator( 52 | batch_size=2, 53 | data_type=3, 54 | need_shuffle=False, 55 | cycle=True 56 | ) 57 | 58 | def on_epoch_end(self, epoch, logs=None): 59 | X, L = next(self.test_iter) 60 | T = X[0] 61 | for i in range(len(T)): 62 | res = model_cls.model.predict(X) 63 | logger.info('==============') 64 | logger.info('Context: {}'.format(data_input.tokenizer.decode(T[i]))) 65 | logger.info('Goal: {} {}'.format(L[i], data_input.reader.all_goals[L[i]])) 66 | logger.info('Answer: {} {}\n'.format(np.argmax(res[i]), 67 | data_input.reader.all_goals[np.argmax(res[i])])) 68 | 69 | def on_batch_end(self, batch, logs=None): 70 | self._step += 1 71 | if self._step % 20 == 0: 72 | logger.info('step: {} loss: {} '.format(self._step, logs['loss'])) 73 | 74 | 75 | checkpoint_callback = keras.callbacks.ModelCheckpoint( 76 | save_path, monitor='val_loss', verbose=0, save_best_only=False, 77 | save_weights_only=True, mode='min', period=3) 78 | tensorboard_callback = keras.callbacks.TensorBoard( 79 | log_dir=join(save_dir, 'tf_logs'), histogram_freq=0, write_graph=False, 80 | write_grads=False, update_freq=320) 81 | 82 | 
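# `epoches` converts the requested number of passes over the data into the
# number of short Keras epochs that are actually run: with batch_size=4 and
# steps_per_epoch=120 each Keras epoch consumes 480 samples, so for a
# hypothetical totle_sample of 9600 (the real value comes from cfg.py via the
# star import) --epoch 3 yields int(3 * 9600 / 4 / 120) = 60 short epochs,
# giving the callbacks frequent on_epoch_end hooks.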
model_cls.model.fit_generator( 83 | data_input.generator( 84 | batch_size=batch_size, 85 | data_type=train_list, 86 | need_shuffle=True, 87 | cycle=True 88 | ), 89 | validation_data=data_input.generator( 90 | batch_size=batch_size, 91 | data_type=1, 92 | need_shuffle=True, 93 | cycle=True 94 | ), 95 | validation_steps=10, 96 | validation_freq=1, 97 | steps_per_epoch=steps_per_epoch, 98 | epochs=epoches, 99 | initial_epoch=init_epoch, 100 | verbose=2, 101 | class_weight=None, 102 | callbacks=[ 103 | checkpoint_callback, 104 | tensorboard_callback, 105 | LogRecord() 106 | ] 107 | ) 108 | -------------------------------------------------------------------------------- /code/train/train_rc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Time : 2020/5/10 11:10 5 | @Author : Apple QXTD 6 | @File : train_rc.py 7 | @Desc: : 8 | """ 9 | import os, sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from cfg import * 12 | from bert4keras_5_8.backend import keras 13 | from data_deal.input_rc import RCInput 14 | from model.model_rc import BertCL 15 | 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--init_epoch', type=int, 20 | help=r'init epoch, you don\'t know ?', 21 | default=0) 22 | parser.add_argument('--epoch', type=int, 23 | help=r'init epoch, you don\'t know ?', 24 | default=3) 25 | args = parser.parse_args(sys.argv[1:]) 26 | 27 | 28 | steps_per_epoch = 120 29 | batch_size = 3 30 | 31 | epoches = int(args.epoch * totle_sample / batch_size / steps_per_epoch) 32 | 33 | save_dir = join(MODEL_PATH, 'rc_' + TAG) 34 | save_path = join(save_dir, 'trained.h5') 35 | if not os.path.isdir(save_dir): 36 | os.makedirs(save_dir) 37 | 38 | data_input = RCInput(from_pre_trans=True) 39 | model_cls = BertCL(tag=TAG) 40 | model_cls.compile() 41 | 42 | tokenizer = data_input.tokenizer 43 | max_p_len = data_input.max_p_len 44 | max_q_len = data_input.max_q_len 45 | max_a_len = data_input.max_a_len 46 | 47 | 48 | class Evaluate(keras.callbacks.Callback): 49 | def __init__(self, ): 50 | super(Evaluate, self).__init__() 51 | self.wait = 0 52 | self.stopped_epoch = 0 53 | self.best_weights = None 54 | 55 | self.eva_iter = data_input.get_sample(data_files=3, cycle=True) 56 | self._step = 0 57 | 58 | def on_batch_end(self, batch, logs=None): 59 | self._step += 1 60 | if self._step % 100 == 0: 61 | logger.info(' step: {} loss: {} '.format(self._step, logs['loss'])) 62 | 63 | def on_epoch_end(self, epoch, logs=None): 64 | sample = next(self.eva_iter) 65 | samples = data_input.get_rc_sample(sample) 66 | for sample in samples: 67 | for q_key, answer in sample['result'].items(): 68 | if q_key not in sample['replace_dict'].keys(): 69 | continue 70 | context = sample['replace_dict'][q_key] # it's a list 71 | context = '|'.join(context).replace(' ', '') # 全部使用 | 作为分隔符 72 | question = '|'.join(sample['history']).replace(' ', '') # 全部使用 | 作为分隔符 73 | answer = answer.replace(' ', '') 74 | question += '|{}'.format(q_key) # question额外添加询问的标记 75 | predict_answer = model_cls.predict(question, context) 76 | 77 | logger.info('Context: {}'.format(context)) 78 | logger.info('Question: {}'.format(question)) 79 | logger.info('Answer: {}'.format(answer)) 80 | logger.info('Gen Answer: {}'.format(predict_answer)) 81 | logger.info('') 82 | 83 | 84 | if __name__ == '__main__': 85 | evaluator = Evaluate() 86 | 87 | checkpoint_callback = 
keras.callbacks.ModelCheckpoint( 88 | save_path, monitor='loss', verbose=0, save_best_only=False, 89 | save_weights_only=True, mode='min', period=1) 90 | tensorboard_callback = keras.callbacks.TensorBoard( 91 | log_dir=join(save_dir, 'tf_logs'), histogram_freq=0, write_graph=False, 92 | write_grads=False, update_freq=160) 93 | early_stop_callback = keras.callbacks.EarlyStopping() 94 | model_cls.model.fit_generator( 95 | data_input.generator( 96 | batch_size=batch_size, 97 | data_type=train_list, 98 | need_shuffle=True, 99 | cycle=True 100 | ), 101 | steps_per_epoch=steps_per_epoch, 102 | epochs=epoches, 103 | verbose=2, 104 | callbacks=[ 105 | checkpoint_callback, 106 | tensorboard_callback, 107 | evaluator, 108 | ]) 109 | -------------------------------------------------------------------------------- /code/train/z_t.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :apple.li 5 | @Time :2020/5/8 18:15 6 | @File :z_t.py 7 | @Desc : 8 | """ 9 | 10 | import os, sys 11 | import re 12 | 13 | birthday_comp = re.compile('\d[\d ]{3,}(-)[\d ]+(-)[\d ]+') 14 | file = r'../../output/out_2020-05-10_15-23-31.txt' 15 | file_w = r'../../output/out_2020-05-10_15-23-31-rn.txt' 16 | i = 0 17 | with open(file, encoding='utf-8') as fr: 18 | with open(file_w, mode='w', encoding='utf-8') as fw: 19 | while True: 20 | i += 1 21 | line = fr.readline() 22 | if not line: 23 | break 24 | else: 25 | line = line.strip() 26 | sp = birthday_comp.search(line) 27 | if sp: 28 | idx_0 = sp.group().find('-') 29 | idx_1 = sp.group()[idx_0 + 1:].find('-') + idx_0 + 1 30 | sp_str = list(sp.group()) 31 | sp_str[idx_0] = '年' 32 | sp_str[idx_1] = '月' 33 | sp_str = ''.join(sp_str) + '日' 34 | line_after = line[:sp.span()[0]] + sp_str 35 | if sp.span()[1] < len(line) and line[sp.span()[1]] == '号': 36 | line_after = line_after[:-1] + line[sp.span()[1]:] 37 | elif sp.span()[1] + 1 < len(line) and line[sp.span()[1]:sp.span()[1] + 2] == ' 号': 38 | line_after = line_after[:-1] + line[sp.span()[1] + 1:] 39 | else: 40 | line_after += line[sp.span()[1]:] 41 | line = line_after 42 | fw.writelines(line + '\n') 43 | if i % 43 == 0: 44 | print('\r', i , end=' ') -------------------------------------------------------------------------------- /code/utils/sif.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :apple.li 5 | @Time :2019/10/18 21:57 6 | @File :sif.py 7 | @Desc : 8 | """ 9 | 10 | # Author: Oliver Borchers 11 | # Copyright (C) 2019 Oliver Borchers 12 | 13 | from gensim.models.base_any2vec import BaseWordEmbeddingsModel 14 | from gensim.models.keyedvectors import BaseKeyedVectors 15 | 16 | from gensim.matutils import unitvec 17 | 18 | from sklearn.decomposition import TruncatedSVD 19 | from wordfreq import get_frequency_dict 20 | 21 | from six.moves import xrange 22 | 23 | import logging 24 | import warnings 25 | import psutil 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | from numpy import float32 as REAL, sum as np_sum, vstack, zeros, ones, \ 30 | dtype, sqrt, newaxis, empty, full, expand_dims 31 | 32 | EPS = 1e-8 33 | 34 | CY_ROUTINES = 0 35 | 36 | 37 | def s2v_train(sentences, len_sentences, outer_vecs, max_seq_len, wv, weights): 38 | """Train sentence embedding on a list of sentences 39 | 40 | Called internally from :meth:`~fse.models.sentence2vec.Sentence2Vec.train`. 
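    For a sentence s the embedding is a SIF-weighted average of the externally
    supplied token vectors (``outer_vecs``; in this repo presumably the BERT
    token outputs passed through cal_output in model_recall.py): each vector is
    scaled by a / (a + p(w)), where p(w) is the word probability estimated from
    the vocabulary counts, the scaled vectors are summed, divided by the number
    of in-vocabulary tokens and finally L2-normalised. Principal-component
    removal is handled separately by Sentence2Vec.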
41 | 42 | Parameters 43 | ---------- 44 | sentences : iterable of list of str 45 | The corpus used to train the model. 46 | len_sentences : int 47 | Length of the sentence iterable 48 | wv : :class:`~gensim.models.keyedvectors.BaseKeyedVectors` 49 | The BaseKeyedVectors instance containing the vectors used for training 50 | weights : np.ndarray 51 | Weights used in the summation of the vectors 52 | 53 | Returns 54 | ------- 55 | np.ndarray 56 | The sentence embedding matrix of dim len(sentences) * vector_size 57 | int 58 | Number of words in the vocabulary actually used for training. 59 | int 60 | Number of sentences used for training. 61 | """ 62 | size = wv.vector_size 63 | vlookup = wv.vocab 64 | 65 | w_trans = weights[:, None] 66 | 67 | output = empty((len_sentences, size), dtype=REAL) 68 | for i in range(len_sentences): 69 | output[i] = full(size, EPS, dtype=REAL) 70 | 71 | effective_words = 0 72 | effective_sentences = 0 73 | 74 | for i, s in enumerate(sentences): 75 | sentence_idx = [vlookup[w].index for w in s if w in vlookup] 76 | if len(sentence_idx): 77 | v = np_sum(outer_vecs[ 78 | i][1:min(max_seq_len, len(sentence_idx) + 1), :] * 79 | w_trans[sentence_idx[:max_seq_len - 1]], axis=0) 80 | effective_words += len(sentence_idx) 81 | effective_sentences += 1 82 | v *= 1 / len(sentence_idx) 83 | v /= sqrt(np_sum(v.dot(v))) 84 | output[i] = v 85 | 86 | return output.astype(REAL), effective_words, effective_sentences 87 | 88 | 89 | class Sentence2Vec(): 90 | """Compute smooth inverse frequency weighted or averaged sentence emeddings. 91 | 92 | This implementation is based on the 2017 ICLR paper (https://openreview.net/pdf?id=SyK00v5xx): 93 | Arora S, Liang Y, Ma T (2017) A Simple but Tough-to-Beat Baseline for Sentence Embeddings. Int. Conf. Learn. Represent. (Toulon, France), 1–16. 94 | All corex routines are optimized based on the Gensim routines (https://github.com/RaRe-Technologies/gensim) 95 | 96 | Attributes 97 | ---------- 98 | model : :class:`~gensim.models.keyedvectors.BaseKeyedVectors` or :class:`~gensim.models.keyedvectors.BaseWordEmbeddingsModel` 99 | This object essentially contains the mapping between words and embeddings. To compute the sentence embeddings 100 | the wv.vocab and wv.vector elements are required. 101 | 102 | numpy.ndarray : sif_weights 103 | Contains the pre-computed SIF weights. 104 | """ 105 | 106 | def __init__(self, model, max_seq_len, alpha=1e-3, components=1, no_frequency=False, lang="en"): 107 | """ 108 | 109 | Parameters 110 | ---------- 111 | model : :class:`~gensim.models.keyedvectors.BaseKeyedVectors` or :class:`~gensim.models.keyedvectors.BaseWordEmbeddingsModel` 112 | This object essentially contains the mapping between words and embeddings. To compute the sentence embeddings 113 | the wv.vocab and wv.vector elements are required. 114 | alpha : float, optional 115 | Parameter which is used to weigh each individual word based on its probability p(w). 116 | If alpha = 1, train simply computes the averaged sentence representation. 117 | components : int, optional 118 | Number of principal components to remove from the sentence embeddings. Independent of alpha. 119 | no_frequency : bool, optional 120 | Some pre-trained embeddings, i.e. "GoogleNews-vectors-negative300.bin", do not contain information about 121 | the frequency of a word. 
As the frequency is required for estimating the weights, setting no_frequency induces 122 | an estimated frequency into the wv.vocab.count attribute based on :class:`~wordfreq`. 123 | lang : str, optional 124 | If no frequency information is available, you can choose the language to estimate the frequency. 125 | See https://github.com/LuminosoInsight/wordfreq 126 | 127 | Returns 128 | ------- 129 | numpy.ndarray 130 | Sentence embedding matrix of dim len(sentences) * dimension 131 | 132 | Examples 133 | -------- 134 | Initialize and train a :class:`~fse.models.sentence2vec.Sentence2Vec` model 135 | 136 | """ 137 | 138 | if isinstance(model, BaseWordEmbeddingsModel): 139 | self.model = model.wv 140 | elif isinstance(model, BaseKeyedVectors): 141 | self.model = model 142 | else: 143 | raise RuntimeError("Model must be child of BaseWordEmbeddingsModel or BaseKeyedVectors.") 144 | 145 | if not hasattr(self.model, 'vectors'): 146 | raise RuntimeError("Parameters required for predicting sentence embeddings not found.") 147 | 148 | assert alpha >= 0 and components >= 0 149 | 150 | self.alpha = float(alpha) 151 | self.components = int(components) 152 | self.no_frequency = bool(no_frequency) 153 | self.lang = str(lang) 154 | 155 | self.sif_weights = self._precompute_sif_weights(self.model, self.alpha, no_frequency, lang) 156 | self.pc = None 157 | self.max_seq_len = max_seq_len 158 | 159 | def _compute_principal_component(self, vectors, npc=1): 160 | """Compute the n principal components for the sentence embeddings 161 | 162 | Notes 163 | ----- 164 | Adapted from https://github.com/PrincetonML/SIF/blob/master/src/SIF_embedding.py 165 | 166 | Parameters 167 | ---------- 168 | vectors : numpy.ndarray 169 | The sentence embedding matrix of dim len(sentences) * vector_size. 170 | npc : int, optional 171 | The number of principal components to be computed. Default : 1. 172 | 173 | Returns 174 | ------- 175 | numpy.ndarray 176 | The principal components as computed by the TruncatedSVD 177 | 178 | """ 179 | logger.info("computing %d principal components", npc) 180 | svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0, algorithm="randomized") 181 | svd.fit(vectors) 182 | return svd.components_ 183 | 184 | def _remove_principal_component(self, vectors, npc=1, train_pc=True): 185 | """Remove the projection from the sentence embeddings 186 | 187 | Notes 188 | ----- 189 | Adapted from https://github.com/PrincetonML/SIF/blob/master/src/SIF_embedding.py 190 | 191 | Parameters 192 | ---------- 193 | vectors : numpy.ndarray 194 | The sentence embedding matrix of dim len(sentences) * vector_size. 195 | npc : int, optional 196 | The number of principal components to be computed. Default : 1.
197 | 198 | Returns 199 | ------- 200 | numpy.ndarray 201 | The sentence embedding matrix of dim len(sentences) * vector size after removing the projection 202 | 203 | """ 204 | if not train_pc and self.pc is None: 205 | raise RuntimeError('not trained!') 206 | if train_pc: 207 | self.pc = self._compute_principal_component(vectors, npc) 208 | logger.debug("removing %d principal components", npc) 209 | if npc == 1: 210 | vectors_rpc = vectors - vectors.dot(self.pc.transpose()) * self.pc 211 | else: 212 | vectors_rpc = vectors - vectors.dot(self.pc.transpose()).dot(self.pc) 213 | # sum_of_vecs = sqrt(np_sum(vectors_rpc * vectors_rpc, axis=-1)) 214 | # sum_of_vecs = expand_dims(sum_of_vecs, axis=1) 215 | # vectors_rpc /= sum_of_vecs 216 | return vectors_rpc 217 | 218 | def _precompute_sif_weights(self, wv, alpha=1e-3, no_frequency=False, lang="en"): 219 | """Precompute the weights used in the vector summation 220 | 221 | Parameters 222 | ---------- 223 | wv : `~gensim.models.keyedvectors.BaseKeyedVectors` 224 | A gensim keyedvectors child that contains the word vectors and the vocabulary 225 | alpha : float, optional 226 | Parameter which is used to weigh each individual word based on its probability p(w). 227 | If alpha = 0, the model computes the average sentence embedding. Common values range from 1e-5 to 1e-1. 228 | For more information, see the original paper. 229 | no_frequency : bool, optional 230 | Use the commonly available frequency table if the Gensim model does not contain information about 231 | the frequency of the words (see model.wv.vocab.count). 232 | lang : str, optional 233 | Determines the language of the frequency table used to compute the weights. 234 | 235 | Returns 236 | ------- 237 | numpy.ndarray 238 | The vector of weights for all words in the model vocabulary 239 | 240 | """ 241 | logger.info("pre-computing SIF weights") 242 | 243 | if no_frequency: 244 | logger.info("no frequency mode: using wordfreq for estimation (lang=%s)", lang) 245 | freq_dict = get_frequency_dict(str(lang), wordlist='best') 246 | 247 | for w in wv.index2word: 248 | if w in freq_dict: 249 | wv.vocab[w].count = int(freq_dict[w] * (2 ** 31 - 1)) 250 | else: 251 | wv.vocab[w].count = 1 252 | 253 | if alpha > 0: 254 | corpus_size = 0 255 | # Set the dtype correct for cython estimation 256 | sif = zeros(shape=len(wv.vocab), dtype=REAL) 257 | 258 | for k in wv.index2word: 259 | # Compute normalization constant 260 | corpus_size += wv.vocab[k].count 261 | 262 | for idx, k in enumerate(wv.index2word): 263 | pw = wv.vocab[k].count / corpus_size 264 | sif[idx] = alpha / (alpha + pw) 265 | else: 266 | sif = ones(shape=len(wv.vocab), dtype=REAL) 267 | 268 | return sif 269 | 270 | def _estimate_memory(self, len_sentences, vocab_size, vector_size): 271 | """Estimate the size of the embedding in memory 272 | 273 | Notes 274 | ----- 275 | Directly adapted from gensim 276 | 277 | Parameters 278 | ---------- 279 | len_sentences : int 280 | Length of the sentences iterable 281 | vocab_size : int 282 | Size of the vocabulary 283 | vector_size : int 284 | Vector size of the sentence embedding 285 | 286 | Returns 287 | ------- 288 | dict 289 | Dictionary of estimated sizes 290 | 291 | """ 292 | report = {} 293 | report["sif_weights"] = vocab_size * dtype(REAL).itemsize 294 | report["sentence_vectors"] = len_sentences * vector_size * dtype(REAL).itemsize 295 | report["total"] = sum(report.values()) 296 | mb_size = int(report["sentence_vectors"] / 1024 ** 2) 297 | logger.info( 298 | "estimated required memory for
%i sentences and %i dimensions: %i MB (%i GB)", 299 | len_sentences, 300 | vector_size, 301 | mb_size, 302 | int(mb_size / 1024) 303 | ) 304 | 305 | if report["total"] >= 0.95 * psutil.virtual_memory()[1]: 306 | warnings.warn("Sentence2Vec: The sentence embeddings will likely not fit into RAM.") 307 | 308 | return report 309 | 310 | def normalize(self, sentence_matrix, inplace=True): 311 | """Normalize the sentence_matrix rows to unit_length 312 | 313 | Notes 314 | ----- 315 | Directly adapted from gensim 316 | 317 | Parameters 318 | ---------- 319 | sentence_matrix : numpy.ndarray 320 | The sentence embedding matrix of dim len(sentences) * vector_size 321 | inplace : bool, optional 322 | If True, the rows are normalized in place; otherwise a normalized copy is returned. 323 | Returns 324 | ------- 325 | numpy.ndarray 326 | The sentence embedding matrix of dim len(sentences) * vector_size 327 | """ 328 | logger.info("computing L2-norms of sentence embeddings") 329 | if inplace: 330 | for i in xrange(len(sentence_matrix)): 331 | sentence_matrix[i, :] /= sqrt((sentence_matrix[i, :] ** 2).sum(-1)) 332 | return sentence_matrix 333 | output = (sentence_matrix / sqrt((sentence_matrix ** 2).sum(-1))[..., newaxis]).astype(REAL) 334 | return output 335 | 336 | def cal_output(self, sentences, outer_vecs, **kwargs): 337 | """Compute the sentence embeddings for the given sentences 338 | 339 | Parameters 340 | ---------- 341 | outer_vecs : numpy.ndarray or list of numpy.ndarray, shape=[N, S, F]; per-token vectors for each sentence (index 0 is skipped during summation) 342 | sentences : iterable of list of str 343 | The `sentences` iterable can be simply a list of lists of tokens, but for larger corpora, 344 | consider an iterable that streams the sentences directly from disk/network. 345 | 346 | Returns 347 | ------- 348 | numpy.ndarray 349 | The sentence embedding matrix of dim len(sentences) * vector_size 350 | """ 351 | assert len(outer_vecs[0].shape) == 2, 'outer_vecs error: assert len(outer_vecs[0].shape) == 2' 352 | 353 | if sentences is None: 354 | raise RuntimeError("Provide sentences object") 355 | 356 | len_sentences = 0 357 | if not hasattr(sentences, '__len__'): 358 | len_sentences = sum(1 for _ in sentences) 359 | else: 360 | len_sentences = len(sentences) 361 | 362 | if len_sentences == 0: 363 | raise RuntimeError("Sentences must be non-empty") 364 | 365 | self._estimate_memory(len_sentences, len(self.model.vocab), self.model.vector_size) 366 | 367 | output, no_words, no_sents = s2v_train(sentences, len_sentences, outer_vecs, self.max_seq_len 368 | , self.model, self.sif_weights) 369 | 370 | logger.debug("finished computing sentence embeddings of %i effective sentences with %i effective words", 371 | no_sents, no_words) 372 | return output 373 | 374 | def train_pc(self, sentence_vectors): 375 | """Fit the principal components on precomputed sentence vectors 376 | 377 | Parameters 378 | ---------- 379 | sentence_vectors : numpy.ndarray 380 | The precomputed sentence embedding matrix; only the principal components are fitted here. 381 | Returns 382 | ------- 383 | numpy.ndarray 384 | The sentence embedding matrix with the fitted principal components removed 385 | """ 386 | if self.components > 0: 387 | sentence_vectors = self._remove_principal_component(sentence_vectors, self.components) 388 | else: 389 | logger.info('No need to train pc') 390 | 391 | return sentence_vectors 392 | 393 | def predict_pc(self, output): 394 | if self.components > 0: 395 | output = self._remove_principal_component(output, self.components, train_pc=False) 396 | 397 | return output 398 | 399 | def train(self, sentences, outer_vecs, **kwargs): 400 | """Train the model on sentences 401 | 402 | Parameters 403 | ---------- 404 | sentences : iterable of list of str 405 | The `sentences` iterable can be simply a list of lists of tokens, but for larger
corpora, 406 | consider an iterable that streams the sentences directly from disk/network. 407 | 408 | Returns 409 | ------- 410 | numpy.ndarray 411 | The sentence embedding matrix of dim len(sentences) * vector_size 412 | """ 413 | output = self.cal_output(sentences, outer_vecs, **kwargs) 414 | 415 | if self.components > 0: 416 | output = self._remove_principal_component(output, self.components) 417 | 418 | return output 419 | 420 | def predict(self, sentences, outer_vecs, **kwargs): 421 | """Train the model on sentences 422 | 423 | Parameters 424 | ---------- 425 | sentences : iterable of list of str 426 | The `sentences` iterable can be simply a list of lists of tokens, but for larger corpora, 427 | consider an iterable that streams the sentences directly from disk/network. 428 | 429 | Returns 430 | ------- 431 | numpy.ndarray 432 | The sentence embedding matrix of dim len(sentences) * vector_size 433 | """ 434 | output = self.cal_output(sentences, outer_vecs, **kwargs) 435 | 436 | if self.components > 0: 437 | output = self._remove_principal_component(output, self.components, train_pc=False) 438 | 439 | return output 440 | -------------------------------------------------------------------------------- /code/utils/snippet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Author :Apple 5 | @Time :2020/5/18 19:44 6 | @File :snippet.py 7 | @Desc : 8 | """ 9 | from bert4keras_5_8.backend import keras, search_layer, K, tf 10 | import numpy as np 11 | 12 | 13 | def normalization(ar): 14 | return ar / np.sqrt(np.sum(np.square(ar), axis=-1, keepdims=True)) 15 | 16 | 17 | def adversarial_training(model, embedding_name, epsilon=1): 18 | """给模型添加对抗训练 19 | 其中model是需要添加对抗训练的keras模型,embedding_name 20 | 则是model里边Embedding层的名字。要在模型compile之后使用。 21 | """ 22 | if model.train_function is None: # 如果还没有训练函数 23 | model._make_train_function() # 手动make 24 | old_train_function = model.train_function # 备份旧的训练函数 25 | 26 | # 查找Embedding层 27 | for output in model.outputs: 28 | embedding_layer = search_layer(output, embedding_name) 29 | if embedding_layer is not None: 30 | break 31 | if embedding_layer is None: 32 | raise Exception('Embedding layer not found') 33 | 34 | # 求Embedding梯度 35 | embeddings = embedding_layer.embeddings # Embedding矩阵 36 | gradients = K.gradients(model.total_loss, [embeddings]) # Embedding梯度 37 | gradients = K.zeros_like(embeddings) + gradients[0] # 转为dense tensor 38 | 39 | # 封装为函数 40 | inputs = (model._feed_inputs + 41 | model._feed_targets + 42 | model._feed_sample_weights) # 所有输入层 43 | embedding_gradients = K.function( 44 | inputs=inputs, 45 | outputs=[gradients], 46 | name='embedding_gradients', 47 | ) # 封装为函数 48 | 49 | def train_function(inputs): # 重新定义训练函数 50 | grads = embedding_gradients(inputs)[0] # Embedding梯度 51 | delta = epsilon * grads / (np.sqrt((grads ** 2).sum()) + 1e-8) # 计算扰动 52 | K.set_value(embeddings, K.eval(embeddings) + delta) # 注入扰动 53 | outputs = old_train_function(inputs) # 梯度下降 54 | K.set_value(embeddings, K.eval(embeddings) - delta) # 删除扰动 55 | return outputs 56 | 57 | model.train_function = train_function # 覆盖原训练函数 58 | -------------------------------------------------------------------------------- /data/roberta/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 
768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | --------------------------------------------------------------------------------
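A minimal, hypothetical usage sketch for code/utils/sif.py and code/utils/snippet.py, assuming the working directory is code/. The word-vector file name, the sample sentences, the random stand-ins for the per-token vectors (`outer_vecs`), and the layer name `Embedding-Token` are illustrative assumptions, not values taken from this repo (in the repo, per-token vectors would come from model/extract_embedding.py).

```python
# Hypothetical usage sketch; paths and data below are placeholders.
import numpy as np
from gensim.models import KeyedVectors  # gensim 3.x, as assumed by utils/sif.py

from utils.sif import Sentence2Vec
from utils.snippet import adversarial_training

# Load pre-trained word vectors ('word2vec.kv' is a placeholder path).
wv = KeyedVectors.load('word2vec.kv')

# SIF sentence embeddings: remove 1 principal component, max 32 tokens per sentence.
s2v = Sentence2Vec(wv, max_seq_len=32, alpha=1e-3, components=1)

sentences = [['你', '喜欢', '看', '电影', '吗'], ['周末', '去', '爬山']]
# outer_vecs: per-token vectors of shape [S, F] per sentence; index 0 is skipped
# inside s2v_train. Random arrays stand in for real encoder outputs here.
outer_vecs = [np.random.randn(32, wv.vector_size).astype(np.float32) for _ in sentences]

train_emb = s2v.train(sentences, outer_vecs)    # fits and removes the principal component
query_emb = s2v.predict(sentences, outer_vecs)  # reuses the fitted principal component

# Adversarial training hook: `model` must be a compiled Keras model;
# 'Embedding-Token' is the token-embedding layer name assumed for bert4keras models.
# adversarial_training(model, 'Embedding-Token', epsilon=0.5)
```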