├── utils ├── __init__.py ├── count.py ├── product.py ├── loss.py ├── utils.py ├── image_feature_extract.py └── bleu_evaluator.py ├── requirements.txt ├── doc └── images │ ├── best_vlgpt.jpg │ ├── data_eda1.jpg │ ├── data_eda2.jpg │ ├── data_fusion.jpg │ ├── baseline_mhred.jpg │ ├── data_preprocess.jpg │ ├── training_data1.jpg │ └── training_data2.jpg ├── .gitignore ├── online_test.sh ├── config └── Custom │ ├── config_base.json │ └── config_medium.json ├── gpt_model ├── activations.py ├── __init__.py ├── optimization.py ├── configuration_gpt2.py ├── configuration_openai.py ├── configuration_utils.py ├── file_utils.py ├── tokenization_bert.py └── modeling_openai.py ├── online_test_data ├── test_answers.json.example └── test_questions.json ├── README.md ├── featurizer ├── dialogue_dataset.py └── get_dataloader.py ├── inference.py └── train.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-ignite==0.4.1 2 | transformers==2.5.1 -------------------------------------------------------------------------------- /doc/images/best_vlgpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/best_vlgpt.jpg -------------------------------------------------------------------------------- /doc/images/data_eda1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/data_eda1.jpg -------------------------------------------------------------------------------- /doc/images/data_eda2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/data_eda2.jpg -------------------------------------------------------------------------------- /doc/images/data_fusion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/data_fusion.jpg -------------------------------------------------------------------------------- /doc/images/baseline_mhred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/baseline_mhred.jpg -------------------------------------------------------------------------------- /doc/images/data_preprocess.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/data_preprocess.jpg -------------------------------------------------------------------------------- /doc/images/training_data1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/training_data1.jpg -------------------------------------------------------------------------------- /doc/images/training_data2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerrylsu/JDMDC2020-Solution-2nd/HEAD/doc/images/training_data2.jpg 
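Note that `requirements.txt` above pins only `pytorch-ignite==0.4.1` and `transformers==2.5.1`; `torch`/`torchvision` (used by `utils/image_feature_extract.py` and the featurizer) are presumably provided by the evaluation environment, with `TORCH_HOME` redirected in `online_test.sh`. A quick, hypothetical sanity check before launching `train.py` or `inference.py` might look like this (a convenience sketch, not part of the repository):

```python
# environment_check.py -- hypothetical helper, not part of the repo.
# Verifies the two pinned dependencies from requirements.txt and reports GPU visibility.
import torch
import ignite
import transformers

assert ignite.__version__ == "0.4.1", f"pytorch-ignite {ignite.__version__} != 0.4.1"
assert transformers.__version__ == "2.5.1", f"transformers {transformers.__version__} != 2.5.1"
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```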
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | experiments/ 3 | data/ 4 | dataset_cache* 5 | runs/ 6 | ParlAI/ 7 | __pycache__ 8 | .env/ 9 | runs/ 10 | .DS_Store 11 | gpt_personachat_cache_v2/ 12 | *dataset_cache* 13 | -------------------------------------------------------------------------------- /online_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | echo "Welcom to JDDC 2020" 3 | export TORCH_HOME=./.torch 4 | export PYTHONPATH=./ 5 | export CUDA_VISIBLE_DEVICES=0 6 | pip3 install -r requirements.txt -i https://pypi.doubanio.com/simple 7 | python3 data/online_data_preprocess.py 8 | python3 inference.py 9 | 10 | echo "Done!" 11 | -------------------------------------------------------------------------------- /utils/count.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def main(): 4 | data_file = "online_test_data/test_answers.json" 5 | with open(data_file) as f: 6 | data = json.load(f) 7 | ans_dict = {} 8 | for ans in data: 9 | text = ans["Answer"] 10 | if text not in ans_dict: 11 | ans_dict[text] = 1 12 | else: 13 | ans_dict[text] += 1 14 | ans_list = list(ans_dict.items()) 15 | ans_list = sorted(ans_list, key=lambda x: x[1], reverse=True) 16 | with open("ans_c.txt", "w") as f: 17 | for text, count in ans_list: 18 | f.write(F"{count} {text}\n") 19 | pass 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /config/Custom/config_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "afn": "gelu_new", 3 | "attn_pdrop": 0.1, 4 | "embd_pdrop": 0.1, 5 | "finetuning_task": null, 6 | "id2label": { 7 | "0": "LABEL_0", 8 | "1": "LABEL_1" 9 | }, 10 | "initializer_range": 0.02, 11 | "is_decoder": false, 12 | "label2id": { 13 | "LABEL_0": 0, 14 | "LABEL_1": 1 15 | }, 16 | "layer_norm_epsilon": 1e-05, 17 | "n_ctx": 512, 18 | "n_embd": 768, 19 | "n_head": 12, 20 | "n_layer": 12, 21 | "n_positions": 513, 22 | "n_special": 0, 23 | "num_labels": 2, 24 | "output_attentions": false, 25 | "output_hidden_states": false, 26 | "output_past": true, 27 | "predict_special_tokens": true, 28 | "pruned_heads": {}, 29 | "resid_pdrop": 0.1, 30 | "summary_activation": null, 31 | "summary_first_dropout": 0.1, 32 | "summary_proj_to_labels": true, 33 | "summary_type": "cls_index", 34 | "summary_use_proj": true, 35 | "torchscript": false, 36 | "use_bfloat16": false, 37 | "feature_extracting": true, 38 | "vocab_size": 29473 39 | } 40 | -------------------------------------------------------------------------------- /config/Custom/config_medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "afn": "gelu_new", 3 | "attn_pdrop": 0.1, 4 | "embd_pdrop": 0.1, 5 | "finetuning_task": null, 6 | "id2label": { 7 | "0": "LABEL_0", 8 | "1": "LABEL_1" 9 | }, 10 | "initializer_range": 0.02, 11 | "is_decoder": false, 12 | "label2id": { 13 | "LABEL_0": 0, 14 | "LABEL_1": 1 15 | }, 16 | "layer_norm_epsilon": 1e-05, 17 | "n_ctx": 512, 18 | "n_embd": 1024, 19 | "n_head": 16, 20 | "n_layer": 24, 21 | "n_positions": 513, 22 | "n_special": 0, 23 | "num_labels": 2, 24 | "output_attentions": false, 25 | "output_hidden_states": false, 26 | "output_past": true, 27 | "predict_special_tokens": true, 28 | 
"pruned_heads": {}, 29 | "resid_pdrop": 0.1, 30 | "summary_activation": null, 31 | "summary_first_dropout": 0.1, 32 | "summary_proj_to_labels": true, 33 | "summary_type": "cls_index", 34 | "summary_use_proj": true, 35 | "torchscript": false, 36 | "use_bfloat16": false, 37 | "feature_extracting": true, 38 | "vocab_size": 29473 39 | } 40 | -------------------------------------------------------------------------------- /utils/product.py: -------------------------------------------------------------------------------- 1 | import jieba 2 | 3 | class ProductInfo(object): 4 | def __init__(self, kb_f): 5 | self.kb_f = kb_f 6 | self.product_infos = {} 7 | self.load() 8 | pass 9 | 10 | def load(self): 11 | import json 12 | with open(self.kb_f) as f: 13 | infos = json.load(f) 14 | for index, info in enumerate(infos): 15 | # content = list(str(index)) 16 | content = [] 17 | pid = info["pid"] 18 | # for k, v in info.items(): 19 | # if k == "pid": 20 | # continue 21 | # content.extend(jieba.cut(k, cut_all=False)) 22 | # content.extend(jieba.cut(v, cut_all=False)) 23 | value = info["分类"] 24 | content = value # list(jieba.cut(value, cut_all=False)) 25 | self.product_infos[pid] = content 26 | 27 | def get_info_by_pid(self, pid=None): 28 | if not pid: 29 | return "未知" 30 | if pid not in self.product_infos: 31 | return "未知" 32 | # return pid + " " + " ".join(self.product_infos[pid]) + " <$$$> " 33 | return self.product_infos[pid] 34 | # return " ".join(product_infos[pid]) 35 | -------------------------------------------------------------------------------- /gpt_model/activations.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def swish(x): 8 | return x * torch.sigmoid(x) 9 | 10 | 11 | def _gelu_python(x): 12 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 13 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 14 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 15 | This is now written in C in torch.nn.functional 16 | Also see https://arxiv.org/abs/1606.08415 17 | """ 18 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 19 | 20 | 21 | if torch.__version__ < "1.4.0": 22 | gelu = _gelu_python 23 | else: 24 | gelu = F.gelu 25 | 26 | 27 | def gelu_new(x): 28 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
29 | Also see https://arxiv.org/abs/1606.08415 30 | """ 31 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 32 | 33 | 34 | ACT2FN = { 35 | "relu": F.relu, 36 | "swish": swish, 37 | "gelu": gelu, 38 | "tanh": F.tanh, 39 | "gelu_new": gelu_new, 40 | } 41 | 42 | 43 | def get_activation(activation_string): 44 | if activation_string in ACT2FN: 45 | return ACT2FN[activation_string] 46 | else: 47 | raise KeyError( 48 | "function {} not found in ACT2FN mapping {} or torch.nn.functional".format( 49 | activation_string, list(ACT2FN.keys()) 50 | ) 51 | ) 52 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = '4' 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class LabelSmoothSoftmaxCE(nn.Module): 8 | def __init__(self, 9 | lb_pos=0.9, 10 | lb_neg=0.005, 11 | reduction='mean', 12 | lb_ignore=255, 13 | ): 14 | super(LabelSmoothSoftmaxCE, self).__init__() 15 | self.lb_pos = lb_pos 16 | self.lb_neg = lb_neg 17 | self.reduction = reduction 18 | self.lb_ignore = lb_ignore 19 | self.log_softmax = nn.LogSoftmax(1) 20 | 21 | def forward(self, logits, label): 22 | logs = self.log_softmax(logits) 23 | ignore = label.data.cpu() == self.lb_ignore 24 | n_valid = (ignore == 0).sum() 25 | label[ignore] = 0 26 | lb_one_hot = logits.data.clone().zero_().scatter_(1, label.unsqueeze(1), 1) 27 | label = self.lb_pos * lb_one_hot + self.lb_neg * (1 - lb_one_hot) 28 | ignore = ignore.nonzero() 29 | _, M = ignore.size() 30 | a, *b = ignore.chunk(M, dim=1) 31 | label[[a, torch.arange(label.size(1)), *b]] = 0 32 | 33 | if self.reduction == 'mean': 34 | loss = -torch.sum(torch.sum(logs * label, dim=1)) / n_valid 35 | elif self.reduction == 'none': 36 | loss = -torch.sum(logs * label, dim=1) 37 | return loss 38 | 39 | 40 | if __name__ == '__main__': 41 | torch.manual_seed(15) 42 | criterion = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=5e-3) 43 | 44 | out = torch.randn(10, 5).cuda() 45 | lbs = torch.randint(5, (10,)).cuda() 46 | print('out:', out) 47 | print('lbs:', lbs) 48 | 49 | import torch.nn.functional as F 50 | 51 | loss = criterion(out, lbs) 52 | print('loss:', loss) -------------------------------------------------------------------------------- /gpt_model/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 
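The `gpt_model` package whose exports follow is a vendored subset of `transformers` 2.5.1: configurations, the BERT-style tokenizer, the GPT/GPT-2 models, and the optimization utilities. Given the `afn` field in `config/Custom/config_base.json`, the model built from it is presumably the OpenAI-GPT variant; a rough construction sketch is shown below (not copied from `train.py` — the paths, the learning pipeline, and the `resize_token_embeddings` call are assumptions):

```python
from gpt_model import BertTokenizer, OpenAIGPTConfig, OpenAIGPTLMHeadModel

# Assumes a vocab.txt (matching vocab_size=29473) sits alongside the JSON configs.
tokenizer = BertTokenizer.from_pretrained("config/Custom")
config = OpenAIGPTConfig.from_json_file("config/Custom/config_base.json")
model = OpenAIGPTLMHeadModel(config)
model.resize_token_embeddings(len(tokenizer))  # keep embeddings in sync with the custom vocabulary
```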
4 | 5 | __version__ = "2.5.1" 6 | 7 | from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config 8 | from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig 9 | from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer 10 | # Configurations 11 | from .configuration_utils import PretrainedConfig 12 | # Files and general utilities 13 | 14 | from .file_utils import ( 15 | CONFIG_NAME, 16 | MODEL_CARD_NAME, 17 | PYTORCH_PRETRAINED_BERT_CACHE, 18 | PYTORCH_TRANSFORMERS_CACHE, 19 | TF2_WEIGHTS_NAME, 20 | TF_WEIGHTS_NAME, 21 | TRANSFORMERS_CACHE, 22 | WEIGHTS_NAME, 23 | add_end_docstrings, 24 | add_start_docstrings, 25 | cached_path, 26 | is_tf_available, 27 | is_torch_available, 28 | ) 29 | 30 | # Modeling 31 | if is_torch_available(): 32 | from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering 33 | from .modeling_gpt2 import ( 34 | GPT2PreTrainedModel, 35 | GPT2Model, 36 | GPT2LMHeadModel, 37 | GPT2DoubleHeadsModel, 38 | load_tf_weights_in_gpt2, 39 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, 40 | ) 41 | from .modeling_openai import ( 42 | OpenAIGPTPreTrainedModel, 43 | OpenAIGPTModel, 44 | OpenAIGPTLMHeadModel, 45 | OpenAIGPTDoubleHeadsModel, 46 | load_tf_weights_in_openai_gpt, 47 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, 48 | ) 49 | 50 | # Optimization 51 | from .optimization import ( 52 | AdamW, 53 | get_constant_schedule, 54 | get_constant_schedule_with_warmup, 55 | get_cosine_schedule_with_warmup, 56 | get_cosine_with_hard_restarts_schedule_with_warmup, 57 | get_linear_schedule_with_warmup, 58 | ) 59 | -------------------------------------------------------------------------------- /online_test_data/test_answers.json.example: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Id": "a11390785fb6eb9cedf248e3800299eb_0", 4 | "Answer": " 您好,这边是京东自营***旗舰店售后客服,请问有什么可以帮您的呐#E-s36#E-s36 您好,这边是京东自营***旗舰店售后客服,请问有什么可以帮您的呐#E-s36#E-s36" 5 | }, 6 | { 7 | "Id": "c82d617f8b4c36446784c4dd3babeabe_0", 8 | "Answer": "建议您自行查看一下 是否有库存1 想要的尺码有库存 是可以换的" 9 | }, 10 | { 11 | "Id": "59610ceba3a58eb13c7a4583e38d1680_0", 12 | "Answer": "给您带来这么多的麻烦实在是非常抱歉,鞋底含有避震科技,材质的密度比较大,鞋子在穿着使用后,材料会有一定弹性,发挥缓震作用,但外面涂层无弹性,所以会产生这样的情况。" 13 | }, 14 | { 15 | "Id": "553434cde6226eb28facaeccb4b61bb1_0", 16 | "Answer": "您收到货了吗 11521243**" 17 | }, 18 | { 19 | "Id": "49fe7e93d0fd0584fc8fab13c222c11c_0", 20 | "Answer": "这边没又办法确定呢" 21 | }, 22 | { 23 | "Id": "779a8ff138c1fa8f7d8553cfed1ead66_0", 24 | "Answer": "#E-s21#E-s21" 25 | }, 26 | { 27 | "Id": "2471df8c3701578a3eb306bfffab4111_0", 28 | "Answer": "尊敬的客户您好,请问有什么可以帮到您的呢~" 29 | }, 30 | { 31 | "Id": "ced106fe76a565cb319793f1d38661b0_0", 32 | "Answer": "已经过保了哦 需要您去就近的***网点看下哦 您可以拨打***全国售后服务热线400-6186-999,查询离您最近的售后网点咨询。或者选择在线查询,***售后服务网站; 感谢您对京东的支持~" 33 | }, 34 | { 35 | "Id": "90f9ce87b263cd18770dce9b23df108b_0", 36 | "Answer": "您现在可以在发热盘倒入白醋(白醋30% 水70%),选择可以煮开的功能, 把水煮沸腾,浸泡10-15分钟,就可以清除掉了 可以按这个方法清洗看看呢" 37 | }, 38 | { 39 | "Id": "c78f19fcb43af0362593993dad5ef24f_0", 40 | "Answer": " 这个哦" 41 | }, 42 | { 43 | "Id": "3b97e429cddcc5cee7d6492f463bca3b_0", 44 | "Answer": "尊敬的客户您好,这边是***自营旗舰店售后客服,请问有什么可以帮到您的呢?#E-s36#E-s36" 45 | }, 46 | { 47 | "Id": "d7bfa9a78b1289393557df1c509280f7_0", 48 | "Answer": "在的" 49 | }, 50 | { 51 | "Id": "c8397fa44f876ebd108c72de8ac0409b_0", 52 | "Answer": "售后单审核通过后可以填写哟" 53 | }, 54 | { 55 | "Id": "2db98b28381ed0c3aab4cfc49f371838_0", 56 | "Answer": "您洗掉色的清洗水有照片吗" 57 | }, 58 | { 59 | "Id": 
"d79eb8dbebff125b1159e133eba328d0_0", 60 | "Answer": "预计时间仅供参考 具体得看物流" 61 | }, 62 | { 63 | "Id": "bec8e9852bfea7449583cf228d2510b2_0", 64 | "Answer": "拍不了是吗 那您吧理由修改一下 按正常的换货走、" 65 | }, 66 | { 67 | "Id": "18794790761fb09ee013e9a89c1616e1_0", 68 | "Answer": "***正品质量都是有保障的哦~您放心购买" 69 | }, 70 | { 71 | "Id": "d87b76bc7dc5157bba6055a01849b063_0", 72 | "Answer": "实在抱歉呢 #E-s17" 73 | }, 74 | { 75 | "Id": "6fa0cc56db146efa6ff7d76b97d97209_0", 76 | "Answer": "限量的哈 您下单看看 没有的话就是抢完了呢 世界上最遥远的距离不是生与死,而是您拍下了我却忘记了付款,为了与您相遇,特飞鸽传书,仅需几分钟,错过需要再等1年,抓紧付款吧!" 77 | }, 78 | { 79 | "Id": "15cc323372499df3c4eceddfa8b1989e_0", 80 | "Answer": "如果发生退换货的话 个人原因运费自理 一般是8元的 15天内质量问题是免邮的哈 您在订单里申请退换货审核通过后里面会有退回方式等的提示的呢 么么哒 检测 没有膜 是我们承担的哦" 81 | } 82 | ] 83 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, HuggingFace Inc. 2 | # All rights reserved. This source code is licensed under the BSD-style license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | from datetime import datetime 5 | import json 6 | import logging 7 | import os 8 | import tarfile 9 | import tempfile 10 | import socket 11 | import torch 12 | from transformers import cached_path 13 | 14 | 15 | def download_pretrained_model(logger): 16 | """ Download and extract finetuned model from S3 """ 17 | resolved_archive_file = cached_path(HF_FINETUNED_MODEL) 18 | tempdir = tempfile.mkdtemp() 19 | logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir)) 20 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 21 | archive.extractall(tempdir) 22 | return tempdir 23 | 24 | 25 | def get_dataset(tokenizer, dataset_path, dataset_cache, logger): 26 | """ Get tokenized PERSONACHAT dataset from S3 or cache_no_pretrained.""" 27 | dataset_path = dataset_path 28 | dataset_cache = dataset_cache + '_' + type(tokenizer).__name__ # To avoid using GPT cache_no_pretrained for GPT-2 and vice-versa 29 | if dataset_cache and os.path.isfile(dataset_cache): 30 | logger.info("Load tokenized dataset from cache_no_pretrained at %s", dataset_cache) 31 | dataset = torch.load(dataset_cache) 32 | else: 33 | logger.info("Download dataset from %s", dataset_path) 34 | personachat_file = cached_path(dataset_path) 35 | with open(personachat_file, "r", encoding="utf-8") as f: 36 | dataset = json.loads(f.read()) 37 | 38 | logger.info("Tokenize and encode the dataset") 39 | def tokenize(obj): 40 | if isinstance(obj, str): 41 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj)) 42 | if isinstance(obj, dict): 43 | return dict((n, tokenize(o)) if n != 'img_list' else (n, o) for n, o in obj.items()) 44 | return list(tokenize(o) for o in obj) 45 | dataset = tokenize(dataset) 46 | torch.save(dataset, dataset_cache) 47 | return dataset 48 | 49 | 50 | class AttrDict(dict): 51 | def __init__(self, *args, **kwargs): 52 | super(AttrDict, self).__init__(*args, **kwargs) 53 | self.__dict__ = self 54 | 55 | 56 | def make_logdir(model_name: str): 57 | """Create unique path to save results and checkpoints, e.g. 
runs/Sep22_19-45-59_gpu-7_gpt2""" 58 | # Code copied from ignite repo 59 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 60 | logdir = os.path.join( 61 | 'runs', current_time + '_' + socket.gethostname() + '_' + model_name) 62 | return logdir 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2020京东多模态对话JDMDC2020第二名解决方案 2 | 3 | [JDMDC2020官网:http://jddc.jd.com](http://jddc.jd.com) 4 | 5 | Hibot团队成绩:初赛第三,决赛第二 6 | 7 | Hibot队员:林克,茸茸,阿布,杰瑞 8 | 9 | ## 数据工作 10 | 11 | ### 数据集探索 12 | 13 | 数据分为五个字段:会话ID, 店铺类别, 商品ID, 对话文本(包含图片ID), 对话角色。通过商品ID与图片ID,来引用商品知识库数据和图片模态数据。 14 | 15 | ![data_eda1](doc/images/data_eda1.jpg) 16 | 17 | 统计原始会话(连续QQAA未合并)轮数长度以及单句文本长度(按字)。 18 | 19 | ![data_eda2](doc/images/data_eda2.jpg) 20 | 21 | ### 数据预处理 22 | 23 | ![data_preprocess](doc/images/data_preprocess.jpg) 24 | 25 | ### 训练数据构造 26 | 27 | ![training_data1](doc/images/training_data1.jpg) 28 | 29 | ![training_data2](doc/images/training_data2.jpg) 30 | 31 | ### 融合多模态与知识库 32 | 33 | ![data_fusion](doc/images/data_fusion.jpg) 34 | 35 | ## 技术方案 36 | 37 | ### 基线模型MHRED 38 | 39 | 官方提供基线生成模型MHRED:https://github.com/jd-aig/nlp_baai/tree/master/jddc2020_baseline/mhred/pytorch 40 | 41 | 基线模型MHRED复现BLEU分为:3.3853,在基线的基础上,我们加入了注意力机制BLEU提分到:5.6237。 42 | 43 | ![mhred](doc/images/baseline_mhred.jpg) 44 | 45 | ### 最佳模型VLGPT 46 | 47 | ![vlgpt](doc/images/best_vlgpt.jpg) 48 | 49 | ## 总结与引用 50 | 51 | **What did work** 52 | 53 | 1. 去除上下文为空的数据 54 | 55 | 2. 去除无意义回复,例如好的,嗯嗯,哦... 56 | 57 | 3. 根据数据集自定义字典,OOV问题 58 | 59 | 4. 带掩码的Loss,masked answer labels 60 | 61 | 5. 增加token type embedding区分角色 62 | 63 | 6. Base基础GPT模型 64 | 65 | 7. 各种注意力机制Self-Attention, Context-Attention, Masked SeLf-Attention[5][12] 66 | 67 | **What didn't work** 68 | 69 | 1. 删除通用模板回复 70 | 71 | 2. Label Smoothing提升微弱 [13] 72 | 73 | 3. 多头任务学习DoubleHead GPT(生成任务+是否为下一句) [14] 74 | 75 | 4. BERT中文wiki预训练模型 [15] 76 | 77 | 5. GPT中文wiki[16]与清华开源预训练模型 [17][18] 78 | 79 | 6. Medium中型GPT模型 80 | 81 | 7. 最大互信息Maximum Mutual Information(MMI) [19] 82 | 83 | 84 | [1] Kishore Papineni, Salim Roukos, et al. BLEU: a Method for Automatic Evaluation of Machine Translation 85 | 86 | [2] Boxing Chen, Colin Cherry et al. A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU 87 | 88 | [3] Amrita Saha, Mitesh Khapra, et al. Towards Building Large Scale Multimodal Domain-Aware Conversation Systems 89 | 90 | [4] Minh-Thang Luong, Hieu Pham, et al. Effective Approaches to Attention-based Neural Machine Translation 91 | 92 | [5] Ashish Vaswani, Noam Shazeer, et al. Attention Is All You Need 93 | 94 | [6] Jacob Devlin, Ming-Wei Chang, et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 95 | 96 | [7] Alec Radford, Karthik Narasimhan, et al. Improving Language Understanding by Generative Pre-Training 97 | 98 | [8] Alec Radford, Jeffrey Wu, et al. Language Models are Unsupervised Multitask Learners 99 | 100 | [9] https://huggingface.co - transformers 101 | 102 | [10] Thomas Wolf, Victor Sanh, et al. TransferTransfo: A Transfer Learning Approach for Neural Network Based Conversational Agents 103 | 104 | [11] https://github.com/huggingface/transfer-learning-conv-ai 105 | 106 | [12] https://jalammar.github.io/illustrated-transformer 107 | 108 | [13] Rafael Müller, Simon Kornblith, st al. when does label smoothing help? 
109 | 110 | [14] https://huggingface.co/transformers/model_doc/gpt2.html#gpt2doubleheadsmodel 111 | 112 | [15] https://huggingface.co/bert-base-chinese 113 | 114 | [16] https://github.com/qingkongzhiqian/GPT2-Summary 115 | 116 | [17] https://cloud.tsinghua.edu.cn/f/4dfb8c6c22ae47fbbe98 117 | 118 | [18] Yida Wang, Pei Ke, et al. A Large-Scale Chinese Short-Text Conversation Dataset 119 | 120 | [19] Yizhe Zhang, Siqi Sun, et al. DialoGPT:Large-Scale Generative Pre-training for Conversational Response Generation -------------------------------------------------------------------------------- /utils/image_feature_extract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | # 使用resnet18作为特征提取器, 对图像进行特征提取 4 | """ 5 | import os 6 | import argparse 7 | import PIL 8 | import torch 9 | import torchvision 10 | from torchvision import transforms 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | 14 | # set default path for data and test data 15 | project_dir = Path(__file__).resolve().parent.parent 16 | 17 | # temporarily use resent18 image statistics 18 | data_transforms = { 19 | 'train': transforms.Compose([ 20 | transforms.RandomResizedCrop(224), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 24 | ]), 25 | 'val': transforms.Compose([ 26 | transforms.Resize(224), 27 | transforms.CenterCrop(224), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 30 | ]), 31 | } 32 | 33 | 34 | class Res18ImgFeatureExtractor(object): 35 | def __init__(self): 36 | self.feature_extractor = self._build_feature_extractor() 37 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | self.feature_extractor.to(self.device) # 将模型拷贝到相应设备 39 | self.feature_extractor.eval() # 设置模式, 此处只进行推理 40 | 41 | self.images_feature = {} # 键值对形式存储image feature 42 | 43 | def _build_feature_extractor(self): 44 | """构建特征提取器""" 45 | model_ft = torchvision.models.resnet18(pretrained=True) 46 | # 剔除last全连接 47 | feature_extractor = torch.nn.Sequential(*list(model_ft.children())[:-1]) 48 | 49 | return feature_extractor 50 | 51 | def save_image_feature(self, image_feature_path='./images_feature.pkl'): 52 | """ 存储image feature 53 | @param image_feature_path: 54 | @return: 55 | """ 56 | torch.save(self.images_feature, f=image_feature_path) 57 | 58 | def get_image_to_feature(self, image_dir, data_type='val'): 59 | """ 获取图像feature 60 | @param image_dir: 图像的存储路劲 61 | @param data_type 62 | @return: 63 | """ 64 | if not os.path.isdir(image_dir): 65 | raise FileExistsError("Image Directory No Exist.") 66 | 67 | image_files = os.listdir(image_dir) # 获取路径下所有文件名字 68 | for image_file in tqdm(image_files): 69 | image_path = os.path.join(image_dir, image_file) 70 | if os.path.isdir(image_path): 71 | # 列表不进行处理, 其实可以递归的, self.get 72 | continue 73 | image_data = self.read_image(image_path, data_type=data_type) 74 | # 0维扩充 75 | image_data = image_data.unsqueeze(0).to(self.device) 76 | image_feature = self.feature_extractor(image_data) 77 | # [1, 512, 1, 1] -> [512, ] 78 | image_feature = torch.flatten(image_feature, 1).cpu().data.numpy().squeeze() 79 | 80 | # 存储数据 81 | self.images_feature[image_file] = image_feature # image name is key 82 | 83 | def get_image_to_feature_from_dirs(self, image_dir_list, data_type='val'): 84 | """ 获取图像feature从多个文件夹下 85 | @param image_dir_list: 图像的存储路劲 86 | @param data_type 87 | @return: 88 | 
""" 89 | for image_dir in image_dir_list: 90 | if not os.path.isdir(image_dir): 91 | raise FileExistsError("Image Directory No Exist.") 92 | 93 | image_files = os.listdir(image_dir) # 获取路径下所有文件名字 94 | for image_file in tqdm(image_files): 95 | image_path = os.path.join(image_dir, image_file) 96 | if os.path.isdir(image_path): 97 | # 列表不进行处理, 其实可以递归的, self.get 98 | continue 99 | image_data = self.read_image(image_path, data_type=data_type) 100 | # 0维扩充 101 | image_data = image_data.unsqueeze(0).to(self.device) 102 | image_feature = self.feature_extractor(image_data) 103 | # [1, 512, 1, 1] -> [512, ] 104 | image_feature = torch.flatten(image_feature, 1).cpu().data.numpy().squeeze() 105 | 106 | # 存储数据 107 | self.images_feature[image_file] = image_feature # image name is key 108 | 109 | def read_image(self, image_path, data_type='val'): 110 | """ 读取图片 111 | @param image_path: 112 | @param data_type: 113 | @return: 114 | """ 115 | image = torch.zeros(3, 224, 224) 116 | try: 117 | image_tmp = PIL.Image.open(image_path) 118 | image = data_transforms[data_type](image_tmp) 119 | except Exception as err: 120 | print(err) 121 | 122 | return image 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = argparse.ArgumentParser(description="设置图像特征提取参数") 127 | 128 | parser.add_argument('--img_dir', default=os.path.join(project_dir, 'data/images_train')) 129 | parser.add_argument('--img_dev_dir', default=os.path.join(project_dir, 'data/images_dev')) 130 | parser.add_argument('--data_type', default='train') 131 | parser.add_argument('--img_feature_path', default=os.path.join(project_dir, 'data/train_images_feature.pkl')) 132 | 133 | args = parser.parse_args() 134 | 135 | main = Res18ImgFeatureExtractor() 136 | if args.img_dev_dir: 137 | img_dir_list = [args.img_dir, args.img_dev_dir] 138 | else: 139 | img_dir_list = [args.img_dir] 140 | # main.get_image_to_feature(image_dir=args.img_dir, data_type=args.data_type) 141 | main.get_image_to_feature_from_dirs(img_dir_list, data_type=args.data_type) 142 | main.save_image_feature(image_feature_path=args.img_feature_path) 143 | -------------------------------------------------------------------------------- /utils/bleu_evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | # 使用nltk库实现BLEU评估算法 3 | """ 4 | import os 5 | import logging 6 | import json 7 | import argparse 8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 9 | 10 | logging.basicConfig( 11 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 12 | datefmt="%m/%d/%Y %H:%M:%S", 13 | level=logging.INFO, 14 | ) 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class BLEUEvaluator(object): 19 | def __init__(self, 20 | predict_json_path='./online_test_data/test_answers.json', 21 | target_json_path='./online_test_data/test_answers_target.json', 22 | question_json_path=None, 23 | good_qa_threshold=None): 24 | with open(predict_json_path, mode='r', encoding='utf-8') as file_predict: 25 | self.predict_content = json.load(file_predict) 26 | with open(target_json_path, mode='r', encoding='utf-8') as file_target: 27 | self.target_content = json.load(file_target) 28 | 29 | self.good_qa_threshold = None 30 | if question_json_path and good_qa_threshold: 31 | # 二者条件都满足才进行此项操作 32 | self.good_qa_threshold = good_qa_threshold 33 | with open(question_json_path, mode='r', encoding='utf-8') as file_question: 34 | # 获取问题列表 35 | self.questions_dict = self.get_question_dict(json.load(file_question)) 36 | 37 | # 记录bleu分比较高的回答, 
追溯其问题以及上下文 38 | self.good_test_questions = [] 39 | self.bad_test_questions = [] 40 | 41 | def eval(self): 42 | """评估函数""" 43 | predict_answers_dict = self.get_answer_dict(self.predict_content) 44 | target_answers_dict = self.get_answer_dict(self.target_content) 45 | 46 | return self.compute_bleu(predict_answers_dict, target_answers_dict) 47 | 48 | def compute_bleu(self, predict_dict, target_dict): 49 | """ 50 | Args: 51 | predict_dict: 预测字典列表 52 | target_dict: 53 | 54 | Returns: 55 | 56 | """ 57 | n_sum = 0 58 | smooth = SmoothingFunction() 59 | 60 | predict_data_length = len(predict_dict.keys()) 61 | _logger.info("all predict data size: {}.".format(predict_data_length)) 62 | for single_key in predict_dict.keys(): 63 | if not target_dict.get(single_key): 64 | # 跳过查不大不到目标的数据 65 | predict_dict -= 1 66 | continue 67 | 68 | target_list_three = target_dict.get(single_key).split("") 69 | n_eval_result = sentence_bleu(target_list_three, predict_dict.get(single_key), 70 | smoothing_function=smooth.method1) 71 | 72 | print(n_eval_result) 73 | if self.good_qa_threshold: 74 | if n_eval_result > self.good_qa_threshold: 75 | self.good_test_questions.append(self.good_qa_track(single_key, predict_dict.get(single_key), n_eval_result)) 76 | else: 77 | self.bad_test_questions.append(self.good_qa_track(single_key, predict_dict.get(single_key), n_eval_result)) 78 | 79 | n_sum += n_eval_result 80 | 81 | _logger.info("resize predict data size: {}.".format(predict_data_length)) 82 | 83 | return float(n_sum) / predict_data_length 84 | 85 | def good_qa_track(self, predict_single_key, predict_single_value, n_eval_result): 86 | """获取最佳QA对""" 87 | # 1. 对应ID的question context 88 | single_question = self.questions_dict.get(predict_single_key) 89 | # 2. 添加预测结果 90 | single_question['PredictAnswer'] = predict_single_value 91 | # 3. 
BLEU score 92 | single_question['BLEU'] = n_eval_result 93 | 94 | return single_question 95 | 96 | def get_answer_dict(self, answers_dict_list): 97 | """将dict list 转为dict, Id为key, answer为value""" 98 | answers_dict = {} 99 | for single_dict in answers_dict_list: 100 | answers_dict[single_dict.get('Id')] = single_dict.get('Answer') 101 | 102 | return answers_dict 103 | 104 | def get_question_dict(self, questions_dict_list): 105 | """将dict list 转为dict, Id为key, 整体dict为value""" 106 | questions_dict = {} 107 | for single_dict in questions_dict_list: 108 | questions_dict[single_dict.get('Id')] = single_dict 109 | 110 | return questions_dict 111 | 112 | def save_good_qa_question(self, file_path): 113 | """保存最佳qa的question""" 114 | with open(file_path, mode='w', encoding='utf-8') as fw: 115 | json.dump(self.good_test_questions, fw, ensure_ascii=False, indent=2) 116 | 117 | def save_bad_qa_question(self, file_path): 118 | """保存最佳qa的question""" 119 | with open(file_path, mode='w', encoding='utf-8') as fw: 120 | json.dump(self.bad_test_questions, fw, ensure_ascii=False, indent=2) 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser(description="Evaluation of the sentence generation effect.") 125 | 126 | parser.add_argument('-p', '--predict_json_path', 127 | default='../online_test_data/test_answers.json', 128 | type=str, 129 | help='The json file for the predict results.') 130 | parser.add_argument('-t', '--target_json_path', 131 | default='../online_test_data/test_answers_target.json', 132 | type=str, 133 | help='The json file for the target results.') 134 | parser.add_argument('-q', '--question_json_path', 135 | default='../online_test_data/test_questions.json', 136 | type=str, 137 | help='The json file for the question contents.') 138 | parser.add_argument('-gq', '--good_question_json_path', 139 | default='../online_test_data/good_test_questions.json', 140 | type=str, 141 | help='The json file for the save good ga track question contents.') 142 | parser.add_argument('-bq', '--bad_question_json_path', 143 | default='../online_test_data/bad_test_questions.json', 144 | type=str, 145 | help='The json file for the save bad ga track question contents.') 146 | parser.add_argument('-trd', '--good_qa_threshold', 147 | default=None, 148 | type=float, 149 | help='the judge threshold for the predict and question is good qa track.') 150 | 151 | args = parser.parse_args() 152 | 153 | evaluator = BLEUEvaluator(predict_json_path=args.predict_json_path, 154 | target_json_path=args.target_json_path, 155 | question_json_path=args.question_json_path, 156 | good_qa_threshold=args.good_qa_threshold) 157 | 158 | eval_result = evaluator.eval() 159 | _logger.info("eval result is {}.".format(eval_result)) 160 | 161 | if args.good_qa_threshold: 162 | _logger.info("save good qa track question contents.") 163 | evaluator.save_good_qa_question(args.good_question_json_path) 164 | evaluator.save_bad_qa_question(args.bad_question_json_path) 165 | -------------------------------------------------------------------------------- /gpt_model/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def get_constant_schedule(optimizer, last_epoch=-1): 29 | """ Create a schedule with a constant learning rate. 30 | """ 31 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 32 | 33 | 34 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 35 | """ Create a schedule with a constant learning rate preceded by a warmup 36 | period during which the learning rate increases linearly between 0 and 1. 37 | """ 38 | 39 | def lr_lambda(current_step): 40 | if current_step < num_warmup_steps: 41 | return float(current_step) / float(max(1.0, num_warmup_steps)) 42 | return 1.0 43 | 44 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 45 | 46 | 47 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 48 | """ Create a schedule with a learning rate that decreases linearly after 49 | linearly increasing during a warmup period. 50 | """ 51 | 52 | def lr_lambda(current_step): 53 | if current_step < num_warmup_steps: 54 | return float(current_step) / float(max(1, num_warmup_steps)) 55 | return max( 56 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 57 | ) 58 | 59 | return LambdaLR(optimizer, lr_lambda, last_epoch) 60 | 61 | 62 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1): 63 | """ Create a schedule with a learning rate that decreases following the 64 | values of the cosine function between 0 and `pi * cycles` after a warmup 65 | period during which it increases linearly between 0 and 1. 66 | """ 67 | 68 | def lr_lambda(current_step): 69 | if current_step < num_warmup_steps: 70 | return float(current_step) / float(max(1, num_warmup_steps)) 71 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 72 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 73 | 74 | return LambdaLR(optimizer, lr_lambda, last_epoch) 75 | 76 | 77 | def get_cosine_with_hard_restarts_schedule_with_warmup( 78 | optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1 79 | ): 80 | """ Create a schedule with a learning rate that decreases following the 81 | values of the cosine function with several hard restarts, after a warmup 82 | period during which it increases linearly between 0 and 1. 
83 | """ 84 | 85 | def lr_lambda(current_step): 86 | if current_step < num_warmup_steps: 87 | return float(current_step) / float(max(1, num_warmup_steps)) 88 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 89 | if progress >= 1.0: 90 | return 0.0 91 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 92 | 93 | return LambdaLR(optimizer, lr_lambda, last_epoch) 94 | 95 | 96 | class AdamW(Optimizer): 97 | """ Implements Adam algorithm with weight decay fix. 98 | 99 | Parameters: 100 | lr (float): learning rate. Default 1e-3. 101 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 102 | eps (float): Adams epsilon. Default: 1e-6 103 | weight_decay (float): Weight decay. Default: 0.0 104 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 105 | """ 106 | 107 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 108 | if lr < 0.0: 109 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 110 | if not 0.0 <= betas[0] < 1.0: 111 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 112 | if not 0.0 <= betas[1] < 1.0: 113 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 114 | if not 0.0 <= eps: 115 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 116 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 117 | super().__init__(params, defaults) 118 | 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | 122 | Arguments: 123 | closure (callable, optional): A closure that reevaluates the model 124 | and returns the loss. 
125 | """ 126 | loss = None 127 | if closure is not None: 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | for p in group["params"]: 132 | if p.grad is None: 133 | continue 134 | grad = p.grad.data 135 | if grad.is_sparse: 136 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 137 | 138 | state = self.state[p] 139 | 140 | # State initialization 141 | if len(state) == 0: 142 | state["step"] = 0 143 | # Exponential moving average of gradient values 144 | state["exp_avg"] = torch.zeros_like(p.data) 145 | # Exponential moving average of squared gradient values 146 | state["exp_avg_sq"] = torch.zeros_like(p.data) 147 | 148 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 149 | beta1, beta2 = group["betas"] 150 | 151 | state["step"] += 1 152 | 153 | # Decay the first and second moment running average coefficient 154 | # In-place operations to update the averages at the same time 155 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 156 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 157 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 158 | 159 | step_size = group["lr"] 160 | if group["correct_bias"]: # No bias correction for Bert 161 | bias_correction1 = 1.0 - beta1 ** state["step"] 162 | bias_correction2 = 1.0 - beta2 ** state["step"] 163 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 164 | 165 | p.data.addcdiv_(-step_size, exp_avg, denom) 166 | 167 | # Just adding the square of the weights to the loss function is *not* 168 | # the correct way of using L2 regularization/weight decay with Adam, 169 | # since that will interact with the m and v parameters in strange ways. 170 | # 171 | # Instead we want to decay the weights in a manner that doesn't interact 172 | # with the m/v parameters. This is equivalent to adding the square 173 | # of the weights to the loss with plain (non-momentum) SGD. 174 | # Add weight decay at the end (fixed version) 175 | if group["weight_decay"] > 0.0: 176 | p.data.add_(-group["lr"] * group["weight_decay"], p.data) 177 | 178 | return loss 179 | -------------------------------------------------------------------------------- /gpt_model/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
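The schedulers and `AdamW` defined in `gpt_model/optimization.py` above are intended to be stepped together once per parameter update: each `get_*_schedule_with_warmup` helper returns a `LambdaLR` that scales the optimizer's base learning rate by `lr_lambda(current_step)`, so `scheduler.step()` belongs inside the batch loop, not at epoch boundaries. A minimal loop sketch (stand-in model, made-up step counts and learning rate, purely illustrative):

```python
import torch
from gpt_model.optimization import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 2)                        # stand-in model, illustration only
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
total_steps = 1000                                    # assumed: epochs * batches per epoch
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100,
                                            num_training_steps=total_steps)

for step in range(total_steps):
    loss = model(torch.randn(4, 10)).pow(2).mean()    # dummy loss
    loss.backward()
    optimizer.step()
    scheduler.step()                                  # advance the LR schedule per update, not per epoch
    optimizer.zero_grad()
```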
16 | """ OpenAI GPT-2 configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", 28 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", 29 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", 30 | "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json", 31 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json", 32 | } 33 | 34 | 35 | class GPT2Config(PretrainedConfig): 36 | """ 37 | This is the configuration class to store the configuration of a :class:`~transformers.GPT2Model`. 38 | It is used to instantiate an GPT-2 model according to the specified arguments, defining the model 39 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 40 | the GPT-2 `small `__ architecture. 41 | 42 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 43 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 44 | for more information. 45 | 46 | 47 | Args: 48 | vocab_size (:obj:`int`, optional, defaults to 50257): 49 | Vocabulary size of the GPT-2 model. Defines the different tokens that 50 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.GPT2Model`. 51 | n_positions (:obj:`int`, optional, defaults to 1024): 52 | The maximum sequence length that this model might ever be used with. 53 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 54 | n_ctx (:obj:`int`, optional, defaults to 1024): 55 | Dimensionality of the causal mask (usually same as n_positions). 56 | n_embd (:obj:`int`, optional, defaults to 768): 57 | Dimensionality of the embeddings and hidden states. 58 | n_layer (:obj:`int`, optional, defaults to 12): 59 | Number of hidden layers in the Transformer encoder. 60 | n_head (:obj:`int`, optional, defaults to 12): 61 | Number of attention heads for each attention layer in the Transformer encoder. 62 | resid_pdrop (:obj:`float`, optional, defaults to 0.1): 63 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 64 | embd_pdrop (:obj:`int`, optional, defaults to 0.1): 65 | The dropout ratio for the embeddings. 66 | attn_pdrop (:obj:`float`, optional, defaults to 0.1): 67 | The dropout ratio for the attention. 68 | layer_norm_epsilon (:obj:`float`, optional, defaults to 1e-5): 69 | The epsilon to use in the layer normalization layers 70 | initializer_range (:obj:`float`, optional, defaults to 16): 71 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 72 | summary_type (:obj:`string`, optional, defaults to "cls_index"): 73 | Argument used when doing sequence summary. Used in for the multiple choice head in 74 | :class:`~transformers.GPT2DoubleHeadsModel`. 
75 | Is one of the following options: 76 | - 'last' => take the last token hidden state (like XLNet) 77 | - 'first' => take the first token hidden state (like Bert) 78 | - 'mean' => take the mean of all tokens hidden states 79 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) 80 | - 'attn' => Not implemented now, use multi-head attention 81 | summary_use_proj (:obj:`boolean`, optional, defaults to :obj:`True`): 82 | Argument used when doing sequence summary. Used in for the multiple choice head in 83 | :class:`~transformers.GPT2DoubleHeadsModel`. 84 | Add a projection after the vector extraction 85 | summary_activation (:obj:`string` or :obj:`None`, optional, defaults to :obj:`None`): 86 | Argument used when doing sequence summary. Used in for the multiple choice head in 87 | :class:`~transformers.GPT2DoubleHeadsModel`. 88 | 'tanh' => add a tanh activation to the output, Other => no activation. 89 | summary_proj_to_labels (:obj:`boolean`, optional, defaults to :obj:`True`): 90 | Argument used when doing sequence summary. Used in for the multiple choice head in 91 | :class:`~transformers.GPT2DoubleHeadsModel`. 92 | If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. 93 | summary_first_dropout (:obj:`float`, optional, defaults to 0.1): 94 | Argument used when doing sequence summary. Used in for the multiple choice head in 95 | :class:`~transformers.GPT2DoubleHeadsModel`. 96 | Add a dropout before the projection and activation 97 | 98 | Example:: 99 | 100 | from transformers import GPT2Model, GPT2Config 101 | 102 | # Initializing a GPT2 configuration 103 | configuration = GPT2Config() 104 | 105 | # Initializing a model from the configuration 106 | model = GPT2Model(configuration) 107 | 108 | # Accessing the model configuration 109 | configuration = model.config 110 | 111 | Attributes: 112 | pretrained_config_archive_map (Dict[str, str]): 113 | A dictionary containing all the available pre-trained checkpoints. 
114 | """ 115 | 116 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 117 | model_type = "gpt2" 118 | 119 | def __init__( 120 | self, 121 | vocab_size=50257, 122 | n_positions=1024, 123 | n_ctx=1024, 124 | n_embd=768, 125 | n_layer=12, 126 | n_head=12, 127 | resid_pdrop=0.1, 128 | embd_pdrop=0.1, 129 | attn_pdrop=0.1, 130 | layer_norm_epsilon=1e-5, 131 | initializer_range=0.02, 132 | summary_type="cls_index", 133 | summary_use_proj=True, 134 | summary_activation=None, 135 | summary_proj_to_labels=True, 136 | summary_first_dropout=0.1, 137 | **kwargs 138 | ): 139 | super().__init__(**kwargs) 140 | 141 | self.vocab_size = vocab_size 142 | self.n_ctx = n_ctx 143 | self.n_positions = n_positions 144 | self.n_embd = n_embd 145 | self.n_layer = n_layer 146 | self.n_head = n_head 147 | self.resid_pdrop = resid_pdrop 148 | self.embd_pdrop = embd_pdrop 149 | self.attn_pdrop = attn_pdrop 150 | self.layer_norm_epsilon = layer_norm_epsilon 151 | self.initializer_range = initializer_range 152 | self.summary_type = summary_type 153 | self.summary_use_proj = summary_use_proj 154 | self.summary_activation = summary_activation 155 | self.summary_first_dropout = summary_first_dropout 156 | self.summary_proj_to_labels = summary_proj_to_labels 157 | 158 | @property 159 | def max_position_embeddings(self): 160 | return self.n_positions 161 | 162 | @property 163 | def hidden_size(self): 164 | return self.n_embd 165 | 166 | @property 167 | def num_attention_heads(self): 168 | return self.n_head 169 | 170 | @property 171 | def num_hidden_layers(self): 172 | return self.n_layer 173 | -------------------------------------------------------------------------------- /gpt_model/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 28 | } 29 | 30 | 31 | class OpenAIGPTConfig(PretrainedConfig): 32 | """ 33 | This is the configuration class to store the configuration of an :class:`~transformers.OpenAIGPTModel`. 34 | It is used to instantiate an GPT model according to the specified arguments, defining the model 35 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 36 | the `GPT `__ architecture from OpenAI. 37 | 38 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 39 | to control the model outputs. 
Read the documentation from :class:`~transformers.PretrainedConfig` 40 | for more information. 41 | 42 | Args: 43 | vocab_size (:obj:`int`, optional, defaults to 40478): 44 | Vocabulary size of the GPT model. Defines the different tokens that 45 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.CTRLModel`. 46 | n_positions (:obj:`int`, optional, defaults to 512): 47 | The maximum sequence length that this model might ever be used with. 48 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 49 | n_ctx (:obj:`int`, optional, defaults to 512): 50 | Dimensionality of the causal mask (usually same as n_positions). 51 | n_embd (:obj:`int`, optional, defaults to 768): 52 | Dimensionality of the embeddings and hidden states. 53 | n_layer (:obj:`int`, optional, defaults to 12): 54 | Number of hidden layers in the Transformer encoder. 55 | n_head (:obj:`int`, optional, defaults to 12): 56 | Number of attention heads for each attention layer in the Transformer encoder. 57 | afn (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 58 | The non-linear activation function (function or string) in the encoder and pooler. 59 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 60 | resid_pdrop (:obj:`float`, optional, defaults to 0.1): 61 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 62 | embd_pdrop (:obj:`int`, optional, defaults to 0.1): 63 | The dropout ratio for the embeddings. 64 | attn_pdrop (:obj:`float`, optional, defaults to 0.1): 65 | The dropout ratio for the attention. 66 | layer_norm_epsilon (:obj:`float`, optional, defaults to 1e-5): 67 | The epsilon to use in the layer normalization layers 68 | initializer_range (:obj:`float`, optional, defaults to 0.02): 69 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 70 | predict_special_tokens (:obj:`boolean`, optional, defaults to :obj:`True`): 71 | Whether special tokens should be predicted when the model is has a language modeling head. 72 | summary_type (:obj:`string`, optional, defaults to "cls_index"): 73 | Argument used when doing sequence summary. Used in for the multiple choice head in 74 | :class:`~transformers.OpenAIGPTDoubleHeadsModel`. 75 | Is one of the following options: 76 | - 'last' => take the last token hidden state (like XLNet) 77 | - 'first' => take the first token hidden state (like Bert) 78 | - 'mean' => take the mean of all tokens hidden states 79 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) 80 | - 'attn' => Not implemented now, use multi-head attention 81 | summary_use_proj (:obj:`boolean`, optional, defaults to :obj:`True`): 82 | Argument used when doing sequence summary. Used in for the multiple choice head in 83 | :class:`~transformers.OpenAIGPTDoubleHeadsModel`. 84 | Add a projection after the vector extraction 85 | summary_activation (:obj:`string` or :obj:`None`, optional, defaults to :obj:`None`): 86 | Argument used when doing sequence summary. Used in for the multiple choice head in 87 | :class:`~transformers.OpenAIGPTDoubleHeadsModel`. 88 | 'tanh' => add a tanh activation to the output, Other => no activation. 89 | summary_proj_to_labels (:obj:`boolean`, optional, defaults to :obj:`True`): 90 | Argument used when doing sequence summary. Used in for the multiple choice head in 91 | :class:`~transformers.OpenAIGPTDoubleHeadsModel`. 
92 | If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. 93 | summary_first_dropout (:obj:`float`, optional, defaults to 0.1): 94 | Argument used when doing sequence summary. Used in for the multiple choice head in 95 | :class:`~transformers.OpenAIGPTDoubleHeadsModel`. 96 | Add a dropout before the projection and activation 97 | 98 | Example:: 99 | 100 | from transformers import OpenAIGPTConfig, OpenAIGPTModel 101 | 102 | # Initializing a GPT configuration 103 | configuration = OpenAIGPTConfig() 104 | 105 | # Initializing a model from the configuration 106 | model = OpenAIGPTModel(configuration) 107 | 108 | # Accessing the model configuration 109 | configuration = model.config 110 | 111 | Attributes: 112 | pretrained_config_archive_map (Dict[str, str]): 113 | A dictionary containing all the available pre-trained checkpoints. 114 | """ 115 | 116 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 117 | model_type = "openai-gpt" 118 | 119 | def __init__( 120 | self, 121 | vocab_size=40478, 122 | n_positions=512, 123 | n_ctx=512, 124 | n_embd=768, 125 | n_layer=12, 126 | n_head=12, 127 | afn="gelu", 128 | resid_pdrop=0.1, 129 | embd_pdrop=0.1, 130 | attn_pdrop=0.1, 131 | layer_norm_epsilon=1e-5, 132 | initializer_range=0.02, 133 | predict_special_tokens=True, 134 | summary_type="cls_index", 135 | summary_use_proj=True, 136 | summary_activation=None, 137 | summary_proj_to_labels=True, 138 | summary_first_dropout=0.1, 139 | **kwargs 140 | ): 141 | super().__init__(**kwargs) 142 | 143 | self.vocab_size = vocab_size 144 | self.n_ctx = n_ctx 145 | self.n_positions = n_positions 146 | self.n_embd = n_embd 147 | self.n_layer = n_layer 148 | self.n_head = n_head 149 | self.afn = afn 150 | self.resid_pdrop = resid_pdrop 151 | self.embd_pdrop = embd_pdrop 152 | self.attn_pdrop = attn_pdrop 153 | self.layer_norm_epsilon = layer_norm_epsilon 154 | self.initializer_range = initializer_range 155 | self.predict_special_tokens = predict_special_tokens 156 | self.summary_type = summary_type 157 | self.summary_use_proj = summary_use_proj 158 | self.summary_activation = summary_activation 159 | self.summary_first_dropout = summary_first_dropout 160 | self.summary_proj_to_labels = summary_proj_to_labels 161 | 162 | @property 163 | def max_position_embeddings(self): 164 | return self.n_positions 165 | 166 | @property 167 | def hidden_size(self): 168 | return self.n_embd 169 | 170 | @property 171 | def num_attention_heads(self): 172 | return self.n_head 173 | 174 | @property 175 | def num_hidden_layers(self): 176 | return self.n_layer 177 | -------------------------------------------------------------------------------- /featurizer/dialogue_dataset.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import torch 3 | import os 4 | from torch.utils.data import Dataset 5 | from torch.nn.utils.rnn import pad_sequence, pack_sequence 6 | from torchvision import transforms 7 | import PIL 8 | import numpy as np 9 | 10 | SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[speaker1]", "[speaker2]", "", "[PAD]"] 11 | 12 | # temporarily use resent18 image statistics 13 | data_transforms = { 14 | 'train': transforms.Compose([ 15 | transforms.RandomResizedCrop(224), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.ToTensor(), 18 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 19 | ]), 20 | 'val': transforms.Compose([ 21 | transforms.Resize(224), 22 | 
transforms.CenterCrop(224), 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]), 26 | } 27 | 28 | 29 | class DialoDataset(Dataset): 30 | def __init__(self, data, tokenizer, args): 31 | self.data = data 32 | self.tokenizer = tokenizer 33 | self.pad = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]) 34 | self.args = args 35 | 36 | def __len__(self): 37 | return len(self.data) 38 | 39 | def __getitem__(self, item): 40 | persona = self.data[item]['personality'] 41 | history = self.data[item]['utterances'][0]['history'] 42 | img_list = self.data[item]['utterances'][0]['img_list'] 43 | reply = [] 44 | return self.process(persona, history, reply, img_list, self.tokenizer, self.args, with_eos=False) 45 | 46 | def process(self, persona, history, reply, img_list, tokenizer, args, lm_labels=False, with_eos=True): 47 | """ Build a sequence of input from 3 segments: persona, history and last reply. """ 48 | bos, eos, speaker1, speaker2, img_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1]) 49 | sequence = [[bos] + persona] + history + [reply + ([eos] if with_eos else [])] 50 | sequence = [sequence[0]] + [[speaker2 if (len(sequence) - i) % 2 else speaker1] + s for i, s in enumerate(sequence[1:])] 51 | instance = {} 52 | input_len, persona_len = len(list(chain(*sequence))), len(persona) 53 | instance["input_ids"] = list(chain(*sequence)) if input_len <= args.max_length else [bos] + persona + list(chain(*sequence))[-args.max_length + 1 + persona_len:] 54 | instance["token_type_ids"] = [bos] + [speaker1] * persona_len + [speaker1 if i % 2 else speaker2 for i, s in enumerate(sequence[1:]) for _ in s][-args.max_length + 1 + persona_len:] 55 | if input_len > args.max_length: 56 | instance["input_ids"][1 + persona_len] = instance["token_type_ids"][1 + persona_len] # Added 'speaker1' or 'speaker2' based on token_type_ids 57 | instance["mc_token_ids"] = len(instance["input_ids"]) - 1 58 | lm_labels_ids = [-100] * len(instance["input_ids"]) 59 | instance["lm_labels"] = lm_labels_ids[-args.max_length:] 60 | if lm_labels: 61 | lm_labels_ids = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:] 62 | instance["lm_labels"] = lm_labels_ids[-args.max_length:] 63 | # imges 64 | def image_transform(images, data_type): 65 | """Read all image data in the single utterance.""" 66 | resp_list = [] 67 | for image in images: 68 | img = torch.zeros(3, 224, 224) 69 | try: 70 | img_tmp = PIL.Image.open(image) 71 | img = data_transforms[data_type](img_tmp) 72 | except: 73 | print("can't open image file: ", image) 74 | pass 75 | finally: 76 | resp_list.append(img) 77 | return resp_list # 没有图片直接传空list 78 | def get_image_chars_indexes(array, token): 79 | """获取图片在input ids 中位置""" 80 | if not isinstance(array, np.ndarray): 81 | array = np.array(array) 82 | # 查找图片字符在上下文中的位置 83 | indexes = np.argwhere(array == token) 84 | return indexes.reshape(1, -1).tolist()[0] 85 | 86 | images_name = [os.path.join(args.image_path, image_name) for image_name in img_list] 87 | images_id = get_image_chars_indexes(instance["input_ids"], img_id) 88 | if not images_id and len(images_id) == 0: 89 | images_name = [] 90 | else: 91 | images_name = images_name[-len(images_id):] # 截取input ids还存在images 92 | instance["input_images"] = image_transform(images_name, "val") 93 | instance["image_ids"] = images_id 94 | assert len(instance["input_images"]) == len(instance["image_ids"]) 95 | return instance["input_ids"], instance["token_type_ids"], instance["input_images"], 
instance["image_ids"] 96 | 97 | def collate(self, batch): 98 | # input_ids = pad_sequence( 99 | # [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in batch], 100 | # batch_first=True, padding_value=self.pad) 101 | # token_type_ids = pad_sequence( 102 | # [torch.tensor(instance["token_type_ids"], dtype=torch.long) for instance in batch], 103 | # batch_first=True, padding_value=self.pad) 104 | # image_names, image_ids = (batch[0]["image_names"],), (batch[0]["image_ids"],) 105 | # return input_ids, token_type_ids, image_names, image_ids 106 | input_ids, token_type_ids, input_images, image_ids = zip(*batch) 107 | return input_ids, token_type_ids, input_images, image_ids 108 | 109 | 110 | class DialoImageDataset(Dataset): 111 | def __init__(self, dataset, images_feature_path, data_type): 112 | self.dataset = dataset 113 | self.data_type = data_type 114 | self.images_feature_json = torch.load(images_feature_path) if images_feature_path else None 115 | 116 | def __len__(self): 117 | return len(self.dataset["input_ids"]) 118 | 119 | def __getitem__(self, item): 120 | input_ids = self.dataset["input_ids"][item] 121 | token_type_ids = self.dataset["token_type_ids"][item] 122 | if self.images_feature_json: 123 | input_images = self.get_images_feature(self.dataset["image_names"][item]) 124 | else: 125 | input_images = self.image_transform(self.dataset["image_names"][item], self.data_type) 126 | image_ids = self.dataset["image_ids"][item] 127 | lm_labels = self.dataset["lm_labels"][item] 128 | mc_token_ids = self.dataset["mc_token_ids"][item] 129 | mc_labels = self.dataset["mc_labels"][item] 130 | return input_ids, token_type_ids, input_images, image_ids, lm_labels, mc_token_ids, mc_labels 131 | 132 | def image_transform(self, images, data_type): 133 | """Read all image data in the single utterance.""" 134 | resp_list = [] 135 | for image in images: 136 | img = torch.zeros(3, 224, 224) 137 | try: 138 | img_tmp = PIL.Image.open(image) 139 | img = data_transforms[data_type](img_tmp) 140 | except: 141 | print("can't open image file: ", image) 142 | pass 143 | finally: 144 | resp_list.append(img) 145 | return resp_list # 没有图片直接传空list 146 | 147 | def get_images_feature(self, images_name): 148 | """获取image feature""" 149 | images_feature = [] 150 | for image_name in images_name: 151 | tmp = self.images_feature_json.get(image_name, np.zeros(512, dtype=np.float32)) 152 | images_feature.append(torch.from_numpy(tmp)) 153 | 154 | return images_feature 155 | 156 | def get_images_feature_padding(self, single_images_name, single_images_id, sentence_length): 157 | """获取image feature 158 | @param single_images_name: images name(key) 159 | @param single_images_id: 160 | @param sentence_length: 161 | @return: 162 | """ 163 | sample_image_embed = np.zeros(512) # image feature sample 164 | sentence_embeds = [sample_image_embed] * sentence_length 165 | assert len(single_images_name) == len(single_images_id) 166 | if single_images_name and len(single_images_name) > 0: 167 | for idx, image_name in enumerate(single_images_name): 168 | i = single_images_id[idx] 169 | image_feature = self.images_feature_json.get(image_name) 170 | sentence_embeds[i] = image_feature 171 | return sentence_embeds 172 | 173 | def collate_fn(self, batch): 174 | input_ids, token_type_ids, input_images, image_ids, lm_labels, mc_token_ids, mc_labels = zip(*batch) 175 | return input_ids, token_type_ids, input_images, image_ids, lm_labels, mc_token_ids, mc_labels 176 | 
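Note on usage: the two dataset classes above return variable-length, unpadded samples, which is why their collate functions simply zip the per-sample tuples and inference runs with batch_size=1. Below is a minimal sketch (illustrative only, not part of the repository) of how DialoDataset is wired into a torch DataLoader; the Namespace values are assumptions, and the actual wiring lives in featurizer/get_dataloader.py.

from argparse import Namespace
from torch.utils.data import DataLoader
from featurizer.dialogue_dataset import DialoDataset


def make_test_loader(data, tokenizer):
    """Sketch of the DialoDataset -> DataLoader wiring used at test time.

    `data` is assumed to be the tokenized test set (see get_test_data in
    featurizer/get_dataloader.py) and `tokenizer` the BertTokenizer loaded
    from the model checkpoint.
    """
    # Illustrative values: max_length and image_path are the only attributes
    # that DialoDataset.process reads from `args`.
    args = Namespace(max_length=256, image_path="online_test_data/images_test")
    dataset = DialoDataset(data, tokenizer, args)
    return DataLoader(dataset,
                      batch_size=1,                # samples are variable-length and unpadded
                      collate_fn=dataset.collate,  # zips per-sample tuples without padding
                      shuffle=False)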
-------------------------------------------------------------------------------- /featurizer/get_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | from itertools import chain 5 | from torch.utils.data import DataLoader, TensorDataset 6 | from collections import defaultdict 7 | from utils.utils import get_dataset 8 | from tqdm import tqdm 9 | import numpy as np 10 | from featurizer.dialogue_dataset import DialoDataset, DialoImageDataset 11 | 12 | SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[speaker1]", "[speaker2]", "", "[PAD]"] 13 | MODEL_INPUTS = ["input_ids", "mc_token_ids", "lm_labels", "mc_labels", "token_type_ids", "input_images_name", "input_images_id"] 14 | PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"] 15 | 16 | 17 | def get_test_data(dataset_path, tokenizer): 18 | with open(dataset_path, "r", encoding="utf-8") as f: 19 | dataset = json.loads(f.read()) 20 | if isinstance(dataset, dict): 21 | dataset = dataset["test"] 22 | def tokenize(obj): 23 | if isinstance(obj, str): 24 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj)) 25 | if isinstance(obj, dict): 26 | return dict((n, tokenize(o)) if n != 'img_list' else (n, o) for n, o in obj.items()) 27 | return list(tokenize(o) for o in obj) 28 | dataset = tokenize(dataset) 29 | print('Finished convert tokens to ids...') 30 | return dataset 31 | 32 | 33 | def pad_dataset(dataset, logger, padding=0): 34 | """ Pad the dataset. This could be optimized by defining a Dataset class and padding at the batch level, but this is simpler. """ 35 | max_l = max(len(x) for x in dataset["input_ids"]) 36 | logger.info(f'The max length is {max_l}') 37 | for name in PADDED_INPUTS: 38 | dataset[name] = [x + [padding if name != "lm_labels" else -100] * (max_l - len(x)) for x in dataset[name]] 39 | return dataset 40 | 41 | 42 | def build_input_from_segments(persona, history, reply, img_list, tokenizer, args, lm_labels=False, with_eos=True): 43 | """ Build a sequence of input from 3 segments: persona, history and last reply. 
""" 44 | bos, eos, speaker1, speaker2, img_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1]) 45 | sequence = [[bos] + persona] + history + [reply + ([eos] if with_eos else [])] 46 | sequence = [sequence[0]] + [[speaker2 if (len(sequence) - i) % 2 else speaker1] + s for i, s in enumerate(sequence[1:])] 47 | instance = {} 48 | input_len, persona_len = len(list(chain(*sequence))), len(persona) 49 | instance["input_ids"] = list(chain(*sequence)) if input_len <= args.max_length else [bos] + persona + list(chain(*sequence))[-args.max_length + 1 + persona_len:] 50 | instance["token_type_ids"] = [bos] + [speaker1] * persona_len + [speaker1 if i % 2 else speaker2 for i, s in enumerate(sequence[1:]) for _ in s][-args.max_length + 1 + persona_len:] 51 | if input_len > args.max_length: 52 | instance["input_ids"][1 + persona_len] = instance["token_type_ids"][1 + persona_len] # Added 'speaker1' or 'speaker2' based on token_type_ids 53 | instance["mc_token_ids"] = len(instance["input_ids"]) - 1 54 | lm_labels_ids = [-100] * len(instance["input_ids"]) 55 | instance["lm_labels"] = lm_labels_ids[-args.max_length:] 56 | if lm_labels: 57 | lm_labels_ids = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:] 58 | instance["lm_labels"] = lm_labels_ids[-args.max_length:] 59 | # imges 60 | def get_image_chars_indexes(array, token): 61 | """获取图片在input ids 中位置""" 62 | if not isinstance(array, np.ndarray): 63 | array = np.array(array) 64 | # 查找图片字符在上下文中的位置 65 | indexes = np.argwhere(array == token) 66 | return indexes.reshape(1, -1).tolist()[0] 67 | 68 | images_name = [os.path.join(args.image_path, image_name) for image_name in img_list] 69 | images_id = get_image_chars_indexes(instance["input_ids"], img_id) 70 | if not images_id and len(images_id) == 0: 71 | images_name = [] 72 | else: 73 | images_name = images_name[-len(images_id):] # 截取input ids还存在images 74 | instance["image_names"] = images_name 75 | instance["image_ids"] = images_id 76 | assert len(instance["image_names"]) == len(instance["image_ids"]) 77 | return instance 78 | 79 | 80 | def build_dataloader(args, tokenizer, logger): 81 | """ Prepare the dataset for training and evaluation """ 82 | personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache, logger) 83 | logger.info("Build inputs and labels") 84 | datasets = {"train": defaultdict(list), "dev": defaultdict(list)} 85 | for dataset_name, dataset in personachat.items(): 86 | num_candidates = len(dataset[0]["utterances"][0]["candidates"]) 87 | if args.num_candidates > 0: # and dataset_name == 'train': 88 | num_candidates = min(args.num_candidates, num_candidates) 89 | for dialog in tqdm(dataset): 90 | persona = dialog["personality"].copy() 91 | for utterance in dialog["utterances"]: 92 | history = utterance["history"][-(2*args.max_history+1):] # +1 as question 93 | img_list = utterance["img_list"] 94 | for j, candidate in enumerate(utterance["candidates"][-num_candidates:]): 95 | lm_labels = bool(j == num_candidates-1) 96 | instance = build_input_from_segments(persona, history, candidate, img_list, tokenizer, args, lm_labels) 97 | for input_name, input_array in instance.items(): 98 | datasets[dataset_name][input_name].append(input_array) 99 | datasets[dataset_name]["mc_labels"].append(num_candidates - 1) 100 | datasets[dataset_name]["n_candidates"] = num_candidates 101 | 102 | logger.info("Pad inputs and convert to Tensor") 103 | data = {} 104 | for dataset_name, dataset in datasets.items(): 105 | dataset = pad_dataset(dataset, logger, 
padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])) 106 | data[dataset_name] = dataset 107 | logger.info("Build train and validation dataloaders") 108 | train_dataset, valid_dataset = DialoImageDataset(data["train"], args.images_feature_path, "train"), DialoImageDataset(data["dev"], args.images_feature_path, "dev") 109 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None 110 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None 111 | train_loader = DataLoader(train_dataset, 112 | sampler=train_sampler, 113 | batch_size=args.train_batch_size, 114 | collate_fn=train_dataset.collate_fn, 115 | num_workers=args.num_workers, 116 | shuffle=(not args.distributed)) 117 | valid_loader = DataLoader(valid_dataset, 118 | sampler=valid_sampler, 119 | batch_size=args.valid_batch_size, 120 | collate_fn=valid_dataset.collate_fn, 121 | num_workers=args.num_workers, 122 | shuffle=False) 123 | logger.info("Train dataset (Batch, Seq length): {}".format(np.array(train_dataset.dataset["input_ids"]).shape)) 124 | logger.info("Valid dataset (Batch, Seq length): {}".format(np.array(valid_dataset.dataset["input_ids"]).shape)) 125 | return train_loader, valid_loader, train_sampler, valid_sampler 126 | 127 | 128 | def build_test_dataloader(args, tokenizer, logger): 129 | dataset = get_test_data(args.test_data_file, tokenizer) 130 | datasets = {"test": defaultdict(list)} 131 | for dialog in tqdm(dataset): 132 | persona = dialog["personality"].copy() 133 | for utterance in dialog["utterances"]: 134 | history = utterance["history"] 135 | img_list = utterance["img_list"] 136 | reply = [] 137 | instance = build_input_from_segments(persona, history, reply, img_list, tokenizer, args, with_eos=False) 138 | for input_name, input_array in instance.items(): 139 | datasets["test"][input_name].append(input_array) 140 | logger.info("Pad inputs and convert to Tensor") 141 | data = {} 142 | for dataset_name, dataset in datasets.items(): 143 | dataset = pad_dataset(dataset, logger, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])) 144 | data[dataset_name] = dataset 145 | logger.info("Build train and validation dataloaders") 146 | test_dataset = DialoImageDataset(data["test"], images_feature_path=None, data_type="val") 147 | test_loader = DataLoader(test_dataset, 148 | sampler=None, 149 | batch_size=1, 150 | collate_fn=test_dataset.collate_fn, 151 | num_workers=0, 152 | shuffle=False) 153 | return test_loader 154 | 155 | 156 | def get_dataloader(args, tokenizer): 157 | dataset = get_test_data(args.test_data_file, tokenizer) 158 | test_dataset = DialoDataset(dataset, tokenizer, args) 159 | test_loader = DataLoader(test_dataset, 160 | collate_fn=test_dataset.collate, 161 | pin_memory=(args.device == "cuda"), 162 | num_workers=0, 163 | batch_size=1, 164 | shuffle=False) 165 | return test_loader -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # # Copyright (c) 2019-present, HuggingFace Inc. 2 | # All rights reserved. 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 7 | import logging 8 | import random 9 | from argparse import ArgumentParser 10 | from tqdm import tqdm 11 | from pprint import pformat 12 | import warnings 13 | import json 14 | import re 15 | import torch 16 | import torch.nn.functional as F 17 | from gpt_model import GPT2LMHeadModel, BertTokenizer 18 | from featurizer.get_dataloader import SPECIAL_TOKENS, build_input_from_segments, get_dataloader 19 | 20 | PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__)) 21 | TEST_DATA = os.path.join(PROJECT_FOLDER, 'data/output_data/gpt2_test.json') 22 | # MODEL_CHECKPOINT = os.path.join(PROJECT_FOLDER, 'runs/model') 23 | # MODEL_CHECKPOINT = os.path.join(PROJECT_FOLDER, 'runs/Sep02_13-11-14_5cc6919aa215_gpt2') 24 | # MODEL_CHECKPOINT = os.path.join(PROJECT_FOLDER, 'runs/Sep07_15-03-23_5cc6919aa215_gpt2') # base model 25 | # MODEL_CHECKPOINT = os.path.join(PROJECT_FOLDER, 'runs/Sep07_15-07-21_5cc6919aa215_gpt2') # medium model 26 | MODEL_CHECKPOINT = os.path.join(PROJECT_FOLDER, 'runs/Sep10_12-46-23_5cc6919aa215_gpt2') # base model with images 27 | ONLINE_DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'online_test_data') 28 | 29 | 30 | def online_test_postprocess(args, predictions): 31 | with open(args.online_test_questions) as fp: 32 | ques_items = json.load(fp) 33 | output_list = [] 34 | for question, answer in zip(ques_items, predictions): 35 | result_dict = dict() 36 | result_dict['Id'] = question['Id'] 37 | result_dict['Answer'] = answer 38 | output_list.append(result_dict) 39 | with open(args.online_test_answers, 'w') as fp: 40 | result = json.dumps(output_list, ensure_ascii=False, indent=2) 41 | fp.write(result) 42 | 43 | 44 | def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')): 45 | """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering 46 | Args: 47 | logits: logits distribution shape (vocabulary size) 48 | top_k: <=0: no filtering, >0: keep only top k tokens with highest probability. 49 | top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset 50 | whose total probability mass is greater than or equal to the threshold top_p. 51 | In practice, we select the highest probability tokens whose cumulative probability mass exceeds 52 | the threshold top_p. 
53 | threshold: a minimal threshold to keep logits 54 | """ 55 | assert logits.dim() == 1 # Only work for batch size 1 for now - could update but it would obfuscate a bit the code 56 | top_k = min(top_k, logits.size(-1)) 57 | if top_k > 0: 58 | # Remove all tokens with a probability less than the last token in the top-k tokens 59 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 60 | logits[indices_to_remove] = filter_value 61 | 62 | if top_p > 0.0: 63 | # Compute cumulative probabilities of sorted tokens 64 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 65 | cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 66 | 67 | # Remove tokens with cumulative probability above the threshold 68 | sorted_indices_to_remove = cumulative_probabilities > top_p 69 | # Shift the indices to the right to keep also the first token above the threshold 70 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 71 | sorted_indices_to_remove[..., 0] = 0 72 | 73 | # Back to unsorted indices and set them to -infinity 74 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 75 | logits[indices_to_remove] = filter_value 76 | 77 | indices_to_remove = logits < threshold 78 | logits[indices_to_remove] = filter_value 79 | 80 | return logits 81 | 82 | 83 | def sample_sequence(persona, history, tokenizer, model, args, current_output=None): 84 | special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS) 85 | if current_output is None: 86 | current_output = [] 87 | 88 | for i in range(args.max_length): 89 | instance = build_input_from_segments(persona, history, current_output, tokenizer, args.input_max_length, with_eos=False) 90 | 91 | input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0) 92 | token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0) 93 | 94 | logits = model(input_ids, token_type_ids=token_type_ids) 95 | if isinstance(logits, tuple): # for gpt2 and maybe others 96 | logits = logits[0] 97 | logits = logits[0, -1, :] / args.temperature 98 | logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p) 99 | probs = F.softmax(logits, dim=-1) 100 | 101 | prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1) 102 | if i < args.min_length and prev.item() in special_tokens_ids: 103 | while prev.item() in special_tokens_ids: 104 | if probs.max().item() == 1: 105 | warnings.warn("Warning: model generating special token with probability 1.") 106 | break # avoid infinitely looping over special token 107 | prev = torch.multinomial(probs, num_samples=1) 108 | 109 | if prev.item() in special_tokens_ids: 110 | break 111 | current_output.append(prev.item()) 112 | 113 | return current_output 114 | 115 | 116 | def beam_search(model, input_ids, token_type_ids, image_names, image_ids, args): 117 | outputs = model.generate(input_ids, 118 | token_type_ids=token_type_ids, 119 | input_images=image_names, 120 | image_ids=image_ids, 121 | num_beams=args.num_beams, 122 | do_sample=False, 123 | temperature=0.7, 124 | top_k=0, 125 | top_p=0.9, 126 | max_length=args.max_length + input_ids.size(-1), 127 | bos_token_id=0, 128 | pad_token_id=1, 129 | eos_token_ids=2, 130 | num_return_sequences=1) 131 | outputs = outputs.data.cpu().numpy().tolist() 132 | return outputs 133 | 134 | 135 | def run(): 136 | parser = ArgumentParser() 137 | parser.add_argument("--test_data_file", type=str, default=TEST_DATA, help="Path or url of the 
dataset.") 138 | parser.add_argument("--output_file", type=str, default="", help="Path of response generated.") 139 | parser.add_argument("--online_test_questions", type=str, default=os.path.join(ONLINE_DATA_FOLDER, "test_questions.json")) 140 | parser.add_argument("--online_test_answers", type=str, default=os.path.join(ONLINE_DATA_FOLDER, "test_answers.json")) 141 | parser.add_argument("--image_path", type=str, default=os.path.join(ONLINE_DATA_FOLDER, "images_test"), help="Path of the images.") 142 | parser.add_argument("--model_checkpoint", type=str, default=MODEL_CHECKPOINT, help="Path, url or short name of the model") 143 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") 144 | parser.add_argument("--no_sample", action='store_true', default=True, help="Set to use greedy decoding instead of sampling") 145 | parser.add_argument("--input_max_length", type=int, default=256, help="Max length of input sentence") 146 | parser.add_argument("--num_beams", type=int, default=3, help="Beam num") 147 | parser.add_argument("--max_length", type=int, default=100, help="Maximum length of the output utterances") 148 | parser.add_argument("--min_length", type=int, default=1, help="Minimum length of the output utterances") 149 | parser.add_argument("--seed", type=int, default=42, help="Seed") 150 | parser.add_argument("--temperature", type=int, default=0.7, help="Sampling softmax temperature") 151 | parser.add_argument("--top_k", type=int, default=0, help="Filter top-k tokens before sampling (<=0: no filtering)") 152 | parser.add_argument("--top_p", type=float, default=0.9, help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)") 153 | args = parser.parse_args() 154 | 155 | logging.basicConfig(level=logging.INFO) 156 | logger = logging.getLogger(__file__) 157 | logger.info(pformat(args)) 158 | 159 | if args.model_checkpoint == "": 160 | logging.error("Loaded model checkpoint error!") 161 | return 162 | 163 | random.seed(args.seed) 164 | torch.random.manual_seed(args.seed) 165 | torch.cuda.manual_seed(args.seed) 166 | 167 | logger.info("Get pretrained model and tokenizer") 168 | tokenizer_class, model_class = BertTokenizer, GPT2LMHeadModel 169 | tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint) 170 | model = model_class.from_pretrained(args.model_checkpoint) 171 | model.to(args.device) 172 | model.eval() 173 | 174 | test_dataloader = get_dataloader(args, tokenizer) 175 | predictions = [] 176 | for batch in tqdm(test_dataloader, ncols=80): 177 | batch = tuple(torch.tensor(input_data).to(args.device) if idx not in [2, 3] else input_data for idx, input_data in enumerate(batch)) 178 | input_ids, token_type_ids, image_names, image_ids = batch 179 | if args.no_sample: 180 | outputs = beam_search(model, input_ids, token_type_ids, image_names, image_ids, args) 181 | else: 182 | with torch.no_grad(): 183 | out_ids = sample_sequence(persona, history, tokenizer, model, args) 184 | for output in outputs: 185 | out_text = tokenizer.convert_ids_to_tokens(output[input_ids.size(-1):]) 186 | out_text = ''.join(out_text) 187 | out_text = out_text.replace('|||', ' ') 188 | out_text = out_text.replace('', '') 189 | out_text = out_text.replace('[UNK]', '') 190 | predictions.append(out_text) 191 | 192 | online_test_postprocess(args, predictions) 193 | 194 | 195 | if __name__ == "__main__": 196 | run() 197 | -------------------------------------------------------------------------------- 
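For clarity, here is a short, self-contained sketch (illustrative only, not part of the repository) of a single nucleus-sampling decoding step using the top_filtering function defined in inference.py above; the vocabulary size and temperature below are placeholders chosen for the example.

import torch
import torch.nn.functional as F

from inference import top_filtering  # assumes the project root is on PYTHONPATH

# Fake next-token logits over a placeholder vocabulary.
logits = torch.randn(29473)
logits = logits / 0.7                               # temperature scaling
logits = top_filtering(logits, top_k=0, top_p=0.9)  # mask everything outside the top-p nucleus
probs = F.softmax(logits, dim=-1)                   # renormalise over the surviving tokens
next_token_id = torch.multinomial(probs, num_samples=1).item()
print(next_token_id)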
/online_test_data/test_questions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Id": "a11390785fb6eb9cedf248e3800299eb_0", 4 | "SessionId": "a11390785fb6eb9cedf248e3800299eb", 5 | "Shop": "xjd", 6 | "ProductId": "80bef854b2dba51409b5c37c577a8492", 7 | "Context": [], 8 | "Question": "砂锅一开始是要烧开水吗?|||已购买→售后咨询组" 9 | }, 10 | { 11 | "Id": "c82d617f8b4c36446784c4dd3babeabe_0", 12 | "SessionId": "c82d617f8b4c36446784c4dd3babeabe", 13 | "Shop": "fx", 14 | "ProductId": "59165100d35fa03d5718ae8e9fe4d764", 15 | "Context": [ 16 | "Q:用户发起转人工", 17 | "A:您好,欢迎光临***官方旗舰店,我是您的运动顾问Jenna,3.5-3.8京东女神节,精选好物低至5折。领优惠券再下单最高满900-190,700-150,600-120,快来抢购女神节精选好货吧1. , 2. , 3. |||您好" 18 | ], 19 | "Question": "因为我有一件衣服买大了一码,想换小一号的" 20 | }, 21 | { 22 | "Id": "59610ceba3a58eb13c7a4583e38d1680_0", 23 | "SessionId": "59610ceba3a58eb13c7a4583e38d1680", 24 | "Shop": "fx", 25 | "ProductId": "4391f366be558318022805f69689ceda", 26 | "Context": [ 27 | "Q:用户发起转人工", 28 | "A:您好,我是***官方旗舰店客服专员法米加,很高兴为您服务。#E-s57", 29 | "Q:你好", 30 | "A:抱歉给您带来的不便,为了更快处理您的问题,麻烦您提供下商品照片(需要2张):第一张:问题部分清晰标记照片。第二张:商品全景照片 ", 31 | "Q:2e462cdab8b56469232d713c2484f894.jpg|||3be1f9cd157fd57971673774ed99b967.jpg", 32 | "A:抱歉给您带来的不便,请问您是一只商品有此种问题的吗" 33 | ], 34 | "Question": "对|||怎么办" 35 | }, 36 | { 37 | "Id": "553434cde6226eb28facaeccb4b61bb1_0", 38 | "SessionId": "553434cde6226eb28facaeccb4b61bb1", 39 | "Shop": "fx", 40 | "ProductId": "b00568a196ad7b513a1a18e3faa7b2f6", 41 | "Context": [ 42 | "Q:9bdd7a89c57b8c8e833913e409888d91.jpg|||这样就是确认收货了对么", 43 | "A:是的|||但是货还没发出哦", 44 | "Q:那我现在提交了申请,只要等京东上门来取就可以了吧", 45 | "A:但是您没有货|||怎么上门取件哦" 46 | ], 47 | "Question": "我有货……" 48 | }, 49 | { 50 | "Id": "49fe7e93d0fd0584fc8fab13c222c11c_0", 51 | "SessionId": "49fe7e93d0fd0584fc8fab13c222c11c", 52 | "Shop": "xjd", 53 | "ProductId": "5a97a9a613278c3f07599a6c972c9bfc", 54 | "Context": [ 55 | "Q:|||我要转人工", 56 | "A:尊敬的客户您好,这边是***自营旗舰店售后客服,请问有什么可以帮到您的呢?#E-s36#E-s36", 57 | "Q:亲,我刚买的豆浆,你们给我寄插头是坏的|||734936a3070c37d92cd2cf46839d355b.jpg", 58 | "A:您好,请问有什么可以帮到您的呢?", 59 | "Q:寄给我的插头是坏的", 60 | "A:没有呀", 61 | "Q:图片你看不到吗", 62 | "A:bb29ef5a1cf3d35b33101241145642a8.jpg|||这个吗", 63 | "Q:对啊", 64 | "A:保护套去掉就可以啊", 65 | "Q:我去掉了|||那个插头那边都压扁了", 66 | "A:您好,可以拍张照片吗", 67 | "Q:我把它翘起来了|||但是怕断", 68 | "A:***承诺:15天质量问题可以换货,30天网点检测为质量问题可以申请退货,180天网点检测为质量问题可以申请换货,***是全国保修一年的", 69 | "Q:我现在在打|||我是决定你们发货没检查吗", 70 | "A:您可以提供下商品损坏的照片吗" 71 | ], 72 | "Question": "我刚忘记把那个套子取下来拍了" 73 | }, 74 | { 75 | "Id": "779a8ff138c1fa8f7d8553cfed1ead66_0", 76 | "SessionId": "779a8ff138c1fa8f7d8553cfed1ead66", 77 | "Shop": "xjd", 78 | "ProductId": "468b3451bc2432b7f65d183c330b6c26", 79 | "Context": [ 80 | "Q:|||已购买→售后咨询组", 81 | "A:", 82 | "Q:拆卸", 83 | "A:旋转就是可以的", 84 | "Q:07fd0f5a079f8a497b94f729cdc81d33.jpg", 85 | "A:这个拆不了哈" 86 | ], 87 | "Question": "好的" 88 | }, 89 | { 90 | "Id": "2471df8c3701578a3eb306bfffab4111_0", 91 | "SessionId": "2471df8c3701578a3eb306bfffab4111", 92 | "Shop": "xjd", 93 | "ProductId": "e2d52c26887a8e54984fbf14674284a3", 94 | "Context": [ 95 | "Q:ea8efa4df622cd00c9e8080a7e04fedc.jpg|||未购买→售前咨询组", 96 | "A:" 97 | ], 98 | "Question": "9f02fd4e0436b67f5b28c0fe306fe44f.jpg|||请问这四周是橡胶吗?" 
99 | }, 100 | { 101 | "Id": "ced106fe76a565cb319793f1d38661b0_0", 102 | "SessionId": "ced106fe76a565cb319793f1d38661b0", 103 | "Shop": "xjd", 104 | "ProductId": "93efd37b75db017a195fe3805c99308a", 105 | "Context": [ 106 | "Q:", 107 | "A:", 108 | "Q:已购买→售后咨询组", 109 | "A:您好,有什么可以帮到您的呢", 110 | "Q:时间总循环|||滴滴滴响不停", 111 | "A:麻烦您提供下订单号 这边给您看下哦" 112 | ], 113 | "Question": "b29d599fe076dadd7d7b20e16d83e84f.jpg|||ae2ad9caffaf9bd6a83eeac628442c23.jpg" 114 | }, 115 | { 116 | "Id": "90f9ce87b263cd18770dce9b23df108b_0", 117 | "SessionId": "90f9ce87b263cd18770dce9b23df108b", 118 | "Shop": "xjd", 119 | "ProductId": "84541ea72ac2646a28c33b9f0450639c", 120 | "Context": [ 121 | "Q:b5daf5e5720fc5564c90ce6312efca70.jpg|||壶类---收货后疑问解答", 122 | "A:您好在的,#E-s57#E-s57我是本次服务导购客服多多,欢迎小主光临!|||* 您好,现在机器是出现了什么问题可以详细的说明一下吗?#E-s31", 123 | "Q:这个是被用过了吗,有刮伤", 124 | "A:刚收到就是这样的吗?" 125 | ], 126 | "Question": "烧完一壶清水就这样" 127 | }, 128 | { 129 | "Id": "c78f19fcb43af0362593993dad5ef24f_0", 130 | "SessionId": "c78f19fcb43af0362593993dad5ef24f", 131 | "Shop": "xjd", 132 | "ProductId": "3d7e44858679a748b3d33825878bb790", 133 | "Context": [ 134 | "Q:这个有吗|||已购买→售后咨询组", 135 | "A:|||您好,这边是京东自营***厨房电器旗舰店,请问有什么可以帮您的呐|||在的呢,有什么可以帮您的呢" 136 | ], 137 | "Question": "c3468dd3b24244cd0981f42f8d1fe893.jpg|||我买的这个" 138 | }, 139 | { 140 | "Id": "3b97e429cddcc5cee7d6492f463bca3b_0", 141 | "SessionId": "3b97e429cddcc5cee7d6492f463bca3b", 142 | "Shop": "xjd", 143 | "ProductId": null, 144 | "Context": [ 145 | "Q:不可以快递回去换吗|||非得去网点?", 146 | "A:" 147 | ], 148 | "Question": "" 149 | }, 150 | { 151 | "Id": "d7bfa9a78b1289393557df1c509280f7_0", 152 | "SessionId": "d7bfa9a78b1289393557df1c509280f7", 153 | "Shop": "xjd", 154 | "ProductId": "d7b3df2fbcd89f4affc497947d6bef70", 155 | "Context": [ 156 | "Q:", 157 | "A:", 158 | "Q:已购买→售后咨询组", 159 | "A:您好,有什么可以帮到您的呢", 160 | "Q:e94dcc3b9d9544c9c945a435ed65ddb8.jpg|||这个电陶炉可以用砂锅吗", 161 | "A:可以的", 162 | "Q:79e61078e7abc567c5166ad9dd48b17a.jpg|||这种可以用吗", 163 | "A:嗯嗯", 164 | "Q:这是我刚买的砂锅还没有使用一次,在这上面不到5分钟就给我炸裂了|||a013500ca475448bc345b4aeb2b2be9a.jpg", 165 | "A:可以使用 但是您要确定耐热吗", 166 | "Q:922a5f4d03f80fe82668d52239480db2.jpg|||我买的的明火都烧不裂,经过商温的|||这是高温彩铀,", 167 | "A:抱歉哈 这个陶炉不限锅体材质的哈 " 168 | ], 169 | "Question": "没有人了吗" 170 | }, 171 | { 172 | "Id": "c8397fa44f876ebd108c72de8ac0409b_0", 173 | "SessionId": "c8397fa44f876ebd108c72de8ac0409b", 174 | "Shop": "fx", 175 | "ProductId": "60cf0b6e4e7135594e49c11eaa2f5eff", 176 | "Context": [ 177 | "Q:亲在吗|||售后咨询组", 178 | "A:您好~真是非常抱歉,现在咨询量比较大,您现在进入咨询排队,排到了会给您回复以上问题的哈,请您耐心等待~(建议您可以一次说好您的疑问,刷屏会加长客服看您问题的时间哦)自主售后问答:1-申请退换货催审核:答:亲爱的~售后是在逐一处理退换货审核的,会在48小时内处理的,还请您耐心等待下,这边售后已经是在加班加点处理,如商品不影响二次销售的可以优先寄回2-咨询商品瑕疵,质量问题:答:亲爱的~真是很抱歉,麻烦您先提供相关的图片给我,这边看到消息后会给您留言处理方案3-咨询物流未更新:答:因疫情物流更新相对较慢些,亲爱的您可以先耐心等待下,超过48小时还未更新的,您可以在联系我们咨询物流核实4-错发、漏发的问题:答:有出现错漏发的顾客,麻烦您提供下包装6个面照片,鞋舌上的信息照片、鞋盒信息标码照片,发货清单照片,客服看到后会给您处理5-鞋盒破损答:亲爱的~出现这样问题我们也很难看,物流途中造成的,鞋盒问题不影响穿着的,如您需要处理的您可以提供照片,客服看到了会给您留言处理方案 ", 179 | "Q:1ff34385a30a60b0387617b31d961d0c.jpg|||这个是订单编号号|||吗", 180 | "A:是的哟" 181 | ], 182 | "Question": "好|||我寄出了|||一会快递单号填哪里" 183 | }, 184 | { 185 | "Id": "2db98b28381ed0c3aab4cfc49f371838_0", 186 | "SessionId": "2db98b28381ed0c3aab4cfc49f371838", 187 | "Shop": "fx", 188 | "ProductId": "7b597edeabb67abeafbf2647c5162929", 189 | "Context": [ 190 | "Q:用户发起转人工", 191 | "A:您好,在的,请问有什么可以帮到您的吗?", 192 | "Q:衣服掉色厉害,把白色部分都染红了", 193 | "A:您提供照片呢", 194 | "Q:99252dcede303832e3d4af6a38d6479d.jpg|||1965180fa63be4185bc528efe2cdbc0c.jpg", 195 | "A:因为新的商品都会有一层保护色的,第一次水洗可能会有掉色正常的呢。针对这个问题,给您造成不便,实在抱歉呢亲~ ", 196 | "Q:那也太厉害了,怎么穿啊,肯定洗不掉了", 197 | "A:您可以试试漂白水哦" 198 | 
], 199 | "Question": "能退吗" 200 | }, 201 | { 202 | "Id": "d79eb8dbebff125b1159e133eba328d0_0", 203 | "SessionId": "d79eb8dbebff125b1159e133eba328d0", 204 | "Shop": "xjd", 205 | "ProductId": "9a836e92973b8f69d59b0fd7c3fcbdb9", 206 | "Context": [ 207 | "Q:用户发起转人工", 208 | "A:亲爱的顾客您好,欢迎光临***京东自营旗舰店,本店产品由***品牌工厂直供,正品行货,京东物流配送,现在购买保价三十天的哦,喜欢的话抓紧时间购买哦。***剃须刀 【TOP爆款】1小时快充I 90分钟续航***剃须刀 【TOP爆款】全身水洗I干湿双剃***理发器?【TOP爆款】全身水洗?I 200分钟续航 ***电吹风 【TOP爆款】负离子护发I 大功率***剃须刀 【新品上新】USB充电款更多优惠尽在***京东自营旗舰店" 209 | ], 210 | "Question": "|||你好|||eee1dbad6302726ceef18b042d8636c3.jpg|||这个意思是这个时间点到嘛" 211 | }, 212 | { 213 | "Id": "bec8e9852bfea7449583cf228d2510b2_0", 214 | "SessionId": "bec8e9852bfea7449583cf228d2510b2", 215 | "Shop": "fx", 216 | "ProductId": "6cb9f57957b0dc81e2675275d2fd3f97", 217 | "Context": [ 218 | "Q:0786a405633bbd718be674b0dbe1bcd5.jpg|||917070b7080172b2414d5d25027495a8.jpg|||8591d69363494b366942a089feddfcd4.jpg|||售后咨询组", 219 | "A:您好,欢迎光临***官方旗舰店,请问有什么可以帮到您的呢?|||是错发尺码吗|||您把那个申请理由改一下就可以了呢|||然后这边把您备注", 220 | "Q:我提交多次申请", 221 | "A:您申请错发 是要举证的呢|||真的很抱歉,让您不愉快了~请问包裹是否收件本人签收,签收时外包装是否完好哦? 麻烦您分别提供下:1、商品的正面照2、鞋底鞋内编码3、外包装的图片4、鞋盒商品信息5、包裹内商品清单 6、快递6个面的照片|||所以 您这边把理由改成7天无理由" 222 | ], 223 | "Question": "京东快递已经把鞋子拿走了" 224 | }, 225 | { 226 | "Id": "18794790761fb09ee013e9a89c1616e1_0", 227 | "SessionId": "18794790761fb09ee013e9a89c1616e1", 228 | "Shop": "xjd", 229 | "ProductId": "a4d22cb830e7de7ce6d8c2d4e4065d66", 230 | "Context": [ 231 | "Q:今天客服休息吗?|||未购买→售前咨询组", 232 | "A:||||||在的呢,有什么可以帮到您", 233 | "Q:你好", 234 | "A:在的呢,有什么可以帮到您", 235 | "Q:榨汁机64元与79元区别在那里?", 236 | "A:请问您看的是哪款商品呢,方便发下链接吗?", 237 | "Q:128e53ba3ae93f3e1cd23660ced82a40.jpg", 238 | "A:在的呢,有什么可以帮到您", 239 | "Q:这二种", 240 | "A:|||颜色不同 功能材质一样的", 241 | "Q:是充电的", 242 | "A:充电使用的呢", 243 | "Q:容量多少", 244 | "A:250ml", 245 | "Q:一木不?|||杯", 246 | "A:是的呢" 247 | ], 248 | "Question": "会不会容易坏" 249 | }, 250 | { 251 | "Id": "d87b76bc7dc5157bba6055a01849b063_0", 252 | "SessionId": "d87b76bc7dc5157bba6055a01849b063", 253 | "Shop": "fx", 254 | "ProductId": "8b41c0f64f529029f746a4558bd6256f", 255 | "Context": [ 256 | "Q:没抢到|||我看到了你上午下单的一股现在降价了呀|||能包价吗", 257 | "A:页面价格为准哦", 258 | "Q:16079b959f743abcdca7d73125be33ab.jpg|||看看降价了呀|||我买的时候480,440", 259 | "A:哪一个订单呢", 260 | "Q:196a087a6237354860b42cfcc1f3e00e.jpg|||快帮我看看", 261 | "A:稍等下哦 这边帮您查询下呢", 262 | "Q:是不是应该返差价呢?", 263 | "A:这边看看哦", 264 | "Q:怎么样?", 265 | "A:稍等一下呢", 266 | "Q:现在还有这上85折吧?", 267 | "A:需要页面价格为准的呢", 268 | "Q:请教你一下,我怎么下单最合适?|||都在页面上,请问哪个是准的?", 269 | "A:需要首页领取优惠券 满足要求就可以呢", 270 | "Q:你好,我问的是,我已经下了单的T恤降价了,给返回差价吗?|||cf77348be65f45ae5ca8751fbcce3064.jpg|||看到了吗,现在下单是这个金额。", 271 | "A:稍等一下哦", 272 | "Q:我在等,看看是否取消订单", 273 | "A:#E-s21", 274 | "Q:怎么办?|||请教一下,应该怎么办?", 275 | "A:这边给您看一下哦", 276 | "Q:抓紧时间呢", 277 | "A:好的呢|||麻烦您稍等一下", 278 | "Q:嗯嗯", 279 | "A:#E-s21|||c858d3132deb0d84493705e53242f519.jpg|||这个是您两件t的价格哈", 280 | "Q:按照现在的价格就不是这个价位了呀", 281 | "A:稍等哈", 282 | "Q:如果重新下单会有接近一百元的差价", 283 | "A:您稍等哦 这边看看呢" 284 | ], 285 | "Question": "认真点好嘛?|||就当是你自己买东西一样认真" 286 | }, 287 | { 288 | "Id": "6fa0cc56db146efa6ff7d76b97d97209_0", 289 | "SessionId": "6fa0cc56db146efa6ff7d76b97d97209", 290 | "Shop": "xjd", 291 | "ProductId": "3d7e44858679a748b3d33825878bb790", 292 | "Context": [ 293 | "Q:有人吗|||我要转人工", 294 | "A:尊敬的客户您好,请问有什么可以帮到您的呢?#E-s68", 295 | "Q:想要破壁机", 296 | "A:您需要选购什么样功能的商品呢", 297 | "Q:有什么推荐", 298 | "A:|||***(上市时间2019年9月2日)这款主要有以下特点哦1:容量:热饮400-1400ml,冷饮250-1750ml2:操作方式:触摸按键控制3: 转速:3**转/min4:6叶旋风割刀5: 功能:五谷浆、滋补糊、养生粥、浓汤、果蔬汁、奶昔、预约、启动/取消、清洗6:功率:搅拌900瓦,加热900瓦核心卖点:1、一键通控制 2、大功率粉碎 3、12H预约", 299 | "Q:豆浆可以吗", 300 | "A:可以的", 301 | "Q:热的吗", 302 | "A:是打|||是的呢", 
303 | "Q:黄豆直接放 进去就可以了吗", 304 | "A:破壁机有加热功能,打出的豆浆可以直接饮用,无需另外加热~用干豆做豆浆,食物与水的比例约为 1:13,用湿豆做豆浆,食物与水的比例约为 1:10,打湿豆需要提前用清水将黄豆浸泡6~7小时为宜。|||是哒", 305 | "Q:干豆和湿豆有什么不同", 306 | "A:湿豆口感好点", 307 | "Q:做豆浆要多久时间", 308 | "A:您好,这款宝贝每个功能时间都是设定好的哦,您可以根据不同食材,选择不同的功能,使用非常方便的哦~|||30分钟的呢", 309 | "Q:几杯呢", 310 | "A:单杯的呢", 311 | "Q:做出来的豆浆有多少杯", 312 | "A:您好,这款破壁机约1.75L大容量,一次可做五杯(300ml),全家共享营养美味。", 313 | "Q:现在多少钱", 314 | "A:您好,商品选择需要的款式,页面上会显示相应的金额,到手价以最终下单金额为准哦~|||***多功能破壁机(邓伦推荐款)❥①特色功能:预约,自动清洗❥②主杯容量:1.71L-1.8L大容量❥③高硼硅玻璃耐热耐冷不易碎❥④立体熬煮不糊底,智能防溢❥⑥冷、热饮自由选,一机多用||||||下单看看哈", 315 | "Q:有优惠吗", 316 | "A:不好意思,店铺价格都是统一的呢,已经是很优惠,希望您可以体谅下呢~|||没有的呢 ", 317 | "Q:不是下单立减130吗", 318 | "A:您好,具体活动以页面为主哦,您可已看下页面或加入购物车看一下相关促销活动哈~", 319 | "Q:a1be33bf6d44c5ff1567acd41f84ade6.jpg|||这些那个好", 320 | "A:您需要选购什么样功能的商品呢", 321 | "Q:功能齐全的", 322 | "A:您想要什么功能的呢", 323 | "Q:你说哪款好", 324 | "A:|||这款也不错的呢 |||***(上市时间2019年9月2日)这款主要有以下特点哦1:容量:热饮400-1400ml,冷饮250-1750ml2:操作方式:触摸按键控制3: 转速:3**转/min4:6叶旋风割刀5: 功能:五谷浆、滋补糊、养生粥、浓汤、果蔬汁、奶昔、预约、启动/取消、清洗6:功率:搅拌900瓦,加热900瓦核心卖点:1、一键通控制 2、大功率粉碎 3、12H预约", 325 | "Q:没秒杀吗", 326 | "A:您好,店铺目前没有活动通知呢,产品价格都是很实惠,有需要可以直接带走它,如需关注活动动态,您可以关注下我们的店铺首页,方便您一手掌握活动详细哦~", 327 | "Q:第一个没货吗|||超值爆款", 328 | "A:您可以关注下商品页面哦,有活动的话,商品页面都是会显示的哦,在产品的这个位置呦【图片】|||您好,能下单商品都是有货的,部分商品为预售商品,具体以商品详情页面为准。如提示缺货,您可以点击到货通知喔~", 329 | "Q:怎么有些有闪购", 330 | "A:是的呢 |||是抢购的额", 331 | "Q:究竟哪款好呢", 332 | "A:|||这款吗", 333 | "Q:其他不怎么好吗", 334 | "A:您是想要什么功能的呢", 335 | "Q:我要又好又不贵的", 336 | "A:***多功能破壁机(邓伦推荐款)❥①特色功能:预约,自动清洗❥②主杯容量:1.71L-1.8L大容量❥③高硼硅玻璃耐热耐冷不易碎❥④立体熬煮不糊底,智能防溢❥⑥冷、热饮自由选,一机多用", 337 | "Q:只有一款啊", 338 | "A:这款是可以的呢|||", 339 | "Q:可以 减130啊?", 340 | "A:您好,具体活动以页面为主哦,您可已看下页面或加入购物车看一下相关促销活动哈~", 341 | "Q:有吗", 342 | "A:下单看看哈" 343 | ], 344 | "Question": "3adde037a80e81eb4a9ddd3c33745314.jpg|||这个是什么" 345 | }, 346 | { 347 | "Id": "15cc323372499df3c4eceddfa8b1989e_0", 348 | "SessionId": "15cc323372499df3c4eceddfa8b1989e", 349 | "Shop": "xjd", 350 | "ProductId": "3d7e44858679a748b3d33825878bb790", 351 | "Context": [ 352 | "Q:64d18188d7f40fdde68816583c97db72.jpg|||这种怎么有撕贴的|||我要转人工", 353 | "A:您好,这边是京东自营***旗舰店售后客服,请问有什么可以帮您的呐|||直接 撕掉哦", 354 | "Q:我的机器没有得撕,是不是用过了的", 355 | "A:不是的饿呢|||有的哦", 356 | "Q:我的没有", 357 | "A:有的呢", 358 | "Q:也没有量杯", 359 | "A:京东承诺:30天质量问题可以申请退货,180天质量问题可以申请换货,***是全国保修一年的哦!", 360 | "Q:都怀疑是别人用过的呢", 361 | "A:***多功能破壁机(邓伦推荐款)❥①特色功能:预约,自动清洗❥②主杯容量:1.71L-1.8L大容量❥③高硼硅玻璃耐热耐冷不易碎❥④立体熬煮不糊底,智能防溢❥⑥冷、热饮自由选,一机多用|||不会的呢", 362 | "Q:你是说标签还是那个膜有", 363 | "A:有膜的", 364 | "Q:标签有还是还有个膜|||没有膜", 365 | "A:还有个膜|||有的|||用力撕‘", 366 | "Q:都没有得撕|||怎么用力", 367 | "A:京东承诺:30天质量问题可以申请退货,180天质量问题可以申请换货,***是全国保修一年的哦!", 368 | "Q:图片是评价的", 369 | "A:", 370 | "Q:都没有得撕开的", 371 | "A:确实很抱歉,给您带来不便了。’", 372 | "Q:都怀疑是用过的", 373 | "A:您好,我们是***自营旗舰店哦,均为正品,商品质量都是经过严格审查,享受三包服务,售后保障,您可以放心选购呢~", 374 | "Q:昨晚那个客服又说很多人都是这样提出这个问题", 375 | "A:不会的呢", 376 | "Q:是真的没有", 377 | "A:京东承诺:30天质量问题可以申请退货,180天质量问题可以申请换货,***是全国保修一年的哦!|||", 378 | "Q:昨晚又说发那个量杯给我的", 379 | "A:确实很抱歉,给您带来不便了。", 380 | "Q:昨晚卖家说没有都没有膜的", 381 | "A:都有的呢", 382 | "Q:可没有", 383 | "A:|||京东承诺:30天质量问题可以申请退货,180天质量问题可以申请换货,***是全国保修一年的哦!", 384 | "Q:昨晚又说是正品,全新的", 385 | "A:您好,我们是***自营旗舰店哦,均为正品,商品质量都是经过严格审查,享受三包服务,售后保障,您可以放心选购呢~|||#E-s21", 386 | "Q:没有膜的", 387 | "A:有的呢", 388 | "Q:真的没有", 389 | "A:好的呢|||", 390 | "Q:也没有量杯,我都怀疑是用过了|||我要投诉", 391 | "A:确实很抱歉,给您带来不便了。", 392 | "Q:卖家", 393 | "A:在的呢", 394 | "Q:没有撕贴的", 395 | "A:|||京东承诺:30天质量问题可以申请退货,180天质量问题可以申请换货,***是全国保修一年的哦!", 396 | "Q:退货是退全额吗", 397 | "A:京东是根据您的实际支付金额进行退款,若核对后,退款金额小于您实际支付请联系京东客服为您核实处理哦", 398 | "Q:保险呢", 399 | "A:感谢您的咨询 祝您生活愉快! 
|||非常抱歉,您咨询的问题属于京东客服服务范围,联系京东客服:手机端--我的--客户服务--在线客服咨询,点击下面的开始咨询即可!", 400 | "Q:现在这不是客服吗|||都没有撕贴的|||我可以退款吗", 401 | "A: 可在京东APP【我的-客户服务-退换/售后】中提交退换申请,具体以售后审核意见为准哦。温馨提示:目前京东支持同订单多个商品中部分商品退货。(注:七天无理由退货需保持商品全新未使用即可)|||保险是京东金融提供的 这边查询不到呢", 402 | "Q:这样子的啊", 403 | "A:很高心为您服务,还有什么可以帮您的嘛?|||是的呢", 404 | "Q:怎么我的机器没有撕贴的", 405 | "A:可以去网点检测一下哦", 406 | "Q:07a746f5b677db49e02c87cc39503090.jpg", 407 | "A:看到是有的哦", 408 | "Q:附近没有网点呀", 409 | "A:有膜的", 410 | "Q:这图片是评价的图片,我的机器在家呢", 411 | "A:c46d53d31e70b810aaac928b6011c8b8.jpg|||好的 那您回家的时候拍个照片哈", 412 | "Q:这是别人的,我在评价那里找的", 413 | "A:好的呢", 414 | "Q:我回家拍的就是没有,昨晚问,卖家是说都没有的,我还有记录", 415 | "A:好的呢", 416 | "Q:你昨晚说没有", 417 | "A:没有的话 可以申请退款哦" 418 | ], 419 | "Question": "退款,可上门收,运费谁承担" 420 | } 421 | ] -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' 3 | import math 4 | import logging 5 | from pprint import pformat 6 | from argparse import ArgumentParser 7 | 8 | import torch 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from torch.nn.parallel import DistributedDataParallel 11 | from ignite.engine import Engine, Events 12 | from ignite.handlers import ModelCheckpoint 13 | from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage 14 | from ignite.contrib.handlers import ProgressBar, PiecewiseLinear, LRScheduler 15 | from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler 16 | from gpt_model import AdamW, GPT2LMHeadModel, GPT2Config, BertTokenizer, WEIGHTS_NAME, CONFIG_NAME 17 | 18 | from utils.utils import make_logdir 19 | from featurizer.get_dataloader import build_dataloader 20 | 21 | PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__)) 22 | MODEL_CHECKPOINT = os.path.join(PROJECT_FOLDER, "runs/pretrained") 23 | IMG_FEATURE_FOLDER = os.path.join(PROJECT_FOLDER, "data/output_data/train_images_feature.pkl") 24 | IMG_FOLDER = os.path.join(PROJECT_FOLDER, "data/raw_data/images_train_dev") 25 | DATA_FOLDER = os.path.join(PROJECT_FOLDER, "data/output_data/gpt2_train_dev_persona_BERTvocab.json") 26 | DATA_CACHE = os.path.join(PROJECT_FOLDER, "data/output_data/gpt2_train_dev_persona_BERTvocab") 27 | # DATA_FOLDER = os.path.join(PROJECT_FOLDER, "data/output_data/test_del.json") 28 | # DATA_CACHE = os.path.join(PROJECT_FOLDER, "data/output_data/test_del") 29 | # DATA_FOLDER = os.path.join(PROJECT_FOLDER, "data/output_data/gpt2_images.json") 30 | # DATA_CACHE = os.path.join(PROJECT_FOLDER, "data/output_data/gpt2_images") 31 | VOCAB_PATH = os.path.join(PROJECT_FOLDER, "config/Custom/vocab_custom.txt") 32 | CONFIG_PATH = os.path.join(PROJECT_FOLDER, "config/Custom/config_base.json") 33 | 34 | ATTR_TO_SPECIAL_TOKEN = {'additional_special_tokens': ['', '', '', '#E-s', '|||', '[UNK]']} 35 | 36 | logger = logging.getLogger(__file__) 37 | 38 | 39 | def average_distributed_scalar(scalar, args): 40 | """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """ 41 | if args.local_rank == -1: 42 | return scalar 43 | scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size() 44 | torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM) 45 | return scalar_t.item() 46 | 47 | 48 | def add_special_tokens_(model, tokenizer): 49 | """ Add special tokens to the tokenizer and the model if they have not already been added. 
""" 50 | num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN) # doesn't add if they are already there 51 | if num_added_tokens > 0: 52 | model.resize_token_embeddings(new_num_tokens=len(tokenizer)) 53 | 54 | 55 | def train(): 56 | parser = ArgumentParser() 57 | parser.add_argument("--dataset_path", type=str, default=DATA_FOLDER, help="Path of the dataset.") 58 | parser.add_argument("--image_path", type=str, default=IMG_FOLDER, help="Path of the images.") 59 | parser.add_argument("--images_feature_path", type=str, default=IMG_FEATURE_FOLDER, help="Path of the images.") 60 | parser.add_argument("--dataset_cache", type=str, default=DATA_CACHE, help="Path of the dataset cache_no_pretrained") 61 | parser.add_argument("--model_checkpoint", type=str, default="gpt2", help="Path, url or short name of the model") 62 | parser.add_argument('--dhead_gpt2', action='store_true', default=False, help="use double head gpt2") 63 | parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step") 64 | parser.add_argument('--pretrained', action='store_true', default=True, help="If False train from scratch") 65 | parser.add_argument("--num_candidates", type=int, default=1, help="Number of candidates for training") 66 | parser.add_argument("--max_history", type=int, default=3, help="Number of previous turns to keep in history") 67 | parser.add_argument("--max_length", type=int, default=256, help="Max length of input sentence") 68 | parser.add_argument("--train_batch_size", type=int, default=58, help="Batch size for training") 69 | parser.add_argument("--valid_batch_size", type=int, default=32, help="Batch size for validation") 70 | parser.add_argument("--gradient_accumulation_steps", type=int, default=9, help="Accumulate gradients on several steps") 71 | parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") 72 | parser.add_argument("--scheduler", type=str, default="linear", choices=['noam', 'linear'], help="method of optim") 73 | parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)") 74 | parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps") 75 | parser.add_argument("--lm_coef", type=float, default=2.0, help="LM loss coefficient") 76 | parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient") 77 | parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") 78 | parser.add_argument("--n_epochs", type=int, default=50, help="Number of training epochs") 79 | parser.add_argument("--num_workers", type=int, default=0, help="Number of subprocesses for data loading") 80 | parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences") 81 | parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training") 82 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") 83 | parser.add_argument("--fp16", type=str, default="O1", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") 84 | parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") 85 | args = parser.parse_args() 86 | 87 | # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. 
logger.info => log main process only, logger.warning => log all processes 88 | logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 89 | logger.warning("Running process %d", args.local_rank) # This is a logger.warning: it will be printed by all distributed processes 90 | logger.info("Arguments: %s", pformat(args)) 91 | 92 | # Initialize distributed training if needed 93 | args.distributed = (args.local_rank != -1) 94 | if args.distributed: 95 | torch.cuda.set_device(args.local_rank) 96 | args.device = torch.device("cuda", args.local_rank) 97 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 98 | logger.info("Prepare tokenizer, pretrained model and optimizer.") 99 | tokenizer_class = BertTokenizer 100 | config_class = GPT2Config # GPT2Config if "gpt2" in args.model_checkpoint else OpenAIGPTConfig 101 | model_class = GPT2LMHeadModel # GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel 102 | if args.pretrained: 103 | tokenizer = tokenizer_class.from_pretrained(MODEL_CHECKPOINT, do_lower_case=False) 104 | # tokenizer = tokenizer_class(vocab_file=VOCAB_PATH, do_lower_case=True) 105 | model = model_class.from_pretrained(MODEL_CHECKPOINT) 106 | else: 107 | tokenizer = tokenizer_class(vocab_file=VOCAB_PATH, do_lower_case=False) 108 | tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN) 109 | config = config_class.from_json_file(CONFIG_PATH) 110 | model = model_class(config) 111 | model.to(args.device) 112 | # Add special tokens if they are not already added 113 | # add_special_tokens_(model, tokenizer) 114 | # optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True) 115 | optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, correct_bias=True) 116 | 117 | # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) 118 | if args.fp16: 119 | from apex import amp # Apex is only required if we use fp16 training 120 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) 121 | if args.distributed: 122 | model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) 123 | 124 | logger.info("Prepare datasets") 125 | train_loader, val_loader, train_sampler, valid_sampler = build_dataloader(args, tokenizer, logger) 126 | 127 | def update(engine, batch): 128 | model.train() 129 | batch = tuple(torch.tensor(input_data).to(args.device) if idx not in [2, 3] else input_data for idx, input_data in enumerate(batch)) 130 | input_ids, token_type_ids, input_images, image_ids, lm_labels, mc_token_ids, mc_labels = batch 131 | if args.dhead_gpt2: 132 | (lm_loss), (mc_loss), *_ = model(input_ids, 133 | token_type_ids=token_type_ids, 134 | mc_token_ids=mc_token_ids, 135 | mc_labels=mc_labels, 136 | lm_labels=lm_labels) 137 | loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps 138 | else: 139 | (lm_loss), *_ = model(input_ids, 140 | labels=lm_labels, 141 | token_type_ids=token_type_ids, 142 | input_images=input_images, 143 | image_ids=image_ids) 144 | loss = lm_loss / args.gradient_accumulation_steps 145 | if args.fp16: 146 | with amp.scale_loss(loss, optimizer) as scaled_loss: 147 | scaled_loss.backward() 148 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) 149 | else: 150 | loss.backward() 151 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) 152 | if engine.state.iteration % 
args.gradient_accumulation_steps == 0: 153 | optimizer.step() 154 | optimizer.zero_grad() 155 | return loss.item() #, optimizer.param_groups[0]['lr'] 156 | trainer = Engine(update) 157 | 158 | # Evaluation function and evaluator (evaluator output is the input of the metrics) 159 | def inference(engine, batch): 160 | model.eval() 161 | with torch.no_grad(): 162 | batch = tuple(input_tensor.to(args.device) for input_tensor in batch) 163 | input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch 164 | # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) 165 | # if we dont send labels to model, it doesnt return losses 166 | if args.dhead_gpt2: 167 | lm_logits, mc_logits, *_ = model( 168 | input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids, 169 | ) 170 | lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1)) 171 | lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) 172 | return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels) 173 | else: 174 | lm_logits, *_ = model(input_ids, token_type_ids=token_type_ids) 175 | lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1)) 176 | lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) 177 | return lm_logits_flat_shifted, lm_labels_flat_shifted 178 | evaluator = Engine(inference) 179 | 180 | # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch 181 | # trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) 182 | if args.n_epochs < 1: 183 | trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) 184 | if args.eval_before_start: 185 | trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) 186 | 187 | # Make sure distributed data samplers split the dataset nicely between the distributed processes 188 | if args.distributed: 189 | trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) 190 | evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) 191 | 192 | # Linearly decrease the learning rate from lr to zero 193 | model_size = args.n_emd 194 | noam_lambda = lambda step: ( 195 | model_size ** (-0.5) * min((step + 1) ** (-0.5), (step + 1) * args.warmup_steps ** (-1.5))) 196 | noam_scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda, last_epoch=args.from_step) 197 | scheduler = LRScheduler(noam_scheduler) 198 | if args.scheduler == "linear": 199 | scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) 200 | trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) 201 | 202 | # Prepare metrics - note how we compute distributed metrics 203 | RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") 204 | metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])), 205 | "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))} 206 | metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args), 207 | "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)}) 208 | metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) 209 | for name, metric in metrics.items(): 210 | metric.attach(evaluator, name) 211 | 212 | # On the main process: add progress bar, tensorboard, 
checkpoints and save model, configuration and tokenizer before we start to train 213 | if args.local_rank in [-1, 0]: 214 | pbar = ProgressBar(persist=True) 215 | pbar.attach(trainer, metric_names=["loss"]) 216 | evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics))) 217 | 218 | log_dir = make_logdir(args.model_checkpoint) 219 | tb_logger = TensorboardLogger(log_dir) 220 | 221 | tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) 222 | tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) 223 | # tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) 224 | 225 | checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', n_saved=None) 226 | trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpoint_handler, {'mymodel': getattr(model, 'module', model)}) # "getattr" takes care of distributed encapsulation 227 | 228 | torch.save(args, log_dir + '/model_training_args.bin') 229 | getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME)) 230 | tokenizer.save_pretrained(log_dir) 231 | 232 | # Run the training 233 | trainer.run(train_loader, max_epochs=args.n_epochs) 234 | 235 | # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) 236 | if args.local_rank in [-1, 0] and args.n_epochs > 0: 237 | os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME)) # TODO: PR in ignite to have better access to saved file paths (cleaner) 238 | tb_logger.close() 239 | 240 | 241 | if __name__ == "__main__": 242 | train() 243 | -------------------------------------------------------------------------------- /gpt_model/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | 19 | import copy 20 | import json 21 | import logging 22 | import os 23 | from typing import Dict, Optional, Tuple 24 | 25 | from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. 
Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 38 | 39 | Class attributes (overridden by derived classes): 40 | - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values. 41 | - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`. 42 | 43 | Args: 44 | finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`): 45 | Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 46 | num_labels (:obj:`int`, `optional`, defaults to `2`): 47 | Number of classes to use when the model is a classification model (sequences/tokens) 48 | output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`): 49 | Should the model returns attentions weights. 50 | output_hidden_states (:obj:`string`, `optional`, defaults to :obj:`False`): 51 | Should the model returns all hidden-states. 52 | torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`): 53 | Is the model used with Torchscript (for PyTorch models). 54 | """ 55 | pretrained_config_archive_map = {} # type: Dict[str, str] 56 | model_type = "" # type: str 57 | 58 | def __init__(self, **kwargs): 59 | # Attributes with defaults 60 | self.output_attentions = kwargs.pop("output_attentions", False) 61 | self.output_hidden_states = kwargs.pop("output_hidden_states", False) 62 | self.output_past = kwargs.pop("output_past", True) # Not used by all models 63 | self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models 64 | self.use_bfloat16 = kwargs.pop("use_bfloat16", False) 65 | self.pruned_heads = kwargs.pop("pruned_heads", {}) 66 | 67 | # Is decoder is used in encoder-decoder models to differentiate encoder from decoder 68 | self.is_decoder = kwargs.pop("is_decoder", False) 69 | 70 | # Parameters for sequence generation 71 | self.max_length = kwargs.pop("max_length", 20) 72 | self.do_sample = kwargs.pop("do_sample", False) 73 | self.num_beams = kwargs.pop("num_beams", 1) 74 | self.temperature = kwargs.pop("temperature", 1.0) 75 | self.top_k = kwargs.pop("top_k", 50) 76 | self.top_p = kwargs.pop("top_p", 1.0) 77 | self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) 78 | self.bos_token_id = kwargs.pop("bos_token_id", None) 79 | self.pad_token_id = kwargs.pop("pad_token_id", None) 80 | self.eos_token_ids = kwargs.pop("eos_token_ids", None) 81 | self.length_penalty = kwargs.pop("length_penalty", 1.0) 82 | self.num_return_sequences = kwargs.pop("num_return_sequences", 1) 83 | 84 | # Fine-tuning task arguments 85 | self.architectures = kwargs.pop("architectures", None) 86 | self.finetuning_task = kwargs.pop("finetuning_task", None) 87 | self.num_labels = kwargs.pop("num_labels", 2) 88 | self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)}) 89 | self.id2label = dict((int(key), value) for key, value in self.id2label.items()) 90 | self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys()))) 91 | self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) 92 | 93 | # Additional attributes without default values 94 | for key, value in kwargs.items(): 95 | try: 96 | setattr(self, key, 
value) 97 | except AttributeError as err: 98 | logger.error("Can't set {} with value {} for {}".format(key, value, self)) 99 | raise err 100 | 101 | def save_pretrained(self, save_directory): 102 | """ 103 | Save a configuration object to the directory `save_directory`, so that it 104 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. 105 | 106 | Args: 107 | save_directory (:obj:`string`): 108 | Directory where the configuration JSON file will be saved. 109 | """ 110 | assert os.path.isdir( 111 | save_directory 112 | ), "Saving path should be a directory where the model and configuration can be saved" 113 | 114 | # If we save using the predefined names, we can load using `from_pretrained` 115 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 116 | 117 | self.to_json_file(output_config_file) 118 | logger.info("Configuration saved in {}".format(output_config_file)) 119 | 120 | @classmethod 121 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig": 122 | r""" 123 | 124 | Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 125 | 126 | Args: 127 | pretrained_model_name_or_path (:obj:`string`): 128 | either: 129 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache_no_pretrained or 130 | download, e.g.: ``bert-base-uncased``. 131 | - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to 132 | our S3, e.g.: ``dbmdz/bert-base-german-cased``. 133 | - a path to a `directory` containing a configuration file saved using the 134 | :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 135 | - a path or url to a saved configuration JSON `file`, e.g.: 136 | ``./my_model_directory/configuration.json``. 137 | cache_dir (:obj:`string`, `optional`): 138 | Path to a directory in which a downloaded pre-trained model 139 | configuration should be cached if the standard cache_no_pretrained should not be used. 140 | kwargs (:obj:`Dict[str, any]`, `optional`): 141 | The values in kwargs of any keys which are configuration attributes will be used to override the loaded 142 | values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is 143 | controlled by the `return_unused_kwargs` keyword parameter. 144 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 145 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 146 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 147 | Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. 148 | proxies (:obj:`Dict`, `optional`): 149 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: 150 | :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` 151 | The proxies are used on each request. 152 | return_unused_kwargs: (`optional`) bool: 153 | If False, then this function returns just the final configuration object. 154 | If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a 155 | dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part 156 | of kwargs which has not been used to update `config` and is otherwise ignored. 
157 | 158 | Returns: 159 | :class:`PretrainedConfig`: An instance of a configuration object 160 | 161 | Examples:: 162 | 163 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 164 | # derived class: BertConfig 165 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache_no_pretrained. 166 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` 167 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 168 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 169 | assert config.output_attention == True 170 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 171 | foo=False, return_unused_kwargs=True) 172 | assert config.output_attention == True 173 | assert unused_kwargs == {'foo': False} 174 | 175 | """ 176 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 177 | return cls.from_dict(config_dict, **kwargs) 178 | 179 | @classmethod 180 | def get_config_dict( 181 | cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs 182 | ) -> Tuple[Dict, Dict]: 183 | """ 184 | From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used 185 | for instantiating a Config using `from_dict`. 186 | 187 | Parameters: 188 | pretrained_model_name_or_path (:obj:`string`): 189 | The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. 190 | pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict: 191 | A map of `shortcut names` to `url`. By default, will use the current class attribute. 192 | 193 | Returns: 194 | :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. 
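        Illustrative sketch (the local directory below is a placeholder, assumed to contain a
        ``config.json`` written earlier by ``save_pretrained``; shown with the derived class
        ``BertConfig`` as in the examples elsewhere in this file)::

            config_dict, remaining_kwargs = BertConfig.get_config_dict(
                './my_model_directory/', output_hidden_states=True
            )
            # config_dict holds the parsed ./my_model_directory/config.json contents;
            # kwargs not consumed here (e.g. output_hidden_states) are returned unchanged,
            # to be applied later by `from_dict`.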
195 | 196 | """ 197 | cache_dir = kwargs.pop("cache_dir", None) 198 | force_download = kwargs.pop("force_download", False) 199 | resume_download = kwargs.pop("resume_download", False) 200 | proxies = kwargs.pop("proxies", None) 201 | local_files_only = kwargs.pop("local_files_only", False) 202 | 203 | if pretrained_config_archive_map is None: 204 | pretrained_config_archive_map = cls.pretrained_config_archive_map 205 | 206 | if pretrained_model_name_or_path in pretrained_config_archive_map: 207 | config_file = pretrained_config_archive_map[pretrained_model_name_or_path] 208 | elif os.path.isdir(pretrained_model_name_or_path): 209 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 210 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): 211 | config_file = pretrained_model_name_or_path 212 | else: 213 | config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME) 214 | 215 | try: 216 | # Load from URL or cache_no_pretrained if already cached 217 | resolved_config_file = cached_path( 218 | config_file, 219 | cache_dir=cache_dir, 220 | force_download=force_download, 221 | proxies=proxies, 222 | resume_download=resume_download, 223 | local_files_only=local_files_only, 224 | ) 225 | # Load config dict 226 | if resolved_config_file is None: 227 | raise EnvironmentError 228 | config_dict = cls._dict_from_json_file(resolved_config_file) 229 | 230 | except EnvironmentError: 231 | if pretrained_model_name_or_path in pretrained_config_archive_map: 232 | msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( 233 | config_file 234 | ) 235 | else: 236 | msg = ( 237 | "Model name '{}' was not found in model name list. " 238 | "We assumed '{}' was a path, a model identifier, or url to a configuration file named {} or " 239 | "a directory containing such a file but couldn't find any such file at this path or url.".format( 240 | pretrained_model_name_or_path, config_file, CONFIG_NAME, 241 | ) 242 | ) 243 | raise EnvironmentError(msg) 244 | 245 | except json.JSONDecodeError: 246 | msg = ( 247 | "Couldn't reach server at '{}' to download configuration file or " 248 | "configuration file is not a valid JSON file. " 249 | "Please check network or file content here: {}.".format(config_file, resolved_config_file) 250 | ) 251 | raise EnvironmentError(msg) 252 | 253 | if resolved_config_file == config_file: 254 | logger.info("loading configuration file {}".format(config_file)) 255 | else: 256 | logger.info("loading configuration file {} from cache_no_pretrained at {}".format(config_file, resolved_config_file)) 257 | 258 | return config_dict, kwargs 259 | 260 | @classmethod 261 | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": 262 | """ 263 | Constructs a `Config` from a Python dictionary of parameters. 264 | 265 | Args: 266 | config_dict (:obj:`Dict[str, any]`): 267 | Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved 268 | from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` 269 | method. 270 | kwargs (:obj:`Dict[str, any]`): 271 | Additional parameters from which to initialize the configuration object. 
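        Illustrative sketch (keys and values are placeholders; shown with the derived class
        ``BertConfig`` as in the examples elsewhere in this file)::

            config, unused = BertConfig.from_dict(
                {'num_labels': 3}, top_k=5, foo=False, return_unused_kwargs=True
            )
            # `top_k` matches an existing configuration attribute and overrides the loaded value;
            # `foo` does not, so unused == {'foo': False}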
272 | 273 | Returns: 274 | :class:`PretrainedConfig`: An instance of a configuration object 275 | """ 276 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) 277 | 278 | config = cls(**config_dict) 279 | 280 | if hasattr(config, "pruned_heads"): 281 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) 282 | 283 | # Update config with kwargs if needed 284 | to_remove = [] 285 | for key, value in kwargs.items(): 286 | if hasattr(config, key): 287 | setattr(config, key, value) 288 | to_remove.append(key) 289 | for key in to_remove: 290 | kwargs.pop(key, None) 291 | 292 | logger.info("Model config %s", str(config)) 293 | if return_unused_kwargs: 294 | return config, kwargs 295 | else: 296 | return config 297 | 298 | @classmethod 299 | def from_json_file(cls, json_file: str) -> "PretrainedConfig": 300 | """ 301 | Constructs a `Config` from the path to a json file of parameters. 302 | 303 | Args: 304 | json_file (:obj:`string`): 305 | Path to the JSON file containing the parameters. 306 | 307 | Returns: 308 | :class:`PretrainedConfig`: An instance of a configuration object 309 | 310 | """ 311 | config_dict = cls._dict_from_json_file(json_file) 312 | return cls(**config_dict) 313 | 314 | @classmethod 315 | def _dict_from_json_file(cls, json_file: str): 316 | with open(json_file, "r", encoding="utf-8") as reader: 317 | text = reader.read() 318 | return json.loads(text) 319 | 320 | def __eq__(self, other): 321 | return self.__dict__ == other.__dict__ 322 | 323 | def __repr__(self): 324 | return "{} {}".format(self.__class__.__name__, self.to_json_string()) 325 | 326 | def to_dict(self): 327 | """ 328 | Serializes this instance to a Python dictionary. 329 | 330 | Returns: 331 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 332 | """ 333 | output = copy.deepcopy(self.__dict__) 334 | if hasattr(self.__class__, "model_type"): 335 | output["model_type"] = self.__class__.model_type 336 | return output 337 | 338 | def to_json_string(self): 339 | """ 340 | Serializes this instance to a JSON string. 341 | 342 | Returns: 343 | :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format. 344 | """ 345 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 346 | 347 | def to_json_file(self, json_file_path): 348 | """ 349 | Save this instance to a json file. 350 | 351 | Args: 352 | json_file_path (:obj:`string`): 353 | Path to the JSON file in which this configuration instance's parameters will be saved. 354 | """ 355 | with open(json_file_path, "w", encoding="utf-8") as writer: 356 | writer.write(self.to_json_string()) 357 | -------------------------------------------------------------------------------- /gpt_model/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache_no_pretrained. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | 7 | import fnmatch 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import sys 13 | import tarfile 14 | import tempfile 15 | from contextlib import contextmanager 16 | from functools import partial, wraps 17 | from hashlib import sha256 18 | from typing import Optional 19 | from urllib.parse import urlparse 20 | from zipfile import ZipFile, is_zipfile 21 | 22 | import boto3 23 | import requests 24 | from botocore.config import Config 25 | from botocore.exceptions import ClientError 26 | from filelock import FileLock 27 | from tqdm.auto import tqdm 28 | 29 | from . import __version__ 30 | 31 | 32 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 33 | 34 | try: 35 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 36 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 37 | if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): 38 | import torch 39 | 40 | _torch_available = True # pylint: disable=invalid-name 41 | logger.info("PyTorch version {} available.".format(torch.__version__)) 42 | else: 43 | logger.info("Disabling PyTorch because USE_TF is set") 44 | _torch_available = False 45 | except ImportError: 46 | _torch_available = False # pylint: disable=invalid-name 47 | 48 | try: 49 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 50 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 51 | 52 | if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): 53 | import tensorflow as tf 54 | 55 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 56 | _tf_available = True # pylint: disable=invalid-name 57 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 58 | else: 59 | logger.info("Disabling Tensorflow because USE_TORCH is set") 60 | _tf_available = False 61 | except (ImportError, AssertionError): 62 | _tf_available = False # pylint: disable=invalid-name 63 | 64 | try: 65 | from torch.hub import _get_torch_home 66 | 67 | torch_cache_home = _get_torch_home() 68 | except ImportError: 69 | torch_cache_home = os.path.expanduser( 70 | os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache_no_pretrained"), "torch")) 71 | ) 72 | default_cache_path = os.path.join(torch_cache_home, "transformers") 73 | 74 | try: 75 | from pathlib import Path 76 | 77 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 78 | os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)) 79 | ) 80 | except (AttributeError, ImportError): 81 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( 82 | "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) 83 | ) 84 | 85 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 86 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 87 | 88 | WEIGHTS_NAME = "pytorch_model.bin" 89 | TF2_WEIGHTS_NAME = "tf_model.h5" 90 | TF_WEIGHTS_NAME = "model.ckpt" 91 | CONFIG_NAME = "config.json" 92 | MODEL_CARD_NAME = "modelcard.json" 93 | 94 | 95 | MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]] 96 | DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] 97 | DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] 98 | 99 | S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" 100 | CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net" 101 | 102 | 103 | def is_torch_available(): 104 | return _torch_available 105 | 106 | 107 | def 
is_tf_available(): 108 | return _tf_available 109 | 110 | 111 | def add_start_docstrings(*docstr): 112 | def docstring_decorator(fn): 113 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 114 | return fn 115 | 116 | return docstring_decorator 117 | 118 | 119 | def add_start_docstrings_to_callable(*docstr): 120 | def docstring_decorator(fn): 121 | class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0]) 122 | intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name) 123 | note = r""" 124 | 125 | .. note:: 126 | Although the recipe for forward pass needs to be defined within 127 | this function, one should call the :class:`Module` instance afterwards 128 | instead of this since the former takes care of running the 129 | pre and post processing steps while the latter silently ignores them. 130 | """ 131 | fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 132 | return fn 133 | 134 | return docstring_decorator 135 | 136 | 137 | def add_end_docstrings(*docstr): 138 | def docstring_decorator(fn): 139 | fn.__doc__ = fn.__doc__ + "".join(docstr) 140 | return fn 141 | 142 | return docstring_decorator 143 | 144 | 145 | def is_remote_url(url_or_filename): 146 | parsed = urlparse(url_or_filename) 147 | return parsed.scheme in ("http", "https", "s3") 148 | 149 | 150 | def hf_bucket_url(identifier, postfix=None, cdn=False) -> str: 151 | endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX 152 | if postfix is None: 153 | return "/".join((endpoint, identifier)) 154 | else: 155 | return "/".join((endpoint, identifier, postfix)) 156 | 157 | 158 | def url_to_filename(url, etag=None): 159 | """ 160 | Convert `url` into a hashed filename in a repeatable way. 161 | If `etag` is specified, append its hash to the url's, delimited 162 | by a period. 163 | If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name 164 | so that TF 2.0 can identify it as a HDF5 file 165 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) 166 | """ 167 | url_bytes = url.encode("utf-8") 168 | url_hash = sha256(url_bytes) 169 | filename = url_hash.hexdigest() 170 | 171 | if etag: 172 | etag_bytes = etag.encode("utf-8") 173 | etag_hash = sha256(etag_bytes) 174 | filename += "." + etag_hash.hexdigest() 175 | 176 | if url.endswith(".h5"): 177 | filename += ".h5" 178 | 179 | return filename 180 | 181 | 182 | def filename_to_url(filename, cache_dir=None): 183 | """ 184 | Return the url and etag (which may be ``None``) stored for `filename`. 185 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
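    Illustrative sketch (the URL and etag are placeholders; assumes the file was previously
    downloaded and cached by ``get_from_cache``, which writes the ``<filename>.json`` sidecar)::

        fname = url_to_filename('https://example.com/config.json', etag='"abc123"')
        url, etag = filename_to_url(fname)  # reads <cache_dir>/<fname>.json; raises EnvironmentError if absent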
186 | """ 187 | if cache_dir is None: 188 | cache_dir = TRANSFORMERS_CACHE 189 | if isinstance(cache_dir, Path): 190 | cache_dir = str(cache_dir) 191 | 192 | cache_path = os.path.join(cache_dir, filename) 193 | if not os.path.exists(cache_path): 194 | raise EnvironmentError("file {} not found".format(cache_path)) 195 | 196 | meta_path = cache_path + ".json" 197 | if not os.path.exists(meta_path): 198 | raise EnvironmentError("file {} not found".format(meta_path)) 199 | 200 | with open(meta_path, encoding="utf-8") as meta_file: 201 | metadata = json.load(meta_file) 202 | url = metadata["url"] 203 | etag = metadata["etag"] 204 | 205 | return url, etag 206 | 207 | 208 | def cached_path( 209 | url_or_filename, 210 | cache_dir=None, 211 | force_download=False, 212 | proxies=None, 213 | resume_download=False, 214 | user_agent=None, 215 | extract_compressed_file=False, 216 | force_extract=False, 217 | local_files_only=False, 218 | ) -> Optional[str]: 219 | """ 220 | Given something that might be a URL (or might be a local path), 221 | determine which. If it's a URL, download the file and cache_no_pretrained it, and 222 | return the path to the cached file. If it's already a local path, 223 | make sure the file exists and then return the path. 224 | Args: 225 | cache_dir: specify a cache_no_pretrained directory to save the file to (overwrite the default cache_no_pretrained dir). 226 | force_download: if True, re-dowload the file even if it's already cached in the cache_no_pretrained dir. 227 | resume_download: if True, resume the download if incompletly recieved file is found. 228 | user_agent: Optional string or dict that will be appended to the user-agent on remote requests. 229 | extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed 230 | file in a folder along the archive. 231 | force_extract: if True when extract_compressed_file is True and the archive was already extracted, 232 | re-extract the archive and overide the folder where it was extracted. 233 | 234 | Return: 235 | None in case of non-recoverable file (non-existent or inaccessible url + no cache_no_pretrained on disk). 236 | Local path (string) otherwise 237 | """ 238 | if cache_dir is None: 239 | cache_dir = TRANSFORMERS_CACHE 240 | if isinstance(url_or_filename, Path): 241 | url_or_filename = str(url_or_filename) 242 | if isinstance(cache_dir, Path): 243 | cache_dir = str(cache_dir) 244 | 245 | if is_remote_url(url_or_filename): 246 | # URL, so get it from the cache_no_pretrained (downloading if necessary) 247 | output_path = get_from_cache( 248 | url_or_filename, 249 | cache_dir=cache_dir, 250 | force_download=force_download, 251 | proxies=proxies, 252 | resume_download=resume_download, 253 | user_agent=user_agent, 254 | local_files_only=local_files_only, 255 | ) 256 | elif os.path.exists(url_or_filename): 257 | # File, and it exists. 258 | output_path = url_or_filename 259 | elif urlparse(url_or_filename).scheme == "": 260 | # File, but it doesn't exist. 261 | raise EnvironmentError("file {} not found".format(url_or_filename)) 262 | else: 263 | # Something unknown 264 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 265 | 266 | if extract_compressed_file: 267 | if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): 268 | return output_path 269 | 270 | # Path where we extract compressed archives 271 | # We avoid '.' 
in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" 272 | output_dir, output_file = os.path.split(output_path) 273 | output_extract_dir_name = output_file.replace(".", "-") + "-extracted" 274 | output_path_extracted = os.path.join(output_dir, output_extract_dir_name) 275 | 276 | if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: 277 | return output_path_extracted 278 | 279 | # Prevent parallel extractions 280 | lock_path = output_path + ".lock" 281 | with FileLock(lock_path): 282 | shutil.rmtree(output_path_extracted, ignore_errors=True) 283 | os.makedirs(output_path_extracted) 284 | if is_zipfile(output_path): 285 | with ZipFile(output_path, "r") as zip_file: 286 | zip_file.extractall(output_path_extracted) 287 | zip_file.close() 288 | elif tarfile.is_tarfile(output_path): 289 | tar_file = tarfile.open(output_path) 290 | tar_file.extractall(output_path_extracted) 291 | tar_file.close() 292 | else: 293 | raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) 294 | 295 | return output_path_extracted 296 | 297 | return output_path 298 | 299 | 300 | def split_s3_path(url): 301 | """Split a full s3 path into the bucket name and path.""" 302 | parsed = urlparse(url) 303 | if not parsed.netloc or not parsed.path: 304 | raise ValueError("bad s3 path {}".format(url)) 305 | bucket_name = parsed.netloc 306 | s3_path = parsed.path 307 | # Remove '/' at beginning of path. 308 | if s3_path.startswith("/"): 309 | s3_path = s3_path[1:] 310 | return bucket_name, s3_path 311 | 312 | 313 | def s3_request(func): 314 | """ 315 | Wrapper function for s3 requests in order to create more helpful error 316 | messages. 317 | """ 318 | 319 | @wraps(func) 320 | def wrapper(url, *args, **kwargs): 321 | try: 322 | return func(url, *args, **kwargs) 323 | except ClientError as exc: 324 | if int(exc.response["Error"]["Code"]) == 404: 325 | raise EnvironmentError("file {} not found".format(url)) 326 | else: 327 | raise 328 | 329 | return wrapper 330 | 331 | 332 | @s3_request 333 | def s3_etag(url, proxies=None): 334 | """Check ETag on S3 object.""" 335 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 336 | bucket_name, s3_path = split_s3_path(url) 337 | s3_object = s3_resource.Object(bucket_name, s3_path) 338 | return s3_object.e_tag 339 | 340 | 341 | @s3_request 342 | def s3_get(url, temp_file, proxies=None): 343 | """Pull a file directly from S3.""" 344 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 345 | bucket_name, s3_path = split_s3_path(url) 346 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 347 | 348 | 349 | def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): 350 | ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) 351 | if is_torch_available(): 352 | ua += "; torch/{}".format(torch.__version__) 353 | if is_tf_available(): 354 | ua += "; tensorflow/{}".format(tf.__version__) 355 | if isinstance(user_agent, dict): 356 | ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) 357 | elif isinstance(user_agent, str): 358 | ua += "; " + user_agent 359 | headers = {"user-agent": ua} 360 | if resume_size > 0: 361 | headers["Range"] = "bytes=%d-" % (resume_size,) 362 | response = requests.get(url, stream=True, proxies=proxies, headers=headers) 363 | if response.status_code == 416: # Range not satisfiable 364 | return 365 | content_length = 
response.headers.get("Content-Length") 366 | total = resume_size + int(content_length) if content_length is not None else None 367 | progress = tqdm( 368 | unit="B", 369 | unit_scale=True, 370 | total=total, 371 | initial=resume_size, 372 | desc="Downloading", 373 | disable=bool(logger.getEffectiveLevel() == logging.NOTSET), 374 | ) 375 | for chunk in response.iter_content(chunk_size=1024): 376 | if chunk: # filter out keep-alive new chunks 377 | progress.update(len(chunk)) 378 | temp_file.write(chunk) 379 | progress.close() 380 | 381 | 382 | def get_from_cache( 383 | url, 384 | cache_dir=None, 385 | force_download=False, 386 | proxies=None, 387 | etag_timeout=10, 388 | resume_download=False, 389 | user_agent=None, 390 | local_files_only=False, 391 | ) -> Optional[str]: 392 | """ 393 | Given a URL, look for the corresponding file in the local cache_no_pretrained. 394 | If it's not there, download it. Then return the path to the cached file. 395 | 396 | Return: 397 | None in case of non-recoverable file (non-existent or inaccessible url + no cache_no_pretrained on disk). 398 | Local path (string) otherwise 399 | """ 400 | if cache_dir is None: 401 | cache_dir = TRANSFORMERS_CACHE 402 | if isinstance(cache_dir, Path): 403 | cache_dir = str(cache_dir) 404 | 405 | os.makedirs(cache_dir, exist_ok=True) 406 | 407 | etag = None 408 | if not local_files_only: 409 | # Get eTag to add to filename, if it exists. 410 | if url.startswith("s3://"): 411 | etag = s3_etag(url, proxies=proxies) 412 | else: 413 | try: 414 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) 415 | if response.status_code == 200: 416 | etag = response.headers.get("ETag") 417 | except (EnvironmentError, requests.exceptions.Timeout): 418 | # etag is already None 419 | pass 420 | 421 | filename = url_to_filename(url, etag) 422 | 423 | # get cache_no_pretrained path to put the file 424 | cache_path = os.path.join(cache_dir, filename) 425 | 426 | # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. 427 | # try to get the last downloaded one 428 | if etag is None: 429 | if os.path.exists(cache_path): 430 | return cache_path 431 | else: 432 | matching_files = [ 433 | file 434 | for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") 435 | if not file.endswith(".json") and not file.endswith(".lock") 436 | ] 437 | if len(matching_files) > 0: 438 | return os.path.join(cache_dir, matching_files[-1]) 439 | else: 440 | # If files cannot be found and local_files_only=True, 441 | # the models might've been found if local_files_only=False 442 | # Notify the user about that 443 | if local_files_only: 444 | raise ValueError( 445 | "Cannot find the requested files in the cached path and outgoing traffic has been" 446 | " disabled. To enable model look-ups and downloads online, set 'local_files_only'" 447 | " to False." 448 | ) 449 | return None 450 | 451 | # From now on, etag is not None. 452 | if os.path.exists(cache_path) and not force_download: 453 | return cache_path 454 | 455 | # Prevent parallel downloads of the same file with a lock. 
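    # The lock is a sibling file '<cache_path>.lock' managed by filelock.FileLock. With
    # resume_download=True, partial data accumulates in '<cache_path>.incomplete' and its current
    # size is passed to http_get() as an HTTP Range offset; otherwise a NamedTemporaryFile inside
    # cache_dir is used. Either way, the finished download is os.rename()'d onto cache_path, so
    # concurrent readers never observe a half-written cache entry.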
456 | lock_path = cache_path + ".lock" 457 | with FileLock(lock_path): 458 | 459 | if resume_download: 460 | incomplete_path = cache_path + ".incomplete" 461 | 462 | @contextmanager 463 | def _resumable_file_manager(): 464 | with open(incomplete_path, "a+b") as f: 465 | yield f 466 | 467 | temp_file_manager = _resumable_file_manager 468 | if os.path.exists(incomplete_path): 469 | resume_size = os.stat(incomplete_path).st_size 470 | else: 471 | resume_size = 0 472 | else: 473 | temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) 474 | resume_size = 0 475 | 476 | # Download to temporary file, then copy to cache_no_pretrained dir once finished. 477 | # Otherwise you get corrupt cache_no_pretrained entries if the download gets interrupted. 478 | with temp_file_manager() as temp_file: 479 | logger.info("%s not found in cache_no_pretrained or force_download set to True, downloading to %s", url, temp_file.name) 480 | 481 | # GET file object 482 | if url.startswith("s3://"): 483 | if resume_download: 484 | logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') 485 | s3_get(url, temp_file, proxies=proxies) 486 | else: 487 | http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) 488 | 489 | logger.info("storing %s in cache_no_pretrained at %s", url, cache_path) 490 | os.rename(temp_file.name, cache_path) 491 | 492 | logger.info("creating metadata file for %s", cache_path) 493 | meta = {"url": url, "etag": etag} 494 | meta_path = cache_path + ".json" 495 | with open(meta_path, "w") as meta_file: 496 | json.dump(meta, meta_file) 497 | 498 | return cache_path 499 | -------------------------------------------------------------------------------- /gpt_model/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | 18 | import collections 19 | import logging 20 | import os 21 | import unicodedata 22 | 23 | from tokenizers import BertWordPieceTokenizer 24 | 25 | from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 31 | 32 | PRETRAINED_VOCAB_FILES_MAP = { 33 | "vocab_file": { 34 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 37 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 38 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 39 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 40 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 41 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 42 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 43 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 44 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 45 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 46 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 47 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", 48 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", 49 | "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", 50 | "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt", 51 | "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt", 52 | } 53 | } 54 | 55 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 56 | "bert-base-uncased": 512, 57 | "bert-large-uncased": 512, 58 | "bert-base-cased": 512, 59 | "bert-large-cased": 512, 60 | "bert-base-multilingual-uncased": 512, 61 | "bert-base-multilingual-cased": 512, 62 | "bert-base-chinese": 512, 63 | "bert-base-german-cased": 512, 64 | "bert-large-uncased-whole-word-masking": 512, 65 | "bert-large-cased-whole-word-masking": 512, 66 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 67 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 68 | "bert-base-cased-finetuned-mrpc": 512, 69 | "bert-base-german-dbmdz-cased": 512, 70 | "bert-base-german-dbmdz-uncased": 512, 71 | 
"bert-base-finnish-cased-v1": 512, 72 | "bert-base-finnish-uncased-v1": 512, 73 | "bert-base-dutch-cased": 512, 74 | } 75 | 76 | PRETRAINED_INIT_CONFIGURATION = { 77 | "bert-base-uncased": {"do_lower_case": True}, 78 | "bert-large-uncased": {"do_lower_case": True}, 79 | "bert-base-cased": {"do_lower_case": False}, 80 | "bert-large-cased": {"do_lower_case": False}, 81 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 82 | "bert-base-multilingual-cased": {"do_lower_case": False}, 83 | "bert-base-chinese": {"do_lower_case": False}, 84 | "bert-base-german-cased": {"do_lower_case": False}, 85 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 86 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 87 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 88 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 89 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 90 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 91 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 92 | "bert-base-finnish-cased-v1": {"do_lower_case": False}, 93 | "bert-base-finnish-uncased-v1": {"do_lower_case": True}, 94 | "bert-base-dutch-cased": {"do_lower_case": False}, 95 | } 96 | 97 | 98 | def load_vocab(vocab_file): 99 | """Loads a vocabulary file into a dictionary.""" 100 | vocab = collections.OrderedDict() 101 | with open(vocab_file, "r", encoding="utf-8") as reader: 102 | tokens = reader.readlines() 103 | for index, token in enumerate(tokens): 104 | token = token.rstrip("\n") 105 | vocab[token] = index 106 | return vocab 107 | 108 | 109 | def whitespace_tokenize(text): 110 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 111 | text = text.strip() 112 | if not text: 113 | return [] 114 | tokens = text.split() 115 | return tokens 116 | 117 | 118 | class BertTokenizer(PreTrainedTokenizer): 119 | r""" 120 | Constructs a BertTokenizer. 121 | :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 122 | 123 | Args: 124 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 125 | do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True 126 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 127 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 128 | minimum of this value (if specified) and the underlying BERT model's sequence length. 129 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 130 | do_basic_tokenize=True 131 | """ 132 | 133 | vocab_files_names = VOCAB_FILES_NAMES 134 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 135 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 136 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 137 | 138 | def __init__( 139 | self, 140 | vocab_file, 141 | do_lower_case=True, 142 | do_basic_tokenize=True, 143 | never_split=None, 144 | unk_token="[UNK]", 145 | sep_token="[SEP]", 146 | pad_token="[PAD]", 147 | cls_token="[CLS]", 148 | mask_token="[MASK]", 149 | tokenize_chinese_chars=True, 150 | **kwargs 151 | ): 152 | """Constructs a BertTokenizer. 
153 | 154 | Args: 155 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 156 | **do_lower_case**: (`optional`) boolean (default True) 157 | Whether to lower case the input 158 | Only has an effect when do_basic_tokenize=True 159 | **do_basic_tokenize**: (`optional`) boolean (default True) 160 | Whether to do basic tokenization before wordpiece. 161 | **never_split**: (`optional`) list of string 162 | List of tokens which will never be split during tokenization. 163 | Only has an effect when do_basic_tokenize=True 164 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 165 | Whether to tokenize Chinese characters. 166 | This should likely be deactivated for Japanese: 167 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 168 | """ 169 | super().__init__( 170 | unk_token=unk_token, 171 | sep_token=sep_token, 172 | pad_token=pad_token, 173 | cls_token=cls_token, 174 | mask_token=mask_token, 175 | **kwargs, 176 | ) 177 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 178 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens 179 | 180 | if not os.path.isfile(vocab_file): 181 | raise ValueError( 182 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 183 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) 184 | ) 185 | self.vocab = load_vocab(vocab_file) 186 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 187 | self.do_basic_tokenize = do_basic_tokenize 188 | if do_basic_tokenize: 189 | self.basic_tokenizer = BasicTokenizer( 190 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars 191 | ) 192 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 193 | 194 | @property 195 | def vocab_size(self): 196 | return len(self.vocab) 197 | 198 | def get_vocab(self): 199 | return dict(self.vocab, **self.added_tokens_encoder) 200 | 201 | def _tokenize(self, text): 202 | split_tokens = [] 203 | if self.do_basic_tokenize: 204 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 205 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 206 | split_tokens.append(sub_token) 207 | else: 208 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 209 | return split_tokens 210 | 211 | def _convert_token_to_id(self, token): 212 | """ Converts a token (str) in an id using the vocab. """ 213 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 214 | 215 | def _convert_id_to_token(self, index): 216 | """Converts an index (integer) in a token (str) using the vocab.""" 217 | return self.ids_to_tokens.get(index, self.unk_token) 218 | 219 | def convert_tokens_to_string(self, tokens): 220 | """ Converts a sequence of tokens (string) in a single string. """ 221 | out_string = " ".join(tokens).replace(" ##", "").strip() 222 | return out_string 223 | 224 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 225 | """ 226 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 227 | by concatenating and adding special tokens. 
228 | A BERT sequence has the following format: 229 | single sequence: [CLS] X [SEP] 230 | pair of sequences: [CLS] A [SEP] B [SEP] 231 | """ 232 | if token_ids_1 is None: 233 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 234 | cls = [self.cls_token_id] 235 | sep = [self.sep_token_id] 236 | return cls + token_ids_0 + sep + token_ids_1 + sep 237 | 238 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 239 | """ 240 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 241 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 242 | 243 | Args: 244 | token_ids_0: list of ids (must not contain special tokens) 245 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 246 | for sequence pairs 247 | already_has_special_tokens: (default False) Set to True if the token list is already formated with 248 | special tokens for the model 249 | 250 | Returns: 251 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 252 | """ 253 | 254 | if already_has_special_tokens: 255 | if token_ids_1 is not None: 256 | raise ValueError( 257 | "You should not supply a second sequence if the provided sequence of " 258 | "ids is already formated with special tokens for the model." 259 | ) 260 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 261 | 262 | if token_ids_1 is not None: 263 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 264 | return [1] + ([0] * len(token_ids_0)) + [1] 265 | 266 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 267 | """ 268 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 269 | A BERT sequence pair mask has the following format: 270 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 271 | | first sequence | second sequence 272 | 273 | if token_ids_1 is None, only returns the first portion of the mask (0's). 274 | """ 275 | sep = [self.sep_token_id] 276 | cls = [self.cls_token_id] 277 | if token_ids_1 is None: 278 | return len(cls + token_ids_0 + sep) * [0] 279 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 280 | 281 | def save_vocabulary(self, vocab_path): 282 | """Save the tokenizer vocabulary to a directory or file.""" 283 | index = 0 284 | if os.path.isdir(vocab_path): 285 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) 286 | else: 287 | vocab_file = vocab_path 288 | with open(vocab_file, "w", encoding="utf-8") as writer: 289 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 290 | if index != token_index: 291 | logger.warning( 292 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 293 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 294 | ) 295 | index = token_index 296 | writer.write(token + "\n") 297 | index += 1 298 | return (vocab_file,) 299 | 300 | 301 | class BasicTokenizer(object): 302 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 303 | 304 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 305 | """ Constructs a BasicTokenizer. 306 | 307 | Args: 308 | **do_lower_case**: Whether to lower case the input. 
309 | **never_split**: (`optional`) list of str 310 | Kept for backward compatibility purposes. 311 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 312 | List of token not to split. 313 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 314 | Whether to tokenize Chinese characters. 315 | This should likely be deactivated for Japanese: 316 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 317 | """ 318 | if never_split is None: 319 | never_split = [] 320 | self.do_lower_case = do_lower_case 321 | self.never_split = never_split 322 | self.tokenize_chinese_chars = tokenize_chinese_chars 323 | 324 | def tokenize(self, text, never_split=None): 325 | """ Basic Tokenization of a piece of text. 326 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 327 | 328 | Args: 329 | **never_split**: (`optional`) list of str 330 | Kept for backward compatibility purposes. 331 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 332 | List of token not to split. 333 | """ 334 | never_split = self.never_split + (never_split if never_split is not None else []) 335 | text = self._clean_text(text) 336 | # This was added on November 1st, 2018 for the multilingual and Chinese 337 | # models. This is also applied to the English models now, but it doesn't 338 | # matter since the English models were not trained on any Chinese data 339 | # and generally don't have any Chinese data in them (there are Chinese 340 | # characters in the vocabulary because Wikipedia does have some Chinese 341 | # words in the English Wikipedia.). 342 | if self.tokenize_chinese_chars: 343 | text = self._tokenize_chinese_chars(text) 344 | orig_tokens = whitespace_tokenize(text) 345 | split_tokens = [] 346 | for token in orig_tokens: 347 | if self.do_lower_case and token not in never_split: 348 | token = token.lower() 349 | token = self._run_strip_accents(token) 350 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 351 | 352 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 353 | return output_tokens 354 | 355 | def _run_strip_accents(self, text): 356 | """Strips accents from a piece of text.""" 357 | text = unicodedata.normalize("NFD", text) 358 | output = [] 359 | for char in text: 360 | cat = unicodedata.category(char) 361 | if cat == "Mn": 362 | continue 363 | output.append(char) 364 | return "".join(output) 365 | 366 | def _run_split_on_punc(self, text, never_split=None): 367 | """Splits punctuation on a piece of text.""" 368 | if never_split is not None and text in never_split: 369 | return [text] 370 | chars = list(text) 371 | i = 0 372 | start_new_word = True 373 | output = [] 374 | while i < len(chars): 375 | char = chars[i] 376 | if _is_punctuation(char): 377 | output.append([char]) 378 | start_new_word = True 379 | else: 380 | if start_new_word: 381 | output.append([]) 382 | start_new_word = False 383 | output[-1].append(char) 384 | i += 1 385 | 386 | return ["".join(x) for x in output] 387 | 388 | def _tokenize_chinese_chars(self, text): 389 | """Adds whitespace around any CJK character.""" 390 | output = [] 391 | for char in text: 392 | cp = ord(char) 393 | if self._is_chinese_char(cp): 394 | output.append(" ") 395 | output.append(char) 396 | output.append(" ") 397 | else: 398 | output.append(char) 399 | return "".join(output) 400 | 401 | def _is_chinese_char(self, cp): 402 | """Checks whether CP is the codepoint of a CJK character.""" 403 | # This 
defines a "chinese character" as anything in the CJK Unicode block: 404 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 405 | # 406 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 407 | # despite its name. The modern Korean Hangul alphabet is a different block, 408 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 409 | # space-separated words, so they are not treated specially and handled 410 | # like the all of the other languages. 411 | if ( 412 | (cp >= 0x4E00 and cp <= 0x9FFF) 413 | or (cp >= 0x3400 and cp <= 0x4DBF) # 414 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 415 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 416 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 417 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 418 | or (cp >= 0xF900 and cp <= 0xFAFF) 419 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 420 | ): # 421 | return True 422 | 423 | return False 424 | 425 | def _clean_text(self, text): 426 | """Performs invalid character removal and whitespace cleanup on text.""" 427 | output = [] 428 | for char in text: 429 | cp = ord(char) 430 | if cp == 0 or cp == 0xFFFD or _is_control(char): 431 | continue 432 | if _is_whitespace(char): 433 | output.append(" ") 434 | else: 435 | output.append(char) 436 | return "".join(output) 437 | 438 | 439 | class WordpieceTokenizer(object): 440 | """Runs WordPiece tokenization.""" 441 | 442 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 443 | self.vocab = vocab 444 | self.unk_token = unk_token 445 | self.max_input_chars_per_word = max_input_chars_per_word 446 | 447 | def tokenize(self, text): 448 | """Tokenizes a piece of text into its word pieces. 449 | 450 | This uses a greedy longest-match-first algorithm to perform tokenization 451 | using the given vocabulary. 452 | 453 | For example: 454 | input = "unaffable" 455 | output = ["un", "##aff", "##able"] 456 | 457 | Args: 458 | text: A single token or whitespace separated tokens. This should have 459 | already been passed through `BasicTokenizer`. 460 | 461 | Returns: 462 | A list of wordpiece tokens. 463 | """ 464 | 465 | output_tokens = [] 466 | for token in whitespace_tokenize(text): 467 | chars = list(token) 468 | if len(chars) > self.max_input_chars_per_word: 469 | output_tokens.append(self.unk_token) 470 | continue 471 | 472 | is_bad = False 473 | start = 0 474 | sub_tokens = [] 475 | while start < len(chars): 476 | end = len(chars) 477 | cur_substr = None 478 | while start < end: 479 | substr = "".join(chars[start:end]) 480 | if start > 0: 481 | substr = "##" + substr 482 | if substr in self.vocab: 483 | cur_substr = substr 484 | break 485 | end -= 1 486 | if cur_substr is None: 487 | is_bad = True 488 | break 489 | sub_tokens.append(cur_substr) 490 | start = end 491 | 492 | if is_bad: 493 | output_tokens.append(self.unk_token) 494 | else: 495 | output_tokens.extend(sub_tokens) 496 | return output_tokens 497 | 498 | 499 | def _is_whitespace(char): 500 | """Checks whether `chars` is a whitespace character.""" 501 | # \t, \n, and \r are technically contorl characters but we treat them 502 | # as whitespace since they are generally considered as such. 
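    # In addition to the four literal characters checked below, any character whose Unicode
    # general category is "Zs" (space separators such as the no-break space U+00A0) is treated
    # as whitespace; control characters are classified separately by _is_control.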
503 | if char == " " or char == "\t" or char == "\n" or char == "\r": 504 | return True 505 | cat = unicodedata.category(char) 506 | if cat == "Zs": 507 | return True 508 | return False 509 | 510 | 511 | def _is_control(char): 512 | """Checks whether `chars` is a control character.""" 513 | # These are technically control characters but we count them as whitespace 514 | # characters. 515 | if char == "\t" or char == "\n" or char == "\r": 516 | return False 517 | cat = unicodedata.category(char) 518 | if cat.startswith("C"): 519 | return True 520 | return False 521 | 522 | 523 | def _is_punctuation(char): 524 | """Checks whether `chars` is a punctuation character.""" 525 | cp = ord(char) 526 | # We treat all non-letter/number ASCII as punctuation. 527 | # Characters such as "^", "$", and "`" are not in the Unicode 528 | # Punctuation class but we treat them as punctuation anyways, for 529 | # consistency. 530 | if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): 531 | return True 532 | cat = unicodedata.category(char) 533 | if cat.startswith("P"): 534 | return True 535 | return False 536 | 537 | 538 | class BertTokenizerFast(PreTrainedTokenizerFast): 539 | vocab_files_names = VOCAB_FILES_NAMES 540 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 541 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 542 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 543 | 544 | def __init__( 545 | self, 546 | vocab_file, 547 | do_lower_case=True, 548 | do_basic_tokenize=True, 549 | never_split=None, 550 | unk_token="[UNK]", 551 | sep_token="[SEP]", 552 | pad_token="[PAD]", 553 | cls_token="[CLS]", 554 | mask_token="[MASK]", 555 | clean_text=True, 556 | tokenize_chinese_chars=True, 557 | add_special_tokens=True, 558 | strip_accents=True, 559 | wordpieces_prefix="##", 560 | **kwargs 561 | ): 562 | super().__init__( 563 | BertWordPieceTokenizer( 564 | vocab_file=vocab_file, 565 | add_special_tokens=add_special_tokens, 566 | unk_token=unk_token, 567 | sep_token=sep_token, 568 | cls_token=cls_token, 569 | clean_text=clean_text, 570 | handle_chinese_chars=tokenize_chinese_chars, 571 | strip_accents=strip_accents, 572 | lowercase=do_lower_case, 573 | wordpieces_prefix=wordpieces_prefix, 574 | ), 575 | unk_token=unk_token, 576 | sep_token=sep_token, 577 | pad_token=pad_token, 578 | cls_token=cls_token, 579 | mask_token=mask_token, 580 | **kwargs, 581 | ) 582 | 583 | self.do_lower_case = do_lower_case 584 | 585 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 586 | output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 587 | 588 | if token_ids_1: 589 | output += token_ids_1 + [self.sep_token_id] 590 | 591 | return output 592 | -------------------------------------------------------------------------------- /gpt_model/modeling_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch OpenAI GPT model.""" 17 | 18 | 19 | import json 20 | import logging 21 | import math 22 | import os 23 | 24 | import torch 25 | import torch.nn as nn 26 | from torch.nn import CrossEntropyLoss 27 | 28 | from .activations import gelu_new, swish 29 | from .configuration_openai import OpenAIGPTConfig 30 | from .file_utils import add_start_docstrings, add_start_docstrings_to_callable 31 | from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer 32 | 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = { 37 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin" 38 | } 39 | 40 | 41 | def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): 42 | """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here) 43 | """ 44 | import re 45 | import numpy as np 46 | 47 | if ".ckpt" in openai_checkpoint_folder_path: 48 | openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) 49 | 50 | logger.info("Loading weights from {}".format(openai_checkpoint_folder_path)) 51 | 52 | with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: 53 | names = json.load(names_handle) 54 | with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: 55 | shapes = json.load(shapes_handle) 56 | offsets = np.cumsum([np.prod(shape) for shape in shapes]) 57 | init_params = [np.load(openai_checkpoint_folder_path + "/params_{}.npy".format(n)) for n in range(10)] 58 | init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] 59 | init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] 60 | 61 | # This was used when we had a single embedding matrix for positions and tokens 62 | # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) 63 | # del init_params[1] 64 | init_params = [arr.squeeze() for arr in init_params] 65 | 66 | try: 67 | assert model.tokens_embed.weight.shape == init_params[1].shape 68 | assert model.positions_embed.weight.shape == init_params[0].shape 69 | except AssertionError as e: 70 | e.args += (model.tokens_embed.weight.shape, init_params[1].shape) 71 | e.args += (model.positions_embed.weight.shape, init_params[0].shape) 72 | raise 73 | 74 | model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) 75 | model.positions_embed.weight.data = torch.from_numpy(init_params[0]) 76 | names.pop(0) 77 | # Pop position and token embedding arrays 78 | init_params.pop(0) 79 | init_params.pop(0) 80 | 81 | for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): 82 | name = name[6:] # skip "model/" 83 | assert name[-2:] == ":0" 84 | name = name[:-2] 85 | name = name.split("/") 86 | pointer = model 87 | for m_name in name: 88 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 89 | scope_names = re.split(r"(\d+)", m_name) 90 | else: 91 | scope_names = [m_name] 92 | if scope_names[0] == "g": 93 | pointer = 
getattr(pointer, "weight") 94 | elif scope_names[0] == "b": 95 | pointer = getattr(pointer, "bias") 96 | elif scope_names[0] == "w": 97 | pointer = getattr(pointer, "weight") 98 | else: 99 | pointer = getattr(pointer, scope_names[0]) 100 | if len(scope_names) >= 2: 101 | num = int(scope_names[1]) 102 | pointer = pointer[num] 103 | try: 104 | assert pointer.shape == array.shape 105 | except AssertionError as e: 106 | e.args += (pointer.shape, array.shape) 107 | raise 108 | try: 109 | assert pointer.shape == array.shape 110 | except AssertionError as e: 111 | e.args += (pointer.shape, array.shape) 112 | raise 113 | logger.info("Initialize PyTorch weight {}".format(name)) 114 | pointer.data = torch.from_numpy(array) 115 | return model 116 | 117 | 118 | ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu_new} 119 | 120 | 121 | class Attention(nn.Module): 122 | def __init__(self, nx, n_ctx, config, scale=False): 123 | super().__init__() 124 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 125 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 126 | assert n_state % config.n_head == 0 127 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 128 | self.n_head = config.n_head 129 | self.split_size = n_state 130 | self.scale = scale 131 | 132 | self.output_attentions = config.output_attentions 133 | 134 | self.c_attn = Conv1D(n_state * 3, nx) 135 | self.c_proj = Conv1D(n_state, nx) 136 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 137 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 138 | self.pruned_heads = set() 139 | 140 | def prune_heads(self, heads): 141 | if len(heads) == 0: 142 | return 143 | mask = torch.ones(self.n_head, self.split_size // self.n_head) 144 | heads = set(heads) - self.pruned_heads 145 | for head in heads: 146 | head -= sum(1 if h < head else 0 for h in self.pruned_heads) 147 | mask[head] = 0 148 | mask = mask.view(-1).contiguous().eq(1) 149 | index = torch.arange(len(mask))[mask].long() 150 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 151 | # Prune conv1d layers 152 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 153 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 154 | # Update hyper params 155 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 156 | self.n_head = self.n_head - len(heads) 157 | self.pruned_heads = self.pruned_heads.union(heads) 158 | 159 | def _attn(self, q, k, v, attention_mask=None, head_mask=None): 160 | w = torch.matmul(q, k) 161 | if self.scale: 162 | w = w / math.sqrt(v.size(-1)) 163 | # w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights 164 | # XD: self.b may be larger than w, so we need to crop it 165 | b = self.bias[:, :, : w.size(-2), : w.size(-1)] 166 | w = w * b + -1e4 * (1 - b) 167 | 168 | if attention_mask is not None: 169 | # Apply the attention mask 170 | w = w + attention_mask 171 | 172 | w = nn.Softmax(dim=-1)(w) 173 | w = self.attn_dropout(w) 174 | 175 | # Mask heads if we want to 176 | if head_mask is not None: 177 | w = w * head_mask 178 | 179 | outputs = [torch.matmul(w, v)] 180 | if self.output_attentions: 181 | outputs.append(w) 182 | return outputs 183 | 184 | def merge_heads(self, x): 185 | x = x.permute(0, 2, 1, 3).contiguous() 186 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 187 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 188 | 189 | def split_heads(self, x, 
k=False): 190 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 191 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 192 | if k: 193 | return x.permute(0, 2, 3, 1) 194 | else: 195 | return x.permute(0, 2, 1, 3) 196 | 197 | def forward(self, x, attention_mask=None, head_mask=None): 198 | x = self.c_attn(x) 199 | query, key, value = x.split(self.split_size, dim=2) 200 | query = self.split_heads(query) 201 | key = self.split_heads(key, k=True) 202 | value = self.split_heads(value) 203 | 204 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask) 205 | a = attn_outputs[0] 206 | 207 | a = self.merge_heads(a) 208 | a = self.c_proj(a) 209 | a = self.resid_dropout(a) 210 | 211 | outputs = [a] + attn_outputs[1:] 212 | return outputs # a, (attentions) 213 | 214 | 215 | class MLP(nn.Module): 216 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 217 | super().__init__() 218 | nx = config.n_embd 219 | self.c_fc = Conv1D(n_state, nx) 220 | self.c_proj = Conv1D(nx, n_state) 221 | self.act = ACT_FNS[config.afn] 222 | self.dropout = nn.Dropout(config.resid_pdrop) 223 | 224 | def forward(self, x): 225 | h = self.act(self.c_fc(x)) 226 | h2 = self.c_proj(h) 227 | return self.dropout(h2) 228 | 229 | 230 | class Block(nn.Module): 231 | def __init__(self, n_ctx, config, scale=False): 232 | super().__init__() 233 | nx = config.n_embd 234 | self.attn = Attention(nx, n_ctx, config, scale) 235 | self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) 236 | self.mlp = MLP(4 * nx, config) 237 | self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) 238 | 239 | def forward(self, x, attention_mask=None, head_mask=None): 240 | attn_outputs = self.attn(x, attention_mask=attention_mask, head_mask=head_mask) 241 | a = attn_outputs[0] 242 | 243 | n = self.ln_1(x + a) 244 | m = self.mlp(n) 245 | h = self.ln_2(n + m) 246 | 247 | outputs = [h] + attn_outputs[1:] 248 | return outputs 249 | 250 | 251 | class OpenAIGPTPreTrainedModel(PreTrainedModel): 252 | """ An abstract class to handle weights initialization and 253 | a simple interface for downloading and loading pretrained models. 254 | """ 255 | 256 | config_class = OpenAIGPTConfig 257 | pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP 258 | load_tf_weights = load_tf_weights_in_openai_gpt 259 | base_model_prefix = "transformer" 260 | 261 | def _init_weights(self, module): 262 | """ Initialize the weights. 263 | """ 264 | if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): 265 | # Slightly different from the TF version which uses truncated_normal for initialization 266 | # cf https://github.com/pytorch/pytorch/pull/5617 267 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 268 | if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: 269 | module.bias.data.zero_() 270 | elif isinstance(module, nn.LayerNorm): 271 | module.bias.data.zero_() 272 | module.weight.data.fill_(1.0) 273 | 274 | 275 | OPENAI_GPT_START_DOCSTRING = r""" 276 | 277 | This model is a PyTorch `torch.nn.Module `_ sub-class. 278 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general 279 | usage and behavior. 280 | 281 | Parameters: 282 | config (:class:`~transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model. 283 | Initializing with a config file does not load the weights associated with the model, only the configuration. 
284 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 285 | """ 286 | 287 | OPENAI_GPT_INPUTS_DOCSTRING = r""" 288 | Args: 289 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 290 | Indices of input sequence tokens in the vocabulary. 291 | 292 | Indices can be obtained using :class:`transformers.OpenAIGPTTokenizer`. 293 | See :func:`transformers.PreTrainedTokenizer.encode` and 294 | :func:`transformers.PreTrainedTokenizer.encode_plus` for details. 295 | 296 | `What are input IDs? <../glossary.html#input-ids>`__ 297 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 298 | Mask to avoid performing attention on padding token indices. 299 | Mask values selected in ``[0, 1]``: 300 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. 301 | 302 | `What are attention masks? <../glossary.html#attention-mask>`__ 303 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 304 | Segment token indices to indicate first and second portions of the inputs. 305 | Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` 306 | corresponds to a `sentence B` token 307 | 308 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 309 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 310 | Indices of positions of each input sequence tokens in the position embeddings. 311 | Selected in the range ``[0, config.max_position_embeddings - 1]``. 312 | 313 | `What are position IDs? <../glossary.html#position-ids>`_ 314 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): 315 | Mask to nullify selected heads of the self-attention modules. 316 | Mask values selected in ``[0, 1]``: 317 | :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. 318 | input_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): 319 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 320 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 321 | than the model's internal embedding lookup matrix. 
322 | """ 323 | 324 | 325 | @add_start_docstrings( 326 | "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", 327 | OPENAI_GPT_START_DOCSTRING, 328 | ) 329 | class OpenAIGPTModel(OpenAIGPTPreTrainedModel): 330 | def __init__(self, config): 331 | super().__init__(config) 332 | self.output_attentions = config.output_attentions 333 | self.output_hidden_states = config.output_hidden_states 334 | 335 | self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) 336 | self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) 337 | self.drop = nn.Dropout(config.embd_pdrop) 338 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) 339 | 340 | self.init_weights() 341 | 342 | def get_input_embeddings(self): 343 | return self.tokens_embed 344 | 345 | def set_input_embeddings(self, new_embeddings): 346 | self.tokens_embed = new_embeddings 347 | 348 | def _prune_heads(self, heads_to_prune): 349 | """ Prunes heads of the model. 350 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 351 | """ 352 | for layer, heads in heads_to_prune.items(): 353 | self.h[layer].attn.prune_heads(heads) 354 | 355 | @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING) 356 | def forward( 357 | self, 358 | input_ids=None, 359 | attention_mask=None, 360 | token_type_ids=None, 361 | position_ids=None, 362 | head_mask=None, 363 | inputs_embeds=None, 364 | ): 365 | r""" 366 | Return: 367 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs: 368 | last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): 369 | Sequence of hidden-states at the last layer of the model. 370 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 371 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 372 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 373 | 374 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 375 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 376 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 377 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 378 | 379 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 380 | heads. 
381 | 382 | Examples:: 383 | 384 | from transformers import OpenAIGPTTokenizer, OpenAIGPTModel 385 | import torch 386 | 387 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 388 | model = OpenAIGPTModel.from_pretrained('openai-gpt') 389 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 390 | outputs = model(input_ids) 391 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 392 | 393 | """ 394 | if input_ids is not None and inputs_embeds is not None: 395 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 396 | elif input_ids is not None: 397 | input_shape = input_ids.size() 398 | input_ids = input_ids.view(-1, input_shape[-1]) 399 | elif inputs_embeds is not None: 400 | input_shape = inputs_embeds.size()[:-1] 401 | else: 402 | raise ValueError("You have to specify either input_ids or inputs_embeds") 403 | 404 | if position_ids is None: 405 | # Code is different from when we had a single embedding matrice from position and token embeddings 406 | device = input_ids.device if input_ids is not None else inputs_embeds.device 407 | position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device) 408 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 409 | 410 | # Attention mask. 411 | if attention_mask is not None: 412 | # We create a 3D attention mask from a 2D tensor mask. 413 | # Sizes are [batch_size, 1, 1, to_seq_length] 414 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 415 | # this attention mask is more simple than the triangular masking of causal attention 416 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 417 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 418 | 419 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 420 | # masked positions, this operation will create a tensor which is 0.0 for 421 | # positions we want to attend and -10000.0 for masked positions. 422 | # Since we are adding it to the raw scores before the softmax, this is 423 | # effectively the same as removing these entirely. 
424 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 425 | attention_mask = (1.0 - attention_mask) * -10000.0 426 | 427 | # Prepare head mask if needed 428 | # 1.0 in head_mask indicate we keep the head 429 | # attention_probs has shape bsz x n_heads x N x N 430 | # head_mask has shape n_layer x batch x n_heads x N x N 431 | if head_mask is not None: 432 | if head_mask.dim() == 1: 433 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 434 | head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) 435 | elif head_mask.dim() == 2: 436 | head_mask = ( 437 | head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) 438 | ) # We can specify head_mask for each layer 439 | head_mask = head_mask.to( 440 | dtype=next(self.parameters()).dtype 441 | ) # switch to fload if need + fp16 compatibility 442 | else: 443 | head_mask = [None] * self.config.n_layer 444 | 445 | if inputs_embeds is None: 446 | inputs_embeds = self.tokens_embed(input_ids) 447 | position_embeds = self.positions_embed(position_ids) 448 | if token_type_ids is not None: 449 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 450 | token_type_embeds = self.tokens_embed(token_type_ids) 451 | else: 452 | token_type_embeds = 0 453 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 454 | hidden_states = self.drop(hidden_states) 455 | 456 | output_shape = input_shape + (hidden_states.size(-1),) 457 | 458 | all_attentions = () 459 | all_hidden_states = () 460 | for i, block in enumerate(self.h): 461 | if self.output_hidden_states: 462 | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) 463 | 464 | outputs = block(hidden_states, attention_mask, head_mask[i]) 465 | hidden_states = outputs[0] 466 | if self.output_attentions: 467 | all_attentions = all_attentions + (outputs[1],) 468 | 469 | # Add last layer 470 | if self.output_hidden_states: 471 | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) 472 | 473 | outputs = (hidden_states.view(*output_shape),) 474 | if self.output_hidden_states: 475 | outputs = outputs + (all_hidden_states,) 476 | if self.output_attentions: 477 | outputs = outputs + (all_attentions,) 478 | return outputs # last hidden state, (all hidden states), (all attentions) 479 | 480 | 481 | @add_start_docstrings( 482 | """OpenAI GPT Model transformer with a language modeling head on top 483 | (linear layer with weights tied to the input embeddings). 
""", 484 | OPENAI_GPT_START_DOCSTRING, 485 | ) 486 | class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): 487 | def __init__(self, config): 488 | super().__init__(config) 489 | self.transformer = OpenAIGPTModel(config) 490 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 491 | 492 | self.init_weights() 493 | 494 | def get_output_embeddings(self): 495 | return self.lm_head 496 | 497 | def prepare_inputs_for_generation(self, input_ids, **kwargs): 498 | # only last token for inputs_ids if past is defined in kwargs 499 | token_type_ids = kwargs["token_type_ids"] 500 | if input_ids.size(-1) > token_type_ids.size(-1): 501 | token_type_ids = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(0)], dim=-1) 502 | inputs = {"input_ids": input_ids, "token_type_ids": token_type_ids} 503 | return inputs 504 | 505 | @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING) 506 | def forward( 507 | self, 508 | input_ids=None, 509 | attention_mask=None, 510 | token_type_ids=None, 511 | position_ids=None, 512 | head_mask=None, 513 | inputs_embeds=None, 514 | labels=None, 515 | ): 516 | r""" 517 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 518 | Labels for language modeling. 519 | Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` 520 | Indices are selected in ``[-100, 0, ..., config.vocab_size]`` 521 | All labels set to ``-100`` are ignored (masked), the loss is only 522 | computed for labels in ``[0, ..., config.vocab_size]`` 523 | 524 | Return: 525 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs: 526 | loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided) 527 | Language modeling loss. 528 | prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): 529 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 530 | past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`): 531 | Contains pre-computed hidden-states (key and values in the attention blocks). 532 | Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model 533 | should not be passed as input ids as they have already been computed. 534 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 535 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 536 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 537 | 538 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 539 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 540 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 541 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 542 | 543 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 544 | heads. 
545 | 546 | Examples:: 547 | 548 | from transformers import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel 549 | import torch 550 | 551 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 552 | model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt') 553 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 554 | outputs = model(input_ids, labels=input_ids) 555 | loss, logits = outputs[:2] 556 | 557 | """ 558 | transformer_outputs = self.transformer( 559 | input_ids, 560 | attention_mask=attention_mask, 561 | token_type_ids=token_type_ids, 562 | position_ids=position_ids, 563 | head_mask=head_mask, 564 | inputs_embeds=inputs_embeds, 565 | ) 566 | hidden_states = transformer_outputs[0] 567 | lm_logits = self.lm_head(hidden_states) 568 | 569 | outputs = (lm_logits,) + transformer_outputs[1:] 570 | if labels is not None: 571 | # Shift so that tokens < n predict n 572 | shift_logits = lm_logits[..., :-1, :].contiguous() 573 | shift_labels = labels[..., 1:].contiguous() 574 | # Flatten the tokens 575 | loss_fct = CrossEntropyLoss() 576 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 577 | outputs = (loss,) + outputs 578 | 579 | return outputs # (loss), lm_logits, (all hidden states), (all attentions) 580 | 581 | 582 | @add_start_docstrings( 583 | """OpenAI GPT Model transformer with a language modeling and a multiple-choice classification 584 | head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers. 585 | The language modeling head has its weights tied to the input embeddings, 586 | the classification head takes as input the input of a specified classification token index in the input sequence). 587 | """, 588 | OPENAI_GPT_START_DOCSTRING, 589 | ) 590 | class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): 591 | def __init__(self, config): 592 | super().__init__(config) 593 | 594 | config.num_labels = 1 595 | self.transformer = OpenAIGPTModel(config) 596 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 597 | self.multiple_choice_head = SequenceSummary(config) 598 | 599 | self.init_weights() 600 | 601 | def get_output_embeddings(self): 602 | return self.lm_head 603 | 604 | @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING) 605 | def forward( 606 | self, 607 | input_ids=None, 608 | attention_mask=None, 609 | token_type_ids=None, 610 | position_ids=None, 611 | head_mask=None, 612 | inputs_embeds=None, 613 | mc_token_ids=None, 614 | lm_labels=None, 615 | mc_labels=None, 616 | ): 617 | r""" 618 | mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input) 619 | Index of the classification token in each input sequence. 620 | Selected in the range ``[0, input_ids.size(-1) - 1[``. 621 | lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`) 622 | Labels for language modeling. 623 | Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` 624 | Indices are selected in ``[-1, 0, ..., config.vocab_size]`` 625 | All labels set to ``-100`` are ignored (masked), the loss is only 626 | computed for labels in ``[0, ..., config.vocab_size]`` 627 | mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`) 628 | Labels for computing the multiple choice classification loss. 
629 | Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension 630 | of the input tensors. (see `input_ids` above) 631 | 632 | Return: 633 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs: 634 | lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided): 635 | Language modeling loss. 636 | mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`multiple_choice_labels` is provided): 637 | Multiple choice classification loss. 638 | lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): 639 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 640 | mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): 641 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 642 | past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`): 643 | Contains pre-computed hidden-states (key and values in the attention blocks). 644 | Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model 645 | should not be passed as input ids as they have already been computed. 646 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 647 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 648 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 649 | 650 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 651 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 652 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 653 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 654 | 655 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 656 | heads. 657 | 658 | Examples:: 659 | 660 | from transformers import OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel 661 | import torch 662 | 663 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 664 | model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt') 665 | tokenizer.add_special_tokens({'cls_token': '[CLS]'}) # Add a [CLS] to the vocabulary (we should train it also!) 
666 | model.resize_token_embeddings(len(tokenizer)) 667 | 668 | choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] 669 | input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices 670 | mc_token_ids = torch.tensor([input_ids.size(-1)-1, input_ids.size(-1)-1]).unsqueeze(0) # Batch size 1 671 | 672 | outputs = model(input_ids, mc_token_ids=mc_token_ids) 673 | lm_prediction_scores, mc_prediction_scores = outputs[:2] 674 | 675 | """ 676 | transformer_outputs = self.transformer( 677 | input_ids, 678 | attention_mask=attention_mask, 679 | token_type_ids=token_type_ids, 680 | position_ids=position_ids, 681 | head_mask=head_mask, 682 | inputs_embeds=inputs_embeds, 683 | ) 684 | hidden_states = transformer_outputs[0] 685 | 686 | lm_logits = self.lm_head(hidden_states) 687 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) 688 | 689 | outputs = (lm_logits, mc_logits) + transformer_outputs[1:] 690 | if mc_labels is not None: 691 | loss_fct = CrossEntropyLoss() 692 | loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) 693 | outputs = (loss,) + outputs 694 | if lm_labels is not None: 695 | shift_logits = lm_logits[..., :-1, :].contiguous() 696 | shift_labels = lm_labels[..., 1:].contiguous() 697 | loss_fct = CrossEntropyLoss() 698 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 699 | outputs = (loss,) + outputs 700 | 701 | return outputs # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions) 702 | --------------------------------------------------------------------------------
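
As a quick orientation for how the classes in gpt_model/modeling_openai.py fit together, here is a minimal usage sketch. This is illustrative only and not a file from the repository: it builds OpenAIGPTLMHeadModel from a small hand-written OpenAIGPTConfig rather than from the project's actual configuration, and the tiny hyperparameters, vocabulary size, and the assumption that the gpt_model package is importable are choices made just so the snippet runs quickly on CPU (afn="gelu" is picked because it is a key of ACT_FNS above).

# Illustrative sketch only -- not part of the repository.
# Assumes the gpt_model package shown above is importable from the working directory.
import torch

from gpt_model.configuration_openai import OpenAIGPTConfig
from gpt_model.modeling_openai import OpenAIGPTLMHeadModel

# A tiny, randomly initialized model; dimensions are deliberately small.
config = OpenAIGPTConfig(
    vocab_size=100, n_positions=64, n_ctx=64,
    n_embd=32, n_layer=2, n_head=2, afn="gelu",
)
model = OpenAIGPTLMHeadModel(config)
model.eval()

# One sequence of 10 token ids; passing labels=input_ids yields the shifted LM loss,
# exactly as described in the labels docstring of OpenAIGPTLMHeadModel.forward.
input_ids = torch.randint(0, config.vocab_size, (1, 10))
with torch.no_grad():
    loss, lm_logits = model(input_ids, labels=input_ids)[:2]

print(loss.item())       # scalar cross-entropy loss over the shifted tokens
print(lm_logits.shape)   # torch.Size([1, 10, 100]): per-position vocabulary logits

The output mirrors the from_pretrained('openai-gpt') example in the docstring, but without downloading pretrained weights: forward returns (loss, lm_logits, ...) when labels are supplied, and lm_logits alone otherwise.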