├── README.md ├── __init__.py ├── config └── model_config_dialogue_small.json ├── dataset.py ├── generate_dialogue_subset.py ├── interact.py ├── interact_mmi.py ├── requirements.txt ├── train.py ├── transformers ├── __init__.py ├── __main__.py ├── commands │ ├── __init__.py │ ├── convert.py │ ├── download.py │ ├── run.py │ ├── serving.py │ ├── train.py │ └── user.py ├── configuration_albert.py ├── configuration_auto.py ├── configuration_bert.py ├── configuration_camembert.py ├── configuration_ctrl.py ├── configuration_distilbert.py ├── configuration_gpt2.py ├── configuration_mmbt.py ├── configuration_openai.py ├── configuration_roberta.py ├── configuration_t5.py ├── configuration_transfo_xl.py ├── configuration_utils.py ├── configuration_xlm.py ├── configuration_xlm_roberta.py ├── configuration_xlnet.py ├── convert_albert_original_tf_checkpoint_to_pytorch.py ├── convert_bert_original_tf_checkpoint_to_pytorch.py ├── convert_bert_pytorch_checkpoint_to_original_tf.py ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py ├── convert_openai_original_tf_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf2.py ├── convert_roberta_original_pytorch_checkpoint_to_pytorch.py ├── convert_t5_original_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py ├── convert_xlnet_original_tf_checkpoint_to_pytorch.py ├── data │ ├── __init__.py │ ├── metrics │ │ ├── __init__.py │ │ └── squad_metrics.py │ └── processors │ │ ├── __init__.py │ │ ├── glue.py │ │ ├── squad.py │ │ ├── utils.py │ │ └── xnli.py ├── file_utils.py ├── hf_api.py ├── modelcard.py ├── modeling_albert.py ├── modeling_auto.py ├── modeling_bert.py ├── modeling_camembert.py ├── modeling_ctrl.py ├── modeling_distilbert.py ├── modeling_encoder_decoder.py ├── modeling_gpt2.py ├── modeling_mmbt.py ├── modeling_openai.py ├── modeling_roberta.py ├── modeling_t5.py ├── modeling_tf_albert.py ├── modeling_tf_auto.py ├── modeling_tf_bert.py ├── modeling_tf_ctrl.py ├── modeling_tf_distilbert.py ├── modeling_tf_gpt2.py ├── modeling_tf_openai.py ├── modeling_tf_pytorch_utils.py ├── modeling_tf_roberta.py ├── modeling_tf_t5.py ├── modeling_tf_transfo_xl.py ├── modeling_tf_transfo_xl_utilities.py ├── modeling_tf_utils.py ├── modeling_tf_xlm.py ├── modeling_tf_xlnet.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── modeling_utils.py ├── modeling_xlm.py ├── modeling_xlm_roberta.py ├── modeling_xlnet.py ├── optimization.py ├── optimization_tf.py ├── pipelines.py ├── tokenization_albert.py ├── tokenization_auto.py ├── tokenization_bert.py ├── tokenization_bert_japanese.py ├── tokenization_camembert.py ├── tokenization_ctrl.py ├── tokenization_distilbert.py ├── tokenization_gpt2.py ├── tokenization_openai.py ├── tokenization_roberta.py ├── tokenization_t5.py ├── tokenization_transfo_xl.py ├── tokenization_utils.py ├── tokenization_xlm.py ├── tokenization_xlm_roberta.py └── tokenization_xlnet.py └── vocabulary └── vocab_small.txt /README.md: -------------------------------------------------------------------------------- 1 | # GPT2 for Chinese Summary 2 | 3 | 4 | ## 项目描述 5 | - 本项目使用 GPT2-Chinese 的模型将wiki中文的数据导入模型训练了通用模型。 6 | - 将GPT2-chitchat的对话任务稍作修改来适用于中文摘要任务。 7 | - 将通用模型的权重应用在摘要问题上进行进一步训练的。 8 | - GPT2-Chinese 参考:https://github.com/Morizeyao/GPT2-Chinese 9 | - GPT2-chitchat参考:https://link.zhihu.com/?target=https%3A//github.com/yangjianxin1/GPT2-chitchat 10 | - 项目工作流程详见:https://zhuanlan.zhihu.com/p/113869509 11 | - 
本项目为GPT2-chitchat稍作修改的内容,在此也感谢大佬的分享。 12 | - 由于NLPCC的摘要数据为新闻语料,涉及话题和内容较多,应用在垂直领域下效果会好一些。 13 | 14 | ## 运行环境 15 | python3.6、 transformers==2.1.1、pytorch==1.3.1 16 | 17 | ## 项目结构 18 | - config:存放GPT2模型的参数的配置文件 19 | - data 20 | - train_with_summary.txt:默认的原始训练集文件,存放摘要语料 21 | - train_tokenized.txt:对原始训练语料进行顺序tokenize之后的文件,用于model的训练 22 | - summary_model:存放摘要生成的模型 23 | - vocabulary:存放GPT2模型的字典 24 | - train.py:训练代码 25 | - interact.py:测试代码 26 | 27 | 28 | ## 模型参数(详见config/model_config_dialogue_small.json文件) 29 | - initializer_range: 0.02 30 | - layer_norm_epsilon: 1e-05 31 | - n_ctx: 300 32 | - n_embd: 768 33 | - n_head: 12 34 | - n_layer: 10 35 | - n_positions: 300 36 | - vocab_size: 13317 37 | 38 | ## Chinese Summary 39 | Dialogue Model是基于GPT2模型的生成模型,对每条训练数据进行"顺序"拼接,然后将其输入到网络中,进行训练(该项目没有训练MMI Model的"逆序") 40 | 41 | 在训练Chinese Summary时,将上述训练数据进行如下拼接然后,将上述拼接结果作为Summary Model的输入,对模型进行训练 42 | ```python 43 | [CLS]"四海网讯,近日,有媒体报道称:章子怡真怀孕了!报道还援引知情人士消息称,“章子怡怀孕大概四五个月,预产期是年底前后,现在已经不接工作了。”这到底是怎么回事?消息是真是假?针对此消息,23日晚8时30分,华西都市报记者迅速联系上了与章子怡家里关系极好的知情人士,这位人士向华西都市报记者证实说:“子怡这次确实怀孕了。她已经36岁了,也该怀孕了。章子怡怀上汪峰的孩子后,子怡的父母亲十分高兴。子怡的母亲,已开始悉心照料女儿了。子怡的预产期大概是今年12月底。”当晚9时,华西都市报记者为了求证章子怡怀孕消息,又电话联系章子怡的亲哥哥章子男,但电话通了,一直没有人接听。有关章子怡怀孕的新闻自从2013年9月份章子怡和汪峰恋情以来,就被传N遍了!不过,时间跨入2015年,事情却发生着微妙的变化。2015年3月21日,章子怡担任制片人的电影《从天儿降》开机,在开机发布会上几张合影,让网友又燃起了好奇心:“章子怡真的怀孕了吗?”但后据证实,章子怡的“大肚照”只是影片宣传的噱头。过了四个月的7月22日,《太平轮》新一轮宣传,章子怡又被发现状态不佳,不时深呼吸,不自觉想捂住肚子,又觉得不妥。然后在8月的一天,章子怡和朋友吃饭,在酒店门口被风行工作室拍到了,疑似有孕在身!今年7月11日,汪峰本来在上海要举行演唱会,后来因为台风“灿鸿”取消了。而消息人士称,汪峰原来打算在演唱会上当着章子怡的面宣布重大消息,而且章子怡已经赴上海准备参加演唱会了,怎知遇到台风,只好延期,相信9月26日的演唱会应该还会有惊喜大白天下吧。"[SEP]"知情人透露章子怡怀孕后,父母很高兴。章母已开始悉心照料。据悉,预产期大概是12月底"[SEP] 44 | ``` 45 | 46 | 47 | 48 | ## 模型分享 49 | |模型 | 百度网盘 |模型描述| 50 | |---------|--------|--------| 51 | |GPT2-nlpcc-summary | 链接:https://pan.baidu.com/s/1atsbABI7Lq5HQNctC11E5g
提取码:grtn |使用nlpcc的摘要数据基于GPT2-wiki训练的摘要模型| 52 | |GPT2-wiki | 链接:https://pan.baidu.com/s/1oo1fpuGPYR9IMCcWQzzE9w
提取码:o1aq
复制这段内容后打开百度网盘手机App,操作更方便哦 |使用GPT2-Chinese训练的通用模型| 53 | 54 | ## 模型使用方法 55 | 56 | ##### 将GPT2-nlpcc-summary参数下载,放入summary_model文件夹中,运行即可。 57 | 58 | 为了直观展示,这里使用的交互式的预测形式,将需要预测的文章粘贴到控制台即可,由于中间的解码使用的是sampling的方式,所以每次生成的内容不能保证一致。可以多运行几次,取rouge得分最高的,该项目仅用于探索项目,若要实际项目中使用,还需重新训练通用模型和summary模型。 59 | 60 | 61 | ## 示例 62 | ``` 63 | 发布日期:2015-04-1917:36:34吉安市气象台2015年4月19日17时20分发布雷电黄色预警信号:预计未来6小时内,我市中部将有强雷电活动,局地可伴有短时强降水、雷雨大风等强对流天气,请注意加强防范。图例标准防御指南6小时内可能发生雷电活动,可能会造成雷电灾害事故。1、政府及相关部门按照职责做好防雷工作;2、密切关注天气,尽量避免户外活动。 64 | 65 | summary:发布雷电黄色预警:预计未来6小时内,我市中部将有强雷电活动,局地可伴有短时强降水、雷雨大风等强对流天气,请注意加强防范。... 66 | summary:组图:我市中部地区发布雷电黄色预警:预计未来6小时内,我市中部将有强雷电活动,局地可伴有短时强降水、雷雨大风等强对流天气,请注意加强防范。... 67 | summary:组图:我市中部发布雷电黄色预警:预计未来6小时内,我市中部将有强雷电活动,局地可伴有短时强降水、雷雨大风... 68 | summary:组图:我市中部发布雷电黄色预警:预计未来6小时内,我市中部将有强雷电活动,局地可伴有短时强降水、雷雨大风... 69 | summary:组图:我市中部发布雷电黄色预警:预计未来6小时内,我市中部将有强雷电活动,局地可伴有短时强降水、雷雨大风... 70 | ``` 71 | 72 | ``` 73 | 台海网1月1日讯据中评社报道,国民党12月31日下午在中常会听取副秘书长兼行管会主委林德瑞专案报告党产处理现况,合计党产总值约277亿元(新台币,下同;约54.1亿元人民币);外界质疑党产总值约为申报5倍,包括中国大陆、海通投资资产估计有1千亿以上,林德瑞表示,此纯属臆测,就目前中投公司所有财务资料,并无中国大陆投资,更无投资1千亿元以上。国民党12月31日中常会由代理主席吴敦义主持,林德瑞依指示提出党产现况报报,回应参选党主席补选的新北市长朱立伦日前提出说要处理不当党产及外界质欵马英九处理党产的成效;中常会后当晚文传会即将党产报告放置在官网“常会特稿”栏内,证明国民党处理党产光明磊落 74 | 75 | summary:国民党濛月國日下午在中常会听取报告;外界质疑党产总值约为申报5倍 76 | summary:外媒称党产总值约为申报5倍,外界质疑党产总值约为申报5倍,媒体称此次 77 | summary:外媒称党产总值约为申报5倍,外界质疑党产总值约为申报5倍,媒体称此次报道为申报5倍,外界质疑党产总值约为申报5倍。 78 | summary:外媒称党产总值约为申报5倍,外界质疑为申报5倍,外界质疑为申报5倍。 79 | summary:外媒称党产总值约4倍,外界质疑为申报5倍,外界质疑为申报5倍,外界质疑为申报5倍。 80 | ``` 81 | 82 | 83 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qingkongzhiqian/GPT2-Summary/9567d345065e67c03493eab9392de41b6afe7081/__init__.py -------------------------------------------------------------------------------- /config/model_config_dialogue_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "initializer_range": 0.02, 3 | "layer_norm_epsilon": 1e-05, 4 | "n_ctx": 1024, 5 | "n_embd": 768, 6 | "n_head": 12, 7 | "n_layer": 10, 8 | "n_positions": 1024, 9 | "vocab_size": 13317 10 | } -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | 5 | class MyDataset(Dataset): 6 | """ 7 | 8 | """ 9 | 10 | def __init__(self, data_list): 11 | self.data_list = data_list 12 | 13 | def __getitem__(self, index): 14 | input_ids = self.data_list[index].strip() 15 | input_ids = [int(token_id) for token_id in input_ids.split()] 16 | return input_ids 17 | 18 | def __len__(self): 19 | return len(self.data_list) 20 | -------------------------------------------------------------------------------- /generate_dialogue_subset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import join 3 | import numpy as np 4 | from collections import Counter 5 | import matplotlib.pyplot as plt 6 | from matplotlib.pyplot import MultipleLocator 7 | 8 | 9 | def generate_subset(): 10 | """ 11 | 用于生成训练子集 12 | :return: 13 | """ 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='原始训练语料') 16 | parser.add_argument('--subset_size', default=500000, type=int, required=False, help='要获取的对话数据子集的规模') 17 | 
parser.add_argument('--subset_data_path', default='data', type=str, required=False, 18 | help='数据子集文件路径,指定文件的父目录') 19 | args = parser.parse_args() 20 | with open(args.raw_data_path, "r", encoding="utf8") as f: 21 | data = f.read() 22 | dialogues = data.split("\n\n") 23 | subset_size = min(len(dialogues), args.subset_size) 24 | 25 | with open(join(args.subset_data_path, "train_{}w.txt".format(int(subset_size / 10000))), "w", encoding="utf8") as f: 26 | print("generating subset,please wait a few seconds ") 27 | for dialogue_index, dialogue in enumerate(dialogues): 28 | if dialogue_index >= subset_size: 29 | break 30 | for utterance in dialogue.split("\n"): 31 | f.writelines(utterance + "\n") 32 | f.writelines("\n") 33 | 34 | 35 | def compute_dialogue_length(): 36 | """ 37 | 查看聊天语料中的dialogue的长度分布 38 | :return: 39 | """ 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='原始训练语料') 42 | args = parser.parse_args() 43 | with open(args.raw_data_path, "r", encoding="utf8") as f: 44 | data = f.read() 45 | dialogues = data.split("\n\n") 46 | # 统计各个dialogue的长度 47 | dialogues_lengths = [len(dialogue.replace("\n", "")) for dialogue in dialogues] 48 | counter = Counter(dialogues_lengths) # {label:sum(label)} 49 | dialogue_length_arr = list(counter) 50 | num_arr = [counter[element] for element in list(counter)] 51 | print(counter[300]) 52 | 53 | x_major_locator = MultipleLocator(100) # MultipleLocator用于设置刻度间隔 54 | # y_major_locator = MultipleLocator(20000) 55 | ax = plt.gca() # ax为两条坐标轴的实例 56 | ax.xaxis.set_major_locator(x_major_locator) # 把x轴的主刻度设置为10的倍数 57 | # ax.yaxis.set_major_locator(y_major_locator) 58 | 59 | plt.xlabel('dialogue length') 60 | plt.ylabel('number of dialogue') 61 | # plt.plot(dialogue_length_arr, num_arr, c='green') 62 | plt.scatter(dialogue_length_arr, num_arr) 63 | plt.show() 64 | 65 | 66 | if __name__ == '__main__': 67 | compute_dialogue_length() 68 | -------------------------------------------------------------------------------- /interact.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import os 4 | import json 5 | import random 6 | import numpy as np 7 | import argparse 8 | from torch.utils.tensorboard import SummaryWriter 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from torch.nn import DataParallel 12 | import logging 13 | from transformers.modeling_gpt2 import GPT2Config, GPT2LMHeadModel 14 | from transformers import BertTokenizer 15 | from os.path import join, exists 16 | from itertools import zip_longest, chain 17 | # from chatbot.model import DialogueGPT2Model 18 | from dataset import MyDataset 19 | from torch.utils.data import Dataset, DataLoader 20 | from torch.nn import CrossEntropyLoss 21 | from sklearn.model_selection import train_test_split 22 | from train import create_model 23 | import torch.nn.functional as F 24 | 25 | PAD = '[PAD]' 26 | pad_id = 0 27 | 28 | 29 | def set_interact_args(): 30 | """ 31 | Sets up the training arguments. 
32 | """ 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--device', default='0,1', type=str, required=False, help='生成设备') 35 | parser.add_argument('--temperature', default=1, type=float, required=False, help='生成的temperature') 36 | parser.add_argument('--topk', default=8, type=int, required=False, help='最高k选1') 37 | parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率') 38 | parser.add_argument('--model_config', default='summary_model/config.json', type=str, required=False, 39 | help='模型参数') 40 | parser.add_argument('--log_path', default='data/interacting.log', type=str, required=False, help='interact日志存放位置') 41 | parser.add_argument('--voca_path', default='vocabulary/vocab_small.txt', type=str, required=False, help='选择词库') 42 | parser.add_argument('--dialogue_model_path', default='summary_model/', type=str, required=False, help='对话模型路径') 43 | parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径") 44 | parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, 45 | help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数") 46 | parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的') 47 | parser.add_argument('--max_len', type=int, default=300, help='每个utterance的最大长度,超过指定长度则进行截断') 48 | parser.add_argument('--max_history_len', type=int, default=1, help="dialogue history的最大长度") 49 | parser.add_argument('--no_cuda', default=False, help='不使用GPU进行预测') 50 | return parser.parse_args() 51 | 52 | 53 | def create_logger(args): 54 | """ 55 | 将日志输出到日志文件和控制台 56 | """ 57 | logger = logging.getLogger(__name__) 58 | logger.setLevel(logging.INFO) 59 | 60 | formatter = logging.Formatter( 61 | '%(asctime)s - %(levelname)s - %(message)s') 62 | 63 | # 创建一个handler,用于写入日志文件 64 | file_handler = logging.FileHandler( 65 | filename=args.log_path) 66 | file_handler.setFormatter(formatter) 67 | file_handler.setLevel(logging.INFO) 68 | logger.addHandler(file_handler) 69 | 70 | # 创建一个handler,用于将日志输出到控制台 71 | console = logging.StreamHandler() 72 | console.setLevel(logging.DEBUG) 73 | console.setFormatter(formatter) 74 | logger.addHandler(console) 75 | 76 | return logger 77 | 78 | 79 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 80 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 81 | Args: 82 | logits: logits distribution shape (vocabulary size) 83 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 84 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 85 | Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 86 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 87 | """ 88 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 89 | top_k = min(top_k, logits.size(-1)) # Safety check 90 | if top_k > 0: 91 | # Remove all tokens with a probability less than the last token of the top-k 92 | # torch.topk()返回最后一维最大的top_k个元素,返回值为二维(values,indices) 93 | # ...表示其他维度由计算机自行推断 94 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 95 | logits[indices_to_remove] = filter_value # 对于topk之外的其他元素的logits值设为负无穷 96 | 97 | if top_p > 0.0: 98 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) # 对logits进行递减排序 99 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 100 | 101 | # Remove tokens with cumulative probability above the threshold 102 | sorted_indices_to_remove = cumulative_probs > top_p 103 | # Shift the indices to the right to keep also the first token above the threshold 104 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 105 | sorted_indices_to_remove[..., 0] = 0 106 | 107 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 108 | logits[indices_to_remove] = filter_value 109 | return logits 110 | 111 | 112 | def main(): 113 | args = set_interact_args() 114 | logger = create_logger(args) 115 | # 当用户使用GPU,并且GPU可用时 116 | args.cuda = torch.cuda.is_available() and not args.no_cuda 117 | # args.cuda = False 118 | device = 'cuda' if args.cuda else 'cpu' 119 | logger.info('using device:{}'.format(device)) 120 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 121 | tokenizer = BertTokenizer(vocab_file=args.voca_path) 122 | model = GPT2LMHeadModel.from_pretrained(args.dialogue_model_path) 123 | model.to(device) 124 | model.eval() 125 | 126 | print('***********************Summary model start************************') 127 | 128 | while True: 129 | try: 130 | 131 | text = input() 132 | for i in range(5): 133 | if len(text) : text = text[:1000] 134 | input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头 135 | input_ids.extend(tokenizer.encode(text)) 136 | input_ids.append(tokenizer.sep_token_id) 137 | curr_input_tensor = torch.tensor(input_ids).long().to(device) 138 | 139 | generated = [] 140 | # 最多生成max_len个token 141 | for _ in range(args.max_len): 142 | outputs = model(input_ids=curr_input_tensor) 143 | next_token_logits = outputs[0][-1, :] 144 | # 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率 145 | for id in set(generated): 146 | next_token_logits[id] /= args.repetition_penalty 147 | next_token_logits = next_token_logits / args.temperature 148 | # 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token 149 | next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') 150 | filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp) 151 | # torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标 152 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 153 | if next_token == tokenizer.sep_token_id: # 遇到[SEP]则表明response生成结束 154 | break 155 | generated.append(next_token.item()) 156 | curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=0) 157 | 158 | text = tokenizer.convert_ids_to_tokens(generated) 159 | print("summary:" + "".join(text)) 160 | 161 | except KeyboardInterrupt: 162 | break 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | 
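The README describes the training layout ("[CLS]article[SEP]summary[SEP]") and dataset.py expects one sample per line as space-separated token ids, but the preprocessing step itself is not shown in this listing. The sketch below is an illustration rather than the project's actual train.py preprocessing: it shows how one article/summary pair could be turned into such a line with the same BertTokenizer that interact.py constructs; the helper name `build_training_line` and the truncation length are assumptions.

```python
# Illustrative sketch only: convert one (article, summary) pair into the
# space-separated token-id line format that MyDataset (dataset.py) reads,
# following the "[CLS]article[SEP]summary[SEP]" layout from the README.
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file="vocabulary/vocab_small.txt")


def build_training_line(article, summary, n_ctx=300):
    # tokenize + convert_tokens_to_ids never insert special tokens,
    # so [CLS]/[SEP] are added exactly once and in the documented order
    ids = [tokenizer.cls_token_id]
    ids += tokenizer.convert_tokens_to_ids(tokenizer.tokenize(article))
    ids.append(tokenizer.sep_token_id)
    ids += tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
    ids.append(tokenizer.sep_token_id)
    ids = ids[:n_ctx]  # keep within the context window (n_ctx in the README's parameter list)
    return " ".join(str(i) for i in ids)  # one training sample per line


# Each returned string would become one line of data/train_tokenized.txt.
# At inference time, interact.py feeds only "[CLS]article[SEP]" and samples
# the summary token by token with top-k / top-p filtering.
line = build_training_line("四海网讯,近日,有媒体报道称:章子怡真怀孕了!……",
                           "知情人透露章子怡怀孕后,父母很高兴。……")
```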
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==2.1.1 2 | pytorch==1.3.1 3 | sklearn 4 | tqdm 5 | numpy 6 | scipy==1.2.1 -------------------------------------------------------------------------------- /transformers/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | 3 | 4 | def main(): 5 | import sys 6 | 7 | if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]: 8 | print( 9 | "First argument to `transformers` command line interface should be one of: \n" 10 | ">> convert serve train predict" 11 | ) 12 | if sys.argv[1] == "convert": 13 | from transformers.commands import convert 14 | 15 | convert(sys.argv) 16 | elif sys.argv[1] == "train": 17 | from transformers.commands import train 18 | 19 | train(sys.argv) 20 | elif sys.argv[1] == "serve": 21 | pass 22 | # from argparse import ArgumentParser 23 | # from transformers.commands.serving import ServeCommand 24 | # parser = ArgumentParser('Transformers CLI tool', usage='transformers serve []') 25 | # commands_parser = parser.add_subparsers(help='transformers-cli command helpers') 26 | 27 | # # Register commands 28 | # ServeCommand.register_subcommand(commands_parser) 29 | 30 | # # Let's go 31 | # args = parser.parse_args() 32 | 33 | # if not hasattr(args, 'func'): 34 | # parser.print_help() 35 | # exit(1) 36 | # # Run 37 | # service = args.func(args) 38 | # service.run() 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /transformers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | 5 | class BaseTransformersCLICommand(ABC): 6 | @staticmethod 7 | @abstractmethod 8 | def register_subcommand(parser: ArgumentParser): 9 | raise NotImplementedError() 10 | 11 | @abstractmethod 12 | def run(self): 13 | raise NotImplementedError() 14 | -------------------------------------------------------------------------------- /transformers/commands/convert.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from logging import getLogger 3 | 4 | from transformers.commands import BaseTransformersCLICommand 5 | 6 | 7 | def convert_command_factory(args: Namespace): 8 | """ 9 | Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint. 
10 | :return: ServeCommand 11 | """ 12 | return ConvertCommand( 13 | args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name 14 | ) 15 | 16 | 17 | class ConvertCommand(BaseTransformersCLICommand): 18 | @staticmethod 19 | def register_subcommand(parser: ArgumentParser): 20 | """ 21 | Register this command to argparse so it's available for the transformer-cli 22 | :param parser: Root parser to register command-specific arguments 23 | :return: 24 | """ 25 | train_parser = parser.add_parser( 26 | "convert", 27 | help="CLI tool to run convert model from original " 28 | "author checkpoints to Transformesr PyTorch checkpoints.", 29 | ) 30 | train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.") 31 | train_parser.add_argument( 32 | "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder." 33 | ) 34 | train_parser.add_argument( 35 | "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch savd model output." 36 | ) 37 | train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.") 38 | train_parser.add_argument( 39 | "--finetuning_task_name", 40 | type=str, 41 | default=None, 42 | help="Optional fine-tuning task name if the TF model was a finetuned model.", 43 | ) 44 | train_parser.set_defaults(func=convert_command_factory) 45 | 46 | def __init__( 47 | self, 48 | model_type: str, 49 | tf_checkpoint: str, 50 | pytorch_dump_output: str, 51 | config: str, 52 | finetuning_task_name: str, 53 | *args 54 | ): 55 | self._logger = getLogger("transformers-cli/converting") 56 | 57 | self._logger.info("Loading model {}".format(model_type)) 58 | self._model_type = model_type 59 | self._tf_checkpoint = tf_checkpoint 60 | self._pytorch_dump_output = pytorch_dump_output 61 | self._config = config 62 | self._finetuning_task_name = finetuning_task_name 63 | 64 | def run(self): 65 | if self._model_type == "bert": 66 | try: 67 | from transformers.convert_bert_original_tf_checkpoint_to_pytorch import ( 68 | convert_tf_checkpoint_to_pytorch, 69 | ) 70 | except ImportError: 71 | msg = ( 72 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 73 | "In that case, it requires TensorFlow to be installed. Please see " 74 | "https://www.tensorflow.org/install/ for installation instructions." 75 | ) 76 | raise ImportError(msg) 77 | 78 | convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) 79 | elif self._model_type == "gpt": 80 | from transformers.convert_openai_original_tf_checkpoint_to_pytorch import ( 81 | convert_openai_checkpoint_to_pytorch, 82 | ) 83 | 84 | convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) 85 | elif self._model_type == "transfo_xl": 86 | try: 87 | from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import ( 88 | convert_transfo_xl_checkpoint_to_pytorch, 89 | ) 90 | except ImportError: 91 | msg = ( 92 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 93 | "In that case, it requires TensorFlow to be installed. Please see " 94 | "https://www.tensorflow.org/install/ for installation instructions." 
95 | ) 96 | raise ImportError(msg) 97 | 98 | if "ckpt" in self._tf_checkpoint.lower(): 99 | TF_CHECKPOINT = self._tf_checkpoint 100 | TF_DATASET_FILE = "" 101 | else: 102 | TF_DATASET_FILE = self._tf_checkpoint 103 | TF_CHECKPOINT = "" 104 | convert_transfo_xl_checkpoint_to_pytorch( 105 | TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE 106 | ) 107 | elif self._model_type == "gpt2": 108 | try: 109 | from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import ( 110 | convert_gpt2_checkpoint_to_pytorch, 111 | ) 112 | except ImportError: 113 | msg = ( 114 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 115 | "In that case, it requires TensorFlow to be installed. Please see " 116 | "https://www.tensorflow.org/install/ for installation instructions." 117 | ) 118 | raise ImportError(msg) 119 | 120 | convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) 121 | elif self._model_type == "xlnet": 122 | try: 123 | from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import ( 124 | convert_xlnet_checkpoint_to_pytorch, 125 | ) 126 | except ImportError: 127 | msg = ( 128 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 129 | "In that case, it requires TensorFlow to be installed. Please see " 130 | "https://www.tensorflow.org/install/ for installation instructions." 131 | ) 132 | raise ImportError(msg) 133 | 134 | convert_xlnet_checkpoint_to_pytorch( 135 | self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name 136 | ) 137 | elif self._model_type == "xlm": 138 | from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import ( 139 | convert_xlm_checkpoint_to_pytorch, 140 | ) 141 | 142 | convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output) 143 | else: 144 | raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]") 145 | -------------------------------------------------------------------------------- /transformers/commands/download.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from transformers.commands import BaseTransformersCLICommand 4 | 5 | 6 | def download_command_factory(args): 7 | return DownloadCommand(args.model, args.cache_dir, args.force) 8 | 9 | 10 | class DownloadCommand(BaseTransformersCLICommand): 11 | @staticmethod 12 | def register_subcommand(parser: ArgumentParser): 13 | download_parser = parser.add_parser("download") 14 | download_parser.add_argument( 15 | "--cache-dir", type=str, default=None, help="Path to location to store the models" 16 | ) 17 | download_parser.add_argument( 18 | "--force", action="store_true", help="Force the model to be download even if already in cache-dir" 19 | ) 20 | download_parser.add_argument("model", type=str, help="Name of the model to download") 21 | download_parser.set_defaults(func=download_command_factory) 22 | 23 | def __init__(self, model: str, cache: str, force: bool): 24 | self._model = model 25 | self._cache = cache 26 | self._force = force 27 | 28 | def run(self): 29 | from transformers import AutoModel, AutoTokenizer 30 | 31 | AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) 32 | AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) 33 | 
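As a usage note, `DownloadCommand.run()` above is a thin wrapper around the auto classes; the hedged snippet below (the model name and cache directory are placeholders) shows the equivalent direct calls for pre-caching a model and its tokenizer.

```python
# Equivalent of DownloadCommand.run(): download a model and its tokenizer
# into a local cache directory for reuse by later from_pretrained() calls.
from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-chinese"   # placeholder model identifier
cache_dir = "./model_cache"        # placeholder cache location

AutoModel.from_pretrained(model_name, cache_dir=cache_dir, force_download=False)
AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, force_download=False)
```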
-------------------------------------------------------------------------------- /transformers/commands/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser 3 | 4 | from transformers.commands import BaseTransformersCLICommand 5 | from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline 6 | 7 | 8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 9 | 10 | 11 | def try_infer_format_from_ext(path: str): 12 | if not path: 13 | return "pipe" 14 | 15 | for ext in PipelineDataFormat.SUPPORTED_FORMATS: 16 | if path.endswith(ext): 17 | return ext 18 | 19 | raise Exception( 20 | "Unable to determine file format from file extension {}. " 21 | "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS) 22 | ) 23 | 24 | 25 | def run_command_factory(args): 26 | nlp = pipeline( 27 | task=args.task, 28 | model=args.model if args.model else None, 29 | config=args.config, 30 | tokenizer=args.tokenizer, 31 | device=args.device, 32 | ) 33 | format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format 34 | reader = PipelineDataFormat.from_str( 35 | format=format, 36 | output_path=args.output, 37 | input_path=args.input, 38 | column=args.column if args.column else nlp.default_input_names, 39 | overwrite=args.overwrite, 40 | ) 41 | return RunCommand(nlp, reader) 42 | 43 | 44 | class RunCommand(BaseTransformersCLICommand): 45 | def __init__(self, nlp: Pipeline, reader: PipelineDataFormat): 46 | self._nlp = nlp 47 | self._reader = reader 48 | 49 | @staticmethod 50 | def register_subcommand(parser: ArgumentParser): 51 | run_parser = parser.add_parser("run", help="Run a pipeline through the CLI") 52 | run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run") 53 | run_parser.add_argument("--input", type=str, help="Path to the file to use for inference") 54 | run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.") 55 | run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.") 56 | run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.") 57 | run_parser.add_argument( 58 | "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)" 59 | ) 60 | run_parser.add_argument( 61 | "--column", 62 | type=str, 63 | help="Name of the column to use as input. 
(For multi columns input as QA use column1,columns2)", 64 | ) 65 | run_parser.add_argument( 66 | "--format", 67 | type=str, 68 | default="infer", 69 | choices=PipelineDataFormat.SUPPORTED_FORMATS, 70 | help="Input format to read from", 71 | ) 72 | run_parser.add_argument( 73 | "--device", 74 | type=int, 75 | default=-1, 76 | help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)", 77 | ) 78 | run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.") 79 | run_parser.set_defaults(func=run_command_factory) 80 | 81 | def run(self): 82 | nlp, outputs = self._nlp, [] 83 | 84 | for entry in self._reader: 85 | output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry) 86 | if isinstance(output, dict): 87 | outputs.append(output) 88 | else: 89 | outputs += output 90 | 91 | # Saving data 92 | if self._nlp.binary_output: 93 | binary_path = self._reader.save_binary(outputs) 94 | logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path)) 95 | else: 96 | self._reader.save(outputs) 97 | -------------------------------------------------------------------------------- /transformers/commands/serving.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser, Namespace 3 | from typing import Any, List, Optional, Union 4 | 5 | from transformers import Pipeline 6 | from transformers.commands import BaseTransformersCLICommand 7 | from transformers.pipelines import SUPPORTED_TASKS, pipeline 8 | 9 | 10 | try: 11 | from uvicorn import run 12 | from fastapi import FastAPI, HTTPException, Body 13 | from pydantic import BaseModel 14 | 15 | _serve_dependancies_installed = True 16 | except (ImportError, AttributeError): 17 | BaseModel = object 18 | 19 | def Body(*x, **y): 20 | pass 21 | 22 | _serve_dependancies_installed = False 23 | 24 | 25 | logger = logging.getLogger("transformers-cli/serving") 26 | 27 | 28 | def serve_command_factory(args: Namespace): 29 | """ 30 | Factory function used to instantiate serving server from provided command line arguments. 31 | :return: ServeCommand 32 | """ 33 | nlp = pipeline( 34 | task=args.task, 35 | model=args.model if args.model else None, 36 | config=args.config, 37 | tokenizer=args.tokenizer, 38 | device=args.device, 39 | ) 40 | return ServeCommand(nlp, args.host, args.port) 41 | 42 | 43 | class ServeModelInfoResult(BaseModel): 44 | """ 45 | Expose model information 46 | """ 47 | 48 | infos: dict 49 | 50 | 51 | class ServeTokenizeResult(BaseModel): 52 | """ 53 | Tokenize result model 54 | """ 55 | 56 | tokens: List[str] 57 | tokens_ids: Optional[List[int]] 58 | 59 | 60 | class ServeDeTokenizeResult(BaseModel): 61 | """ 62 | DeTokenize result model 63 | """ 64 | 65 | text: str 66 | 67 | 68 | class ServeForwardResult(BaseModel): 69 | """ 70 | Forward result model 71 | """ 72 | 73 | output: Any 74 | 75 | 76 | class ServeCommand(BaseTransformersCLICommand): 77 | @staticmethod 78 | def register_subcommand(parser: ArgumentParser): 79 | """ 80 | Register this command to argparse so it's available for the transformer-cli 81 | :param parser: Root parser to register command-specific arguments 82 | :return: 83 | """ 84 | serve_parser = parser.add_parser( 85 | "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints." 
86 | ) 87 | serve_parser.add_argument( 88 | "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on" 89 | ) 90 | serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.") 91 | serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.") 92 | serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.") 93 | serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.") 94 | serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.") 95 | serve_parser.add_argument( 96 | "--device", 97 | type=int, 98 | default=-1, 99 | help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)", 100 | ) 101 | serve_parser.set_defaults(func=serve_command_factory) 102 | 103 | def __init__(self, pipeline: Pipeline, host: str, port: int): 104 | 105 | self._pipeline = pipeline 106 | 107 | self._host = host 108 | self._port = port 109 | if not _serve_dependancies_installed: 110 | raise RuntimeError( 111 | "Using serve command requires FastAPI and unicorn. " 112 | "Please install transformers with [serving]: pip install transformers[serving]." 113 | "Or install FastAPI and unicorn separatly." 114 | ) 115 | else: 116 | logger.info("Serving model over {}:{}".format(host, port)) 117 | self._app = FastAPI() 118 | 119 | # Register routes 120 | self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"]) 121 | self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"]) 122 | self._app.add_api_route( 123 | "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"] 124 | ) 125 | self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"]) 126 | 127 | def run(self): 128 | run(self._app, host=self._host, port=self._port) 129 | 130 | def model_info(self): 131 | return ServeModelInfoResult(infos=vars(self._pipeline.model.config)) 132 | 133 | def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)): 134 | """ 135 | Tokenize the provided input and eventually returns corresponding tokens id: 136 | - **text_input**: String to tokenize 137 | - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping. 138 | """ 139 | try: 140 | tokens_txt = self._pipeline.tokenizer.tokenize(text_input) 141 | 142 | if return_ids: 143 | tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt) 144 | return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids) 145 | else: 146 | return ServeTokenizeResult(tokens=tokens_txt) 147 | 148 | except Exception as e: 149 | raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) 150 | 151 | def detokenize( 152 | self, 153 | tokens_ids: List[int] = Body(None, embed=True), 154 | skip_special_tokens: bool = Body(False, embed=True), 155 | cleanup_tokenization_spaces: bool = Body(True, embed=True), 156 | ): 157 | """ 158 | Detokenize the provided tokens ids to readable text: 159 | - **tokens_ids**: List of tokens ids 160 | - **skip_special_tokens**: Flag indicating to not try to decode special tokens 161 | - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones. 
162 | """ 163 | try: 164 | decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces) 165 | return ServeDeTokenizeResult(model="", text=decoded_str) 166 | except Exception as e: 167 | raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) 168 | 169 | def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)): 170 | """ 171 | **inputs**: 172 | **attention_mask**: 173 | **tokens_type_ids**: 174 | """ 175 | 176 | # Check we don't have empty string 177 | if len(inputs) == 0: 178 | return ServeForwardResult(output=[], attention=[]) 179 | 180 | try: 181 | # Forward through the model 182 | output = self._pipeline(inputs) 183 | return ServeForwardResult(output=output) 184 | except Exception as e: 185 | raise HTTPException(500, {"error": str(e)}) 186 | -------------------------------------------------------------------------------- /transformers/commands/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser, Namespace 3 | from logging import getLogger 4 | 5 | from transformers import SingleSentenceClassificationProcessor as Processor 6 | from transformers import TextClassificationPipeline, is_tf_available, is_torch_available 7 | from transformers.commands import BaseTransformersCLICommand 8 | 9 | 10 | if not is_tf_available() and not is_torch_available(): 11 | raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training") 12 | 13 | # TF training parameters 14 | USE_XLA = False 15 | USE_AMP = False 16 | 17 | 18 | def train_command_factory(args: Namespace): 19 | """ 20 | Factory function used to instantiate serving server from provided command line arguments. 21 | :return: ServeCommand 22 | """ 23 | return TrainCommand(args) 24 | 25 | 26 | class TrainCommand(BaseTransformersCLICommand): 27 | @staticmethod 28 | def register_subcommand(parser: ArgumentParser): 29 | """ 30 | Register this command to argparse so it's available for the transformer-cli 31 | :param parser: Root parser to register command-specific arguments 32 | :return: 33 | """ 34 | train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.") 35 | 36 | train_parser.add_argument( 37 | "--train_data", 38 | type=str, 39 | required=True, 40 | help="path to train (and optionally evaluation) dataset as a csv with " 41 | "tab separated labels and sentences.", 42 | ) 43 | train_parser.add_argument( 44 | "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels." 45 | ) 46 | train_parser.add_argument( 47 | "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts." 48 | ) 49 | train_parser.add_argument( 50 | "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids." 51 | ) 52 | train_parser.add_argument( 53 | "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)." 
54 | ) 55 | 56 | train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.") 57 | train_parser.add_argument( 58 | "--validation_split", 59 | type=float, 60 | default=0.1, 61 | help="if validation dataset is not provided, fraction of train dataset " "to use as validation dataset.", 62 | ) 63 | 64 | train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.") 65 | 66 | train_parser.add_argument( 67 | "--task", type=str, default="text_classification", help="Task to train the model on." 68 | ) 69 | train_parser.add_argument( 70 | "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model." 71 | ) 72 | train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.") 73 | train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.") 74 | train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.") 75 | train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.") 76 | train_parser.set_defaults(func=train_command_factory) 77 | 78 | def __init__(self, args: Namespace): 79 | self.logger = getLogger("transformers-cli/training") 80 | 81 | self.framework = "tf" if is_tf_available() else "torch" 82 | 83 | os.makedirs(args.output, exist_ok=True) 84 | assert os.path.isdir(args.output) 85 | self.output = args.output 86 | 87 | self.column_label = args.column_label 88 | self.column_text = args.column_text 89 | self.column_id = args.column_id 90 | 91 | self.logger.info("Loading {} pipeline for {}".format(args.task, args.model)) 92 | if args.task == "text_classification": 93 | self.pipeline = TextClassificationPipeline.from_pretrained(args.model) 94 | elif args.task == "token_classification": 95 | raise NotImplementedError 96 | elif args.task == "question_answering": 97 | raise NotImplementedError 98 | 99 | self.logger.info("Loading dataset from {}".format(args.train_data)) 100 | self.train_dataset = Processor.create_from_csv( 101 | args.train_data, 102 | column_label=args.column_label, 103 | column_text=args.column_text, 104 | column_id=args.column_id, 105 | skip_first_row=args.skip_first_row, 106 | ) 107 | self.valid_dataset = None 108 | if args.validation_data: 109 | self.logger.info("Loading validation dataset from {}".format(args.validation_data)) 110 | self.valid_dataset = Processor.create_from_csv( 111 | args.validation_data, 112 | column_label=args.column_label, 113 | column_text=args.column_text, 114 | column_id=args.column_id, 115 | skip_first_row=args.skip_first_row, 116 | ) 117 | 118 | self.validation_split = args.validation_split 119 | self.train_batch_size = args.train_batch_size 120 | self.valid_batch_size = args.valid_batch_size 121 | self.learning_rate = args.learning_rate 122 | self.adam_epsilon = args.adam_epsilon 123 | 124 | def run(self): 125 | if self.framework == "tf": 126 | return self.run_tf() 127 | return self.run_torch() 128 | 129 | def run_torch(self): 130 | raise NotImplementedError 131 | 132 | def run_tf(self): 133 | self.pipeline.fit( 134 | self.train_dataset, 135 | validation_data=self.valid_dataset, 136 | validation_split=self.validation_split, 137 | learning_rate=self.learning_rate, 138 | adam_epsilon=self.adam_epsilon, 139 | train_batch_size=self.train_batch_size, 140 | valid_batch_size=self.valid_batch_size, 141 | ) 142 | 143 | # Save trained pipeline 144 | self.pipeline.save_pretrained(self.output) 
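`TrainCommand` is normally registered on the CLI parser, but it can also be exercised directly. The sketch below is an illustration only (file paths and the model name are placeholders) that fills in exactly the fields its `__init__` reads; note that `run()` only proceeds on the TensorFlow path here, since `run_torch()` raises `NotImplementedError`.

```python
# Illustrative only: drive TrainCommand programmatically with the same fields
# its argument parser defines above. Paths and the model name are placeholders.
from argparse import Namespace
from transformers.commands.train import TrainCommand

args = Namespace(
    train_data="data/train.csv",      # csv with tab-separated labels and sentences (placeholder)
    column_label=0, column_text=1, column_id=2, skip_first_row=True,
    validation_data="", validation_split=0.1,
    output="./trained_pipeline",
    task="text_classification",
    model="bert-base-uncased",
    train_batch_size=32, valid_batch_size=64,
    learning_rate=3e-5, adam_epsilon=1e-08,
)

TrainCommand(args).run()  # dispatches to run_tf() when TensorFlow 2.0+ is available
```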
145 | -------------------------------------------------------------------------------- /transformers/commands/user.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from getpass import getpass 4 | from typing import List, Union 5 | 6 | from requests.exceptions import HTTPError 7 | 8 | from transformers.commands import BaseTransformersCLICommand 9 | from transformers.hf_api import HfApi, HfFolder 10 | 11 | 12 | class UserCommands(BaseTransformersCLICommand): 13 | @staticmethod 14 | def register_subcommand(parser: ArgumentParser): 15 | login_parser = parser.add_parser("login") 16 | login_parser.set_defaults(func=lambda args: LoginCommand(args)) 17 | whoami_parser = parser.add_parser("whoami") 18 | whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) 19 | logout_parser = parser.add_parser("logout") 20 | logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) 21 | list_parser = parser.add_parser("ls") 22 | list_parser.set_defaults(func=lambda args: ListObjsCommand(args)) 23 | # upload 24 | upload_parser = parser.add_parser("upload") 25 | upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.") 26 | upload_parser.add_argument( 27 | "--filename", type=str, default=None, help="Optional: override individual object filename on S3." 28 | ) 29 | upload_parser.set_defaults(func=lambda args: UploadCommand(args)) 30 | 31 | 32 | class ANSI: 33 | """ 34 | Helper for en.wikipedia.org/wiki/ANSI_escape_code 35 | """ 36 | 37 | _bold = "\u001b[1m" 38 | _reset = "\u001b[0m" 39 | 40 | @classmethod 41 | def bold(cls, s): 42 | return "{}{}{}".format(cls._bold, s, cls._reset) 43 | 44 | 45 | class BaseUserCommand: 46 | def __init__(self, args): 47 | self.args = args 48 | self._api = HfApi() 49 | 50 | 51 | class LoginCommand(BaseUserCommand): 52 | def run(self): 53 | print( 54 | """ 55 | _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| 56 | _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| 57 | _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| 58 | _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| 59 | _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| 60 | 61 | """ 62 | ) 63 | username = input("Username: ") 64 | password = getpass() 65 | try: 66 | token = self._api.login(username, password) 67 | except HTTPError as e: 68 | # probably invalid credentials, display error message. 
69 | print(e) 70 | exit(1) 71 | HfFolder.save_token(token) 72 | print("Login successful") 73 | print("Your token:", token, "\n") 74 | print("Your token has been saved to", HfFolder.path_token) 75 | 76 | 77 | class WhoamiCommand(BaseUserCommand): 78 | def run(self): 79 | token = HfFolder.get_token() 80 | if token is None: 81 | print("Not logged in") 82 | exit() 83 | try: 84 | user = self._api.whoami(token) 85 | print(user) 86 | except HTTPError as e: 87 | print(e) 88 | 89 | 90 | class LogoutCommand(BaseUserCommand): 91 | def run(self): 92 | token = HfFolder.get_token() 93 | if token is None: 94 | print("Not logged in") 95 | exit() 96 | HfFolder.delete_token() 97 | self._api.logout(token) 98 | print("Successfully logged out.") 99 | 100 | 101 | class ListObjsCommand(BaseUserCommand): 102 | def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str: 103 | """ 104 | Inspired by: 105 | stackoverflow.com/a/8356620/593036 106 | stackoverflow.com/questions/9535954/printing-lists-as-tabular-data 107 | """ 108 | col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] 109 | row_format = ("{{:{}}} " * len(headers)).format(*col_widths) 110 | lines = [] 111 | lines.append(row_format.format(*headers)) 112 | lines.append(row_format.format(*["-" * w for w in col_widths])) 113 | for row in rows: 114 | lines.append(row_format.format(*row)) 115 | return "\n".join(lines) 116 | 117 | def run(self): 118 | token = HfFolder.get_token() 119 | if token is None: 120 | print("Not logged in") 121 | exit(1) 122 | try: 123 | objs = self._api.list_objs(token) 124 | except HTTPError as e: 125 | print(e) 126 | exit(1) 127 | if len(objs) == 0: 128 | print("No shared file yet") 129 | exit() 130 | rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs] 131 | print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])) 132 | 133 | 134 | class UploadCommand(BaseUserCommand): 135 | def walk_dir(self, rel_path): 136 | """ 137 | Recursively list all files in a folder. 138 | """ 139 | entries: List[os.DirEntry] = list(os.scandir(rel_path)) 140 | files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename 141 | for f in entries: 142 | if f.is_dir(): 143 | files += self.walk_dir(f.path) 144 | return files 145 | 146 | def run(self): 147 | token = HfFolder.get_token() 148 | if token is None: 149 | print("Not logged in") 150 | exit(1) 151 | local_path = os.path.abspath(self.args.path) 152 | if os.path.isdir(local_path): 153 | if self.args.filename is not None: 154 | raise ValueError("Cannot specify a filename override when uploading a folder.") 155 | rel_path = os.path.basename(local_path) 156 | files = self.walk_dir(rel_path) 157 | elif os.path.isfile(local_path): 158 | filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path) 159 | files = [(local_path, filename)] 160 | else: 161 | raise ValueError("Not a valid file or directory: {}".format(local_path)) 162 | 163 | for filepath, filename in files: 164 | print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename))) 165 | 166 | choice = input("Proceed? [Y/n] ").lower() 167 | if not (choice == "" or choice == "y" or choice == "yes"): 168 | print("Abort") 169 | exit() 170 | print(ANSI.bold("Uploading... 
This might take a while if files are large")) 171 | for filepath, filename in files: 172 | access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath) 173 | print("Your file now lives at:") 174 | print(access_url) 175 | -------------------------------------------------------------------------------- /transformers/configuration_albert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ ALBERT model configuration """ 17 | 18 | from .configuration_utils import PretrainedConfig 19 | 20 | 21 | ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 22 | "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json", 23 | "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json", 24 | "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json", 25 | "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json", 26 | "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json", 27 | "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json", 28 | "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json", 29 | "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json", 30 | } 31 | 32 | 33 | class AlbertConfig(PretrainedConfig): 34 | """Configuration for `AlbertModel`. 35 | 36 | The default settings match the configuration of model `albert_xxlarge`. 37 | """ 38 | 39 | pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 40 | 41 | def __init__( 42 | self, 43 | vocab_size=30000, 44 | embedding_size=128, 45 | hidden_size=4096, 46 | num_hidden_layers=12, 47 | num_hidden_groups=1, 48 | num_attention_heads=64, 49 | intermediate_size=16384, 50 | inner_group_num=1, 51 | hidden_act="gelu_new", 52 | hidden_dropout_prob=0, 53 | attention_probs_dropout_prob=0, 54 | max_position_embeddings=512, 55 | type_vocab_size=2, 56 | initializer_range=0.02, 57 | layer_norm_eps=1e-12, 58 | **kwargs 59 | ): 60 | """Constructs AlbertConfig. 61 | 62 | Args: 63 | vocab_size: Vocabulary size of `inputs_ids` in `AlbertModel`. 64 | embedding_size: size of voc embeddings. 65 | hidden_size: Size of the encoder layers and the pooler layer. 66 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 67 | num_hidden_groups: Number of group for the hidden layers, parameters in 68 | the same group are shared. 69 | num_attention_heads: Number of attention heads for each attention layer in 70 | the Transformer encoder. 
71 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 72 | layer in the Transformer encoder. 73 | inner_group_num: int, number of inner repetition of attention and ffn. 74 | down_scale_factor: float, the scale to apply 75 | hidden_act: The non-linear activation function (function or string) in the 76 | encoder and pooler. 77 | hidden_dropout_prob: The dropout probability for all fully connected 78 | layers in the embeddings, encoder, and pooler. 79 | attention_probs_dropout_prob: The dropout ratio for the attention 80 | probabilities. 81 | max_position_embeddings: The maximum sequence length that this model might 82 | ever be used with. Typically set this to something large just in case 83 | (e.g., 512 or 1024 or 2048). 84 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 85 | `AlbertModel`. 86 | initializer_range: The stdev of the truncated_normal_initializer for 87 | initializing all weight matrices. 88 | """ 89 | super(AlbertConfig, self).__init__(**kwargs) 90 | 91 | self.vocab_size = vocab_size 92 | self.embedding_size = embedding_size 93 | self.hidden_size = hidden_size 94 | self.num_hidden_layers = num_hidden_layers 95 | self.num_hidden_groups = num_hidden_groups 96 | self.num_attention_heads = num_attention_heads 97 | self.inner_group_num = inner_group_num 98 | self.hidden_act = hidden_act 99 | self.intermediate_size = intermediate_size 100 | self.hidden_dropout_prob = hidden_dropout_prob 101 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 102 | self.max_position_embeddings = max_position_embeddings 103 | self.type_vocab_size = type_vocab_size 104 | self.initializer_range = initializer_range 105 | self.layer_norm_eps = layer_norm_eps 106 | -------------------------------------------------------------------------------- /transformers/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json", 43 | "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json", 44 | "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json", 45 | "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json", 46 | "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | } 49 | 50 | 51 | class BertConfig(PretrainedConfig): 52 | r""" 53 | :class:`~transformers.BertConfig` is the configuration class to store the configuration of a 54 | `BertModel`. 55 | 56 | 57 | Arguments: 58 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 59 | hidden_size: Size of the encoder layers and the pooler layer. 60 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 
61 | num_attention_heads: Number of attention heads for each attention layer in 62 | the Transformer encoder. 63 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 64 | layer in the Transformer encoder. 65 | hidden_act: The non-linear activation function (function or string) in the 66 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 67 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 68 | layers in the embeddings, encoder, and pooler. 69 | attention_probs_dropout_prob: The dropout ratio for the attention 70 | probabilities. 71 | max_position_embeddings: The maximum sequence length that this model might 72 | ever be used with. Typically set this to something large just in case 73 | (e.g., 512 or 1024 or 2048). 74 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 75 | `BertModel`. 76 | initializer_range: The sttdev of the truncated_normal_initializer for 77 | initializing all weight matrices. 78 | layer_norm_eps: The epsilon used by LayerNorm. 79 | """ 80 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 81 | 82 | def __init__( 83 | self, 84 | vocab_size=30522, 85 | hidden_size=768, 86 | num_hidden_layers=12, 87 | num_attention_heads=12, 88 | intermediate_size=3072, 89 | hidden_act="gelu", 90 | hidden_dropout_prob=0.1, 91 | attention_probs_dropout_prob=0.1, 92 | max_position_embeddings=512, 93 | type_vocab_size=2, 94 | initializer_range=0.02, 95 | layer_norm_eps=1e-12, 96 | **kwargs 97 | ): 98 | super(BertConfig, self).__init__(**kwargs) 99 | self.vocab_size = vocab_size 100 | self.hidden_size = hidden_size 101 | self.num_hidden_layers = num_hidden_layers 102 | self.num_attention_heads = num_attention_heads 103 | self.hidden_act = hidden_act 104 | self.intermediate_size = intermediate_size 105 | self.hidden_dropout_prob = hidden_dropout_prob 106 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 107 | self.max_position_embeddings = max_position_embeddings 108 | self.type_vocab_size = type_vocab_size 109 | self.initializer_range = initializer_range 110 | self.layer_norm_eps = layer_norm_eps 111 | -------------------------------------------------------------------------------- /transformers/configuration_camembert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
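A short usage sketch, not part of the original files: `BertConfig` above is a plain container, so it can be constructed from keyword arguments and serialized the same way the conversion scripts later in this listing do. The vocabulary size below is only an illustrative Chinese-scale value.

```python
from transformers import BertConfig

config = BertConfig(vocab_size=21128, max_position_embeddings=512)  # illustrative values
assert config.hidden_size == 768          # defaults from the signature above
assert config.num_hidden_layers == 12
print(config.to_json_string())            # inherited from PretrainedConfig
```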
16 | """ CamemBERT configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_roberta import RobertaConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json", 28 | } 29 | 30 | 31 | class CamembertConfig(RobertaConfig): 32 | pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 33 | -------------------------------------------------------------------------------- /transformers/configuration_ctrl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Salesforce CTRL configuration """ 16 | 17 | 18 | import logging 19 | 20 | from .configuration_utils import PretrainedConfig 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"} 26 | 27 | 28 | class CTRLConfig(PretrainedConfig): 29 | """Configuration class to store the configuration of a `CTRLModel`. 30 | 31 | Args: 32 | vocab_size: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 33 | n_positions: Number of positional embeddings. 34 | n_ctx: Size of the causal mask (usually same as n_positions). 35 | dff: Size of the inner dimension of the FFN. 36 | n_embd: Dimensionality of the embeddings and hidden states. 37 | n_layer: Number of hidden layers in the Transformer encoder. 38 | n_head: Number of attention heads for each attention layer in 39 | the Transformer encoder. 40 | layer_norm_epsilon: epsilon to use in the layer norm layers 41 | resid_pdrop: The dropout probabilitiy for all fully connected 42 | layers in the embeddings, encoder, and pooler. 43 | attn_pdrop: The dropout ratio for the attention 44 | probabilities. 45 | embd_pdrop: The dropout ratio for the embeddings. 46 | initializer_range: The sttdev of the truncated_normal_initializer for 47 | initializing all weight matrices. 48 | """ 49 | 50 | pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP 51 | 52 | def __init__( 53 | self, 54 | vocab_size=246534, 55 | n_positions=256, 56 | n_ctx=256, 57 | n_embd=1280, 58 | dff=8192, 59 | n_layer=48, 60 | n_head=16, 61 | resid_pdrop=0.1, 62 | embd_pdrop=0.1, 63 | attn_pdrop=0.1, 64 | layer_norm_epsilon=1e-6, 65 | initializer_range=0.02, 66 | summary_type="cls_index", 67 | summary_use_proj=True, 68 | summary_activation=None, 69 | summary_proj_to_labels=True, 70 | summary_first_dropout=0.1, 71 | **kwargs 72 | ): 73 | """Constructs CTRLConfig. 74 | 75 | Args: 76 | vocab_size: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 77 | n_positions: Number of positional embeddings. 
78 | n_ctx: Size of the causal mask (usually same as n_positions). 79 | dff: Size of the inner dimension of the FFN. 80 | n_embd: Dimensionality of the embeddings and hidden states. 81 | n_layer: Number of hidden layers in the Transformer encoder. 82 | n_head: Number of attention heads for each attention layer in 83 | the Transformer encoder. 84 | layer_norm_epsilon: epsilon to use in the layer norm layers 85 | resid_pdrop: The dropout probabilitiy for all fully connected 86 | layers in the embeddings, encoder, and pooler. 87 | attn_pdrop: The dropout ratio for the attention 88 | probabilities. 89 | embd_pdrop: The dropout ratio for the embeddings. 90 | initializer_range: The sttdev of the truncated_normal_initializer for 91 | initializing all weight matrices. 92 | """ 93 | super(CTRLConfig, self).__init__(**kwargs) 94 | self.vocab_size = vocab_size 95 | self.n_ctx = n_ctx 96 | self.n_positions = n_positions 97 | self.n_embd = n_embd 98 | self.n_layer = n_layer 99 | self.n_head = n_head 100 | self.dff = dff 101 | self.resid_pdrop = resid_pdrop 102 | self.embd_pdrop = embd_pdrop 103 | self.attn_pdrop = attn_pdrop 104 | self.layer_norm_epsilon = layer_norm_epsilon 105 | self.initializer_range = initializer_range 106 | 107 | self.summary_type = summary_type 108 | self.summary_use_proj = summary_use_proj 109 | self.summary_activation = summary_activation 110 | self.summary_first_dropout = summary_first_dropout 111 | self.summary_proj_to_labels = summary_proj_to_labels 112 | 113 | @property 114 | def max_position_embeddings(self): 115 | return self.n_positions 116 | 117 | @property 118 | def hidden_size(self): 119 | return self.n_embd 120 | 121 | @property 122 | def num_attention_heads(self): 123 | return self.n_head 124 | 125 | @property 126 | def num_hidden_layers(self): 127 | return self.n_layer 128 | -------------------------------------------------------------------------------- /transformers/configuration_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
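A minimal sketch, not part of the original files: `CTRLConfig` keeps GPT-style field names (`n_embd`, `n_head`, `n_positions`), and the properties defined above map them onto the generic `hidden_size` / `num_attention_heads` / `max_position_embeddings` names. It assumes the vendored package exports `CTRLConfig` as upstream does.

```python
from transformers import CTRLConfig

config = CTRLConfig()  # defaults from the signature above
assert config.hidden_size == config.n_embd == 1280
assert config.num_attention_heads == config.n_head == 16
assert config.max_position_embeddings == config.n_positions == 256
```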
15 | """ DistilBERT model configuration """ 16 | 17 | 18 | import logging 19 | 20 | from .configuration_utils import PretrainedConfig 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 26 | "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 27 | "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json", 28 | "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json", 29 | "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json", 30 | } 31 | 32 | 33 | class DistilBertConfig(PretrainedConfig): 34 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 35 | 36 | def __init__( 37 | self, 38 | vocab_size=30522, 39 | max_position_embeddings=512, 40 | sinusoidal_pos_embds=False, 41 | n_layers=6, 42 | n_heads=12, 43 | dim=768, 44 | hidden_dim=4 * 768, 45 | dropout=0.1, 46 | attention_dropout=0.1, 47 | activation="gelu", 48 | initializer_range=0.02, 49 | tie_weights_=True, 50 | qa_dropout=0.1, 51 | seq_classif_dropout=0.2, 52 | **kwargs 53 | ): 54 | super(DistilBertConfig, self).__init__(**kwargs) 55 | self.vocab_size = vocab_size 56 | self.max_position_embeddings = max_position_embeddings 57 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 58 | self.n_layers = n_layers 59 | self.n_heads = n_heads 60 | self.dim = dim 61 | self.hidden_dim = hidden_dim 62 | self.dropout = dropout 63 | self.attention_dropout = attention_dropout 64 | self.activation = activation 65 | self.initializer_range = initializer_range 66 | self.tie_weights_ = tie_weights_ 67 | self.qa_dropout = qa_dropout 68 | self.seq_classif_dropout = seq_classif_dropout 69 | 70 | @property 71 | def hidden_size(self): 72 | return self.dim 73 | 74 | @property 75 | def num_attention_heads(self): 76 | return self.n_heads 77 | 78 | @property 79 | def num_hidden_layers(self): 80 | return self.n_layers 81 | -------------------------------------------------------------------------------- /transformers/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ OpenAI GPT-2 configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", 28 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", 29 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", 30 | "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json", 31 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json", 32 | } 33 | 34 | 35 | class GPT2Config(PretrainedConfig): 36 | """Configuration class to store the configuration of a `GPT2Model`. 37 | 38 | Args: 39 | vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 40 | n_positions: Number of positional embeddings. 41 | n_ctx: Size of the causal mask (usually same as n_positions). 42 | n_embd: Dimensionality of the embeddings and hidden states. 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | layer_norm_epsilon: epsilon to use in the layer norm layers 47 | resid_pdrop: The dropout probabilitiy for all fully connected 48 | layers in the embeddings, encoder, and pooler. 49 | attn_pdrop: The dropout ratio for the attention 50 | probabilities. 51 | embd_pdrop: The dropout ratio for the embeddings. 52 | initializer_range: The sttdev of the truncated_normal_initializer for 53 | initializing all weight matrices. 54 | """ 55 | 56 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 57 | 58 | def __init__( 59 | self, 60 | vocab_size=50257, 61 | n_positions=1024, 62 | n_ctx=1024, 63 | n_embd=768, 64 | n_layer=12, 65 | n_head=12, 66 | resid_pdrop=0.1, 67 | embd_pdrop=0.1, 68 | attn_pdrop=0.1, 69 | layer_norm_epsilon=1e-5, 70 | initializer_range=0.02, 71 | summary_type="cls_index", 72 | summary_use_proj=True, 73 | summary_activation=None, 74 | summary_proj_to_labels=True, 75 | summary_first_dropout=0.1, 76 | **kwargs 77 | ): 78 | """Constructs GPT2Config. 79 | 80 | Args: 81 | vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 82 | n_positions: Number of positional embeddings. 83 | n_ctx: Size of the causal mask (usually same as n_positions). 84 | n_embd: Dimensionality of the embeddings and hidden states. 85 | n_layer: Number of hidden layers in the Transformer encoder. 86 | n_head: Number of attention heads for each attention layer in 87 | the Transformer encoder. 88 | layer_norm_epsilon: epsilon to use in the layer norm layers 89 | resid_pdrop: The dropout probabilitiy for all fully connected 90 | layers in the embeddings, encoder, and pooler. 91 | attn_pdrop: The dropout ratio for the attention 92 | probabilities. 93 | embd_pdrop: The dropout ratio for the embeddings. 94 | initializer_range: The sttdev of the truncated_normal_initializer for 95 | initializing all weight matrices. 
96 | """ 97 | super(GPT2Config, self).__init__(**kwargs) 98 | self.vocab_size = vocab_size 99 | self.n_ctx = n_ctx 100 | self.n_positions = n_positions 101 | self.n_embd = n_embd 102 | self.n_layer = n_layer 103 | self.n_head = n_head 104 | self.resid_pdrop = resid_pdrop 105 | self.embd_pdrop = embd_pdrop 106 | self.attn_pdrop = attn_pdrop 107 | self.layer_norm_epsilon = layer_norm_epsilon 108 | self.initializer_range = initializer_range 109 | self.summary_type = summary_type 110 | self.summary_use_proj = summary_use_proj 111 | self.summary_activation = summary_activation 112 | self.summary_first_dropout = summary_first_dropout 113 | self.summary_proj_to_labels = summary_proj_to_labels 114 | 115 | @property 116 | def max_position_embeddings(self): 117 | return self.n_positions 118 | 119 | @property 120 | def hidden_size(self): 121 | return self.n_embd 122 | 123 | @property 124 | def num_attention_heads(self): 125 | return self.n_head 126 | 127 | @property 128 | def num_hidden_layers(self): 129 | return self.n_layer 130 | -------------------------------------------------------------------------------- /transformers/configuration_mmbt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # Copyright (c) HuggingFace Inc. team. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ MMBT configuration """ 17 | 18 | 19 | import logging 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class MMBTConfig(object): 26 | """Configuration class to store the configuration of a `MMBT Model`. 27 | 28 | Args: 29 | config: config of the underlying Transformer models. It's values are copied over to use a single config. 30 | num_labels: Size of final Linear layer for classification. 31 | modal_hidden_size: Embedding dimension of the non-text modality encoder. 32 | """ 33 | 34 | def __init__(self, config, num_labels=None, modal_hidden_size=2048): 35 | self.__dict__ = config.__dict__ 36 | self.modal_hidden_size = modal_hidden_size 37 | if num_labels: 38 | self.num_labels = num_labels 39 | -------------------------------------------------------------------------------- /transformers/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 28 | } 29 | 30 | 31 | class OpenAIGPTConfig(PretrainedConfig): 32 | """ 33 | Configuration class to store the configuration of a `OpenAIGPTModel`. 34 | 35 | Args: 36 | vocab_size: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. 37 | n_positions: Number of positional embeddings. 38 | n_ctx: Size of the causal mask (usually same as n_positions). 39 | n_embd: Dimensionality of the embeddings and hidden states. 40 | n_layer: Number of hidden layers in the Transformer encoder. 41 | n_head: Number of attention heads for each attention layer in 42 | the Transformer encoder. 43 | afn: The non-linear activation function (function or string) in the 44 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 45 | resid_pdrop: The dropout probabilitiy for all fully connected 46 | layers in the embeddings, encoder, and pooler. 47 | attn_pdrop: The dropout ratio for the attention 48 | probabilities. 49 | embd_pdrop: The dropout ratio for the embeddings. 50 | layer_norm_epsilon: epsilon to use in the layer norm layers 51 | initializer_range: The sttdev of the truncated_normal_initializer for 52 | initializing all weight matrices. 53 | predict_special_tokens: should we predict special tokens (when the model has a LM head) 54 | """ 55 | 56 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 57 | 58 | def __init__( 59 | self, 60 | vocab_size=40478, 61 | n_positions=512, 62 | n_ctx=512, 63 | n_embd=768, 64 | n_layer=12, 65 | n_head=12, 66 | afn="gelu", 67 | resid_pdrop=0.1, 68 | embd_pdrop=0.1, 69 | attn_pdrop=0.1, 70 | layer_norm_epsilon=1e-5, 71 | initializer_range=0.02, 72 | predict_special_tokens=True, 73 | summary_type="cls_index", 74 | summary_use_proj=True, 75 | summary_activation=None, 76 | summary_proj_to_labels=True, 77 | summary_first_dropout=0.1, 78 | **kwargs 79 | ): 80 | """Constructs OpenAIGPTConfig. 
81 | """ 82 | super(OpenAIGPTConfig, self).__init__(**kwargs) 83 | self.vocab_size = vocab_size 84 | self.n_ctx = n_ctx 85 | self.n_positions = n_positions 86 | self.n_embd = n_embd 87 | self.n_layer = n_layer 88 | self.n_head = n_head 89 | self.afn = afn 90 | self.resid_pdrop = resid_pdrop 91 | self.embd_pdrop = embd_pdrop 92 | self.attn_pdrop = attn_pdrop 93 | self.layer_norm_epsilon = layer_norm_epsilon 94 | self.initializer_range = initializer_range 95 | self.predict_special_tokens = predict_special_tokens 96 | self.summary_type = summary_type 97 | self.summary_use_proj = summary_use_proj 98 | self.summary_activation = summary_activation 99 | self.summary_first_dropout = summary_first_dropout 100 | self.summary_proj_to_labels = summary_proj_to_labels 101 | 102 | @property 103 | def max_position_embeddings(self): 104 | return self.n_positions 105 | 106 | @property 107 | def hidden_size(self): 108 | return self.n_embd 109 | 110 | @property 111 | def num_attention_heads(self): 112 | return self.n_head 113 | 114 | @property 115 | def num_hidden_layers(self): 116 | return self.n_layer 117 | -------------------------------------------------------------------------------- /transformers/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_bert import BertConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 28 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 29 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 30 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json", 31 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json", 32 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json", 33 | } 34 | 35 | 36 | class RobertaConfig(BertConfig): 37 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 38 | -------------------------------------------------------------------------------- /transformers/configuration_t5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2010, The T5 Authors and HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ T5 model configuration """ 16 | 17 | 18 | import logging 19 | 20 | from .configuration_utils import PretrainedConfig 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | T5_PRETRAINED_CONFIG_ARCHIVE_MAP = { 26 | "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json", 27 | "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json", 28 | "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json", 29 | "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json", 30 | "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json", 31 | } 32 | 33 | 34 | class T5Config(PretrainedConfig): 35 | r""" 36 | :class:`~transformers.T5Config` is the configuration class to store the configuration of a 37 | `T5Model`. 38 | 39 | 40 | Arguments: 41 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `T5Model`. 42 | hidden_size: Size of the encoder layers and the pooler layer. 43 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 44 | num_attention_heads: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 47 | layer in the Transformer encoder. 48 | hidden_act: The non-linear activation function (function or string) in the 49 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 50 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 51 | layers in the embeddings, encoder, and pooler. 52 | attention_probs_dropout_prob: The dropout ratio for the attention 53 | probabilities. 54 | max_position_embeddings: The maximum sequence length that this model might 55 | ever be used with. Typically set this to something large just in case 56 | (e.g., 512 or 1024 or 2048). 57 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 58 | `T5Model`. 59 | initializer_factor: A factor for initializing all weight matrices (should be kept to 1.0, used for initialization testing). 60 | layer_norm_eps: The epsilon used by LayerNorm. 
61 | """ 62 | pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP 63 | 64 | def __init__( 65 | self, 66 | vocab_size=32128, 67 | n_positions=512, 68 | d_model=512, 69 | d_kv=64, 70 | d_ff=2048, 71 | num_layers=6, 72 | num_heads=8, 73 | relative_attention_num_buckets=32, 74 | dropout_rate=0.1, 75 | layer_norm_epsilon=1e-6, 76 | initializer_factor=1.0, 77 | **kwargs 78 | ): 79 | super(T5Config, self).__init__(**kwargs) 80 | self.vocab_size = vocab_size 81 | self.n_positions = n_positions 82 | self.d_model = d_model 83 | self.d_kv = d_kv 84 | self.d_ff = d_ff 85 | self.num_layers = num_layers 86 | self.num_heads = num_heads 87 | self.relative_attention_num_buckets = relative_attention_num_buckets 88 | self.dropout_rate = dropout_rate 89 | self.layer_norm_epsilon = layer_norm_epsilon 90 | self.initializer_factor = initializer_factor 91 | 92 | @property 93 | def max_position_embeddings(self): 94 | return self.n_positions 95 | 96 | @property 97 | def hidden_size(self): 98 | return self.d_model 99 | 100 | @property 101 | def num_attention_heads(self): 102 | return self.num_heads 103 | 104 | @property 105 | def num_hidden_layers(self): 106 | return self.num_layers 107 | -------------------------------------------------------------------------------- /transformers/configuration_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Transformer XL configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", 28 | } 29 | 30 | 31 | class TransfoXLConfig(PretrainedConfig): 32 | """Configuration class to store the configuration of a `TransfoXLModel`. 33 | 34 | Args: 35 | vocab_size: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. 36 | cutoffs: cutoffs for the adaptive softmax 37 | d_model: Dimensionality of the model's hidden states. 38 | d_embed: Dimensionality of the embeddings 39 | d_head: Dimensionality of the model's heads. 40 | div_val: divident value for adapative input and softmax 41 | pre_lnorm: apply LayerNorm to the input instead of the output 42 | d_inner: Inner dimension in FF 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 
46 | tgt_len: number of tokens to predict 47 | ext_len: length of the extended context 48 | mem_len: length of the retained previous heads 49 | same_length: use the same attn length for all tokens 50 | proj_share_all_but_first: True to share all but first projs, False not to share. 51 | attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. 52 | clamp_len: use the same pos embeddings after clamp_len 53 | sample_softmax: number of samples in sampled softmax 54 | adaptive: use adaptive softmax 55 | tie_weight: tie the word embedding and softmax weights 56 | dropout: The dropout probabilitiy for all fully connected 57 | layers in the embeddings, encoder, and pooler. 58 | dropatt: The dropout ratio for the attention probabilities. 59 | untie_r: untie relative position biases 60 | embd_pdrop: The dropout ratio for the embeddings. 61 | init: parameter initializer to use 62 | init_range: parameters initialized by U(-init_range, init_range). 63 | proj_init_std: parameters initialized by N(0, init_std) 64 | init_std: parameters initialized by N(0, init_std) 65 | """ 66 | 67 | pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP 68 | 69 | def __init__( 70 | self, 71 | vocab_size=267735, 72 | cutoffs=[20000, 40000, 200000], 73 | d_model=1024, 74 | d_embed=1024, 75 | n_head=16, 76 | d_head=64, 77 | d_inner=4096, 78 | div_val=4, 79 | pre_lnorm=False, 80 | n_layer=18, 81 | tgt_len=128, 82 | ext_len=0, 83 | mem_len=1600, 84 | clamp_len=1000, 85 | same_length=True, 86 | proj_share_all_but_first=True, 87 | attn_type=0, 88 | sample_softmax=-1, 89 | adaptive=True, 90 | tie_weight=True, 91 | dropout=0.1, 92 | dropatt=0.0, 93 | untie_r=True, 94 | init="normal", 95 | init_range=0.01, 96 | proj_init_std=0.01, 97 | init_std=0.02, 98 | layer_norm_epsilon=1e-5, 99 | **kwargs 100 | ): 101 | """Constructs TransfoXLConfig. 
102 | """ 103 | super(TransfoXLConfig, self).__init__(**kwargs) 104 | self.vocab_size = vocab_size 105 | self.cutoffs = [] 106 | self.cutoffs.extend(cutoffs) 107 | self.tie_weight = tie_weight 108 | if proj_share_all_but_first: 109 | self.tie_projs = [False] + [True] * len(self.cutoffs) 110 | else: 111 | self.tie_projs = [False] + [False] * len(self.cutoffs) 112 | self.d_model = d_model 113 | self.d_embed = d_embed 114 | self.d_head = d_head 115 | self.d_inner = d_inner 116 | self.div_val = div_val 117 | self.pre_lnorm = pre_lnorm 118 | self.n_layer = n_layer 119 | self.n_head = n_head 120 | self.tgt_len = tgt_len 121 | self.ext_len = ext_len 122 | self.mem_len = mem_len 123 | self.same_length = same_length 124 | self.attn_type = attn_type 125 | self.clamp_len = clamp_len 126 | self.sample_softmax = sample_softmax 127 | self.adaptive = adaptive 128 | self.dropout = dropout 129 | self.dropatt = dropatt 130 | self.untie_r = untie_r 131 | self.init = init 132 | self.init_range = init_range 133 | self.proj_init_std = proj_init_std 134 | self.init_std = init_std 135 | self.layer_norm_epsilon = layer_norm_epsilon 136 | 137 | @property 138 | def max_position_embeddings(self): 139 | return self.tgt_len + self.ext_len + self.mem_len 140 | 141 | @property 142 | def n_token(self): # Backward compatibility 143 | return self.vocab_size 144 | 145 | @n_token.setter 146 | def n_token(self, value): # Backward compatibility 147 | self.vocab_size = value 148 | 149 | @property 150 | def hidden_size(self): 151 | return self.d_model 152 | 153 | @property 154 | def num_attention_heads(self): 155 | return self.n_head 156 | 157 | @property 158 | def num_hidden_layers(self): 159 | return self.n_layer 160 | -------------------------------------------------------------------------------- /transformers/configuration_xlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ XLM configuration """ 16 | 17 | 18 | import logging 19 | 20 | from .configuration_utils import PretrainedConfig 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 26 | "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", 27 | "xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json", 28 | "xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json", 29 | "xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json", 30 | "xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json", 31 | "xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", 32 | "xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", 33 | "xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", 34 | "xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", 35 | "xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json", 36 | } 37 | 38 | 39 | class XLMConfig(PretrainedConfig): 40 | """Configuration class to store the configuration of a `XLMModel`. 41 | 42 | Args: 43 | vocab_size: Vocabulary size of `inputs_ids` in `XLMModel`. 44 | d_model: Size of the encoder layers and the pooler layer. 45 | n_layer: Number of hidden layers in the Transformer encoder. 46 | n_head: Number of attention heads for each attention layer in 47 | the Transformer encoder. 48 | d_inner: The size of the "intermediate" (i.e., feed-forward) 49 | layer in the Transformer encoder. 50 | ff_activation: The non-linear activation function (function or string) in the 51 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 52 | untie_r: untie relative position biases 53 | attn_type: 'bi' for XLM, 'uni' for Transformer-XL 54 | 55 | dropout: The dropout probabilitiy for all fully connected 56 | layers in the embeddings, encoder, and pooler. 57 | max_position_embeddings: The maximum sequence length that this model might 58 | ever be used with. Typically set this to something large just in case 59 | (e.g., 512 or 1024 or 2048). 60 | initializer_range: The sttdev of the truncated_normal_initializer for 61 | initializing all weight matrices. 62 | layer_norm_eps: The epsilon used by LayerNorm. 63 | 64 | dropout: float, dropout rate. 65 | init: str, the initialization scheme, either "normal" or "uniform". 66 | init_range: float, initialize the parameters with a uniform distribution 67 | in [-init_range, init_range]. Only effective when init="uniform". 68 | init_std: float, initialize the parameters with a normal distribution 69 | with mean 0 and stddev init_std. Only effective when init="normal". 70 | mem_len: int, the number of tokens to cache. 71 | reuse_len: int, the number of tokens in the currect batch to be cached 72 | and reused in the future. 73 | bi_data: bool, whether to use bidirectional input pipeline. 74 | Usually set to True during pretraining and False during finetuning. 75 | clamp_len: int, clamp all relative distances larger than clamp_len. 76 | -1 means no clamping. 77 | same_length: bool, whether to use the same attention length for each token. 
78 | """ 79 | 80 | pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP 81 | 82 | def __init__( 83 | self, 84 | vocab_size=30145, 85 | emb_dim=2048, 86 | n_layers=12, 87 | n_heads=16, 88 | dropout=0.1, 89 | attention_dropout=0.1, 90 | gelu_activation=True, 91 | sinusoidal_embeddings=False, 92 | causal=False, 93 | asm=False, 94 | n_langs=1, 95 | use_lang_emb=True, 96 | max_position_embeddings=512, 97 | embed_init_std=2048 ** -0.5, 98 | layer_norm_eps=1e-12, 99 | init_std=0.02, 100 | bos_index=0, 101 | eos_index=1, 102 | pad_index=2, 103 | unk_index=3, 104 | mask_index=5, 105 | is_encoder=True, 106 | summary_type="first", 107 | summary_use_proj=True, 108 | summary_activation=None, 109 | summary_proj_to_labels=True, 110 | summary_first_dropout=0.1, 111 | start_n_top=5, 112 | end_n_top=5, 113 | mask_token_id=0, 114 | lang_id=0, 115 | **kwargs 116 | ): 117 | """Constructs XLMConfig. 118 | """ 119 | super(XLMConfig, self).__init__(**kwargs) 120 | self.vocab_size = vocab_size 121 | self.emb_dim = emb_dim 122 | self.n_layers = n_layers 123 | self.n_heads = n_heads 124 | self.dropout = dropout 125 | self.attention_dropout = attention_dropout 126 | self.gelu_activation = gelu_activation 127 | self.sinusoidal_embeddings = sinusoidal_embeddings 128 | self.causal = causal 129 | self.asm = asm 130 | self.n_langs = n_langs 131 | self.use_lang_emb = use_lang_emb 132 | self.layer_norm_eps = layer_norm_eps 133 | self.bos_index = bos_index 134 | self.eos_index = eos_index 135 | self.pad_index = pad_index 136 | self.unk_index = unk_index 137 | self.mask_index = mask_index 138 | self.is_encoder = is_encoder 139 | self.max_position_embeddings = max_position_embeddings 140 | self.embed_init_std = embed_init_std 141 | self.init_std = init_std 142 | self.summary_type = summary_type 143 | self.summary_use_proj = summary_use_proj 144 | self.summary_activation = summary_activation 145 | self.summary_proj_to_labels = summary_proj_to_labels 146 | self.summary_first_dropout = summary_first_dropout 147 | self.start_n_top = start_n_top 148 | self.end_n_top = end_n_top 149 | self.mask_token_id = mask_token_id 150 | self.lang_id = lang_id 151 | 152 | if "n_words" in kwargs: 153 | self.n_words = kwargs["n_words"] 154 | 155 | @property 156 | def n_words(self): # For backward compatibility 157 | return self.vocab_size 158 | 159 | @n_words.setter 160 | def n_words(self, value): # For backward compatibility 161 | self.vocab_size = value 162 | 163 | @property 164 | def hidden_size(self): 165 | return self.emb_dim 166 | 167 | @property 168 | def num_attention_heads(self): 169 | return self.n_heads 170 | 171 | @property 172 | def num_hidden_layers(self): 173 | return self.n_layers 174 | -------------------------------------------------------------------------------- /transformers/configuration_xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ XLM-RoBERTa configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_roberta import RobertaConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json", 28 | "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json", 29 | "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json", 30 | "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json", 31 | "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json", 32 | "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json", 33 | } 34 | 35 | 36 | class XLMRobertaConfig(RobertaConfig): 37 | pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 38 | -------------------------------------------------------------------------------- /transformers/configuration_xlnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ XLNet configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", 28 | "xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", 29 | } 30 | 31 | 32 | class XLNetConfig(PretrainedConfig): 33 | """Configuration class to store the configuration of a ``XLNetModel``. 34 | 35 | Args: 36 | vocab_size: Vocabulary size of ``inputs_ids`` in ``XLNetModel``. 37 | d_model: Size of the encoder layers and the pooler layer. 38 | n_layer: Number of hidden layers in the Transformer encoder. 
39 | n_head: Number of attention heads for each attention layer in 40 | the Transformer encoder. 41 | d_inner: The size of the "intermediate" (i.e., feed-forward) 42 | layer in the Transformer encoder. 43 | ff_activation: The non-linear activation function (function or string) in the 44 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 45 | untie_r: untie relative position biases 46 | attn_type: 'bi' for XLNet, 'uni' for Transformer-XL 47 | 48 | dropout: The dropout probabilitiy for all fully connected 49 | layers in the embeddings, encoder, and pooler. 50 | initializer_range: The sttdev of the truncated_normal_initializer for 51 | initializing all weight matrices. 52 | layer_norm_eps: The epsilon used by LayerNorm. 53 | 54 | dropout: float, dropout rate. 55 | init: str, the initialization scheme, either "normal" or "uniform". 56 | init_range: float, initialize the parameters with a uniform distribution 57 | in [-init_range, init_range]. Only effective when init="uniform". 58 | init_std: float, initialize the parameters with a normal distribution 59 | with mean 0 and stddev init_std. Only effective when init="normal". 60 | mem_len: int, the number of tokens to cache. 61 | reuse_len: int, the number of tokens in the currect batch to be cached 62 | and reused in the future. 63 | bi_data: bool, whether to use bidirectional input pipeline. 64 | Usually set to True during pretraining and False during finetuning. 65 | clamp_len: int, clamp all relative distances larger than clamp_len. 66 | -1 means no clamping. 67 | same_length: bool, whether to use the same attention length for each token. 68 | finetuning_task: name of the glue task on which the model was fine-tuned if any 69 | """ 70 | 71 | pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP 72 | 73 | def __init__( 74 | self, 75 | vocab_size=32000, 76 | d_model=1024, 77 | n_layer=24, 78 | n_head=16, 79 | d_inner=4096, 80 | ff_activation="gelu", 81 | untie_r=True, 82 | attn_type="bi", 83 | initializer_range=0.02, 84 | layer_norm_eps=1e-12, 85 | dropout=0.1, 86 | mem_len=None, 87 | reuse_len=None, 88 | bi_data=False, 89 | clamp_len=-1, 90 | same_length=False, 91 | summary_type="last", 92 | summary_use_proj=True, 93 | summary_activation="tanh", 94 | summary_last_dropout=0.1, 95 | start_n_top=5, 96 | end_n_top=5, 97 | **kwargs 98 | ): 99 | """Constructs XLNetConfig. 
100 | """ 101 | super(XLNetConfig, self).__init__(**kwargs) 102 | self.vocab_size = vocab_size 103 | self.d_model = d_model 104 | self.n_layer = n_layer 105 | self.n_head = n_head 106 | assert d_model % n_head == 0 107 | self.d_head = d_model // n_head 108 | self.ff_activation = ff_activation 109 | self.d_inner = d_inner 110 | self.untie_r = untie_r 111 | self.attn_type = attn_type 112 | 113 | self.initializer_range = initializer_range 114 | self.layer_norm_eps = layer_norm_eps 115 | 116 | self.dropout = dropout 117 | self.mem_len = mem_len 118 | self.reuse_len = reuse_len 119 | self.bi_data = bi_data 120 | self.clamp_len = clamp_len 121 | self.same_length = same_length 122 | 123 | self.summary_type = summary_type 124 | self.summary_use_proj = summary_use_proj 125 | self.summary_activation = summary_activation 126 | self.summary_last_dropout = summary_last_dropout 127 | self.start_n_top = start_n_top 128 | self.end_n_top = end_n_top 129 | 130 | @property 131 | def max_position_embeddings(self): 132 | return -1 133 | 134 | @property 135 | def n_token(self): # Backward compatibility 136 | return self.vocab_size 137 | 138 | @n_token.setter 139 | def n_token(self, value): # Backward compatibility 140 | self.vocab_size = value 141 | 142 | @property 143 | def hidden_size(self): 144 | return self.d_model 145 | 146 | @property 147 | def num_attention_heads(self): 148 | return self.n_head 149 | 150 | @property 151 | def num_hidden_layers(self): 152 | return self.n_layer 153 | -------------------------------------------------------------------------------- /transformers/convert_albert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert ALBERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | 21 | import torch 22 | 23 | from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = AlbertConfig.from_json_file(albert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = AlbertForMaskedLM(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_albert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 
48 | ) 49 | parser.add_argument( 50 | "--albert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained ALBERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /transformers/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | 21 | import torch 22 | 23 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--bert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained BERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /transformers/convert_bert_pytorch_checkpoint_to_original_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
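A minimal sketch, not part of the original files: the ALBERT and BERT converters above are thin argparse wrappers around `convert_tf_checkpoint_to_pytorch`, so the same steps can be run directly from Python. All paths below are hypothetical, and TensorFlow must be installed for `load_tf_weights_in_bert` to read the checkpoint.

```python
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

config = BertConfig.from_json_file("bert_model/bert_config.json")     # hypothetical path
model = BertForPreTraining(config)
load_tf_weights_in_bert(model, config, "bert_model/bert_model.ckpt")  # TF checkpoint prefix, hypothetical
torch.save(model.state_dict(), "bert_model/pytorch_model.bin")
```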
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import argparse 19 | import os 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | import torch 24 | 25 | from transformers import BertModel 26 | 27 | 28 | def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str): 29 | 30 | """ 31 | :param model:BertModel Pytorch model instance to be converted 32 | :param ckpt_dir: Tensorflow model directory 33 | :param model_name: model name 34 | :return: 35 | 36 | Currently supported HF models: 37 | Y BertModel 38 | N BertForMaskedLM 39 | N BertForPreTraining 40 | N BertForMultipleChoice 41 | N BertForNextSentencePrediction 42 | N BertForSequenceClassification 43 | N BertForQuestionAnswering 44 | """ 45 | 46 | tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value") 47 | 48 | var_map = ( 49 | ("layer.", "layer_"), 50 | ("word_embeddings.weight", "word_embeddings"), 51 | ("position_embeddings.weight", "position_embeddings"), 52 | ("token_type_embeddings.weight", "token_type_embeddings"), 53 | (".", "/"), 54 | ("LayerNorm/weight", "LayerNorm/gamma"), 55 | ("LayerNorm/bias", "LayerNorm/beta"), 56 | ("weight", "kernel"), 57 | ) 58 | 59 | if not os.path.isdir(ckpt_dir): 60 | os.makedirs(ckpt_dir) 61 | 62 | state_dict = model.state_dict() 63 | 64 | def to_tf_var_name(name: str): 65 | for patt, repl in iter(var_map): 66 | name = name.replace(patt, repl) 67 | return "bert/{}".format(name) 68 | 69 | def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session): 70 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 71 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 72 | session.run(tf.variables_initializer([tf_var])) 73 | session.run(tf_var) 74 | return tf_var 75 | 76 | tf.reset_default_graph() 77 | with tf.Session() as session: 78 | for var_name in state_dict: 79 | tf_name = to_tf_var_name(var_name) 80 | torch_tensor = state_dict[var_name].numpy() 81 | if any([x in var_name for x in tensors_to_transpose]): 82 | torch_tensor = torch_tensor.T 83 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 84 | tf.keras.backend.set_value(tf_var, torch_tensor) 85 | tf_weight = session.run(tf_var) 86 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 87 | 88 | saver = tf.train.Saver(tf.trainable_variables()) 89 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 90 | 91 | 92 | def main(raw_args=None): 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--model_name", type=str, required=True, help="model name e.g. 
bert-base-uncased") 95 | parser.add_argument( 96 | "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model" 97 | ) 98 | parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/.bin") 99 | parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model") 100 | args = parser.parse_args(raw_args) 101 | 102 | model = BertModel.from_pretrained( 103 | pretrained_model_name_or_path=args.model_name, 104 | state_dict=torch.load(args.pytorch_model_path), 105 | cache_dir=args.cache_dir, 106 | ) 107 | 108 | convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | 21 | import torch 22 | 23 | from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 30 | # Construct model 31 | if gpt2_config_file == "": 32 | config = GPT2Config() 33 | else: 34 | config = GPT2Config.from_json_file(gpt2_config_file) 35 | model = GPT2Model(config) 36 | 37 | # Load weights from numpy 38 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 39 | 40 | # Save pytorch-model 41 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 42 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 43 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 44 | torch.save(model.state_dict(), pytorch_weights_dump_path) 45 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 46 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 47 | f.write(config.to_json_string()) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | # Required parameters 53 | parser.add_argument( 54 | "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 55 | ) 56 | parser.add_argument( 57 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 58 | ) 59 | parser.add_argument( 60 | "--gpt2_config_file", 61 | default="", 62 | type=str, 63 | help="An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 64 | "This specifies the model architecture.", 65 | ) 66 | args = parser.parse_args() 67 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path) 68 | -------------------------------------------------------------------------------- /transformers/convert_openai_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | 21 | import torch 22 | 23 | from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 30 | # Construct model 31 | if openai_config_file == "": 32 | config = OpenAIGPTConfig() 33 | else: 34 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 35 | model = OpenAIGPTModel(config) 36 | 37 | # Load weights from numpy 38 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 39 | 40 | # Save pytorch-model 41 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 42 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 43 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 44 | torch.save(model.state_dict(), pytorch_weights_dump_path) 45 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 46 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 47 | f.write(config.to_json_string()) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | # Required parameters 53 | parser.add_argument( 54 | "--openai_checkpoint_folder_path", 55 | default=None, 56 | type=str, 57 | required=True, 58 | help="Path to the TensorFlow checkpoint path.", 59 | ) 60 | parser.add_argument( 61 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 62 | ) 63 | parser.add_argument( 64 | "--openai_config_file", 65 | default="", 66 | type=str, 67 | help="An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.", 69 | ) 70 | args = parser.parse_args() 71 | convert_openai_checkpoint_to_pytorch( 72 | args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path 73 | ) 74 | -------------------------------------------------------------------------------- /transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert RoBERTa checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | import pathlib 21 | 22 | import fairseq 23 | import torch 24 | from fairseq.models.roberta import RobertaModel as FairseqRobertaModel 25 | from fairseq.modules import TransformerSentenceEncoderLayer 26 | from packaging import version 27 | 28 | from transformers.modeling_bert import ( 29 | BertConfig, 30 | BertIntermediate, 31 | BertLayer, 32 | BertOutput, 33 | BertSelfAttention, 34 | BertSelfOutput, 35 | ) 36 | from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification 37 | 38 | 39 | if version.parse(fairseq.__version__) < version.parse("0.9.0"): 40 | raise Exception("requires fairseq >= 0.9.0") 41 | 42 | 43 | logging.basicConfig(level=logging.INFO) 44 | logger = logging.getLogger(__name__) 45 | 46 | SAMPLE_TEXT = "Hello world! cécé herlolip" 47 | 48 | 49 | def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head): 50 | """ 51 | Copy/paste/tweak roberta's weights to our BERT structure. 52 | """ 53 | roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path) 54 | roberta.eval() # disable dropout 55 | roberta_sent_encoder = roberta.model.decoder.sentence_encoder 56 | config = BertConfig( 57 | vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings, 58 | hidden_size=roberta.args.encoder_embed_dim, 59 | num_hidden_layers=roberta.args.encoder_layers, 60 | num_attention_heads=roberta.args.encoder_attention_heads, 61 | intermediate_size=roberta.args.encoder_ffn_embed_dim, 62 | max_position_embeddings=514, 63 | type_vocab_size=1, 64 | layer_norm_eps=1e-5, # PyTorch default used in fairseq 65 | ) 66 | if classification_head: 67 | config.num_labels = roberta.args.num_classes 68 | print("Our BERT config:", config) 69 | 70 | model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config) 71 | model.eval() 72 | 73 | # Now let's copy all the weights. 74 | # Embeddings 75 | model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight 76 | model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight 77 | model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like( 78 | model.roberta.embeddings.token_type_embeddings.weight 79 | ) # just zero them out b/c RoBERTa doesn't use them. 
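# fairseq's embedding-level "emb_layer_norm" corresponds to embeddings.LayerNorm in
# the BERT-style model, so its weight and bias are copied over as-is below.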
80 | model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight 81 | model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias 82 | 83 | for i in range(config.num_hidden_layers): 84 | # Encoder: start of layer 85 | layer: BertLayer = model.roberta.encoder.layer[i] 86 | roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i] 87 | 88 | # self attention 89 | self_attn: BertSelfAttention = layer.attention.self 90 | assert ( 91 | roberta_layer.self_attn.k_proj.weight.data.shape 92 | == roberta_layer.self_attn.q_proj.weight.data.shape 93 | == roberta_layer.self_attn.v_proj.weight.data.shape 94 | == torch.Size((config.hidden_size, config.hidden_size)) 95 | ) 96 | 97 | self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight 98 | self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias 99 | self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight 100 | self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias 101 | self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight 102 | self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias 103 | 104 | # self-attention output 105 | self_output: BertSelfOutput = layer.attention.output 106 | assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape 107 | self_output.dense.weight = roberta_layer.self_attn.out_proj.weight 108 | self_output.dense.bias = roberta_layer.self_attn.out_proj.bias 109 | self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight 110 | self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias 111 | 112 | # intermediate 113 | intermediate: BertIntermediate = layer.intermediate 114 | assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape 115 | intermediate.dense.weight = roberta_layer.fc1.weight 116 | intermediate.dense.bias = roberta_layer.fc1.bias 117 | 118 | # output 119 | bert_output: BertOutput = layer.output 120 | assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape 121 | bert_output.dense.weight = roberta_layer.fc2.weight 122 | bert_output.dense.bias = roberta_layer.fc2.bias 123 | bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight 124 | bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias 125 | # end of layer 126 | 127 | if classification_head: 128 | model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight 129 | model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias 130 | model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight 131 | model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias 132 | else: 133 | # LM Head 134 | model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight 135 | model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias 136 | model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight 137 | model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias 138 | model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight 139 | model.lm_head.bias = roberta.model.decoder.lm_head.bias 140 | 141 | # Let's check that we get the same results. 
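# The check below encodes SAMPLE_TEXT with the original fairseq model and with the
# converted model, then compares the two output tensors element-wise.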
142 | input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 143 | 144 | our_output = model(input_ids)[0] 145 | if classification_head: 146 | their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids)) 147 | else: 148 | their_output = roberta.model(input_ids)[0] 149 | print(our_output.shape, their_output.shape) 150 | max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() 151 | print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7 152 | success = torch.allclose(our_output, their_output, atol=1e-3) 153 | print("Do both models output the same tensors?", "🔥" if success else "💩") 154 | if not success: 155 | raise Exception("Something went wRoNg") 156 | 157 | pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True) 158 | print(f"Saving model to {pytorch_dump_folder_path}") 159 | model.save_pretrained(pytorch_dump_folder_path) 160 | 161 | 162 | if __name__ == "__main__": 163 | parser = argparse.ArgumentParser() 164 | # Required parameters 165 | parser.add_argument( 166 | "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump." 167 | ) 168 | parser.add_argument( 169 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 170 | ) 171 | parser.add_argument( 172 | "--classification_head", action="store_true", help="Whether to convert a final classification head." 173 | ) 174 | args = parser.parse_args() 175 | convert_roberta_checkpoint_to_pytorch( 176 | args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head 177 | ) 178 | -------------------------------------------------------------------------------- /transformers/convert_t5_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The T5 authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
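# Illustrative invocation of this conversion script (all paths below are placeholders):
#
#   python convert_t5_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path /path/to/t5/model.ckpt \
#       --config_file /path/to/t5/config.json \
#       --pytorch_dump_path /path/to/output/pytorch_model.bin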
15 | """Convert T5 checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | 21 | import torch 22 | 23 | from transformers import T5Config, T5Model, load_tf_weights_in_t5 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = T5Config.from_json_file(config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = T5Model(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_t5(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained T5 model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | import os 21 | import pickle 22 | import sys 23 | 24 | import torch 25 | 26 | import transformers.tokenization_transfo_xl as data_utils 27 | from transformers import ( 28 | CONFIG_NAME, 29 | WEIGHTS_NAME, 30 | TransfoXLConfig, 31 | TransfoXLLMHeadModel, 32 | load_tf_weights_in_transfo_xl, 33 | ) 34 | from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES 35 | 36 | 37 | logging.basicConfig(level=logging.INFO) 38 | 39 | # We do this to be able to load python 2 datasets pickles 40 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 41 | data_utils.Vocab = data_utils.TransfoXLTokenizer 42 | data_utils.Corpus = data_utils.TransfoXLCorpus 43 | sys.modules["data_utils"] = data_utils 44 | sys.modules["vocabulary"] = data_utils 45 | 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch( 48 | tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file 49 | ): 50 | if transfo_xl_dataset_file: 51 | # Convert a pre-processed corpus (see original TensorFlow repo) 52 | with open(transfo_xl_dataset_file, "rb") as fp: 53 | corpus = pickle.load(fp, encoding="latin1") 54 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"] 56 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 57 | corpus_vocab_dict = corpus.vocab.__dict__ 58 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 59 | 60 | corpus_dict_no_vocab = corpus.__dict__ 61 | corpus_dict_no_vocab.pop("vocab", None) 62 | pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME 63 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 64 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 65 | 66 | if tf_checkpoint_path: 67 | # Convert a pre-trained TensorFlow model 68 | config_path = os.path.abspath(transfo_xl_config_file) 69 | tf_path = os.path.abspath(tf_checkpoint_path) 70 | 71 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 72 | # Initialise PyTorch model 73 | if transfo_xl_config_file == "": 74 | config = TransfoXLConfig() 75 | else: 76 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 77 | print("Building PyTorch model from configuration: {}".format(str(config))) 78 | model = TransfoXLLMHeadModel(config) 79 | 80 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 81 | # Save pytorch-model 82 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 83 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 84 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 85 | torch.save(model.state_dict(), pytorch_weights_dump_path) 86 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 87 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 88 | f.write(config.to_json_string()) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument( 94 | "--pytorch_dump_folder_path", 95 | default=None, 96 | type=str, 97 | required=True, 98 | help="Path to the folder to store the PyTorch model or dataset/vocab.", 99 | ) 100 | parser.add_argument( 101 | "--tf_checkpoint_path", 102 | default="", 103 | type=str, 104 | help="An optional path to a TensorFlow checkpoint path to be converted.", 105 | ) 106 | parser.add_argument( 107 | "--transfo_xl_config_file", 108 | default="", 109 | type=str, 110 | help="An optional config json file corresponding to the pre-trained BERT model. 
\n" 111 | "This specifies the model architecture.", 112 | ) 113 | parser.add_argument( 114 | "--transfo_xl_dataset_file", 115 | default="", 116 | type=str, 117 | help="An optional dataset file to be converted in a vocabulary.", 118 | ) 119 | args = parser.parse_args() 120 | convert_transfo_xl_checkpoint_to_pytorch( 121 | args.tf_checkpoint_path, 122 | args.transfo_xl_config_file, 123 | args.pytorch_dump_folder_path, 124 | args.transfo_xl_dataset_file, 125 | ) 126 | -------------------------------------------------------------------------------- /transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import json 20 | import logging 21 | 22 | import numpy 23 | import torch 24 | 25 | from transformers import CONFIG_NAME, WEIGHTS_NAME 26 | from transformers.tokenization_xlm import VOCAB_FILES_NAMES 27 | 28 | 29 | logging.basicConfig(level=logging.INFO) 30 | 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location="cpu") 35 | 36 | state_dict = chkpt["model"] 37 | 38 | # We have the base model one level deeper than the original XLM repository 39 | two_levels_state_dict = {} 40 | for k, v in state_dict.items(): 41 | if "pred_layer" in k: 42 | two_levels_state_dict[k] = v 43 | else: 44 | two_levels_state_dict["transformer." 
+ k] = v 45 | 46 | config = chkpt["params"] 47 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 48 | 49 | vocab = chkpt["dico_word2id"] 50 | vocab = dict((s + "" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items()) 51 | 52 | # Save pytorch-model 53 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 54 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"] 56 | 57 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 58 | torch.save(two_levels_state_dict, pytorch_weights_dump_path) 59 | 60 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 61 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 62 | f.write(json.dumps(config, indent=2) + "\n") 63 | 64 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 65 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 66 | f.write(json.dumps(vocab, indent=2) + "\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | # Required parameters 72 | parser.add_argument( 73 | "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump." 74 | ) 75 | parser.add_argument( 76 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 77 | ) 78 | args = parser.parse_args() 79 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 80 | -------------------------------------------------------------------------------- /transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
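# Illustrative invocation of this conversion script (all paths below are placeholders;
# --finetuning_task is optional and shown here with a GLUE task name):
#
#   python convert_xlnet_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path /path/to/xlnet/model.ckpt \
#       --xlnet_config_file /path/to/xlnet_config.json \
#       --pytorch_dump_folder_path /path/to/output/ \
#       --finetuning_task sst-2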
15 | """Convert BERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import logging 20 | import os 21 | 22 | import torch 23 | 24 | from transformers import ( 25 | CONFIG_NAME, 26 | WEIGHTS_NAME, 27 | XLNetConfig, 28 | XLNetForQuestionAnswering, 29 | XLNetForSequenceClassification, 30 | XLNetLMHeadModel, 31 | load_tf_weights_in_xlnet, 32 | ) 33 | 34 | 35 | GLUE_TASKS_NUM_LABELS = { 36 | "cola": 2, 37 | "mnli": 3, 38 | "mrpc": 2, 39 | "sst-2": 2, 40 | "sts-b": 1, 41 | "qqp": 2, 42 | "qnli": 2, 43 | "rte": 2, 44 | "wnli": 2, 45 | } 46 | 47 | 48 | logging.basicConfig(level=logging.INFO) 49 | 50 | 51 | def convert_xlnet_checkpoint_to_pytorch( 52 | tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None 53 | ): 54 | # Initialise PyTorch model 55 | config = XLNetConfig.from_json_file(bert_config_file) 56 | 57 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 58 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 59 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 60 | config.finetuning_task = finetuning_task 61 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 62 | model = XLNetForSequenceClassification(config) 63 | elif "squad" in finetuning_task: 64 | config.finetuning_task = finetuning_task 65 | model = XLNetForQuestionAnswering(config) 66 | else: 67 | model = XLNetLMHeadModel(config) 68 | 69 | # Load weights from tf checkpoint 70 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 71 | 72 | # Save pytorch-model 73 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 74 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 75 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 76 | torch.save(model.state_dict(), pytorch_weights_dump_path) 77 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 78 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 79 | f.write(config.to_json_string()) 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser() 84 | # Required parameters 85 | parser.add_argument( 86 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 87 | ) 88 | parser.add_argument( 89 | "--xlnet_config_file", 90 | default=None, 91 | type=str, 92 | required=True, 93 | help="The config json file corresponding to the pre-trained XLNet model. \n" 94 | "This specifies the model architecture.", 95 | ) 96 | parser.add_argument( 97 | "--pytorch_dump_folder_path", 98 | default=None, 99 | type=str, 100 | required=True, 101 | help="Path to the folder to store the PyTorch model or dataset/vocab.", 102 | ) 103 | parser.add_argument( 104 | "--finetuning_task", 105 | default=None, 106 | type=str, 107 | help="Name of a task on which the XLNet TensorFloaw model was fine-tuned", 108 | ) 109 | args = parser.parse_args() 110 | print(args) 111 | 112 | convert_xlnet_checkpoint_to_pytorch( 113 | args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task 114 | ) 115 | -------------------------------------------------------------------------------- /transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. 
So, don't check this module at all. 4 | 5 | from .metrics import is_sklearn_available 6 | from .processors import ( 7 | DataProcessor, 8 | InputExample, 9 | InputFeatures, 10 | SingleSentenceClassificationProcessor, 11 | SquadExample, 12 | SquadFeatures, 13 | SquadV1Processor, 14 | SquadV2Processor, 15 | glue_convert_examples_to_features, 16 | glue_output_modes, 17 | glue_processors, 18 | glue_tasks_num_labels, 19 | squad_convert_examples_to_features, 20 | xnli_output_modes, 21 | xnli_processors, 22 | xnli_tasks_num_labels, 23 | ) 24 | 25 | 26 | if is_sklearn_available(): 27 | from .metrics import glue_compute_metrics, xnli_compute_metrics 28 | -------------------------------------------------------------------------------- /transformers/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | try: 18 | from scipy.stats import pearsonr, spearmanr 19 | from sklearn.metrics import matthews_corrcoef, f1_score 20 | 21 | _has_sklearn = True 22 | except (AttributeError, ImportError) as e: 23 | _has_sklearn = False 24 | 25 | 26 | def is_sklearn_available(): 27 | return _has_sklearn 28 | 29 | 30 | if _has_sklearn: 31 | 32 | def simple_accuracy(preds, labels): 33 | return (preds == labels).mean() 34 | 35 | def acc_and_f1(preds, labels): 36 | acc = simple_accuracy(preds, labels) 37 | f1 = f1_score(y_true=labels, y_pred=preds) 38 | return { 39 | "acc": acc, 40 | "f1": f1, 41 | "acc_and_f1": (acc + f1) / 2, 42 | } 43 | 44 | def pearson_and_spearman(preds, labels): 45 | pearson_corr = pearsonr(preds, labels)[0] 46 | spearman_corr = spearmanr(preds, labels)[0] 47 | return { 48 | "pearson": pearson_corr, 49 | "spearmanr": spearman_corr, 50 | "corr": (pearson_corr + spearman_corr) / 2, 51 | } 52 | 53 | def glue_compute_metrics(task_name, preds, labels): 54 | assert len(preds) == len(labels) 55 | if task_name == "cola": 56 | return {"mcc": matthews_corrcoef(labels, preds)} 57 | elif task_name == "sst-2": 58 | return {"acc": simple_accuracy(preds, labels)} 59 | elif task_name == "mrpc": 60 | return acc_and_f1(preds, labels) 61 | elif task_name == "sts-b": 62 | return pearson_and_spearman(preds, labels) 63 | elif task_name == "qqp": 64 | return acc_and_f1(preds, labels) 65 | elif task_name == "mnli": 66 | return {"acc": simple_accuracy(preds, labels)} 67 | elif task_name == "mnli-mm": 68 | return {"acc": simple_accuracy(preds, labels)} 69 | elif task_name == "qnli": 70 | return {"acc": simple_accuracy(preds, labels)} 71 | elif task_name == "rte": 72 | return {"acc": simple_accuracy(preds, labels)} 73 | elif task_name == "wnli": 74 | return {"acc": simple_accuracy(preds, labels)} 75 | else: 76 | raise KeyError(task_name) 77 | 78 | def xnli_compute_metrics(task_name, preds, labels): 79 | assert 
len(preds) == len(labels) 80 | if task_name == "xnli": 81 | return {"acc": simple_accuracy(preds, labels)} 82 | else: 83 | raise KeyError(task_name) 84 | -------------------------------------------------------------------------------- /transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels 6 | from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features 7 | from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor 8 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels 9 | -------------------------------------------------------------------------------- /transformers/data/processors/xnli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ XNLI utils (dataset loading and evaluation) """ 17 | 18 | 19 | import logging 20 | import os 21 | 22 | from .utils import DataProcessor, InputExample 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class XnliProcessor(DataProcessor): 29 | """Processor for the XNLI dataset. 
30 | Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207""" 31 | 32 | def __init__(self, language, train_language=None): 33 | self.language = language 34 | self.train_language = train_language 35 | 36 | def get_train_examples(self, data_dir): 37 | """See base class.""" 38 | lg = self.language if self.train_language is None else self.train_language 39 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg))) 40 | examples = [] 41 | for (i, line) in enumerate(lines): 42 | if i == 0: 43 | continue 44 | guid = "%s-%s" % ("train", i) 45 | text_a = line[0] 46 | text_b = line[1] 47 | label = "contradiction" if line[2] == "contradictory" else line[2] 48 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) 49 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 50 | return examples 51 | 52 | def get_test_examples(self, data_dir): 53 | """See base class.""" 54 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv")) 55 | examples = [] 56 | for (i, line) in enumerate(lines): 57 | if i == 0: 58 | continue 59 | language = line[0] 60 | if language != self.language: 61 | continue 62 | guid = "%s-%s" % ("test", i) 63 | text_a = line[6] 64 | text_b = line[7] 65 | label = line[1] 66 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) 67 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 68 | return examples 69 | 70 | def get_labels(self): 71 | """See base class.""" 72 | return ["contradiction", "entailment", "neutral"] 73 | 74 | 75 | xnli_processors = { 76 | "xnli": XnliProcessor, 77 | } 78 | 79 | xnli_output_modes = { 80 | "xnli": "classification", 81 | } 82 | 83 | xnli_tasks_num_labels = { 84 | "xnli": 3, 85 | } 86 | -------------------------------------------------------------------------------- /transformers/hf_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import io 18 | import os 19 | from os.path import expanduser 20 | from typing import List 21 | 22 | import requests 23 | from tqdm import tqdm 24 | 25 | 26 | ENDPOINT = "https://huggingface.co" 27 | 28 | 29 | class S3Obj: 30 | def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs): 31 | self.filename = filename 32 | self.LastModified = LastModified 33 | self.ETag = ETag 34 | self.Size = Size 35 | 36 | 37 | class PresignedUrl: 38 | def __init__(self, write: str, access: str, type: str, **kwargs): 39 | self.write = write 40 | self.access = access 41 | self.type = type # mime-type to send to S3. 
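# Illustrative usage sketch for the client defined next (credentials and file names
# are placeholders, not part of the library):
#
#   api = HfApi()
#   token = api.login(username="my-user", password="my-pass")
#   url = api.presign_and_upload(token, filename="pytorch_model.bin",
#                                filepath="./pytorch_model.bin")
#   files = api.list_objs(token)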
42 | 43 | 44 | class HfApi: 45 | def __init__(self, endpoint=None): 46 | self.endpoint = endpoint if endpoint is not None else ENDPOINT 47 | 48 | def login(self, username: str, password: str) -> str: 49 | """ 50 | Call HF API to sign in a user and get a token if credentials are valid. 51 | 52 | Outputs: 53 | token if credentials are valid 54 | 55 | Throws: 56 | requests.exceptions.HTTPError if credentials are invalid 57 | """ 58 | path = "{}/api/login".format(self.endpoint) 59 | r = requests.post(path, json={"username": username, "password": password}) 60 | r.raise_for_status() 61 | d = r.json() 62 | return d["token"] 63 | 64 | def whoami(self, token: str) -> str: 65 | """ 66 | Call HF API to know "whoami" 67 | """ 68 | path = "{}/api/whoami".format(self.endpoint) 69 | r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) 70 | r.raise_for_status() 71 | d = r.json() 72 | return d["user"] 73 | 74 | def logout(self, token: str) -> None: 75 | """ 76 | Call HF API to log out. 77 | """ 78 | path = "{}/api/logout".format(self.endpoint) 79 | r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) 80 | r.raise_for_status() 81 | 82 | def presign(self, token: str, filename) -> PresignedUrl: 83 | """ 84 | Call HF API to get a presigned url to upload `filename` to S3. 85 | """ 86 | path = "{}/api/presign".format(self.endpoint) 87 | r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) 88 | r.raise_for_status() 89 | d = r.json() 90 | return PresignedUrl(**d) 91 | 92 | def presign_and_upload(self, token: str, filename, filepath) -> str: 93 | """ 94 | Get a presigned url, then upload file to S3. 95 | 96 | Outputs: 97 | url: Read-only url for the stored file on S3. 98 | """ 99 | urls = self.presign(token, filename=filename) 100 | # streaming upload: 101 | # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads 102 | # 103 | # Even though we presign with the correct content-type, 104 | # the client still has to specify it when uploading the file. 105 | with open(filepath, "rb") as f: 106 | pf = TqdmProgressFileReader(f) 107 | data = f if pf.total_size > 0 else "" 108 | 109 | r = requests.put(urls.write, data=data, headers={"content-type": urls.type}) 110 | r.raise_for_status() 111 | pf.close() 112 | return urls.access 113 | 114 | def list_objs(self, token) -> List[S3Obj]: 115 | """ 116 | Call HF API to list all stored files for user. 117 | """ 118 | path = "{}/api/listObjs".format(self.endpoint) 119 | r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) 120 | r.raise_for_status() 121 | d = r.json() 122 | return [S3Obj(**x) for x in d] 123 | 124 | 125 | class TqdmProgressFileReader: 126 | """ 127 | Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) 128 | and override `f.read()` so as to display a tqdm progress bar. 129 | 130 | see github.com/huggingface/transformers/pull/2078#discussion_r354739608 131 | for implementation details. 
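    Each call to the patched `read(n)` advances the progress bar by `n` bytes before
    delegating to the original `read`.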
132 | """ 133 | 134 | def __init__(self, f: io.BufferedReader): 135 | self.f = f 136 | self.total_size = os.fstat(f.fileno()).st_size 137 | self.pbar = tqdm(total=self.total_size, leave=False) 138 | self.read = f.read 139 | f.read = self._read 140 | 141 | def _read(self, n=-1): 142 | self.pbar.update(n) 143 | return self.read(n) 144 | 145 | def close(self): 146 | self.pbar.close() 147 | 148 | 149 | class HfFolder: 150 | path_token = expanduser("~/.huggingface/token") 151 | 152 | @classmethod 153 | def save_token(cls, token): 154 | """ 155 | Save token, creating folder as needed. 156 | """ 157 | os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) 158 | with open(cls.path_token, "w+") as f: 159 | f.write(token) 160 | 161 | @classmethod 162 | def get_token(cls): 163 | """ 164 | Get token or None if not existent. 165 | """ 166 | try: 167 | with open(cls.path_token, "r") as f: 168 | return f.read() 169 | except FileNotFoundError: 170 | pass 171 | 172 | @classmethod 173 | def delete_token(cls): 174 | """ 175 | Delete token. 176 | Do not fail if token does not exist. 177 | """ 178 | try: 179 | os.remove(cls.path_token) 180 | except FileNotFoundError: 181 | pass 182 | -------------------------------------------------------------------------------- /transformers/modeling_tf_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ A TF 2.0 Adaptive Softmax for Transformer XL model. 
17 | """ 18 | 19 | 20 | import tensorflow as tf 21 | 22 | from .modeling_tf_utils import shape_list 23 | 24 | 25 | class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer): 26 | def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs): 27 | super(TFAdaptiveSoftmaxMask, self).__init__(**kwargs) 28 | 29 | self.vocab_size = vocab_size 30 | self.d_embed = d_embed 31 | self.d_proj = d_proj 32 | 33 | self.cutoffs = cutoffs + [vocab_size] 34 | self.cutoff_ends = [0] + self.cutoffs 35 | self.div_val = div_val 36 | 37 | self.shortlist_size = self.cutoffs[0] 38 | self.n_clusters = len(self.cutoffs) - 1 39 | self.head_size = self.shortlist_size + self.n_clusters 40 | self.keep_order = keep_order 41 | 42 | self.out_layers = [] 43 | self.out_projs = [] 44 | 45 | def build(self, input_shape): 46 | if self.n_clusters > 0: 47 | self.cluster_weight = self.add_weight( 48 | shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight" 49 | ) 50 | self.cluster_bias = self.add_weight( 51 | shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias" 52 | ) 53 | 54 | if self.div_val == 1: 55 | for i in range(len(self.cutoffs)): 56 | if self.d_proj != self.d_embed: 57 | weight = self.add_weight( 58 | shape=(self.d_embed, self.d_proj), 59 | initializer="zeros", 60 | trainable=True, 61 | name="out_projs_._{}".format(i), 62 | ) 63 | self.out_projs.append(weight) 64 | else: 65 | self.out_projs.append(None) 66 | weight = self.add_weight( 67 | shape=(self.vocab_size, self.d_embed,), 68 | initializer="zeros", 69 | trainable=True, 70 | name="out_layers_._{}_._weight".format(i), 71 | ) 72 | bias = self.add_weight( 73 | shape=(self.vocab_size,), 74 | initializer="zeros", 75 | trainable=True, 76 | name="out_layers_._{}_._bias".format(i), 77 | ) 78 | self.out_layers.append((weight, bias)) 79 | else: 80 | for i in range(len(self.cutoffs)): 81 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 82 | d_emb_i = self.d_embed // (self.div_val ** i) 83 | 84 | weight = self.add_weight( 85 | shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name="out_projs_._{}".format(i) 86 | ) 87 | self.out_projs.append(weight) 88 | weight = self.add_weight( 89 | shape=(r_idx - l_idx, d_emb_i,), 90 | initializer="zeros", 91 | trainable=True, 92 | name="out_layers_._{}_._weight".format(i), 93 | ) 94 | bias = self.add_weight( 95 | shape=(r_idx - l_idx,), 96 | initializer="zeros", 97 | trainable=True, 98 | name="out_layers_._{}_._bias".format(i), 99 | ) 100 | self.out_layers.append((weight, bias)) 101 | super(TFAdaptiveSoftmaxMask, self).build(input_shape) 102 | 103 | @staticmethod 104 | def _logit(x, W, b, proj=None): 105 | y = x 106 | if proj is not None: 107 | y = tf.einsum("ibd,ed->ibe", y, proj) 108 | return tf.einsum("ibd,nd->ibn", y, W) + b 109 | 110 | @staticmethod 111 | def _gather_logprob(logprob, target): 112 | lp_size = shape_list(logprob) 113 | r = tf.range(lp_size[0]) 114 | idx = tf.stack([r, target], 1) 115 | return tf.gather_nd(logprob, idx) 116 | 117 | def call(self, inputs, return_mean=True, training=False): 118 | hidden, target = inputs 119 | head_logprob = 0 120 | if self.n_clusters == 0: 121 | softmax_b = tf.get_variable("bias", [self.config.vocab_size], initializer=tf.zeros_initializer()) 122 | output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0]) 123 | if target is not None: 124 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output) 125 | out 
= tf.nn.log_softmax(output, axis=-1) 126 | else: 127 | hidden_sizes = shape_list(hidden) 128 | out = [] 129 | loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32) 130 | for i in range(len(self.cutoffs)): 131 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 132 | if target is not None: 133 | mask = (target >= l_idx) & (target < r_idx) 134 | mask_idx = tf.where(mask) 135 | cur_target = tf.boolean_mask(target, mask) - l_idx 136 | 137 | if self.div_val == 1: 138 | cur_W = self.out_layers[0][0][l_idx:r_idx] 139 | cur_b = self.out_layers[0][1][l_idx:r_idx] 140 | else: 141 | cur_W = self.out_layers[i][0] 142 | cur_b = self.out_layers[i][1] 143 | 144 | if i == 0: 145 | cur_W = tf.concat([cur_W, self.cluster_weight], 0) 146 | cur_b = tf.concat([cur_b, self.cluster_bias], 0) 147 | 148 | head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0]) 149 | head_logprob = tf.nn.log_softmax(head_logit) 150 | out.append(head_logprob[..., : self.cutoffs[0]]) 151 | if target is not None: 152 | cur_head_logprob = tf.boolean_mask(head_logprob, mask) 153 | cur_logprob = self._gather_logprob(cur_head_logprob, cur_target) 154 | else: 155 | tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i]) 156 | tail_logprob = tf.nn.log_softmax(tail_logit) 157 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 158 | logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob 159 | out.append(logprob_i) 160 | if target is not None: 161 | cur_head_logprob = tf.boolean_mask(head_logprob, mask) 162 | cur_tail_logprob = tf.boolean_mask(tail_logprob, mask) 163 | cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target) 164 | cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1] 165 | if target is not None: 166 | loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64)) 167 | out = tf.concat(out, axis=-1) 168 | 169 | if target is not None: 170 | if return_mean: 171 | loss = tf.reduce_mean(loss) 172 | # Add the training-time loss value to the layer using `self.add_loss()`. 173 | self.add_loss(loss) 174 | 175 | # Log the loss as a metric (we could log arbitrary metrics, 176 | # including different metrics for training and inference. 177 | self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "") 178 | 179 | return out 180 | -------------------------------------------------------------------------------- /transformers/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
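# Illustrative usage of the schedulers and optimizer defined below (hyperparameters
# are placeholders): create the optimizer first, wrap it in a schedule, then call
# scheduler.step() after each optimizer.step().
#
#   optimizer = AdamW(model.parameters(), lr=5e-5)
#   scheduler = get_linear_schedule_with_warmup(
#       optimizer, num_warmup_steps=100, num_training_steps=1000
#   )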
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def get_constant_schedule(optimizer, last_epoch=-1): 29 | """ Create a schedule with a constant learning rate. 30 | """ 31 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 32 | 33 | 34 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 35 | """ Create a schedule with a constant learning rate preceded by a warmup 36 | period during which the learning rate increases linearly between 0 and 1. 37 | """ 38 | 39 | def lr_lambda(current_step): 40 | if current_step < num_warmup_steps: 41 | return float(current_step) / float(max(1.0, num_warmup_steps)) 42 | return 1.0 43 | 44 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 45 | 46 | 47 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 48 | """ Create a schedule with a learning rate that decreases linearly after 49 | linearly increasing during a warmup period. 50 | """ 51 | 52 | def lr_lambda(current_step): 53 | if current_step < num_warmup_steps: 54 | return float(current_step) / float(max(1, num_warmup_steps)) 55 | return max( 56 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 57 | ) 58 | 59 | return LambdaLR(optimizer, lr_lambda, last_epoch) 60 | 61 | 62 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1): 63 | """ Create a schedule with a learning rate that decreases following the 64 | values of the cosine function between 0 and `pi * cycles` after a warmup 65 | period during which it increases linearly between 0 and 1. 66 | """ 67 | 68 | def lr_lambda(current_step): 69 | if current_step < num_warmup_steps: 70 | return float(current_step) / float(max(1, num_warmup_steps)) 71 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 72 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 73 | 74 | return LambdaLR(optimizer, lr_lambda, last_epoch) 75 | 76 | 77 | def get_cosine_with_hard_restarts_schedule_with_warmup( 78 | optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1 79 | ): 80 | """ Create a schedule with a learning rate that decreases following the 81 | values of the cosine function with several hard restarts, after a warmup 82 | period during which it increases linearly between 0 and 1. 83 | """ 84 | 85 | def lr_lambda(current_step): 86 | if current_step < num_warmup_steps: 87 | return float(current_step) / float(max(1, num_warmup_steps)) 88 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 89 | if progress >= 1.0: 90 | return 0.0 91 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 92 | 93 | return LambdaLR(optimizer, lr_lambda, last_epoch) 94 | 95 | 96 | class AdamW(Optimizer): 97 | """ Implements Adam algorithm with weight decay fix. 98 | 99 | Parameters: 100 | lr (float): learning rate. Default 1e-3. 101 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 102 | eps (float): Adams epsilon. Default: 1e-6 103 | weight_decay (float): Weight decay. 
Default: 0.0 104 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 105 | """ 106 | 107 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 108 | if lr < 0.0: 109 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 110 | if not 0.0 <= betas[0] < 1.0: 111 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 112 | if not 0.0 <= betas[1] < 1.0: 113 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 114 | if not 0.0 <= eps: 115 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 116 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 117 | super(AdamW, self).__init__(params, defaults) 118 | 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | 122 | Arguments: 123 | closure (callable, optional): A closure that reevaluates the model 124 | and returns the loss. 125 | """ 126 | loss = None 127 | if closure is not None: 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | for p in group["params"]: 132 | if p.grad is None: 133 | continue 134 | grad = p.grad.data 135 | if grad.is_sparse: 136 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 137 | 138 | state = self.state[p] 139 | 140 | # State initialization 141 | if len(state) == 0: 142 | state["step"] = 0 143 | # Exponential moving average of gradient values 144 | state["exp_avg"] = torch.zeros_like(p.data) 145 | # Exponential moving average of squared gradient values 146 | state["exp_avg_sq"] = torch.zeros_like(p.data) 147 | 148 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 149 | beta1, beta2 = group["betas"] 150 | 151 | state["step"] += 1 152 | 153 | # Decay the first and second moment running average coefficient 154 | # In-place operations to update the averages at the same time 155 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 156 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 157 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 158 | 159 | step_size = group["lr"] 160 | if group["correct_bias"]: # No bias correction for Bert 161 | bias_correction1 = 1.0 - beta1 ** state["step"] 162 | bias_correction2 = 1.0 - beta2 ** state["step"] 163 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 164 | 165 | p.data.addcdiv_(-step_size, exp_avg, denom) 166 | 167 | # Just adding the square of the weights to the loss function is *not* 168 | # the correct way of using L2 regularization/weight decay with Adam, 169 | # since that will interact with the m and v parameters in strange ways. 170 | # 171 | # Instead we want to decay the weights in a manner that doesn't interact 172 | # with the m/v parameters. This is equivalent to adding the square 173 | # of the weights to the loss with plain (non-momentum) SGD. 
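# Decoupled weight decay (the "fix" referenced in the class docstring; see
# Loshchilov & Hutter, "Decoupled Weight Decay Regularization"):
# p <- p - lr * weight_decay * p, applied independently of the Adam moments.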
174 | # Add weight decay at the end (fixed version) 175 | if group["weight_decay"] > 0.0: 176 | p.data.add_(-group["lr"] * group["weight_decay"], p.data) 177 | 178 | return loss 179 | -------------------------------------------------------------------------------- /transformers/tokenization_camembert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License 15 | """ Tokenization classes for Camembert model.""" 16 | 17 | 18 | import logging 19 | import os 20 | from shutil import copyfile 21 | 22 | import sentencepiece as spm 23 | 24 | from transformers.tokenization_utils import PreTrainedTokenizer 25 | 26 | from .tokenization_xlnet import SPIECE_UNDERLINE 27 | 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} 32 | 33 | PRETRAINED_VOCAB_FILES_MAP = { 34 | "vocab_file": { 35 | "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model", 36 | } 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | "camembert-base": None, 41 | } 42 | 43 | 44 | class CamembertTokenizer(PreTrainedTokenizer): 45 | """ 46 | Adapted from RobertaTokenizer and XLNetTokenizer 47 | SentencePiece based tokenizer. 
Peculiarities: 48 | 49 | - requires `SentencePiece `_ 50 | """ 51 | 52 | vocab_files_names = VOCAB_FILES_NAMES 53 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 54 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 55 | 56 | def __init__( 57 | self, 58 | vocab_file, 59 | bos_token="", 60 | eos_token="", 61 | sep_token="", 62 | cls_token="", 63 | unk_token="", 64 | pad_token="", 65 | mask_token="", 66 | additional_special_tokens=["NOTUSED", "NOTUSED"], 67 | **kwargs 68 | ): 69 | super(CamembertTokenizer, self).__init__( 70 | max_len=512, 71 | bos_token=bos_token, 72 | eos_token=eos_token, 73 | unk_token=unk_token, 74 | sep_token=sep_token, 75 | cls_token=cls_token, 76 | pad_token=pad_token, 77 | mask_token=mask_token, 78 | additional_special_tokens=additional_special_tokens, 79 | **kwargs 80 | ) 81 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 82 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 83 | self.sp_model = spm.SentencePieceProcessor() 84 | self.sp_model.Load(str(vocab_file)) 85 | self.vocab_file = vocab_file 86 | # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual 87 | # sentencepiece vocabulary (this is the case for and 88 | self.fairseq_tokens_to_ids = {"NOTUSED": 0, "": 1, "NOTUSED": 2, "": 3} 89 | self.fairseq_offset = len(self.fairseq_tokens_to_ids) 90 | self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.fairseq_tokens_to_ids) 91 | self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} 92 | 93 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 94 | """ 95 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 96 | by concatenating and adding special tokens. 97 | A RoBERTa sequence has the following format: 98 | single sequence: X 99 | pair of sequences: A B 100 | """ 101 | if token_ids_1 is None: 102 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 103 | cls = [self.cls_token_id] 104 | sep = [self.sep_token_id] 105 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 106 | 107 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 108 | """ 109 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 110 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 111 | 112 | Args: 113 | token_ids_0: list of ids (must not contain special tokens) 114 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 115 | for sequence pairs 116 | already_has_special_tokens: (default False) Set to True if the token list is already formated with 117 | special tokens for the model 118 | 119 | Returns: 120 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 121 | """ 122 | if already_has_special_tokens: 123 | if token_ids_1 is not None: 124 | raise ValueError( 125 | "You should not supply a second sequence if the provided sequence of " 126 | "ids is already formated with special tokens for the model." 
127 | ) 128 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 129 | 130 | if token_ids_1 is None: 131 | return [1] + ([0] * len(token_ids_0)) + [1] 132 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 133 | 134 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 135 | """ 136 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 137 | A RoBERTa sequence pair mask has the following format: 138 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 139 | | first sequence | second sequence 140 | 141 | if token_ids_1 is None, only returns the first portion of the mask (0's). 142 | """ 143 | sep = [self.sep_token_id] 144 | cls = [self.cls_token_id] 145 | 146 | if token_ids_1 is None: 147 | return len(cls + token_ids_0 + sep) * [0] 148 | return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] 149 | 150 | @property 151 | def vocab_size(self): 152 | return len(self.fairseq_tokens_to_ids) + len(self.sp_model) 153 | 154 | def _tokenize(self, text): 155 | return self.sp_model.EncodeAsPieces(text) 156 | 157 | def _convert_token_to_id(self, token): 158 | """ Converts a token (str) in an id using the vocab. """ 159 | if token in self.fairseq_tokens_to_ids: 160 | return self.fairseq_tokens_to_ids[token] 161 | elif self.sp_model.PieceToId(token) == 0: 162 | # Convert sentence piece unk token to fairseq unk token index 163 | return self.unk_token_id 164 | return self.fairseq_offset + self.sp_model.PieceToId(token) 165 | 166 | def _convert_id_to_token(self, index): 167 | """Converts an index (integer) in a token (str) using the vocab.""" 168 | if index in self.fairseq_ids_to_tokens: 169 | return self.fairseq_ids_to_tokens[index] 170 | return self.sp_model.IdToPiece(index - self.fairseq_offset) 171 | 172 | def convert_tokens_to_string(self, tokens): 173 | """Converts a sequence of tokens (strings for sub-words) in a single string.""" 174 | out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() 175 | return out_string 176 | 177 | def save_vocabulary(self, save_directory): 178 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file 179 | to a directory. 180 | """ 181 | if not os.path.isdir(save_directory): 182 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 183 | return 184 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) 185 | 186 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 187 | copyfile(self.vocab_file, out_vocab_file) 188 | 189 | return (out_vocab_file,) 190 | -------------------------------------------------------------------------------- /transformers/tokenization_ctrl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
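`CamembertTokenizer.build_inputs_with_special_tokens` and `create_token_type_ids_from_sequences` above (and the identical methods on `RobertaTokenizer` further down) wrap a pair of sequences as CLS + A + SEP + SEP + B + SEP, where the CLS and SEP slots hold the `<s>` and `</s>` tokens of the upstream vocabularies. A self-contained sketch of the resulting layout, with made-up integer ids so no sentencepiece model is needed:

```python
# Standalone sketch of the sequence-pair layout built by the tokenizer methods above.
# The ids are invented for the demo; 5 and 6 stand in for the real CLS/SEP ids.
CLS_ID, SEP_ID = 5, 6

def pair_inputs(token_ids_0, token_ids_1):
    input_ids = [CLS_ID] + token_ids_0 + [SEP_ID, SEP_ID] + token_ids_1 + [SEP_ID]
    token_type_ids = [0] * (len(token_ids_0) + 3) + [1] * (len(token_ids_1) + 1)
    return input_ids, token_type_ids

ids, types = pair_inputs([10, 11], [20, 21, 22])
print(ids)    # [5, 10, 11, 6, 6, 20, 21, 22, 6]
print(types)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]
```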
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for Salesforce CTRL.""" 16 | 17 | 18 | import json 19 | import logging 20 | import os 21 | 22 | import regex as re 23 | 24 | from .tokenization_utils import PreTrainedTokenizer 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = { 30 | "vocab_file": "vocab.json", 31 | "merges_file": "merges.txt", 32 | } 33 | 34 | PRETRAINED_VOCAB_FILES_MAP = { 35 | "vocab_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json"}, 36 | "merges_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt"}, 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | "ctrl": 256, 41 | } 42 | 43 | CONTROL_CODES = { 44 | "Pregnancy": 168629, 45 | "Christianity": 7675, 46 | "Explain": 106423, 47 | "Fitness": 63440, 48 | "Saving": 63163, 49 | "Ask": 27171, 50 | "Ass": 95985, 51 | "Joke": 163509, 52 | "Questions": 45622, 53 | "Thoughts": 49605, 54 | "Retail": 52342, 55 | "Feminism": 164338, 56 | "Writing": 11992, 57 | "Atheism": 192263, 58 | "Netflix": 48616, 59 | "Computing": 39639, 60 | "Opinion": 43213, 61 | "Alone": 44967, 62 | "Funny": 58917, 63 | "Gaming": 40358, 64 | "Human": 4088, 65 | "India": 1331, 66 | "Joker": 77138, 67 | "Diet": 36206, 68 | "Legal": 11859, 69 | "Norman": 4939, 70 | "Tip": 72689, 71 | "Weight": 52343, 72 | "Movies": 46273, 73 | "Running": 23425, 74 | "Science": 2090, 75 | "Horror": 37793, 76 | "Confession": 60572, 77 | "Finance": 12250, 78 | "Politics": 16360, 79 | "Scary": 191985, 80 | "Support": 12654, 81 | "Technologies": 32516, 82 | "Teenage": 66160, 83 | "Event": 32769, 84 | "Learned": 67460, 85 | "Notion": 182770, 86 | "Wikipedia": 37583, 87 | "Books": 6665, 88 | "Extract": 76050, 89 | "Confessions": 102701, 90 | "Conspiracy": 75932, 91 | "Links": 63674, 92 | "Narcissus": 150425, 93 | "Relationship": 54766, 94 | "Relationships": 134796, 95 | "Reviews": 41671, 96 | "News": 4256, 97 | "Translation": 26820, 98 | "multilingual": 128406, 99 | } 100 | 101 | 102 | def get_pairs(word): 103 | """Return set of symbol pairs in a word. 104 | 105 | Word is represented as tuple of symbols (symbols being variable-length strings). 106 | """ 107 | pairs = set() 108 | prev_char = word[0] 109 | for char in word[1:]: 110 | pairs.add((prev_char, char)) 111 | prev_char = char 112 | 113 | pairs = set(pairs) 114 | return pairs 115 | 116 | 117 | class CTRLTokenizer(PreTrainedTokenizer): 118 | """ 119 | CTRL BPE tokenizer. 
Peculiarities: 120 | - Byte-Pair-Encoding 121 | """ 122 | 123 | vocab_files_names = VOCAB_FILES_NAMES 124 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 125 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 126 | control_codes = CONTROL_CODES 127 | 128 | def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): 129 | super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs) 130 | self.max_len_single_sentence = ( 131 | self.max_len 132 | ) # no default special tokens - you can update this value if you add special tokens 133 | self.max_len_sentences_pair = ( 134 | self.max_len 135 | ) # no default special tokens - you can update this value if you add special tokens 136 | 137 | with open(vocab_file, encoding="utf-8") as vocab_handle: 138 | self.encoder = json.load(vocab_handle) 139 | self.decoder = {v: k for k, v in self.encoder.items()} 140 | with open(merges_file, encoding="utf-8") as merges_handle: 141 | merges = merges_handle.read().split("\n")[1:-1] 142 | merges = [tuple(merge.split()) for merge in merges] 143 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 144 | self.cache = {} 145 | 146 | @property 147 | def vocab_size(self): 148 | return len(self.encoder) 149 | 150 | def bpe(self, token): 151 | if token in self.cache: 152 | return self.cache[token] 153 | word = tuple(token) 154 | word = tuple(list(word[:-1]) + [word[-1] + ""]) 155 | pairs = get_pairs(word) 156 | 157 | if not pairs: 158 | return token 159 | 160 | while True: 161 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 162 | if bigram not in self.bpe_ranks: 163 | break 164 | first, second = bigram 165 | new_word = [] 166 | i = 0 167 | while i < len(word): 168 | try: 169 | j = word.index(first, i) 170 | except ValueError: 171 | new_word.extend(word[i:]) 172 | break 173 | else: 174 | new_word.extend(word[i:j]) 175 | i = j 176 | 177 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 178 | new_word.append(first + second) 179 | i += 2 180 | else: 181 | new_word.append(word[i]) 182 | i += 1 183 | new_word = tuple(new_word) 184 | word = new_word 185 | if len(word) == 1: 186 | break 187 | else: 188 | pairs = get_pairs(word) 189 | word = "@@ ".join(word) 190 | word = word[:-4] 191 | self.cache[token] = word 192 | return word 193 | 194 | def _tokenize(self, text): 195 | """ Tokenize a string. 196 | """ 197 | split_tokens = [] 198 | 199 | words = re.findall(r"\S+\n?", text) 200 | 201 | for token in words: 202 | split_tokens.extend([t for t in self.bpe(token).split(" ")]) 203 | return split_tokens 204 | 205 | def _convert_token_to_id(self, token): 206 | """ Converts a token (str) in an id using the vocab. """ 207 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 208 | 209 | def _convert_id_to_token(self, index): 210 | """Converts an index (integer) in a token (str) using the vocab.""" 211 | return self.decoder.get(index, self.unk_token) 212 | 213 | def convert_tokens_to_string(self, tokens): 214 | """ Converts a sequence of tokens (string) in a single string. 
""" 215 | out_string = " ".join(tokens).replace("@@ ", "").strip() 216 | return out_string 217 | 218 | def save_vocabulary(self, save_directory): 219 | """Save the tokenizer vocabulary and merge files to a directory.""" 220 | if not os.path.isdir(save_directory): 221 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 222 | return 223 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) 224 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"]) 225 | 226 | with open(vocab_file, "w", encoding="utf-8") as f: 227 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 228 | 229 | index = 0 230 | with open(merge_file, "w", encoding="utf-8") as writer: 231 | writer.write("#version: 0.2\n") 232 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 233 | if index != token_index: 234 | logger.warning( 235 | "Saving vocabulary to {}: BPE merge indices are not consecutive." 236 | " Please check that the tokenizer is not corrupted!".format(merge_file) 237 | ) 238 | index = token_index 239 | writer.write(" ".join(bpe_tokens) + "\n") 240 | index += 1 241 | 242 | return vocab_file, merge_file 243 | 244 | # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 245 | # filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)) 246 | # tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens) 247 | # tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far) 248 | # return ''.join(tokens_generated_so_far) 249 | -------------------------------------------------------------------------------- /transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for DistilBERT.""" 16 | 17 | 18 | import logging 19 | 20 | from .tokenization_bert import BertTokenizer 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 26 | 27 | PRETRAINED_VOCAB_FILES_MAP = { 28 | "vocab_file": { 29 | "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 30 | "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 31 | "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt", 32 | "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 33 | } 34 | } 35 | 36 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 37 | "distilbert-base-uncased": 512, 38 | "distilbert-base-uncased-distilled-squad": 512, 39 | "distilbert-base-german-cased": 512, 40 | "distilbert-base-multilingual-cased": 512, 41 | } 42 | 43 | 44 | class DistilBertTokenizer(BertTokenizer): 45 | r""" 46 | Constructs a DistilBertTokenizer. 47 | :class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 48 | 49 | Args: 50 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 51 | do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True 52 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 53 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 54 | minimum of this value (if specified) and the underlying BERT model's sequence length. 55 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 56 | do_basic_tokenize=True 57 | """ 58 | 59 | vocab_files_names = VOCAB_FILES_NAMES 60 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 61 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 62 | -------------------------------------------------------------------------------- /transformers/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for OpenAI GPT.""" 16 | 17 | 18 | import json 19 | import logging 20 | import os 21 | import re 22 | 23 | from .tokenization_bert import BasicTokenizer 24 | from .tokenization_utils import PreTrainedTokenizer 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = { 30 | "vocab_file": "vocab.json", 31 | "merges_file": "merges.txt", 32 | } 33 | 34 | PRETRAINED_VOCAB_FILES_MAP = { 35 | "vocab_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json"}, 36 | "merges_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt"}, 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | "openai-gpt": 512, 41 | } 42 | 43 | 44 | def get_pairs(word): 45 | """ 46 | Return set of symbol pairs in a word. 47 | word is represented as tuple of symbols (symbols being variable-length strings) 48 | """ 49 | pairs = set() 50 | prev_char = word[0] 51 | for char in word[1:]: 52 | pairs.add((prev_char, char)) 53 | prev_char = char 54 | return pairs 55 | 56 | 57 | def text_standardize(text): 58 | """ 59 | fixes some issues the spacy tokenizer had on books corpus 60 | also does some whitespace standardization 61 | """ 62 | text = text.replace("—", "-") 63 | text = text.replace("–", "-") 64 | text = text.replace("―", "-") 65 | text = text.replace("…", "...") 66 | text = text.replace("´", "'") 67 | text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text) 68 | text = re.sub(r"\s*\n\s*", " \n ", text) 69 | text = re.sub(r"[^\S\n]+", " ", text) 70 | return text.strip() 71 | 72 | 73 | class OpenAIGPTTokenizer(PreTrainedTokenizer): 74 | """ 75 | BPE tokenizer. Peculiarities: 76 | - lower case all inputs 77 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 
78 | """ 79 | 80 | vocab_files_names = VOCAB_FILES_NAMES 81 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 82 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 83 | 84 | def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): 85 | super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) 86 | 87 | self.max_len_single_sentence = ( 88 | self.max_len 89 | ) # no default special tokens - you can update this value if you add special tokens 90 | self.max_len_sentences_pair = ( 91 | self.max_len 92 | ) # no default special tokens - you can update this value if you add special tokens 93 | 94 | try: 95 | import ftfy 96 | from spacy.lang.en import English 97 | 98 | _nlp = English() 99 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp) 100 | self.fix_text = ftfy.fix_text 101 | except ImportError: 102 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 103 | self.nlp = BasicTokenizer(do_lower_case=True) 104 | self.fix_text = None 105 | 106 | with open(vocab_file, encoding="utf-8") as vocab_handle: 107 | self.encoder = json.load(vocab_handle) 108 | self.decoder = {v: k for k, v in self.encoder.items()} 109 | with open(merges_file, encoding="utf-8") as merges_handle: 110 | merges = merges_handle.read().split("\n")[1:-1] 111 | merges = [tuple(merge.split()) for merge in merges] 112 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 113 | self.cache = {} 114 | 115 | @property 116 | def vocab_size(self): 117 | return len(self.encoder) 118 | 119 | def bpe(self, token): 120 | word = tuple(token[:-1]) + (token[-1] + "",) 121 | if token in self.cache: 122 | return self.cache[token] 123 | pairs = get_pairs(word) 124 | 125 | if not pairs: 126 | return token + "" 127 | 128 | while True: 129 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 130 | if bigram not in self.bpe_ranks: 131 | break 132 | first, second = bigram 133 | new_word = [] 134 | i = 0 135 | while i < len(word): 136 | try: 137 | j = word.index(first, i) 138 | except ValueError: 139 | new_word.extend(word[i:]) 140 | break 141 | else: 142 | new_word.extend(word[i:j]) 143 | i = j 144 | 145 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 146 | new_word.append(first + second) 147 | i += 2 148 | else: 149 | new_word.append(word[i]) 150 | i += 1 151 | new_word = tuple(new_word) 152 | word = new_word 153 | if len(word) == 1: 154 | break 155 | else: 156 | pairs = get_pairs(word) 157 | word = " ".join(word) 158 | if word == "\n ": 159 | word = "\n" 160 | self.cache[token] = word 161 | return word 162 | 163 | def _tokenize(self, text): 164 | """ Tokenize a string. """ 165 | split_tokens = [] 166 | if self.fix_text is None: 167 | # Using BERT's BasicTokenizer 168 | text = self.nlp.tokenize(text) 169 | for token in text: 170 | split_tokens.extend([t for t in self.bpe(token).split(" ")]) 171 | else: 172 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 173 | text = self.nlp(text_standardize(self.fix_text(text))) 174 | for token in text: 175 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")]) 176 | return split_tokens 177 | 178 | def _convert_token_to_id(self, token): 179 | """ Converts a token (str) in an id using the vocab. 
""" 180 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 181 | 182 | def _convert_id_to_token(self, index): 183 | """Converts an id in a token (BPE) using the vocab.""" 184 | return self.decoder.get(index, self.unk_token) 185 | 186 | def convert_tokens_to_string(self, tokens): 187 | """ Converts a sequence of tokens (string) in a single string. """ 188 | out_string = "".join(tokens).replace("", " ").strip() 189 | return out_string 190 | 191 | def save_vocabulary(self, save_directory): 192 | """Save the tokenizer vocabulary and merge files to a directory.""" 193 | if not os.path.isdir(save_directory): 194 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 195 | return 196 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) 197 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"]) 198 | 199 | with open(vocab_file, "w", encoding="utf-8") as f: 200 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 201 | 202 | index = 0 203 | with open(merge_file, "w", encoding="utf-8") as writer: 204 | writer.write("#version: 0.2\n") 205 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 206 | if index != token_index: 207 | logger.warning( 208 | "Saving vocabulary to {}: BPE merge indices are not consecutive." 209 | " Please check that the tokenizer is not corrupted!".format(merge_file) 210 | ) 211 | index = token_index 212 | writer.write(" ".join(bpe_tokens) + "\n") 213 | index += 1 214 | 215 | return vocab_file, merge_file 216 | -------------------------------------------------------------------------------- /transformers/tokenization_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for RoBERTa.""" 16 | 17 | 18 | import logging 19 | 20 | from .tokenization_gpt2 import GPT2Tokenizer 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | VOCAB_FILES_NAMES = { 26 | "vocab_file": "vocab.json", 27 | "merges_file": "merges.txt", 28 | } 29 | 30 | PRETRAINED_VOCAB_FILES_MAP = { 31 | "vocab_file": { 32 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 33 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 34 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", 35 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json", 36 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 37 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 38 | }, 39 | "merges_file": { 40 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 41 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 42 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", 43 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt", 44 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 45 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 46 | }, 47 | } 48 | 49 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 50 | "roberta-base": 512, 51 | "roberta-large": 512, 52 | "roberta-large-mnli": 512, 53 | "distilroberta-base": 512, 54 | "roberta-base-openai-detector": 512, 55 | "roberta-large-openai-detector": 512, 56 | } 57 | 58 | 59 | class RobertaTokenizer(GPT2Tokenizer): 60 | """ 61 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: 62 | - Byte-level Byte-Pair-Encoding 63 | - Requires a space to start the input string => the encoding methods should be called with the 64 | ``add_prefix_space`` flag set to ``True``. 
65 | Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve 66 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"` 67 | """ 68 | 69 | vocab_files_names = VOCAB_FILES_NAMES 70 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 71 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 72 | 73 | def __init__( 74 | self, 75 | vocab_file, 76 | merges_file, 77 | errors="replace", 78 | bos_token="", 79 | eos_token="", 80 | sep_token="", 81 | cls_token="", 82 | unk_token="", 83 | pad_token="", 84 | mask_token="", 85 | **kwargs 86 | ): 87 | super(RobertaTokenizer, self).__init__( 88 | vocab_file=vocab_file, 89 | merges_file=merges_file, 90 | errors=errors, 91 | bos_token=bos_token, 92 | eos_token=eos_token, 93 | unk_token=unk_token, 94 | sep_token=sep_token, 95 | cls_token=cls_token, 96 | pad_token=pad_token, 97 | mask_token=mask_token, 98 | **kwargs 99 | ) 100 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 101 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 102 | 103 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 104 | """ 105 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 106 | by concatenating and adding special tokens. 107 | A RoBERTa sequence has the following format: 108 | single sequence: X 109 | pair of sequences: A B 110 | """ 111 | if token_ids_1 is None: 112 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 113 | cls = [self.cls_token_id] 114 | sep = [self.sep_token_id] 115 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 116 | 117 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 118 | """ 119 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 120 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 121 | 122 | Args: 123 | token_ids_0: list of ids (must not contain special tokens) 124 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 125 | for sequence pairs 126 | already_has_special_tokens: (default False) Set to True if the token list is already formated with 127 | special tokens for the model 128 | 129 | Returns: 130 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 131 | """ 132 | if already_has_special_tokens: 133 | if token_ids_1 is not None: 134 | raise ValueError( 135 | "You should not supply a second sequence if the provided sequence of " 136 | "ids is already formated with special tokens for the model." 137 | ) 138 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 139 | 140 | if token_ids_1 is None: 141 | return [1] + ([0] * len(token_ids_0)) + [1] 142 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 143 | 144 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 145 | """ 146 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 147 | A RoBERTa sequence pair mask has the following format: 148 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 149 | | first sequence | second sequence 150 | 151 | if token_ids_1 is None, only returns the first portion of the mask (0's). 
152 | """ 153 | sep = [self.sep_token_id] 154 | cls = [self.cls_token_id] 155 | 156 | if token_ids_1 is None: 157 | return len(cls + token_ids_0 + sep) * [0] 158 | return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] 159 | -------------------------------------------------------------------------------- /transformers/tokenization_t5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 T5 Authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Tokenization class for model T5.""" 16 | 17 | 18 | import logging 19 | import os 20 | import re 21 | from shutil import copyfile 22 | 23 | from .tokenization_utils import PreTrainedTokenizer 24 | 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | SPIECE_UNDERLINE = "▁" 29 | 30 | #################################################### 31 | # Mapping from the keyword arguments names of Tokenizer `__init__` 32 | # to file names for serializing Tokenizer instances 33 | #################################################### 34 | VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} 35 | 36 | #################################################### 37 | # Mapping from the keyword arguments names of Tokenizer `__init__` 38 | # to pretrained vocabulary URL for all the model shortcut names. 39 | #################################################### 40 | PRETRAINED_VOCAB_FILES_MAP = { 41 | "vocab_file": { 42 | "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", 43 | "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", 44 | "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", 45 | "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", 46 | "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", 47 | } 48 | } 49 | 50 | #################################################### 51 | # Mapping from model shortcut names to max length of inputs 52 | #################################################### 53 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 54 | "t5-small": 512, 55 | "t5-base": 512, 56 | "t5-large": 512, 57 | "t5-3b": 512, 58 | "t5-11b": 512, 59 | } 60 | 61 | 62 | class T5Tokenizer(PreTrainedTokenizer): 63 | """ 64 | SentencePiece based tokenizer. Peculiarities: 65 | 66 | - requires `SentencePiece `_ 67 | - `extra_ids` add a number of extra ids added to the end of the vocabulary for use as sentinels. 68 | These tokens are accessible as `` where `{%d}` is a number between 0 and extra_ids-1. 
69 | Extra tokens are indexed from the end of the vocabulary up to beginnning ( is the last token in the vocabulary) 70 | (like in T5 preprocessing 71 | see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117) 72 | """ 73 | 74 | vocab_files_names = VOCAB_FILES_NAMES 75 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 76 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 77 | 78 | def __init__( 79 | self, 80 | vocab_file, 81 | eos_token="", 82 | unk_token="", 83 | pad_token="", 84 | extra_ids=100, 85 | additional_special_tokens=None, 86 | **kwargs 87 | ): 88 | # Add extra_ids to the special token list 89 | if extra_ids > 0: 90 | if additional_special_tokens is None: 91 | additional_special_tokens = [] 92 | additional_special_tokens.extend(["".format(i) for i in range(extra_ids)]) 93 | 94 | super(T5Tokenizer, self).__init__( 95 | eos_token=eos_token, 96 | unk_token=unk_token, 97 | pad_token=pad_token, 98 | additional_special_tokens=additional_special_tokens, 99 | **kwargs 100 | ) 101 | 102 | try: 103 | import sentencepiece as spm 104 | except ImportError: 105 | logger.warning( 106 | "You need to install SentencePiece to use T5Tokenizer:" 107 | "https://github.com/google/sentencepiece" 108 | "pip install sentencepiece" 109 | ) 110 | raise 111 | 112 | self.vocab_file = vocab_file 113 | self._extra_ids = extra_ids 114 | 115 | self.sp_model = spm.SentencePieceProcessor() 116 | self.sp_model.Load(vocab_file) 117 | 118 | @property 119 | def vocab_size(self): 120 | return self.sp_model.get_piece_size() + self._extra_ids 121 | 122 | def __getstate__(self): 123 | state = self.__dict__.copy() 124 | state["sp_model"] = None 125 | return state 126 | 127 | def __setstate__(self, d): 128 | self.__dict__ = d 129 | try: 130 | import sentencepiece as spm 131 | except ImportError: 132 | logger.warning( 133 | "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 134 | "pip install sentencepiece" 135 | ) 136 | raise 137 | self.sp_model = spm.SentencePieceProcessor() 138 | self.sp_model.Load(self.vocab_file) 139 | 140 | def _tokenize(self, text, sample=False): 141 | """ Take as input a string and return a list of strings (tokens) for words/sub-words 142 | """ 143 | if not sample: 144 | pieces = self.sp_model.EncodeAsPieces(text) 145 | else: 146 | pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) 147 | return pieces 148 | 149 | def _convert_token_to_id(self, token): 150 | """ Converts a token (str) in an id using the vocab. """ 151 | if token.startswith("", token) 153 | num = int(match.group(1)) 154 | return self.vocab_size - num - 1 155 | return self.sp_model.piece_to_id(token) 156 | 157 | def _convert_id_to_token(self, index): 158 | """Converts an index (integer) in a token (str) using the vocab.""" 159 | if index < self.sp_model.get_piece_size(): 160 | token = self.sp_model.IdToPiece(index) 161 | else: 162 | token = "".format(self.vocab_size - 1 - index) 163 | return token 164 | 165 | def convert_tokens_to_string(self, tokens): 166 | """ Converts a sequence of tokens (string) in a single string. """ 167 | out_string = self.sp_model.decode_pieces(tokens) 168 | return out_string 169 | 170 | def save_vocabulary(self, save_directory): 171 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file 172 | to a directory. 
173 | """ 174 | if not os.path.isdir(save_directory): 175 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 176 | return 177 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"]) 178 | 179 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 180 | copyfile(self.vocab_file, out_vocab_file) 181 | 182 | return (out_vocab_file,) 183 | --------------------------------------------------------------------------------
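`T5Tokenizer` above maps its `extra_ids` sentinel tokens (named `<extra_id_0>` through `<extra_id_99>` in the upstream library) onto the top of the id space, counting backwards from `vocab_size - 1`. A standalone sketch of that arithmetic with toy sizes, so it runs without a `spiece.model` file:

```python
# Standalone sketch of the sentinel-id arithmetic in T5Tokenizer above.
# Toy sizes; the real values come from spiece.model and the extra_ids argument.
sp_size, extra_ids = 32000, 100
vocab_size = sp_size + extra_ids               # mirrors T5Tokenizer.vocab_size

def sentinel_to_id(n):
    # mirrors _convert_token_to_id for a token of the form "<extra_id_{n}>"
    return vocab_size - n - 1

def id_to_token(index):
    # mirrors _convert_id_to_token for ids beyond the SentencePiece range
    if index < sp_size:
        return "(regular sentencepiece piece)"
    return "<extra_id_{}>".format(vocab_size - 1 - index)

print(sentinel_to_id(0), sentinel_to_id(99))   # 32099 32000
print(id_to_token(32099), id_to_token(32000))  # <extra_id_0> <extra_id_99>
```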