├── .gitignore ├── README.md ├── setup.py ├── src ├── conlleval.py ├── dataset.py ├── models.py ├── run.py └── task_utils.py ├── train.sh └── transformers ├── __init__.py ├── commands ├── __init__.py └── user.py ├── configuration_bert.py ├── configuration_utils.py ├── data ├── __init__.py ├── metrics │ ├── __init__.py │ └── squad_metrics.py └── processors │ ├── __init__.py │ ├── glue.py │ ├── squad.py │ ├── utils.py │ └── xnli.py ├── file_utils.py ├── hf_api.py ├── modeling_bert.py ├── modeling_utils.py ├── optimization.py ├── tests ├── __init__.py ├── configuration_common_test.py ├── fixtures │ ├── input.txt │ ├── sample_text.txt │ ├── spiece.model │ └── test_sentencepiece.model ├── hf_api_test.py ├── modeling_albert_test.py ├── modeling_auto_test.py ├── modeling_bert_test.py ├── modeling_common_test.py ├── modeling_ctrl_test.py ├── modeling_distilbert_test.py ├── modeling_encoder_decoder_test.py ├── modeling_gpt2_test.py ├── modeling_openai_test.py ├── modeling_roberta_test.py ├── modeling_tf_albert_test.py ├── modeling_tf_auto_test.py ├── modeling_tf_bert_test.py ├── modeling_tf_common_test.py ├── modeling_tf_ctrl_test.py ├── modeling_tf_distilbert_test.py ├── modeling_tf_gpt2_test.py ├── modeling_tf_openai_gpt_test.py ├── modeling_tf_roberta_test.py ├── modeling_tf_transfo_xl_test.py ├── modeling_tf_xlm_test.py ├── modeling_tf_xlnet_test.py ├── modeling_transfo_xl_test.py ├── modeling_xlm_test.py ├── modeling_xlnet_test.py ├── optimization_test.py ├── optimization_tf_test.py ├── tokenization_albert_test.py ├── tokenization_auto_test.py ├── tokenization_bert_japanese_test.py ├── tokenization_bert_test.py ├── tokenization_ctrl_test.py ├── tokenization_distilbert_test.py ├── tokenization_gpt2_test.py ├── tokenization_openai_test.py ├── tokenization_roberta_test.py ├── tokenization_tests_commons.py ├── tokenization_transfo_xl_test.py ├── tokenization_utils_test.py ├── tokenization_xlm_test.py ├── tokenization_xlnet_test.py └── utils.py ├── tokenization_bert.py └── tokenization_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.zip 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # Improving BERT with Syntax-aware Local Attention
3 |
4 | 5 | The implementation of the ACL2021 Findings paper "Improving BERT with Syntax-aware Local Attention". 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py, setup.py as well as docs/source/conf.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions. 21 | 22 | 5. Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (pypi suggest using twine as other methods upload files via plaintext.) 26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi transformers 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 34 | 35 | """ 36 | from io import open 37 | from setuptools import find_packages, setup 38 | 39 | 40 | extras = { 41 | 'serving': ['uvicorn', 'fastapi'] 42 | } 43 | extras['all'] = [package for package in extras.values()] 44 | 45 | setup( 46 | name="transformers", 47 | version="2.2.2", 48 | author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors", 49 | author_email="thomas@huggingface.co", 50 | description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch", 51 | keywords='NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU', 52 | license='Apache', 53 | packages=find_packages(exclude=["*.tests", "*.tests.*", 54 | "tests.*", "tests"]), 55 | install_requires=['numpy', 56 | 'boto3', 57 | 'requests', 58 | 'tqdm', 59 | 'regex', 60 | 'sentencepiece', 61 | 'sacremoses'], 62 | extras_require=extras, 63 | # python_requires='>=3.5.0', 64 | classifiers=[ 65 | 'Intended Audience :: Science/Research', 66 | 'License :: OSI Approved :: Apache Software License', 67 | 'Programming Language :: Python :: 3', 68 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 69 | ], 70 | ) 71 | -------------------------------------------------------------------------------- /src/conlleval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script applies to IOB2 or IOBES tagging scheme. 3 | If you are using a different scheme, please convert to IOB2 or IOBES. 4 | 5 | IOB2: 6 | - B = begin, 7 | - I = inside but not the first, 8 | - O = outside 9 | 10 | e.g. 11 | John lives in New York City . 
12 | B-PER O O B-LOC I-LOC I-LOC O 13 | 14 | IOBES: 15 | - B = begin, 16 | - E = end, 17 | - S = singleton, 18 | - I = inside but not the first or the last, 19 | - O = outside 20 | 21 | e.g. 22 | John lives in New York City . 23 | S-PER O O B-LOC I-LOC E-LOC O 24 | 25 | prefix: IOBES 26 | chunk_type: PER, LOC, etc. 27 | """ 28 | from __future__ import division, print_function, unicode_literals 29 | 30 | import sys 31 | from collections import defaultdict 32 | 33 | def split_tag(chunk_tag): 34 | """ 35 | split chunk tag into IOBES prefix and chunk_type 36 | e.g. 37 | B-PER -> (B, PER) 38 | O -> (O, None) 39 | """ 40 | if chunk_tag == 'O': 41 | return ('O', None) 42 | return chunk_tag.split('-', maxsplit=1) 43 | 44 | def is_chunk_end(prev_tag, tag): 45 | """ 46 | check if the previous chunk ended between the previous and current word 47 | e.g. 48 | (B-PER, I-PER) -> False 49 | (B-LOC, O) -> True 50 | 51 | Note: in case of contradicting tags, e.g. (B-PER, I-LOC) 52 | this is considered as (B-PER, B-LOC) 53 | """ 54 | prefix1, chunk_type1 = split_tag(prev_tag) 55 | prefix2, chunk_type2 = split_tag(tag) 56 | 57 | if prefix1 == 'O': 58 | return False 59 | if prefix2 == 'O': 60 | return prefix1 != 'O' 61 | 62 | if chunk_type1 != chunk_type2: 63 | return True 64 | 65 | return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S'] 66 | 67 | def is_chunk_start(prev_tag, tag): 68 | """ 69 | check if a new chunk started between the previous and current word 70 | """ 71 | prefix1, chunk_type1 = split_tag(prev_tag) 72 | prefix2, chunk_type2 = split_tag(tag) 73 | 74 | if prefix2 == 'O': 75 | return False 76 | if prefix1 == 'O': 77 | return prefix2 != 'O' 78 | 79 | if chunk_type1 != chunk_type2: 80 | return True 81 | 82 | return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S'] 83 | 84 | 85 | def calc_metrics(tp, p, t, percent=True): 86 | """ 87 | compute overall precision, recall and FB1 (default values are 0.0) 88 | if percent is True, return 100 * original decimal value 89 | """ 90 | precision = tp / p if p else 0 91 | recall = tp / t if t else 0 92 | fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0 93 | if percent: 94 | return 100 * precision, 100 * recall, 100 * fb1 95 | else: 96 | return precision, recall, fb1 97 | 98 | 99 | def count_chunks(true_seqs, pred_seqs): 100 | """ 101 | true_seqs: a list of true tags 102 | pred_seqs: a list of predicted tags 103 | 104 | return: 105 | correct_chunks: a dict (counter), 106 | key = chunk types, 107 | value = number of correctly identified chunks per type 108 | true_chunks: a dict, number of true chunks per type 109 | pred_chunks: a dict, number of identified chunks per type 110 | 111 | correct_counts, true_counts, pred_counts: similar to above, but for tags 112 | """ 113 | correct_chunks = defaultdict(int) 114 | true_chunks = defaultdict(int) 115 | pred_chunks = defaultdict(int) 116 | 117 | correct_counts = defaultdict(int) 118 | true_counts = defaultdict(int) 119 | pred_counts = defaultdict(int) 120 | 121 | prev_true_tag, prev_pred_tag = 'O', 'O' 122 | correct_chunk = None 123 | 124 | for true_tag, pred_tag in zip(true_seqs, pred_seqs): 125 | if true_tag == pred_tag: 126 | correct_counts[true_tag] += 1 127 | true_counts[true_tag] += 1 128 | pred_counts[pred_tag] += 1 129 | 130 | _, true_type = split_tag(true_tag) 131 | _, pred_type = split_tag(pred_tag) 132 | 133 | if correct_chunk is not None: 134 | true_end = is_chunk_end(prev_true_tag, true_tag) 135 | pred_end = is_chunk_end(prev_pred_tag, pred_tag) 136 | 137 | if pred_end and 
true_end: 138 | correct_chunks[correct_chunk] += 1 139 | correct_chunk = None 140 | elif pred_end != true_end or true_type != pred_type: 141 | correct_chunk = None 142 | 143 | true_start = is_chunk_start(prev_true_tag, true_tag) 144 | pred_start = is_chunk_start(prev_pred_tag, pred_tag) 145 | 146 | if true_start and pred_start and true_type == pred_type: 147 | correct_chunk = true_type 148 | if true_start: 149 | true_chunks[true_type] += 1 150 | if pred_start: 151 | pred_chunks[pred_type] += 1 152 | 153 | prev_true_tag, prev_pred_tag = true_tag, pred_tag 154 | if correct_chunk is not None: 155 | correct_chunks[correct_chunk] += 1 156 | 157 | return (correct_chunks, true_chunks, pred_chunks, 158 | correct_counts, true_counts, pred_counts) 159 | 160 | def get_result(correct_chunks, true_chunks, pred_chunks, 161 | correct_counts, true_counts, pred_counts, verbose=True): 162 | """ 163 | if verbose, print overall performance, as well as preformance per chunk type; 164 | otherwise, simply return overall prec, rec, f1 scores 165 | """ 166 | # sum counts 167 | sum_correct_chunks = sum(correct_chunks.values()) 168 | sum_true_chunks = sum(true_chunks.values()) 169 | sum_pred_chunks = sum(pred_chunks.values()) 170 | 171 | sum_correct_counts = sum(correct_counts.values()) 172 | sum_true_counts = sum(true_counts.values()) 173 | 174 | nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O') 175 | nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O') 176 | 177 | chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks)))) 178 | 179 | # compute overall precision, recall and FB1 (default values are 0.0) 180 | prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks) 181 | res = (prec, rec, f1) 182 | if not verbose: 183 | return res 184 | 185 | # print overall performance, and performance per chunk type 186 | 187 | print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='') 188 | print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='') 189 | 190 | print("accuracy: %6.2f%%; (non-O)" % (100*nonO_correct_counts/nonO_true_counts)) 191 | print("accuracy: %6.2f%%; " % (100*sum_correct_counts/sum_true_counts), end='') 192 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1)) 193 | 194 | # for each chunk type, compute precision, recall and FB1 (default values are 0.0) 195 | for t in chunk_types: 196 | prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t]) 197 | print("%17s: " %t , end='') 198 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % 199 | (prec, rec, f1), end='') 200 | print(" %d" % pred_chunks[t]) 201 | 202 | return res 203 | # you can generate LaTeX output for tables like in 204 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 205 | # but I'm not implementing this 206 | 207 | def evaluate(true_seqs, pred_seqs, verbose=True): 208 | (correct_chunks, true_chunks, pred_chunks, 209 | correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs) 210 | result = get_result(correct_chunks, true_chunks, pred_chunks, 211 | correct_counts, true_counts, pred_counts, verbose=verbose) 212 | return result 213 | 214 | def evaluate_conll_file(fileIterator): 215 | true_seqs, pred_seqs = [], [] 216 | 217 | for line in fileIterator: 218 | cols = line.strip().split() 219 | # each non-empty line must contain >= 3 columns 220 | if not cols: 221 | true_seqs.append('O') 222 | pred_seqs.append('O') 223 | elif len(cols) < 
3: 224 | raise IOError("conlleval: too few columns in line %s\n" % line) 225 | else: 226 | # extract tags from last 2 columns 227 | true_seqs.append(cols[-2]) 228 | pred_seqs.append(cols[-1]) 229 | return evaluate(true_seqs, pred_seqs) 230 | 231 | if __name__ == '__main__': 232 | """ 233 | usage: conlleval < file 234 | """ 235 | evaluate_conll_file(sys.stdin) -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | import argparse 3 | import re, csv 4 | import pickle 5 | import numpy as np 6 | from transformers import BertTokenizer 7 | import spacy 8 | from spacy.tokens import Token 9 | from spacy.tokenizer import Tokenizer 10 | import task_utils 11 | from tqdm import tqdm 12 | import copy 13 | 14 | global spacy_parser, tokenizer 15 | Token.set_extension('tid', default=0) 16 | spacy_parser = spacy.load("en_core_web_sm", disable=['ner']) 17 | spacy_parser.tokenizer = Tokenizer(spacy_parser.vocab) 18 | tokenizer = { 19 | 'en-cased': BertTokenizer.from_pretrained('bert-base-cased'), 20 | 'en-uncased': BertTokenizer.from_pretrained('bert-base-uncased'), 21 | } 22 | 23 | class InputRawExample(object): 24 | def __init__(self, text, label): 25 | self.text = text 26 | self.label = label 27 | 28 | class Graph(object): 29 | """docstring for Graph""" 30 | def __init__(self, n): 31 | super(Graph, self).__init__() 32 | self.n = n 33 | self.link_list = [] 34 | self.vis = [0] * self.n 35 | for i in range(self.n): 36 | self.link_list.append([]) 37 | 38 | def add_edge(self, u, v): 39 | if u == v: 40 | return 41 | self.link_list[u].append(v) 42 | self.link_list[v].append(u) 43 | 44 | def bfs(self, start, dist): 45 | que = [start] 46 | self.vis[start] = 1 47 | for _ in range(dist): 48 | que2 = [] 49 | for u in que: 50 | #self.vis[u] = 1 51 | for v in self.link_list[u]: 52 | if self.vis[v]: 53 | continue 54 | que2.append(v) 55 | self.vis[v] = 1 56 | que = copy.deepcopy(que2) 57 | 58 | 59 | def solve(self, start, dist): 60 | self.vis = [0] * self.n 61 | self.bfs(start, dist) 62 | self.vis[0] = 1 63 | return copy.deepcopy(self.vis) 64 | 65 | def process(args, text, label): 66 | # text: str 67 | # label: list[str] or int 68 | global spacy_parser, tokenizer 69 | if args.lang == 'en': 70 | local_tokenizer = tokenizer['en-uncased'] if args.do_lower_case else tokenizer['en-cased'] 71 | 72 | while ' ' in text: 73 | text = text.replace(' ', ' ') 74 | doc = spacy_parser(text) 75 | 76 | tokens = ['[CLS]'] 77 | for token in doc: 78 | token._.tid = len(tokens) 79 | tokens.append(token.text) 80 | 81 | G = Graph(len(tokens)) 82 | for token in doc: 83 | if token.dep_ == 'ROOT': 84 | continue 85 | G.add_edge(token._.tid, token.head._.tid) 86 | 87 | 88 | ntokens = [] 89 | ws = [] 90 | 91 | for i, token in enumerate(tokens): 92 | if token == '[CLS]': 93 | ntokens.append(token) 94 | ws.append(1) 95 | else: 96 | sub_tokens = local_tokenizer.tokenize(token) 97 | ws.append(len(sub_tokens)) 98 | for j, st in enumerate(sub_tokens): 99 | ntokens.append(st) 100 | 101 | dep_att_mask = [] 102 | for i, token in enumerate(tokens): 103 | vis = G.solve(i, args.dist) 104 | 105 | if i-1>=0: 106 | vis_tmp = G.solve(i-1, args.dist) 107 | for j in range(len(vis_tmp)): 108 | vis[j] |= vis_tmp[j] 109 | 110 | if i+1?@[\\]^_': 155 | quote_idx.append(i) 156 | 157 | for i in range(len(ntokens)): 158 | for j in quote_idx: 159 | dep_att_mask[i][j] = 1 160 | 161 | input_ids = 
local_tokenizer.convert_tokens_to_ids(ntokens) 162 | else: 163 | raise KeyError(args.lang) 164 | 165 | example = { 166 | 'input_ids': input_ids, 167 | 'dep_att_mask': dep_att_mask, 168 | 'labels': labels, 169 | } 170 | if isinstance(label, list): 171 | example['loss_mask'] = loss_mask 172 | return example 173 | 174 | def _read_tsv(input_file, quotechar=None): 175 | """Reads a tab separated value file.""" 176 | with open(input_file, "r", encoding="utf-8-sig") as f: 177 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 178 | lines = [] 179 | for line in reader: 180 | if sys.version_info[0] == 2: 181 | line = list(unicode(cell, 'utf-8') for cell in line) 182 | lines.append(line) 183 | return lines 184 | 185 | def _read_conll_format_file(input_file): 186 | lines = [] 187 | with open(input_file, 'r', encoding='utf-8') as f: 188 | words = [] 189 | labels = [] 190 | for line in f: 191 | contends = line.strip() 192 | feature_vector = contends.split(' ') 193 | word = feature_vector[0].strip() 194 | label = feature_vector[-1].strip() 195 | if len(contends) == 0: 196 | w = ' '.join(words) 197 | l = ' '.join(labels) 198 | lines.append((w, l)) 199 | words = [] 200 | labels = [] 201 | continue 202 | #word = re.sub(u"([^\u4e00-\u9fa5\u0030-\u0039\u0041-\u005a\u0061-\u007a'!'#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘'![\\]^_`{|}~])","",word) 203 | if len(word) == 0 or len(label) == 0: 204 | continue 205 | words.append(word) 206 | labels.append(label) 207 | return lines 208 | 209 | def main(args): 210 | raw_examples = [] 211 | if args.task == 'cola': 212 | lines = _read_tsv(args.data_file) 213 | for line in lines: 214 | text = line[3].lower() if args.do_lower_case else line[3] 215 | raw_examples.append(InputRawExample(text, int(line[1]))) 216 | elif args.task == 'sst-2': 217 | lines = _read_tsv(args.data_file) 218 | for line in lines[1:]: 219 | text = line[0].lower() if args.do_lower_case else line[0] 220 | raw_examples.append(InputRawExample(text, int(line[1]))) 221 | elif args.task == 'fce': 222 | lines = _read_conll_format_file(args.data_file) 223 | for w, l in lines: 224 | text = w.lower() if args.do_lower_case else w 225 | raw_examples.append(InputRawExample(text, l.split(' '))) 226 | else: 227 | raise KeyError(args.task) 228 | examples = [] 229 | for item in tqdm(raw_examples, desc='Convert'): 230 | examples.append(process(args, item.text, item.label)) 231 | 232 | filename = args.data_file.split('/')[-1] 233 | pickle.dump(examples, open('./%s.%s.d%d.pkl'%(args.task, filename, args.dist), 'wb')) 234 | 235 | 236 | if __name__=='__main__': 237 | p = argparse.ArgumentParser() 238 | p.add_argument("--task", default="cola", type=str) 239 | p.add_argument("--data_file", default="cola.tsv", type=str) 240 | p.add_argument("--lang", default="en", type=str) 241 | p.add_argument("--do_lower_case", action="store_true") 242 | p.add_argument("--dist", default=3, type=int) 243 | args = p.parse_args() 244 | main(args) 245 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import logging 3 | import math 4 | import os 5 | import sys 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import CrossEntropyLoss, MSELoss 10 | 11 | from transformers.modeling_bert import * 12 | 13 | class TokenClsModel(BertPreTrainedModel): 14 | def __init__(self, config): 15 | super(TokenClsModel, 
self).__init__(config) 16 | self.num_labels = config.num_labels 17 | 18 | self.bert = BertModel(config) 19 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 20 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 21 | 22 | self.init_weights() 23 | 24 | @staticmethod 25 | def build_batch(batch, tokenizer): 26 | return batch 27 | 28 | def forward(self, batch): 29 | input_ids = batch['input_ids'] 30 | attention_mask = batch['attention_mask'] 31 | dep_att_mask = batch['dep_att_mask'] 32 | labels = batch['labels'] 33 | loss_mask = batch['loss_mask'] if 'loss_mask' in batch else None 34 | 35 | outputs = self.bert(input_ids, 36 | attention_mask=attention_mask, 37 | dep_att_mask=dep_att_mask) 38 | 39 | sequence_output = outputs[0] 40 | 41 | sequence_output = self.dropout(sequence_output) 42 | logits = self.classifier(sequence_output) 43 | 44 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 45 | if labels is not None: 46 | loss_fct = CrossEntropyLoss() 47 | # Only keep active parts of the loss 48 | if loss_mask is not None: 49 | active_loss = loss_mask.view(-1) == 1 50 | active_logits = logits.view(-1, self.num_labels)[active_loss] 51 | active_labels = labels.view(-1)[active_loss] 52 | loss = loss_fct(active_logits, active_labels) 53 | else: 54 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 55 | outputs = (loss,) + outputs 56 | return outputs # (loss), scores, (hidden_states), (attentions) 57 | 58 | class SentClsModel(BertPreTrainedModel): 59 | def __init__(self, config): 60 | super(SentClsModel, self).__init__(config) 61 | self.num_labels = config.num_labels 62 | 63 | self.bert = BertModel(config) 64 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 65 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 66 | 67 | self.init_weights() 68 | 69 | @staticmethod 70 | def build_batch(batch, tokenizer): 71 | return batch 72 | 73 | def forward(self, batch): 74 | input_ids = batch['input_ids'] 75 | attention_mask = batch['attention_mask'] 76 | dep_att_mask = batch['dep_att_mask'] 77 | labels = batch['labels'] 78 | 79 | outputs = self.bert(input_ids, 80 | attention_mask=attention_mask, 81 | dep_att_mask=dep_att_mask) 82 | pooled_output = outputs[1] 83 | pooled_output = self.dropout(pooled_output) 84 | logits = self.classifier(pooled_output) 85 | 86 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 87 | 88 | if labels is not None: 89 | loss_fct = CrossEntropyLoss() 90 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 91 | outputs = (loss,) + outputs 92 | return outputs # (loss), logits, (hidden_states), (attentions) 93 | 94 | -------------------------------------------------------------------------------- /src/task_utils.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | import logging 3 | import numpy as np 4 | logger = logging.getLogger(__name__) 5 | 6 | try: 7 | from scipy.stats import pearsonr, spearmanr 8 | from sklearn.metrics import matthews_corrcoef, f1_score 9 | _has_sklearn = True 10 | except (AttributeError, ImportError) as e: 11 | logger.warning("To use data.metrics please install scikit-learn. 
See https://scikit-learn.org/stable/index.html") 12 | _has_sklearn = False 13 | 14 | from conlleval import evaluate as con_eval 15 | 16 | num_labels = { 17 | 'cola': 2, 18 | 'sst-2': 2, 19 | 'chns': 2, 20 | 'cged': 9, 21 | 'fce': 6, 22 | } 23 | 24 | class CgedProcessor(object): 25 | def get_labels(self): 26 | return ['O', 'B-R', 'I-R', 'B-M', 'I-M', 'B-S', 'I-S', 'B-W', 'I-W'] 27 | 28 | class FceProcessor(object): 29 | def get_labels(self): 30 | return ['O', 'B-R', 'I-R', 'B-U', 'I-U', 'B-M'] 31 | 32 | task_processors = { 33 | 'cged': CgedProcessor, 34 | 'fce' : FceProcessor, 35 | } 36 | 37 | def is_sklearn_available(): 38 | return _has_sklearn 39 | 40 | if _has_sklearn: 41 | 42 | def simple_accuracy(preds, labels): 43 | return {'acc': (preds == labels).mean()} 44 | 45 | 46 | def acc_and_f1(preds, labels): 47 | acc = simple_accuracy(preds, labels) 48 | f1 = f1_score(y_true=labels, y_pred=preds) 49 | return { 50 | "acc": acc, 51 | "f1": f1, 52 | "acc_and_f1": (acc + f1) / 2, 53 | } 54 | 55 | 56 | def pearson_and_spearman(preds, labels): 57 | pearson_corr = pearsonr(preds, labels)[0] 58 | spearman_corr = spearmanr(preds, labels)[0] 59 | return { 60 | "pearson": pearson_corr, 61 | "spearmanr": spearman_corr, 62 | "corr": (pearson_corr + spearman_corr) / 2, 63 | } 64 | 65 | def evaluate(task, eval_data): 66 | if task == 'cola': 67 | preds = np.array(eval_data['preds']) 68 | labels = np.array(eval_data['labels']) 69 | return {"mcc": matthews_corrcoef(labels, preds)} 70 | elif task == 'chns' or task == 'sst-2': 71 | preds = np.array(eval_data['preds']) 72 | labels = np.array(eval_data['labels']) 73 | return simple_accuracy(preds, labels) 74 | elif task == 'fce': 75 | preds = eval_data['preds'] 76 | labels = eval_data['labels'] 77 | loss_mask = eval_data['loss_mask'] 78 | true_tags, pred_tags = [], [] 79 | label_list = FceProcessor().get_labels() 80 | for i in range(len(preds)): 81 | for pid, tid, mask in zip(preds[i], labels[i], loss_mask[i]): 82 | if mask == 1: 83 | pred_tags.append(label_list[pid]) 84 | true_tags.append(label_list[tid]) 85 | pred_tags.append('O') 86 | true_tags.append('O') 87 | prec, rec, f1 = con_eval(true_tags, pred_tags, verbose=False) 88 | return { 89 | 'precision': prec, 90 | 'recall': rec, 91 | 'f1': f1, 92 | 'f0.5': 1.25 * prec * rec / (rec + 0.25 * prec) 93 | } 94 | else: 95 | raise KeyError(task) -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export INPUT_DIR=/path_to_load/ 3 | export OUTPUT_DIR=/path_to_store/ 4 | 5 | pip install --user --editable . 
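# Note: run.py consumes the pickled examples produced by src/dataset.py, so a
# preprocessing step like the sketch below has to run first (hypothetical raw-data
# path; the --dist value must match the ".d3" suffix of --train_file/--dev_file).
# dataset.py writes <task>.<data_file basename>.d<dist>.pkl into the working
# directory; move or rename the pickles into $INPUT_DIR as needed.
#   python src/dataset.py --task cola --data_file /path_to_raw/cola.train.tsv --do_lower_case --dist 3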
6 | python src/run.py \ 7 | --task cola \ 8 | --model_type sent-cls \ 9 | --model_name_or_path bert-base-uncased \ 10 | --data_dir $INPUT_DIR \ 11 | --output_dir $OUTPUT_DIR \ 12 | --train_file cola.train.tsv.d3.pkl \ 13 | --dev_file cola.dev.tsv.d3.pkl \ 14 | --do_train --do_eval --do_lower_case \ 15 | --learning_rate 3e-5 \ 16 | --num_train_epochs 2 \ 17 | --order_metric mcc \ 18 | --metric_reverse \ 19 | --remove_unused_ckpts \ 20 | --per_gpu_train_batch_size 16 \ 21 | --eval_all_checkpoints \ 22 | --overwrite_output_dir \ 23 | --save_steps 100 \ 24 | --seed 17 25 | 26 | -------------------------------------------------------------------------------- /transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.2.2" 2 | 3 | # Work around to update TensorFlow's absl.logging threshold which alters the 4 | # default Python logging output behavior when present. 5 | # see: https://github.com/abseil/abseil-py/issues/99 6 | # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493 7 | try: 8 | import absl.logging 9 | absl.logging.set_verbosity('info') 10 | absl.logging.set_stderrthreshold('info') 11 | absl.logging._warn_preinit_stderr = False 12 | except: 13 | pass 14 | 15 | import logging 16 | 17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 18 | 19 | # Files and general utilities 20 | from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, 21 | cached_path, add_start_docstrings, add_end_docstrings, 22 | WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, 23 | is_tf_available, is_torch_available) 24 | 25 | from .data import (is_sklearn_available, 26 | InputExample, InputFeatures, DataProcessor, 27 | glue_output_modes, glue_convert_examples_to_features, 28 | glue_processors, glue_tasks_num_labels, 29 | xnli_output_modes, xnli_processors, xnli_tasks_num_labels, 30 | squad_convert_examples_to_features, SquadFeatures, 31 | SquadExample, SquadV1Processor, SquadV2Processor) 32 | 33 | if is_sklearn_available(): 34 | from .data import glue_compute_metrics, xnli_compute_metrics 35 | 36 | # Tokenizers 37 | from .tokenization_utils import (PreTrainedTokenizer) 38 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 39 | 40 | # Configurations 41 | from .configuration_utils import PretrainedConfig 42 | from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 43 | 44 | # Modeling 45 | if is_torch_available(): 46 | from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) 47 | 48 | from .modeling_bert import (BertModel, 49 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP) 50 | 51 | # Optimization 52 | from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, 53 | get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup) 54 | 55 | if not is_tf_available() and not is_torch_available(): 56 | logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found." 
57 | "Models won't be available and only tokenizers, configuration" 58 | "and file/data utilities can be used.") 59 | -------------------------------------------------------------------------------- /transformers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | class BaseTransformersCLICommand(ABC): 5 | @staticmethod 6 | @abstractmethod 7 | def register_subcommand(parser: ArgumentParser): 8 | raise NotImplementedError() 9 | 10 | @abstractmethod 11 | def run(self): 12 | raise NotImplementedError() 13 | -------------------------------------------------------------------------------- /transformers/commands/user.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from getpass import getpass 3 | import os 4 | 5 | from transformers.commands import BaseTransformersCLICommand 6 | from transformers.hf_api import HfApi, HfFolder, HTTPError 7 | 8 | 9 | class UserCommands(BaseTransformersCLICommand): 10 | @staticmethod 11 | def register_subcommand(parser: ArgumentParser): 12 | login_parser = parser.add_parser('login') 13 | login_parser.set_defaults(func=lambda args: LoginCommand(args)) 14 | whoami_parser = parser.add_parser('whoami') 15 | whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) 16 | logout_parser = parser.add_parser('logout') 17 | logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) 18 | list_parser = parser.add_parser('ls') 19 | list_parser.set_defaults(func=lambda args: ListObjsCommand(args)) 20 | # upload 21 | upload_parser = parser.add_parser('upload') 22 | upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.') 23 | upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.') 24 | upload_parser.set_defaults(func=lambda args: UploadCommand(args)) 25 | 26 | 27 | 28 | class ANSI: 29 | """ 30 | Helper for en.wikipedia.org/wiki/ANSI_escape_code 31 | """ 32 | _bold = u"\u001b[1m" 33 | _reset = u"\u001b[0m" 34 | @classmethod 35 | def bold(cls, s): 36 | return "{}{}{}".format(cls._bold, s, cls._reset) 37 | 38 | 39 | class BaseUserCommand: 40 | def __init__(self, args): 41 | self.args = args 42 | self._api = HfApi() 43 | 44 | 45 | class LoginCommand(BaseUserCommand): 46 | def run(self): 47 | print(""" 48 | _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| 49 | _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| 50 | _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| 51 | _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| 52 | _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| 53 | 54 | """) 55 | username = input("Username: ") 56 | password = getpass() 57 | try: 58 | token = self._api.login(username, password) 59 | except HTTPError as e: 60 | # probably invalid credentials, display error message. 
61 | print(e) 62 | exit(1) 63 | HfFolder.save_token(token) 64 | print("Login successful") 65 | print("Your token:", token, "\n") 66 | print("Your token has been saved to", HfFolder.path_token) 67 | 68 | 69 | class WhoamiCommand(BaseUserCommand): 70 | def run(self): 71 | token = HfFolder.get_token() 72 | if token is None: 73 | print("Not logged in") 74 | exit() 75 | try: 76 | user = self._api.whoami(token) 77 | print(user) 78 | except HTTPError as e: 79 | print(e) 80 | 81 | 82 | class LogoutCommand(BaseUserCommand): 83 | def run(self): 84 | token = HfFolder.get_token() 85 | if token is None: 86 | print("Not logged in") 87 | exit() 88 | HfFolder.delete_token() 89 | self._api.logout(token) 90 | print("Successfully logged out.") 91 | 92 | 93 | class ListObjsCommand(BaseUserCommand): 94 | def tabulate(self, rows, headers): 95 | # type: (List[List[Union[str, int]]], List[str]) -> str 96 | """ 97 | Inspired by: 98 | stackoverflow.com/a/8356620/593036 99 | stackoverflow.com/questions/9535954/printing-lists-as-tabular-data 100 | """ 101 | col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] 102 | row_format = ("{{:{}}} " * len(headers)).format(*col_widths) 103 | lines = [] 104 | lines.append( 105 | row_format.format(*headers) 106 | ) 107 | lines.append( 108 | row_format.format(*["-" * w for w in col_widths]) 109 | ) 110 | for row in rows: 111 | lines.append( 112 | row_format.format(*row) 113 | ) 114 | return "\n".join(lines) 115 | 116 | def run(self): 117 | token = HfFolder.get_token() 118 | if token is None: 119 | print("Not logged in") 120 | exit(1) 121 | try: 122 | objs = self._api.list_objs(token) 123 | except HTTPError as e: 124 | print(e) 125 | exit(1) 126 | if len(objs) == 0: 127 | print("No shared file yet") 128 | exit() 129 | rows = [ [ 130 | obj.filename, 131 | obj.LastModified, 132 | obj.ETag, 133 | obj.Size 134 | ] for obj in objs ] 135 | print( 136 | self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]) 137 | ) 138 | 139 | 140 | class UploadCommand(BaseUserCommand): 141 | def walk_dir(self, rel_path): 142 | """ 143 | Recursively list all files in a folder. 144 | """ 145 | entries: List[os.DirEntry] = list(os.scandir(rel_path)) 146 | files = [ 147 | ( 148 | os.path.join(os.getcwd(), f.path), # filepath 149 | f.path # filename 150 | ) 151 | for f in entries if f.is_file() 152 | ] 153 | for f in entries: 154 | if f.is_dir(): 155 | files += self.walk_dir(f.path) 156 | return files 157 | 158 | def run(self): 159 | token = HfFolder.get_token() 160 | if token is None: 161 | print("Not logged in") 162 | exit(1) 163 | local_path = os.path.abspath(self.args.path) 164 | if os.path.isdir(local_path): 165 | if self.args.filename is not None: 166 | raise ValueError("Cannot specify a filename override when uploading a folder.") 167 | rel_path = os.path.basename(local_path) 168 | files = self.walk_dir(rel_path) 169 | elif os.path.isfile(local_path): 170 | filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path) 171 | files = [(local_path, filename)] 172 | else: 173 | raise ValueError("Not a valid file or directory: {}".format(local_path)) 174 | 175 | for filepath, filename in files: 176 | print( 177 | "About to upload file {} to S3 under filename {}".format( 178 | ANSI.bold(filepath), ANSI.bold(filename) 179 | ) 180 | ) 181 | 182 | choice = input("Proceed? [Y/n] ").lower() 183 | if not(choice == "" or choice == "y" or choice == "yes"): 184 | print("Abort") 185 | exit() 186 | print( 187 | ANSI.bold("Uploading... 
This might take a while if files are large") 188 | ) 189 | for filepath, filename in files: 190 | access_url = self._api.presign_and_upload( 191 | token=token, filename=filename, filepath=filepath 192 | ) 193 | print("Your file now lives at:") 194 | print(access_url) 195 | -------------------------------------------------------------------------------- /transformers/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 44 | 'bert-base-german-dbmdz-uncased': 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 45 | 'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json", 46 | 'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json", 47 | 'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json", 48 | 'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json" 49 | } 50 | 51 | 52 | class BertConfig(PretrainedConfig): 53 | r""" 54 | :class:`~transformers.BertConfig` is the configuration class to store the configuration of a 55 | `BertModel`. 56 | 57 | 58 | Arguments: 59 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 60 | hidden_size: Size of the encoder layers and the pooler layer. 61 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 62 | num_attention_heads: Number of attention heads for each attention layer in 63 | the Transformer encoder. 64 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 65 | layer in the Transformer encoder. 66 | hidden_act: The non-linear activation function (function or string) in the 67 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 68 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 69 | layers in the embeddings, encoder, and pooler. 70 | attention_probs_dropout_prob: The dropout ratio for the attention 71 | probabilities. 72 | max_position_embeddings: The maximum sequence length that this model might 73 | ever be used with. Typically set this to something large just in case 74 | (e.g., 512 or 1024 or 2048). 75 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 76 | `BertModel`. 77 | initializer_range: The sttdev of the truncated_normal_initializer for 78 | initializing all weight matrices. 79 | layer_norm_eps: The epsilon used by LayerNorm. 
80 | """ 81 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 82 | 83 | def __init__(self, 84 | vocab_size_or_config_json_file=30522, 85 | hidden_size=768, 86 | num_hidden_layers=12, 87 | num_attention_heads=12, 88 | intermediate_size=3072, 89 | hidden_act="gelu", 90 | hidden_dropout_prob=0.1, 91 | attention_probs_dropout_prob=0.1, 92 | max_position_embeddings=512, 93 | type_vocab_size=2, 94 | initializer_range=0.02, 95 | layer_norm_eps=1e-12, 96 | **kwargs): 97 | super(BertConfig, self).__init__(**kwargs) 98 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 99 | and isinstance(vocab_size_or_config_json_file, unicode)): 100 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 101 | json_config = json.loads(reader.read()) 102 | for key, value in json_config.items(): 103 | self.__dict__[key] = value 104 | elif isinstance(vocab_size_or_config_json_file, int): 105 | self.vocab_size = vocab_size_or_config_json_file 106 | self.hidden_size = hidden_size 107 | self.num_hidden_layers = num_hidden_layers 108 | self.num_attention_heads = num_attention_heads 109 | self.hidden_act = hidden_act 110 | self.intermediate_size = intermediate_size 111 | self.hidden_dropout_prob = hidden_dropout_prob 112 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 113 | self.max_position_embeddings = max_position_embeddings 114 | self.type_vocab_size = type_vocab_size 115 | self.initializer_range = initializer_range 116 | self.layer_norm_eps = layer_norm_eps 117 | else: 118 | raise ValueError("First argument must be either a vocabulary size (int)" 119 | " or the path to a pretrained model config file (str)") 120 | -------------------------------------------------------------------------------- /transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures 2 | from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor 4 | from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels 5 | 6 | from .metrics import is_sklearn_available 7 | if is_sklearn_available(): 8 | from .metrics import glue_compute_metrics, xnli_compute_metrics 9 | -------------------------------------------------------------------------------- /transformers/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import csv 18 | import sys 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | try: 24 | from scipy.stats import pearsonr, spearmanr 25 | from sklearn.metrics import matthews_corrcoef, f1_score 26 | _has_sklearn = True 27 | except (AttributeError, ImportError) as e: 28 | logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html") 29 | _has_sklearn = False 30 | 31 | def is_sklearn_available(): 32 | return _has_sklearn 33 | 34 | if _has_sklearn: 35 | 36 | def simple_accuracy(preds, labels): 37 | return (preds == labels).mean() 38 | 39 | 40 | def acc_and_f1(preds, labels): 41 | acc = simple_accuracy(preds, labels) 42 | f1 = f1_score(y_true=labels, y_pred=preds) 43 | return { 44 | "acc": acc, 45 | "f1": f1, 46 | "acc_and_f1": (acc + f1) / 2, 47 | } 48 | 49 | 50 | def pearson_and_spearman(preds, labels): 51 | pearson_corr = pearsonr(preds, labels)[0] 52 | spearman_corr = spearmanr(preds, labels)[0] 53 | return { 54 | "pearson": pearson_corr, 55 | "spearmanr": spearman_corr, 56 | "corr": (pearson_corr + spearman_corr) / 2, 57 | } 58 | 59 | 60 | def glue_compute_metrics(task_name, preds, labels): 61 | assert len(preds) == len(labels) 62 | if task_name == "cola": 63 | return {"mcc": matthews_corrcoef(labels, preds)} 64 | elif task_name == "sst-2": 65 | return {"acc": simple_accuracy(preds, labels)} 66 | elif task_name == "mrpc": 67 | return acc_and_f1(preds, labels) 68 | elif task_name == "sts-b": 69 | return pearson_and_spearman(preds, labels) 70 | elif task_name == "qqp": 71 | return acc_and_f1(preds, labels) 72 | elif task_name == "mnli": 73 | return {"acc": simple_accuracy(preds, labels)} 74 | elif task_name == "mnli-mm": 75 | return {"acc": simple_accuracy(preds, labels)} 76 | elif task_name == "qnli": 77 | return {"acc": simple_accuracy(preds, labels)} 78 | elif task_name == "rte": 79 | return {"acc": simple_accuracy(preds, labels)} 80 | elif task_name == "wnli": 81 | return {"acc": simple_accuracy(preds, labels)} 82 | else: 83 | raise KeyError(task_name) 84 | 85 | 86 | def xnli_compute_metrics(task_name, preds, labels): 87 | assert len(preds) == len(labels) 88 | if task_name == "xnli": 89 | return {"acc": simple_accuracy(preds, labels)} 90 | else: 91 | raise KeyError(task_name) 92 | -------------------------------------------------------------------------------- /transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import InputExample, InputFeatures, DataProcessor 2 | from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor 4 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels -------------------------------------------------------------------------------- /transformers/data/processors/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import copy 20 | import json 21 | 22 | class InputExample(object): 23 | """ 24 | A single training/test example for simple sequence classification. 25 | 26 | Args: 27 | guid: Unique id for the example. 28 | text_a: string. The untokenized text of the first sequence. For single 29 | sequence tasks, only this sequence must be specified. 30 | text_b: (Optional) string. The untokenized text of the second sequence. 31 | Only must be specified for sequence pair tasks. 32 | label: (Optional) string. The label of the example. This should be 33 | specified for train and dev examples, but not for test examples. 34 | """ 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | self.guid = guid 37 | self.text_a = text_a 38 | self.text_b = text_b 39 | self.label = label 40 | 41 | def __repr__(self): 42 | return str(self.to_json_string()) 43 | 44 | def to_dict(self): 45 | """Serializes this instance to a Python dictionary.""" 46 | output = copy.deepcopy(self.__dict__) 47 | return output 48 | 49 | def to_json_string(self): 50 | """Serializes this instance to a JSON string.""" 51 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 52 | 53 | 54 | class InputFeatures(object): 55 | """ 56 | A single set of features of data. 57 | 58 | Args: 59 | input_ids: Indices of input sequence tokens in the vocabulary. 60 | attention_mask: Mask to avoid performing attention on padding token indices. 61 | Mask values selected in ``[0, 1]``: 62 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 63 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 64 | label: Label corresponding to the input 65 | """ 66 | 67 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 68 | self.input_ids = input_ids 69 | self.attention_mask = attention_mask 70 | self.token_type_ids = token_type_ids 71 | self.label = label 72 | 73 | def __repr__(self): 74 | return str(self.to_json_string()) 75 | 76 | def to_dict(self): 77 | """Serializes this instance to a Python dictionary.""" 78 | output = copy.deepcopy(self.__dict__) 79 | return output 80 | 81 | def to_json_string(self): 82 | """Serializes this instance to a JSON string.""" 83 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 84 | 85 | 86 | class DataProcessor(object): 87 | """Base class for data converters for sequence classification data sets.""" 88 | 89 | def get_example_from_tensor_dict(self, tensor_dict): 90 | """Gets an example from a dict with tensorflow tensors 91 | 92 | Args: 93 | tensor_dict: Keys and values should match the corresponding Glue 94 | tensorflow_dataset examples. 
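
        A possible subclass implementation, as a sketch (the ``tensor_dict``
        keys here are hypothetical and depend on the tensorflow_datasets
        schema for the task)::

            def get_example_from_tensor_dict(self, tensor_dict):
                return InputExample(tensor_dict['idx'].numpy(),
                                    tensor_dict['sentence'].numpy().decode('utf-8'),
                                    None,
                                    str(tensor_dict['label'].numpy()))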
95 | """ 96 | raise NotImplementedError() 97 | 98 | def get_train_examples(self, data_dir): 99 | """Gets a collection of `InputExample`s for the train set.""" 100 | raise NotImplementedError() 101 | 102 | def get_dev_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the dev set.""" 104 | raise NotImplementedError() 105 | 106 | def get_labels(self): 107 | """Gets the list of labels for this data set.""" 108 | raise NotImplementedError() 109 | 110 | def tfds_map(self, example): 111 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. 112 | This method converts examples to the correct format.""" 113 | if len(self.get_labels()) > 1: 114 | example.label = self.get_labels()[int(example.label)] 115 | return example 116 | 117 | @classmethod 118 | def _read_tsv(cls, input_file, quotechar=None): 119 | """Reads a tab separated value file.""" 120 | with open(input_file, "r", encoding="utf-8-sig") as f: 121 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 122 | lines = [] 123 | for line in reader: 124 | if sys.version_info[0] == 2: 125 | line = list(unicode(cell, 'utf-8') for cell in line) 126 | lines.append(line) 127 | return lines 128 | -------------------------------------------------------------------------------- /transformers/data/processors/xnli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ XNLI utils (dataset loading and evaluation) """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import logging 21 | import os 22 | 23 | from .utils import DataProcessor, InputExample 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | class XnliProcessor(DataProcessor): 28 | """Processor for the XNLI dataset. 
29 | Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207""" 30 | 31 | def __init__(self, language, train_language = None): 32 | self.language = language 33 | self.train_language = train_language 34 | 35 | def get_train_examples(self, data_dir): 36 | """See base class.""" 37 | lg = self.language if self.train_language is None else self.train_language 38 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg))) 39 | examples = [] 40 | for (i, line) in enumerate(lines): 41 | if i == 0: 42 | continue 43 | guid = "%s-%s" % ('train', i) 44 | text_a = line[0] 45 | text_b = line[1] 46 | label = "contradiction" if line[2] == "contradictory" else line[2] 47 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) 48 | examples.append( 49 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 50 | return examples 51 | 52 | def get_test_examples(self, data_dir): 53 | """See base class.""" 54 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv")) 55 | examples = [] 56 | for (i, line) in enumerate(lines): 57 | if i == 0: 58 | continue 59 | language = line[0] 60 | if language != self.language: 61 | continue 62 | guid = "%s-%s" % ('test', i) 63 | text_a = line[6] 64 | text_b = line[7] 65 | label = line[1] 66 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) 67 | examples.append( 68 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 69 | return examples 70 | 71 | def get_labels(self): 72 | """See base class.""" 73 | return ["contradiction", "entailment", "neutral"] 74 | 75 | xnli_processors = { 76 | "xnli": XnliProcessor, 77 | } 78 | 79 | xnli_output_modes = { 80 | "xnli": "classification", 81 | } 82 | 83 | xnli_tasks_num_labels = { 84 | "xnli": 3, 85 | } 86 | -------------------------------------------------------------------------------- /transformers/hf_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function 16 | 17 | import os 18 | from os.path import expanduser 19 | 20 | import requests 21 | import six 22 | from requests.exceptions import HTTPError 23 | from tqdm import tqdm 24 | 25 | ENDPOINT = "https://huggingface.co" 26 | 27 | class S3Obj: 28 | def __init__( 29 | self, 30 | filename, # type: str 31 | LastModified, # type: str 32 | ETag, # type: str 33 | Size, # type: int 34 | **kwargs 35 | ): 36 | self.filename = filename 37 | self.LastModified = LastModified 38 | self.ETag = ETag 39 | self.Size = Size 40 | 41 | 42 | class PresignedUrl: 43 | def __init__( 44 | self, 45 | write, # type: str 46 | access, # type: str 47 | type, # type: str 48 | **kwargs 49 | ): 50 | self.write = write 51 | self.access = access 52 | self.type = type # mime-type to send to S3. 53 | 54 | 55 | class HfApi: 56 | def __init__(self, endpoint=None): 57 | self.endpoint = endpoint if endpoint is not None else ENDPOINT 58 | 59 | def login( 60 | self, 61 | username, # type: str 62 | password, # type: str 63 | ): 64 | # type: (...) -> str 65 | """ 66 | Call HF API to sign in a user and get a token if credentials are valid. 67 | 68 | Outputs: 69 | token if credentials are valid 70 | 71 | Throws: 72 | requests.exceptions.HTTPError if credentials are invalid 73 | """ 74 | path = "{}/api/login".format(self.endpoint) 75 | r = requests.post(path, json={"username": username, "password": password}) 76 | r.raise_for_status() 77 | d = r.json() 78 | return d["token"] 79 | 80 | def whoami( 81 | self, 82 | token, # type: str 83 | ): 84 | # type: (...) -> str 85 | """ 86 | Call HF API to get the username of the authenticated user. 87 | """ 88 | path = "{}/api/whoami".format(self.endpoint) 89 | r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) 90 | r.raise_for_status() 91 | d = r.json() 92 | return d["user"] 93 | 94 | def logout(self, token): 95 | # type: (...) -> None 96 | """ 97 | Call HF API to log out. 98 | """ 99 | path = "{}/api/logout".format(self.endpoint) 100 | r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) 101 | r.raise_for_status() 102 | 103 | def presign(self, token, filename): 104 | # type: (...) -> PresignedUrl 105 | """ 106 | Call HF API to get a presigned url to upload `filename` to S3. 107 | """ 108 | path = "{}/api/presign".format(self.endpoint) 109 | r = requests.post( 110 | path, 111 | headers={"authorization": "Bearer {}".format(token)}, 112 | json={"filename": filename}, 113 | ) 114 | r.raise_for_status() 115 | d = r.json() 116 | return PresignedUrl(**d) 117 | 118 | def presign_and_upload(self, token, filename, filepath): 119 | # type: (...) -> str 120 | """ 121 | Get a presigned url, then upload file to S3. 122 | 123 | Outputs: 124 | url: Read-only url for the stored file on S3. 125 | """ 126 | urls = self.presign(token, filename=filename) 127 | # streaming upload: 128 | # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads 129 | # 130 | # Even though we presign with the correct content-type, 131 | # the client still has to specify it when uploading the file. 132 | with open(filepath, "rb") as f: 133 | pf = TqdmProgressFileReader(f) 134 | 135 | r = requests.put(urls.write, data=f, headers={ 136 | "content-type": urls.type, 137 | }) 138 | r.raise_for_status() 139 | pf.close() 140 | return urls.access 141 | 142 | def list_objs(self, token): 143 | # type: (...) -> List[S3Obj] 144 | """ 145 | Call HF API to list all stored files for user.
146 | """ 147 | path = "{}/api/listObjs".format(self.endpoint) 148 | r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) 149 | r.raise_for_status() 150 | d = r.json() 151 | return [S3Obj(**x) for x in d] 152 | 153 | 154 | 155 | class TqdmProgressFileReader: 156 | """ 157 | Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) 158 | and override `f.read()` so as to display a tqdm progress bar. 159 | 160 | see github.com/huggingface/transformers/pull/2078#discussion_r354739608 161 | for implementation details. 162 | """ 163 | def __init__( 164 | self, 165 | f # type: io.BufferedReader 166 | ): 167 | self.f = f 168 | self.total_size = os.fstat(f.fileno()).st_size # type: int 169 | self.pbar = tqdm(total=self.total_size, leave=False) 170 | if six.PY3: 171 | # does not work unless PY3 172 | # no big deal as the CLI does not currently support PY2 anyways. 173 | self.read = f.read 174 | f.read = self._read 175 | 176 | def _read(self, n=-1): 177 | self.pbar.update(n) 178 | return self.read(n) 179 | 180 | def close(self): 181 | self.pbar.close() 182 | 183 | 184 | 185 | class HfFolder: 186 | path_token = expanduser("~/.huggingface/token") 187 | 188 | @classmethod 189 | def save_token(cls, token): 190 | """ 191 | Save token, creating folder as needed. 192 | """ 193 | if six.PY3: 194 | os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) 195 | else: 196 | # Python 2 197 | try: 198 | os.makedirs(os.path.dirname(cls.path_token)) 199 | except OSError as e: 200 | if e.errno != os.errno.EEXIST: 201 | raise e 202 | pass 203 | with open(cls.path_token, 'w+') as f: 204 | f.write(token) 205 | 206 | @classmethod 207 | def get_token(cls): 208 | """ 209 | Get token or None if not existent. 210 | """ 211 | try: 212 | with open(cls.path_token, 'r') as f: 213 | return f.read() 214 | except: 215 | # this is too wide. When Py2 is dead use: 216 | # `except FileNotFoundError:` instead 217 | return None 218 | 219 | @classmethod 220 | def delete_token(cls): 221 | """ 222 | Delete token. 223 | Do not fail if token does not exist. 224 | """ 225 | try: 226 | os.remove(cls.path_token) 227 | except: 228 | return 229 | -------------------------------------------------------------------------------- /transformers/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def get_constant_schedule(optimizer, last_epoch=-1): 28 | """ Create a schedule with a constant learning rate. 
29 | """ 30 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 31 | 32 | 33 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 34 | """ Create a schedule with a constant learning rate preceded by a warmup 35 | period during which the learning rate increases linearly from 0 to the initial lr set in the optimizer. 36 | """ 37 | def lr_lambda(current_step): 38 | if current_step < num_warmup_steps: 39 | return float(current_step) / float(max(1.0, num_warmup_steps)) 40 | return 1. 41 | 42 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 43 | 44 | 45 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 46 | """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the 47 | optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr. 48 | """ 49 | def lr_lambda(current_step): 50 | if current_step < num_warmup_steps: 51 | return float(current_step) / float(max(1, num_warmup_steps)) 52 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 53 | 54 | return LambdaLR(optimizer, lr_lambda, last_epoch) 55 | 56 | 57 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1): 58 | """ Create a schedule with a learning rate that decreases following the 59 | values of the cosine function between the initial lr set in the optimizer and 0 (half a cosine 60 | wave for the default `num_cycles=0.5`), after a warmup period during which it increases linearly from 0 to the initial lr. 61 | """ 62 | def lr_lambda(current_step): 63 | if current_step < num_warmup_steps: 64 | return float(current_step) / float(max(1, num_warmup_steps)) 65 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 66 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress))) 67 | 68 | return LambdaLR(optimizer, lr_lambda, last_epoch) 69 | 70 | 71 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1): 72 | """ Create a schedule with a learning rate that decreases following the 73 | values of the cosine function with `num_cycles` hard restarts, after a warmup 74 | period during which it increases linearly from 0 to the initial lr. 75 | """ 76 | def lr_lambda(current_step): 77 | if current_step < num_warmup_steps: 78 | return float(current_step) / float(max(1, num_warmup_steps)) 79 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 80 | if progress >= 1.: 81 | return 0. 82 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.)))) 83 | 84 | return LambdaLR(optimizer, lr_lambda, last_epoch) 85 | 86 | 87 | class AdamW(Optimizer): 88 | """ Implements the Adam algorithm with the decoupled weight decay fix (AdamW). 89 | 90 | Parameters: 91 | lr (float): learning rate. Default: 1e-3. 92 | betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999) 93 | eps (float): Adam's epsilon. Default: 1e-6 94 | weight_decay (float): Weight decay. Default: 0.0 95 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. as in the BERT TF repository). Default: True.
96 | """ 97 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 98 | if lr < 0.0: 99 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 104 | if not 0.0 <= eps: 105 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 106 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 107 | correct_bias=correct_bias) 108 | super(AdamW, self).__init__(params, defaults) 109 | 110 | def step(self, closure=None): 111 | """Performs a single optimization step. 112 | 113 | Arguments: 114 | closure (callable, optional): A closure that reevaluates the model 115 | and returns the loss. 116 | """ 117 | loss = None 118 | if closure is not None: 119 | loss = closure() 120 | 121 | for group in self.param_groups: 122 | for p in group['params']: 123 | if p.grad is None: 124 | continue 125 | grad = p.grad.data 126 | if grad.is_sparse: 127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 128 | 129 | state = self.state[p] 130 | 131 | # State initialization 132 | if len(state) == 0: 133 | state['step'] = 0 134 | # Exponential moving average of gradient values 135 | state['exp_avg'] = torch.zeros_like(p.data) 136 | # Exponential moving average of squared gradient values 137 | state['exp_avg_sq'] = torch.zeros_like(p.data) 138 | 139 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 140 | beta1, beta2 = group['betas'] 141 | 142 | state['step'] += 1 143 | 144 | # Decay the first and second moment running average coefficient 145 | # In-place operations to update the averages at the same time 146 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 147 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 148 | denom = exp_avg_sq.sqrt().add_(group['eps']) 149 | 150 | step_size = group['lr'] 151 | if group['correct_bias']: # No bias correction for Bert 152 | bias_correction1 = 1.0 - beta1 ** state['step'] 153 | bias_correction2 = 1.0 - beta2 ** state['step'] 154 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 155 | 156 | p.data.addcdiv_(-step_size, exp_avg, denom) 157 | 158 | # Just adding the square of the weights to the loss function is *not* 159 | # the correct way of using L2 regularization/weight decay with Adam, 160 | # since that will interact with the m and v parameters in strange ways. 161 | # 162 | # Instead we want to decay the weights in a manner that doesn't interact 163 | # with the m/v parameters. This is equivalent to adding the square 164 | # of the weights to the loss with plain (non-momentum) SGD. 
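# In symbols, the Adam step above and the weight-decay step below compose to
#     p <- p - step_size * exp_avg / (sqrt(exp_avg_sq) + eps) - lr * weight_decay * p
# i.e. the decay rescales the weights directly instead of entering through the
# gradient, following Loshchilov & Hutter, "Decoupled Weight Decay Regularization".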
165 | # Add weight decay at the end (fixed version) 166 | if group['weight_decay'] > 0.0: 167 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 168 | 169 | return loss 170 | -------------------------------------------------------------------------------- /transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Neutralzz/syntax_aware_local_attention/e1b9397278fedb56b09320ad18e9cb9c543f5306/transformers/tests/__init__.py -------------------------------------------------------------------------------- /transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import copy 20 | import os 21 | import shutil 22 | import json 23 | import random 24 | import uuid 25 | 26 | import unittest 27 | import logging 28 | 29 | 30 | class ConfigTester(object): 31 | def __init__(self, parent, config_class=None, **kwargs): 32 | self.parent = parent 33 | self.config_class = config_class 34 | self.inputs_dict = kwargs 35 | 36 | def create_and_test_config_common_properties(self): 37 | config = self.config_class(**self.inputs_dict) 38 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 39 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 42 | 43 | def create_and_test_config_to_json_string(self): 44 | config = self.config_class(**self.inputs_dict) 45 | obj = json.loads(config.to_json_string()) 46 | for key, value in self.inputs_dict.items(): 47 | self.parent.assertEqual(obj[key], value) 48 | 49 | def create_and_test_config_to_json_file(self): 50 | config_first = self.config_class(**self.inputs_dict) 51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") 52 | config_first.to_json_file(json_file_path) 53 | config_second = self.config_class.from_json_file(json_file_path) 54 | os.remove(json_file_path) 55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 56 | 57 | def run_common_tests(self): 58 | self.create_and_test_config_common_properties() 59 | self.create_and_test_config_to_json_string() 60 | self.create_and_test_config_to_json_file() 61 | 62 | if __name__ == "__main__": 63 | unittest.main() -------------------------------------------------------------------------------- /transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? 
||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Neutralzz/syntax_aware_local_attention/e1b9397278fedb56b09320ad18e9cb9c543f5306/transformers/tests/fixtures/spiece.model -------------------------------------------------------------------------------- /transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Neutralzz/syntax_aware_local_attention/e1b9397278fedb56b09320ad18e9cb9c543f5306/transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /transformers/tests/hf_api_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
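# The tests below are integration tests: they exercise the `HfApi` client against
# HuggingFace's staging endpoint with the dummy credentials defined at the top of
# the file. Assuming the package is importable from the repository root, they can
# be run standalone with the standard unittest runner:
#
#     python -m unittest transformers.tests.hf_api_test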
15 | from __future__ import absolute_import, division, print_function 16 | 17 | import os 18 | import six 19 | import time 20 | import unittest 21 | 22 | from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError 23 | 24 | USER = "__DUMMY_TRANSFORMERS_USER__" 25 | PASS = "__DUMMY_TRANSFORMERS_PASS__" 26 | FILE_KEY = "Test-{}.txt".format(int(time.time())) 27 | FILE_PATH = os.path.join( 28 | os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt" 29 | ) 30 | 31 | 32 | 33 | class HfApiCommonTest(unittest.TestCase): 34 | _api = HfApi(endpoint="https://moon-staging.huggingface.co") 35 | 36 | 37 | class HfApiLoginTest(HfApiCommonTest): 38 | def test_login_invalid(self): 39 | with self.assertRaises(HTTPError): 40 | self._api.login(username=USER, password="fake") 41 | 42 | def test_login_valid(self): 43 | token = self._api.login(username=USER, password=PASS) 44 | self.assertIsInstance(token, six.string_types) 45 | 46 | 47 | class HfApiEndpointsTest(HfApiCommonTest): 48 | @classmethod 49 | def setUpClass(cls): 50 | """ 51 | Share this valid token in all tests below. 52 | """ 53 | cls._token = cls._api.login(username=USER, password=PASS) 54 | 55 | def test_whoami(self): 56 | user = self._api.whoami(token=self._token) 57 | self.assertEqual(user, USER) 58 | 59 | def test_presign(self): 60 | urls = self._api.presign(token=self._token, filename=FILE_KEY) 61 | self.assertIsInstance(urls, PresignedUrl) 62 | self.assertEqual(urls.type, "text/plain") 63 | 64 | def test_presign_and_upload(self): 65 | access_url = self._api.presign_and_upload( 66 | token=self._token, filename=FILE_KEY, filepath=FILE_PATH 67 | ) 68 | self.assertIsInstance(access_url, six.string_types) 69 | 70 | def test_list_objs(self): 71 | objs = self._api.list_objs(token=self._token) 72 | self.assertIsInstance(objs, list) 73 | if len(objs) > 0: 74 | o = objs[-1] 75 | self.assertIsInstance(o, S3Obj) 76 | 77 | 78 | 79 | class HfFolderTest(unittest.TestCase): 80 | def test_token_workflow(self): 81 | """ 82 | Test the whole token save/get/delete workflow, 83 | with the desired behavior with respect to non-existent tokens. 84 | """ 85 | token = "token-{}".format(int(time.time())) 86 | HfFolder.save_token(token) 87 | self.assertEqual( 88 | HfFolder.get_token(), 89 | token 90 | ) 91 | HfFolder.delete_token() 92 | HfFolder.delete_token() 93 | # ^^ not an error, we test that the 94 | # second call does not fail. 95 | self.assertEqual( 96 | HfFolder.get_token(), 97 | None 98 | ) 99 | 100 | 101 | if __name__ == "__main__": 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
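# The tests below exercise the `Auto*` factory classes, which inspect a pretrained
# identifier and dispatch to a concrete architecture (here, BERT). A minimal sketch
# of the pattern under test, assuming the `bert-base-uncased` weights can be
# downloaded (network access required):

from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained("bert-base-uncased")  # resolves to a BertConfig
model = AutoModel.from_pretrained("bert-base-uncased")    # resolves to a BertModel

# The same identifier dispatches to task-specific heads as well:
clf = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")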
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import logging 22 | 23 | from transformers import is_torch_available 24 | 25 | from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER 26 | 27 | if is_torch_available(): 28 | from transformers import (AutoConfig, BertConfig, 29 | AutoModel, BertModel, 30 | AutoModelWithLMHead, BertForMaskedLM, 31 | AutoModelForSequenceClassification, BertForSequenceClassification, 32 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 33 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 34 | 35 | from .modeling_common_test import (CommonTestCases, ids_tensor) 36 | from .configuration_common_test import ConfigTester 37 | 38 | 39 | @require_torch 40 | class AutoModelTest(unittest.TestCase): 41 | @slow 42 | def test_model_from_pretrained(self): 43 | logging.basicConfig(level=logging.INFO) 44 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 45 | config = AutoConfig.from_pretrained(model_name) 46 | self.assertIsNotNone(config) 47 | self.assertIsInstance(config, BertConfig) 48 | 49 | model = AutoModel.from_pretrained(model_name) 50 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 51 | self.assertIsNotNone(model) 52 | self.assertIsInstance(model, BertModel) 53 | for value in loading_info.values(): 54 | self.assertEqual(len(value), 0) 55 | 56 | @slow 57 | def test_lmhead_model_from_pretrained(self): 58 | logging.basicConfig(level=logging.INFO) 59 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 60 | config = AutoConfig.from_pretrained(model_name) 61 | self.assertIsNotNone(config) 62 | self.assertIsInstance(config, BertConfig) 63 | 64 | model = AutoModelWithLMHead.from_pretrained(model_name) 65 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 66 | self.assertIsNotNone(model) 67 | self.assertIsInstance(model, BertForMaskedLM) 68 | 69 | @slow 70 | def test_sequence_classification_model_from_pretrained(self): 71 | logging.basicConfig(level=logging.INFO) 72 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 73 | config = AutoConfig.from_pretrained(model_name) 74 | self.assertIsNotNone(config) 75 | self.assertIsInstance(config, BertConfig) 76 | 77 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 78 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 79 | self.assertIsNotNone(model) 80 | self.assertIsInstance(model, BertForSequenceClassification) 81 | 82 | @slow 83 | def test_question_answering_model_from_pretrained(self): 84 | logging.basicConfig(level=logging.INFO) 85 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 86 | config = AutoConfig.from_pretrained(model_name) 87 | self.assertIsNotNone(config) 88 | self.assertIsInstance(config, BertConfig) 89 | 90 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 91 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 92 | self.assertIsNotNone(model) 93 | self.assertIsInstance(model, BertForQuestionAnswering) 94 | 95 | def test_from_pretrained_identifier(self): 96 | logging.basicConfig(level=logging.INFO) 97 | model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) 98 | self.assertIsInstance(model, BertForMaskedLM) 99 | 100 | 
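# Note: `output_loading_info=True`, used in the tests above, makes `from_pretrained`
# additionally return a dict of loading diagnostics (e.g. lists of missing and
# unexpected checkpoint keys); asserting that every list in it is empty verifies
# that the checkpoint matched the instantiated architecture exactly.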
101 | if __name__ == "__main__": 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /transformers/tests/modeling_ctrl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | import unittest 19 | import shutil 20 | import pdb 21 | 22 | from transformers import is_torch_available 23 | 24 | if is_torch_available(): 25 | from transformers import (CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, 26 | CTRLLMHeadModel) 27 | 28 | from .modeling_common_test import (CommonTestCases, ids_tensor) 29 | from .configuration_common_test import ConfigTester 30 | from .utils import require_torch, slow, torch_device 31 | 32 | 33 | @require_torch 34 | class CTRLModelTest(CommonTestCases.CommonModelTester): 35 | 36 | all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else () 37 | test_pruning = False 38 | test_torchscript = False 39 | test_resize_embeddings = False 40 | test_head_masking = False 41 | 42 | class CTRLModelTester(object): 43 | 44 | def __init__(self, 45 | parent, 46 | batch_size=13, 47 | seq_length=7, 48 | is_training=True, 49 | use_token_type_ids=True, 50 | use_input_mask=True, 51 | use_labels=True, 52 | use_mc_token_ids=True, 53 | vocab_size=99, 54 | hidden_size=32, 55 | num_hidden_layers=5, 56 | num_attention_heads=4, 57 | intermediate_size=37, 58 | hidden_act="gelu", 59 | hidden_dropout_prob=0.1, 60 | attention_probs_dropout_prob=0.1, 61 | max_position_embeddings=512, 62 | type_vocab_size=16, 63 | type_sequence_label_size=2, 64 | initializer_range=0.02, 65 | num_labels=3, 66 | num_choices=4, 67 | scope=None, 68 | ): 69 | self.parent = parent 70 | self.batch_size = batch_size 71 | self.seq_length = seq_length 72 | self.is_training = is_training 73 | self.use_token_type_ids = use_token_type_ids 74 | self.use_input_mask = use_input_mask 75 | self.use_labels = use_labels 76 | self.use_mc_token_ids = use_mc_token_ids 77 | self.vocab_size = vocab_size 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.intermediate_size = intermediate_size 82 | self.hidden_act = hidden_act 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.type_sequence_label_size = type_sequence_label_size 88 | self.initializer_range = initializer_range 89 | self.num_labels = num_labels 90 | self.num_choices = num_choices 91 | self.scope = scope 92 | 93 | def prepare_config_and_inputs(self): 94 | input_ids = ids_tensor([self.batch_size, 
self.seq_length], self.vocab_size) 95 | 96 | input_mask = None 97 | if self.use_input_mask: 98 | input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) 99 | 100 | token_type_ids = None 101 | if self.use_token_type_ids: 102 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) 103 | 104 | mc_token_ids = None 105 | if self.use_mc_token_ids: 106 | mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) 107 | 108 | sequence_labels = None 109 | token_labels = None 110 | choice_labels = None 111 | if self.use_labels: 112 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) 113 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) 114 | choice_labels = ids_tensor([self.batch_size], self.num_choices) 115 | 116 | config = CTRLConfig( 117 | vocab_size_or_config_json_file=self.vocab_size, 118 | n_embd=self.hidden_size, 119 | n_layer=self.num_hidden_layers, 120 | n_head=self.num_attention_heads, 121 | # intermediate_size=self.intermediate_size, 122 | # hidden_act=self.hidden_act, 123 | # hidden_dropout_prob=self.hidden_dropout_prob, 124 | # attention_probs_dropout_prob=self.attention_probs_dropout_prob, 125 | n_positions=self.max_position_embeddings, 126 | n_ctx=self.max_position_embeddings 127 | # type_vocab_size=self.type_vocab_size, 128 | # initializer_range=self.initializer_range 129 | ) 130 | 131 | head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) 132 | 133 | return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels 134 | 135 | def check_loss_output(self, result): 136 | self.parent.assertListEqual( 137 | list(result["loss"].size()), 138 | []) 139 | 140 | def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): 141 | model = CTRLModel(config=config) 142 | model.to(torch_device) 143 | model.eval() 144 | 145 | model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) 146 | model(input_ids, token_type_ids=token_type_ids) 147 | sequence_output, presents = model(input_ids) 148 | 149 | result = { 150 | "sequence_output": sequence_output, 151 | "presents": presents, 152 | } 153 | self.parent.assertListEqual( 154 | list(result["sequence_output"].size()), 155 | [self.batch_size, self.seq_length, self.hidden_size]) 156 | self.parent.assertEqual(len(result["presents"]), config.n_layer) 157 | 158 | def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): 159 | model = CTRLLMHeadModel(config) 160 | model.to(torch_device) 161 | model.eval() 162 | 163 | loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) 164 | 165 | result = { 166 | "loss": loss, 167 | "lm_logits": lm_logits 168 | } 169 | self.parent.assertListEqual( 170 | list(result["loss"].size()), 171 | []) 172 | self.parent.assertListEqual( 173 | list(result["lm_logits"].size()), 174 | [self.batch_size, self.seq_length, self.vocab_size]) 175 | 176 | 177 | def prepare_config_and_inputs_for_common(self): 178 | config_and_inputs = self.prepare_config_and_inputs() 179 | 180 | (config, input_ids, input_mask, head_mask, token_type_ids, 181 | mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs 182 | 183 | inputs_dict = { 184 | 'input_ids': input_ids, 185 | 'token_type_ids': token_type_ids, 186 | 'head_mask': head_mask 187 | } 188 | 189 | return config, inputs_dict 190 | 191 
| def setUp(self): 192 | self.model_tester = CTRLModelTest.CTRLModelTester(self) 193 | self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37) 194 | 195 | def test_config(self): 196 | self.config_tester.run_common_tests() 197 | 198 | def test_ctrl_model(self): 199 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 200 | self.model_tester.create_and_check_ctrl_model(*config_and_inputs) 201 | 202 | def test_ctrl_lm_head_model(self): 203 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 204 | self.model_tester.create_and_check_lm_head_model(*config_and_inputs) 205 | 206 | @slow 207 | def test_model_from_pretrained(self): 208 | cache_dir = "/tmp/transformers_test/" 209 | for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 210 | model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir) 211 | shutil.rmtree(cache_dir) 212 | self.assertIsNotNone(model) 213 | 214 | 215 | if __name__ == "__main__": 216 | unittest.main() 217 | -------------------------------------------------------------------------------- /transformers/tests/modeling_encoder_decoder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Hugging Face Inc. Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | import unittest 18 | 19 | from transformers import is_torch_available 20 | from .utils import require_torch, slow 21 | 22 | if is_torch_available(): 23 | from transformers import BertModel, BertForMaskedLM, Model2Model 24 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 25 | 26 | 27 | @require_torch 28 | class EncoderDecoderModelTest(unittest.TestCase): 29 | @slow 30 | def test_model2model_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 33 | model = Model2Model.from_pretrained(model_name) 34 | self.assertIsInstance(model.encoder, BertModel) 35 | self.assertIsInstance(model.decoder, BertForMaskedLM) 36 | self.assertEqual(model.decoder.config.is_decoder, True) 37 | self.assertEqual(model.encoder.config.is_decoder, False) 38 | 39 | def test_model2model_from_pretrained_not_bert(self): 40 | logging.basicConfig(level=logging.INFO) 41 | with self.assertRaises(ValueError): 42 | _ = Model2Model.from_pretrained('roberta') 43 | 44 | with self.assertRaises(ValueError): 45 | _ = Model2Model.from_pretrained('distilbert') 46 | 47 | with self.assertRaises(ValueError): 48 | _ = Model2Model.from_pretrained('does-not-exist') 49 | 50 | 51 | if __name__ == "__main__": 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /transformers/tests/modeling_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | 22 | from transformers import is_torch_available 23 | 24 | if is_torch_available(): 25 | from transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, 26 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) 27 | 28 | from .modeling_common_test import (CommonTestCases, ids_tensor) 29 | from .configuration_common_test import ConfigTester 30 | from .utils import require_torch, slow, torch_device 31 | 32 | 33 | @require_torch 34 | class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): 35 | 36 | all_model_classes = (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else () 37 | 38 | class OpenAIGPTModelTester(object): 39 | 40 | def __init__(self, 41 | parent, 42 | batch_size=13, 43 | seq_length=7, 44 | is_training=True, 45 | use_token_type_ids=True, 46 | use_labels=True, 47 | vocab_size=99, 48 | hidden_size=32, 49 | num_hidden_layers=5, 50 | num_attention_heads=4, 51 | intermediate_size=37, 52 | hidden_act="gelu", 53 | hidden_dropout_prob=0.1, 54 | attention_probs_dropout_prob=0.1, 55 | max_position_embeddings=512, 56 | type_vocab_size=16, 57 | type_sequence_label_size=2, 58 | initializer_range=0.02, 59 | num_labels=3, 60 | num_choices=4, 61 | scope=None, 62 | ): 63 | self.parent = parent 64 | self.batch_size = batch_size 65 | self.seq_length = seq_length 66 | self.is_training = is_training 67 | self.use_token_type_ids = use_token_type_ids 68 | self.use_labels = use_labels 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.intermediate_size = intermediate_size 74 | self.hidden_act = hidden_act 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.type_sequence_label_size = type_sequence_label_size 80 | self.initializer_range = initializer_range 81 | self.num_labels = num_labels 82 | self.num_choices = num_choices 83 | self.scope = scope 84 | 85 | def prepare_config_and_inputs(self): 86 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 87 | 88 | token_type_ids = None 89 | if self.use_token_type_ids: 90 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) 91 | 92 | sequence_labels = None 93 | token_labels = None 94 | choice_labels = None 95 | if self.use_labels: 96 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) 97 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) 98 | choice_labels = ids_tensor([self.batch_size], self.num_choices) 
99 | 100 | config = OpenAIGPTConfig( 101 | vocab_size_or_config_json_file=self.vocab_size, 102 | n_embd=self.hidden_size, 103 | n_layer=self.num_hidden_layers, 104 | n_head=self.num_attention_heads, 105 | # intermediate_size=self.intermediate_size, 106 | # hidden_act=self.hidden_act, 107 | # hidden_dropout_prob=self.hidden_dropout_prob, 108 | # attention_probs_dropout_prob=self.attention_probs_dropout_prob, 109 | n_positions=self.max_position_embeddings, 110 | n_ctx=self.max_position_embeddings 111 | # type_vocab_size=self.type_vocab_size, 112 | # initializer_range=self.initializer_range 113 | ) 114 | 115 | head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) 116 | 117 | return config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels 118 | 119 | def check_loss_output(self, result): 120 | self.parent.assertListEqual( 121 | list(result["loss"].size()), 122 | []) 123 | 124 | def create_and_check_openai_gpt_model(self, config, input_ids, head_mask, token_type_ids, *args): 125 | model = OpenAIGPTModel(config=config) 126 | model.to(torch_device) 127 | model.eval() 128 | 129 | model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) 130 | model(input_ids, token_type_ids=token_type_ids) 131 | (sequence_output,) = model(input_ids) 132 | 133 | result = { 134 | "sequence_output": sequence_output 135 | } 136 | self.parent.assertListEqual( 137 | list(result["sequence_output"].size()), 138 | [self.batch_size, self.seq_length, self.hidden_size]) 139 | 140 | def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): 141 | model = OpenAIGPTLMHeadModel(config) 142 | model.to(torch_device) 143 | model.eval() 144 | 145 | loss, lm_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) 146 | 147 | result = { 148 | "loss": loss, 149 | "lm_logits": lm_logits 150 | } 151 | 152 | self.parent.assertListEqual( 153 | list(result["loss"].size()), 154 | []) 155 | self.parent.assertListEqual( 156 | list(result["lm_logits"].size()), 157 | [self.batch_size, self.seq_length, self.vocab_size]) 158 | 159 | def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): 160 | model = OpenAIGPTDoubleHeadsModel(config) 161 | model.to(torch_device) 162 | model.eval() 163 | 164 | loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids) 165 | 166 | result = { 167 | "loss": loss, 168 | "lm_logits": lm_logits 169 | } 170 | 171 | self.parent.assertListEqual( 172 | list(result["loss"].size()), 173 | []) 174 | self.parent.assertListEqual( 175 | list(result["lm_logits"].size()), 176 | [self.batch_size, self.seq_length, self.vocab_size]) 177 | 178 | def prepare_config_and_inputs_for_common(self): 179 | config_and_inputs = self.prepare_config_and_inputs() 180 | (config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs 181 | inputs_dict = { 182 | 'input_ids': input_ids, 183 | 'token_type_ids': token_type_ids, 184 | 'head_mask': head_mask 185 | } 186 | 187 | return config, inputs_dict 188 | 189 | def setUp(self): 190 | self.model_tester = OpenAIGPTModelTest.OpenAIGPTModelTester(self) 191 | self.config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) 192 | 193 | def test_config(self): 194 | self.config_tester.run_common_tests() 195 | 196 | def test_openai_gpt_model(self): 197 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 198 | 
self.model_tester.create_and_check_openai_gpt_model(*config_and_inputs) 199 | 200 | def test_openai_gpt_lm_head_model(self): 201 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 202 | self.model_tester.create_and_check_lm_head_model(*config_and_inputs) 203 | 204 | def test_openai_gpt_double_lm_head_model(self): 205 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 206 | self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs) 207 | 208 | @slow 209 | def test_model_from_pretrained(self): 210 | cache_dir = "/tmp/transformers_test/" 211 | for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 212 | model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir) 213 | shutil.rmtree(cache_dir) 214 | self.assertIsNotNone(model) 215 | 216 | 217 | if __name__ == "__main__": 218 | unittest.main() 219 | -------------------------------------------------------------------------------- /transformers/tests/modeling_tf_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import logging 22 | 23 | from transformers import is_tf_available 24 | 25 | from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER 26 | 27 | if is_tf_available(): 28 | from transformers import (AutoConfig, BertConfig, 29 | TFAutoModel, TFBertModel, 30 | TFAutoModelWithLMHead, TFBertForMaskedLM, 31 | TFAutoModelForSequenceClassification, TFBertForSequenceClassification, 32 | TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering) 33 | from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP 34 | 35 | from .modeling_common_test import (CommonTestCases, ids_tensor) 36 | from .configuration_common_test import ConfigTester 37 | 38 | 39 | @require_tf 40 | class TFAutoModelTest(unittest.TestCase): 41 | @slow 42 | def test_model_from_pretrained(self): 43 | import h5py 44 | self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) 45 | 46 | logging.basicConfig(level=logging.INFO) 47 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 48 | for model_name in ['bert-base-uncased']: 49 | config = AutoConfig.from_pretrained(model_name, force_download=True) 50 | self.assertIsNotNone(config) 51 | self.assertIsInstance(config, BertConfig) 52 | 53 | model = TFAutoModel.from_pretrained(model_name, force_download=True) 54 | self.assertIsNotNone(model) 55 | self.assertIsInstance(model, TFBertModel) 56 | 57 | @slow 58 | def test_lmhead_model_from_pretrained(self): 59 | logging.basicConfig(level=logging.INFO) 60 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 61 | for model_name in ['bert-base-uncased']: 62 | config = 
AutoConfig.from_pretrained(model_name, force_download=True) 63 | self.assertIsNotNone(config) 64 | self.assertIsInstance(config, BertConfig) 65 | 66 | model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True) 67 | self.assertIsNotNone(model) 68 | self.assertIsInstance(model, TFBertForMaskedLM) 69 | 70 | @slow 71 | def test_sequence_classification_model_from_pretrained(self): 72 | logging.basicConfig(level=logging.INFO) 73 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 74 | for model_name in ['bert-base-uncased']: 75 | config = AutoConfig.from_pretrained(model_name, force_download=True) 76 | self.assertIsNotNone(config) 77 | self.assertIsInstance(config, BertConfig) 78 | 79 | model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True) 80 | self.assertIsNotNone(model) 81 | self.assertIsInstance(model, TFBertForSequenceClassification) 82 | 83 | @slow 84 | def test_question_answering_model_from_pretrained(self): 85 | logging.basicConfig(level=logging.INFO) 86 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 87 | for model_name in ['bert-base-uncased']: 88 | config = AutoConfig.from_pretrained(model_name, force_download=True) 89 | self.assertIsNotNone(config) 90 | self.assertIsInstance(config, BertConfig) 91 | 92 | model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True) 93 | self.assertIsNotNone(model) 94 | self.assertIsInstance(model, TFBertForQuestionAnswering) 95 | 96 | def test_from_pretrained_identifier(self): 97 | logging.basicConfig(level=logging.INFO) 98 | model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True) 99 | self.assertIsInstance(model, TFBertForMaskedLM) 100 | 101 | 102 | if __name__ == "__main__": 103 | unittest.main() 104 | -------------------------------------------------------------------------------- /transformers/tests/modeling_tf_ctrl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
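# The TF test below calls the model with the three input formats the TF 2.0 models
# in this repository accept: a single tensor, a positional list, or a dict of named
# inputs. A minimal sketch of that calling convention, assuming eager execution and
# a small randomly initialized model (the config values here are toy values, not a
# released checkpoint):

import tensorflow as tf
from transformers import CTRLConfig
from transformers.modeling_tf_ctrl import TFCTRLModel

config = CTRLConfig(vocab_size_or_config_json_file=99, n_embd=32, n_layer=2,
                    n_head=4, n_positions=64, n_ctx=64)
model = TFCTRLModel(config)

input_ids = tf.constant([[5, 6, 7, 8]])
sequence_output = model(input_ids)[0]                 # single tensor
sequence_output = model([input_ids, None, None])[0]   # positional list: (input_ids, past, attention_mask)
sequence_output = model({"input_ids": input_ids})[0]  # dict of named inputs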
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import sys 22 | 23 | from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) 24 | from .configuration_common_test import ConfigTester 25 | from .utils import require_tf, slow 26 | 27 | from transformers import CTRLConfig, is_tf_available 28 | 29 | if is_tf_available(): 30 | import tensorflow as tf 31 | from transformers.modeling_tf_ctrl import (TFCTRLModel, TFCTRLLMHeadModel, 32 | TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) 33 | 34 | 35 | @require_tf 36 | class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester): 37 | 38 | all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else () 39 | 40 | class TFCTRLModelTester(object): 41 | 42 | def __init__(self, 43 | parent, 44 | batch_size=13, 45 | seq_length=7, 46 | is_training=True, 47 | use_token_type_ids=True, 48 | use_input_mask=True, 49 | use_labels=True, 50 | use_mc_token_ids=True, 51 | vocab_size=99, 52 | hidden_size=32, 53 | num_hidden_layers=5, 54 | num_attention_heads=4, 55 | intermediate_size=37, 56 | hidden_act="gelu", 57 | hidden_dropout_prob=0.1, 58 | attention_probs_dropout_prob=0.1, 59 | max_position_embeddings=512, 60 | type_vocab_size=16, 61 | type_sequence_label_size=2, 62 | initializer_range=0.02, 63 | num_labels=3, 64 | num_choices=4, 65 | scope=None, 66 | ): 67 | self.parent = parent 68 | self.batch_size = batch_size 69 | self.seq_length = seq_length 70 | self.is_training = is_training 71 | self.use_token_type_ids = use_token_type_ids 72 | self.use_input_mask = use_input_mask 73 | self.use_labels = use_labels 74 | self.use_mc_token_ids = use_mc_token_ids 75 | self.vocab_size = vocab_size 76 | self.hidden_size = hidden_size 77 | self.num_hidden_layers = num_hidden_layers 78 | self.num_attention_heads = num_attention_heads 79 | self.intermediate_size = intermediate_size 80 | self.hidden_act = hidden_act 81 | self.hidden_dropout_prob = hidden_dropout_prob 82 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 83 | self.max_position_embeddings = max_position_embeddings 84 | self.type_vocab_size = type_vocab_size 85 | self.type_sequence_label_size = type_sequence_label_size 86 | self.initializer_range = initializer_range 87 | self.num_labels = num_labels 88 | self.num_choices = num_choices 89 | self.scope = scope 90 | 91 | def prepare_config_and_inputs(self): 92 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 93 | 94 | input_mask = None 95 | if self.use_input_mask: 96 | input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) 97 | 98 | token_type_ids = None 99 | if self.use_token_type_ids: 100 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) 101 | 102 | mc_token_ids = None 103 | if self.use_mc_token_ids: 104 | mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) 105 | 106 | sequence_labels = None 107 | token_labels = None 108 | choice_labels = None 109 | if self.use_labels: 110 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) 111 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) 112 | choice_labels = ids_tensor([self.batch_size], self.num_choices) 113 | 114 | config = CTRLConfig( 115 | vocab_size_or_config_json_file=self.vocab_size, 116 | n_embd=self.hidden_size, 117 | n_layer=self.num_hidden_layers, 118 | n_head=self.num_attention_heads, 
119 | # intermediate_size=self.intermediate_size, 120 | # hidden_act=self.hidden_act, 121 | # hidden_dropout_prob=self.hidden_dropout_prob, 122 | # attention_probs_dropout_prob=self.attention_probs_dropout_prob, 123 | n_positions=self.max_position_embeddings, 124 | n_ctx=self.max_position_embeddings 125 | # type_vocab_size=self.type_vocab_size, 126 | # initializer_range=self.initializer_range 127 | ) 128 | 129 | head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) 130 | 131 | return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels 132 | 133 | def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): 134 | model = TFCTRLModel(config=config) 135 | inputs = {'input_ids': input_ids, 136 | 'attention_mask': input_mask, 137 | 'token_type_ids': token_type_ids} 138 | sequence_output = model(inputs)[0] 139 | 140 | inputs = [input_ids, None, input_mask] # None is the input for 'past' 141 | sequence_output = model(inputs)[0] 142 | 143 | sequence_output = model(input_ids)[0] 144 | 145 | result = { 146 | "sequence_output": sequence_output.numpy(), 147 | } 148 | self.parent.assertListEqual( 149 | list(result["sequence_output"].shape), 150 | [self.batch_size, self.seq_length, self.hidden_size]) 151 | 152 | 153 | def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): 154 | model = TFCTRLLMHeadModel(config=config) 155 | inputs = {'input_ids': input_ids, 156 | 'attention_mask': input_mask, 157 | 'token_type_ids': token_type_ids} 158 | prediction_scores = model(inputs)[0] 159 | result = { 160 | "prediction_scores": prediction_scores.numpy(), 161 | } 162 | self.parent.assertListEqual( 163 | list(result["prediction_scores"].shape), 164 | [self.batch_size, self.seq_length, self.vocab_size]) 165 | 166 | def prepare_config_and_inputs_for_common(self): 167 | config_and_inputs = self.prepare_config_and_inputs() 168 | 169 | (config, input_ids, input_mask, head_mask, token_type_ids, 170 | mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs 171 | 172 | inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} 173 | return config, inputs_dict 174 | 175 | def setUp(self): 176 | self.model_tester = TFCTRLModelTest.TFCTRLModelTester(self) 177 | self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37) 178 | 179 | def test_config(self): 180 | self.config_tester.run_common_tests() 181 | 182 | def test_ctrl_model(self): 183 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 184 | self.model_tester.create_and_check_ctrl_model(*config_and_inputs) 185 | 186 | def test_ctrl_lm_head(self): 187 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 188 | self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs) 189 | 190 | @slow 191 | def test_model_from_pretrained(self): 192 | cache_dir = "/tmp/transformers_test/" 193 | for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 194 | model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir) 195 | shutil.rmtree(cache_dir) 196 | self.assertIsNotNone(model) 197 | 198 | if __name__ == "__main__": 199 | unittest.main() 200 | 201 | -------------------------------------------------------------------------------- /transformers/tests/modeling_tf_distilbert_test.py: -------------------------------------------------------------------------------- 1 | 
# coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | 21 | from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) 22 | from .configuration_common_test import ConfigTester 23 | from .utils import require_tf, slow 24 | 25 | from transformers import DistilBertConfig, is_tf_available 26 | 27 | if is_tf_available(): 28 | import tensorflow as tf 29 | from transformers.modeling_tf_distilbert import (TFDistilBertModel, 30 | TFDistilBertForMaskedLM, 31 | TFDistilBertForQuestionAnswering, 32 | TFDistilBertForSequenceClassification) 33 | 34 | 35 | @require_tf 36 | class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester): 37 | 38 | all_model_classes = (TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, 39 | TFDistilBertForSequenceClassification) if is_tf_available() else () 40 | test_pruning = True 41 | test_torchscript = True 42 | test_resize_embeddings = True 43 | test_head_masking = True 44 | 45 | class TFDistilBertModelTester(object): 46 | 47 | def __init__(self, 48 | parent, 49 | batch_size=13, 50 | seq_length=7, 51 | is_training=True, 52 | use_input_mask=True, 53 | use_token_type_ids=False, 54 | use_labels=True, 55 | vocab_size=99, 56 | hidden_size=32, 57 | num_hidden_layers=5, 58 | num_attention_heads=4, 59 | intermediate_size=37, 60 | hidden_act="gelu", 61 | hidden_dropout_prob=0.1, 62 | attention_probs_dropout_prob=0.1, 63 | max_position_embeddings=512, 64 | type_vocab_size=16, 65 | type_sequence_label_size=2, 66 | initializer_range=0.02, 67 | num_labels=3, 68 | num_choices=4, 69 | scope=None, 70 | ): 71 | self.parent = parent 72 | self.batch_size = batch_size 73 | self.seq_length = seq_length 74 | self.is_training = is_training 75 | self.use_input_mask = use_input_mask 76 | self.use_token_type_ids = use_token_type_ids 77 | self.use_labels = use_labels 78 | self.vocab_size = vocab_size 79 | self.hidden_size = hidden_size 80 | self.num_hidden_layers = num_hidden_layers 81 | self.num_attention_heads = num_attention_heads 82 | self.intermediate_size = intermediate_size 83 | self.hidden_act = hidden_act 84 | self.hidden_dropout_prob = hidden_dropout_prob 85 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 86 | self.max_position_embeddings = max_position_embeddings 87 | self.type_vocab_size = type_vocab_size 88 | self.type_sequence_label_size = type_sequence_label_size 89 | self.initializer_range = initializer_range 90 | self.num_labels = num_labels 91 | self.num_choices = num_choices 92 | self.scope = scope 93 | 94 | def prepare_config_and_inputs(self): 95 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 96 | 97 | input_mask = None 98 | if self.use_input_mask: 99 | input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) 100 |
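# Note: ids_tensor(..., vocab_size=2) above draws uniformly from {0, 1}, so the same helper that builds random token ids doubles as a random attention mask.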
101 | sequence_labels = None 102 | token_labels = None 103 | choice_labels = None 104 | if self.use_labels: 105 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) 106 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) 107 | choice_labels = ids_tensor([self.batch_size], self.num_choices) 108 | 109 | config = DistilBertConfig( 110 | vocab_size_or_config_json_file=self.vocab_size, 111 | dim=self.hidden_size, 112 | n_layers=self.num_hidden_layers, 113 | n_heads=self.num_attention_heads, 114 | hidden_dim=self.intermediate_size, 115 | hidden_act=self.hidden_act, 116 | dropout=self.hidden_dropout_prob, 117 | attention_dropout=self.attention_probs_dropout_prob, 118 | max_position_embeddings=self.max_position_embeddings, 119 | initializer_range=self.initializer_range) 120 | 121 | return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels 122 | 123 | def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): 124 | model = TFDistilBertModel(config=config) 125 | inputs = {'input_ids': input_ids, 126 | 'attention_mask': input_mask} 127 | 128 | outputs = model(inputs) 129 | sequence_output = outputs[0] 130 | 131 | inputs = [input_ids, input_mask] 132 | 133 | (sequence_output,) = model(inputs) 134 | 135 | result = { 136 | "sequence_output": sequence_output.numpy(), 137 | } 138 | self.parent.assertListEqual( 139 | list(result["sequence_output"].shape), 140 | [self.batch_size, self.seq_length, self.hidden_size]) 141 | 142 | def create_and_check_distilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): 143 | model = TFDistilBertForMaskedLM(config=config) 144 | inputs = {'input_ids': input_ids, 145 | 'attention_mask': input_mask} 146 | (prediction_scores,) = model(inputs) 147 | result = { 148 | "prediction_scores": prediction_scores.numpy(), 149 | } 150 | self.parent.assertListEqual( 151 | list(result["prediction_scores"].shape), 152 | [self.batch_size, self.seq_length, self.vocab_size]) 153 | 154 | def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): 155 | model = TFDistilBertForQuestionAnswering(config=config) 156 | inputs = {'input_ids': input_ids, 157 | 'attention_mask': input_mask} 158 | start_logits, end_logits = model(inputs) 159 | result = { 160 | "start_logits": start_logits.numpy(), 161 | "end_logits": end_logits.numpy(), 162 | } 163 | self.parent.assertListEqual( 164 | list(result["start_logits"].shape), 165 | [self.batch_size, self.seq_length]) 166 | self.parent.assertListEqual( 167 | list(result["end_logits"].shape), 168 | [self.batch_size, self.seq_length]) 169 | 170 | def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): 171 | config.num_labels = self.num_labels 172 | model = TFDistilBertForSequenceClassification(config) 173 | inputs = {'input_ids': input_ids, 174 | 'attention_mask': input_mask} 175 | (logits,) = model(inputs) 176 | result = { 177 | "logits": logits.numpy(), 178 | } 179 | self.parent.assertListEqual( 180 | list(result["logits"].shape), 181 | [self.batch_size, self.num_labels]) 182 | 183 | def prepare_config_and_inputs_for_common(self): 184 | config_and_inputs = self.prepare_config_and_inputs() 185 | (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs 186 | 
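# DistilBERT takes no token_type_ids, so the inputs shared with the common test suite carry only input_ids and attention_mask.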
inputs_dict = {'input_ids': input_ids, 'attention_mask': input_mask} 187 | return config, inputs_dict 188 | 189 | def setUp(self): 190 | self.model_tester = TFDistilBertModelTest.TFDistilBertModelTester(self) 191 | self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37) 192 | 193 | def test_config(self): 194 | self.config_tester.run_common_tests() 195 | 196 | def test_distilbert_model(self): 197 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 198 | self.model_tester.create_and_check_distilbert_model(*config_and_inputs) 199 | 200 | def test_for_masked_lm(self): 201 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 202 | self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs) 203 | 204 | def test_for_question_answering(self): 205 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 206 | self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs) 207 | 208 | def test_for_sequence_classification(self): 209 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 210 | self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs) 211 | 212 | # @slow 213 | # def test_model_from_pretrained(self): 214 | # cache_dir = "/tmp/transformers_test/" 215 | # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 216 | # model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir) 217 | # shutil.rmtree(cache_dir) 218 | # self.assertIsNotNone(model) 219 | 220 | if __name__ == "__main__": 221 | unittest.main() 222 | -------------------------------------------------------------------------------- /transformers/tests/modeling_tf_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
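# TensorFlow counterpart of the Transformer-XL tests: the same memory-recurrence and shape checks as the PyTorch suite that follows, with .numpy() conversions standing in for tensor .size() calls.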
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import random 21 | import shutil 22 | 23 | from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) 24 | from .configuration_common_test import ConfigTester 25 | from .utils import require_tf, slow 26 | 27 | from transformers import TransfoXLConfig, is_tf_available 28 | 29 | if is_tf_available(): 30 | import tensorflow as tf 31 | from transformers.modeling_tf_transfo_xl import (TFTransfoXLModel, 32 | TFTransfoXLLMHeadModel, 33 | TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) 34 | 35 | 36 | @require_tf 37 | class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): 38 | 39 | all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else () 40 | test_pruning = False 41 | test_torchscript = False 42 | test_resize_embeddings = False 43 | 44 | class TFTransfoXLModelTester(object): 45 | 46 | def __init__(self, 47 | parent, 48 | batch_size=13, 49 | seq_length=7, 50 | mem_len=30, 51 | clamp_len=15, 52 | is_training=True, 53 | use_labels=True, 54 | vocab_size=99, 55 | cutoffs=[10, 50, 80], 56 | hidden_size=32, 57 | d_embed=32, 58 | num_attention_heads=4, 59 | d_head=8, 60 | d_inner=128, 61 | div_val=2, 62 | num_hidden_layers=5, 63 | scope=None, 64 | seed=1, 65 | ): 66 | self.parent = parent 67 | self.batch_size = batch_size 68 | self.seq_length = seq_length 69 | self.mem_len = mem_len 70 | self.key_len = seq_length + mem_len 71 | self.clamp_len = clamp_len 72 | self.is_training = is_training 73 | self.use_labels = use_labels 74 | self.vocab_size = vocab_size 75 | self.cutoffs = cutoffs 76 | self.hidden_size = hidden_size 77 | self.d_embed = d_embed 78 | self.num_attention_heads = num_attention_heads 79 | self.d_head = d_head 80 | self.d_inner = d_inner 81 | self.div_val = div_val 82 | self.num_hidden_layers = num_hidden_layers 83 | self.scope = scope 84 | self.seed = seed 85 | 86 | def prepare_config_and_inputs(self): 87 | input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 88 | input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 89 | 90 | lm_labels = None 91 | if self.use_labels: 92 | lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 93 | 94 | config = TransfoXLConfig( 95 | vocab_size_or_config_json_file=self.vocab_size, 96 | mem_len=self.mem_len, 97 | clamp_len=self.clamp_len, 98 | cutoffs=self.cutoffs, 99 | d_model=self.hidden_size, 100 | d_embed=self.d_embed, 101 | n_head=self.num_attention_heads, 102 | d_head=self.d_head, 103 | d_inner=self.d_inner, 104 | div_val=self.div_val, 105 | n_layer=self.num_hidden_layers) 106 | 107 | return (config, input_ids_1, input_ids_2, lm_labels) 108 | 109 | def set_seed(self): 110 | random.seed(self.seed) 111 | tf.random.set_seed(self.seed) 112 | 113 | def create_and_check_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels): 114 | model = TFTransfoXLModel(config) 115 | 116 | hidden_states_1, mems_1 = model(input_ids_1) 117 | 118 | inputs = {'input_ids': input_ids_2, 119 | 'mems': mems_1} 120 | 121 | hidden_states_2, mems_2 = model(inputs) 122 | 123 | result = { 124 | "hidden_states_1": hidden_states_1.numpy(), 125 | "mems_1": [mem.numpy() for mem in mems_1], 126 | "hidden_states_2": hidden_states_2.numpy(), 127 | "mems_2": [mem.numpy() for mem in mems_2], 128 | } 129 | 130 | self.parent.assertListEqual( 131 | list(result["hidden_states_1"].shape), 132 | [self.batch_size, 
self.seq_length, self.hidden_size]) 133 | self.parent.assertListEqual( 134 | list(result["hidden_states_2"].shape), 135 | [self.batch_size, self.seq_length, self.hidden_size]) 136 | self.parent.assertListEqual( 137 | list(list(mem.shape) for mem in result["mems_1"]), 138 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 139 | self.parent.assertListEqual( 140 | list(list(mem.shape) for mem in result["mems_2"]), 141 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 142 | 143 | 144 | def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): 145 | model = TFTransfoXLLMHeadModel(config) 146 | 147 | lm_logits_1, mems_1 = model(input_ids_1) 148 | 149 | inputs = {'input_ids': input_ids_1, 150 | 'labels': lm_labels} 151 | _, mems_1 = model(inputs) 152 | 153 | lm_logits_2, mems_2 = model([input_ids_2, mems_1]) 154 | 155 | inputs = {'input_ids': input_ids_1, 156 | 'mems': mems_1, 157 | 'labels': lm_labels} 158 | 159 | _, mems_2 = model(inputs) 160 | 161 | result = { 162 | "mems_1": [mem.numpy() for mem in mems_1], 163 | "lm_logits_1": lm_logits_1.numpy(), 164 | "mems_2": [mem.numpy() for mem in mems_2], 165 | "lm_logits_2": lm_logits_2.numpy(), 166 | } 167 | 168 | self.parent.assertListEqual( 169 | list(result["lm_logits_1"].shape), 170 | [self.batch_size, self.seq_length, self.vocab_size]) 171 | self.parent.assertListEqual( 172 | list(list(mem.shape) for mem in result["mems_1"]), 173 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 174 | 175 | self.parent.assertListEqual( 176 | list(result["lm_logits_2"].shape), 177 | [self.batch_size, self.seq_length, self.vocab_size]) 178 | self.parent.assertListEqual( 179 | list(list(mem.shape) for mem in result["mems_2"]), 180 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 181 | 182 | def prepare_config_and_inputs_for_common(self): 183 | config_and_inputs = self.prepare_config_and_inputs() 184 | (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs 185 | inputs_dict = {'input_ids': input_ids_1} 186 | return config, inputs_dict 187 | 188 | 189 | def setUp(self): 190 | self.model_tester = TFTransfoXLModelTest.TFTransfoXLModelTester(self) 191 | self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37) 192 | 193 | def test_config(self): 194 | self.config_tester.run_common_tests() 195 | 196 | def test_transfo_xl_model(self): 197 | self.model_tester.set_seed() 198 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 199 | self.model_tester.create_and_check_transfo_xl_model(*config_and_inputs) 200 | 201 | def test_transfo_xl_lm_head(self): 202 | self.model_tester.set_seed() 203 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 204 | self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs) 205 | 206 | @slow 207 | def test_model_from_pretrained(self): 208 | cache_dir = "/tmp/transformers_test/" 209 | for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 210 | model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) 211 | shutil.rmtree(cache_dir) 212 | self.assertIsNotNone(model) 213 | 214 | 215 | if __name__ == "__main__": 216 | unittest.main() 217 | -------------------------------------------------------------------------------- /transformers/tests/modeling_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 
Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import random 21 | import shutil 22 | 23 | from transformers import is_torch_available 24 | 25 | if is_torch_available(): 26 | import torch 27 | from transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) 28 | from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP 29 | 30 | from .modeling_common_test import (CommonTestCases, ids_tensor) 31 | from .configuration_common_test import ConfigTester 32 | from .utils import require_torch, slow, torch_device 33 | 34 | 35 | @require_torch 36 | class TransfoXLModelTest(CommonTestCases.CommonModelTester): 37 | 38 | all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () 39 | test_pruning = False 40 | test_torchscript = False 41 | test_resize_embeddings = False 42 | 43 | class TransfoXLModelTester(object): 44 | 45 | def __init__(self, 46 | parent, 47 | batch_size=13, 48 | seq_length=7, 49 | mem_len=30, 50 | clamp_len=15, 51 | is_training=True, 52 | use_labels=True, 53 | vocab_size=99, 54 | cutoffs=[10, 50, 80], 55 | hidden_size=32, 56 | d_embed=32, 57 | num_attention_heads=4, 58 | d_head=8, 59 | d_inner=128, 60 | div_val=2, 61 | num_hidden_layers=5, 62 | scope=None, 63 | seed=1, 64 | ): 65 | self.parent = parent 66 | self.batch_size = batch_size 67 | self.seq_length = seq_length 68 | self.mem_len = mem_len 69 | self.key_len = seq_length + mem_len 70 | self.clamp_len = clamp_len 71 | self.is_training = is_training 72 | self.use_labels = use_labels 73 | self.vocab_size = vocab_size 74 | self.cutoffs = cutoffs 75 | self.hidden_size = hidden_size 76 | self.d_embed = d_embed 77 | self.num_attention_heads = num_attention_heads 78 | self.d_head = d_head 79 | self.d_inner = d_inner 80 | self.div_val = div_val 81 | self.num_hidden_layers = num_hidden_layers 82 | self.scope = scope 83 | self.seed = seed 84 | 85 | def prepare_config_and_inputs(self): 86 | input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 87 | input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 88 | 89 | lm_labels = None 90 | if self.use_labels: 91 | lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 92 | 93 | config = TransfoXLConfig( 94 | vocab_size_or_config_json_file=self.vocab_size, 95 | mem_len=self.mem_len, 96 | clamp_len=self.clamp_len, 97 | cutoffs=self.cutoffs, 98 | d_model=self.hidden_size, 99 | d_embed=self.d_embed, 100 | n_head=self.num_attention_heads, 101 | d_head=self.d_head, 102 | d_inner=self.d_inner, 103 | div_val=self.div_val, 104 | n_layer=self.num_hidden_layers) 105 | 106 | return (config, input_ids_1, input_ids_2, lm_labels) 107 | 108 | def set_seed(self): 109 | random.seed(self.seed) 110 | torch.manual_seed(self.seed) 111 | 
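# The two forward passes below exercise Transformer-XL's segment-level recurrence: mems returned for the first segment are fed back in with the second one, and each memory tensor must keep the shape [mem_len, batch_size, hidden_size].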
112 | def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels): 113 | model = TransfoXLModel(config) 114 | model.to(torch_device) 115 | model.eval() 116 | 117 | hidden_states_1, mems_1 = model(input_ids_1) 118 | hidden_states_2, mems_2 = model(input_ids_2, mems_1) 119 | outputs = { 120 | "hidden_states_1": hidden_states_1, 121 | "mems_1": mems_1, 122 | "hidden_states_2": hidden_states_2, 123 | "mems_2": mems_2, 124 | } 125 | return outputs 126 | 127 | def check_transfo_xl_model_output(self, result): 128 | self.parent.assertListEqual( 129 | list(result["hidden_states_1"].size()), 130 | [self.batch_size, self.seq_length, self.hidden_size]) 131 | self.parent.assertListEqual( 132 | list(result["hidden_states_2"].size()), 133 | [self.batch_size, self.seq_length, self.hidden_size]) 134 | self.parent.assertListEqual( 135 | list(list(mem.size()) for mem in result["mems_1"]), 136 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 137 | self.parent.assertListEqual( 138 | list(list(mem.size()) for mem in result["mems_2"]), 139 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 140 | 141 | 142 | def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): 143 | model = TransfoXLLMHeadModel(config) 144 | model.to(torch_device) 145 | model.eval() 146 | 147 | lm_logits_1, mems_1 = model(input_ids_1) 148 | loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels) 149 | lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1) 150 | loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1) 151 | 152 | outputs = { 153 | "loss_1": loss_1, 154 | "mems_1": mems_1, 155 | "lm_logits_1": lm_logits_1, 156 | "loss_2": loss_2, 157 | "mems_2": mems_2, 158 | "lm_logits_2": lm_logits_2, 159 | } 160 | return outputs 161 | 162 | def check_transfo_xl_lm_head_output(self, result): 163 | self.parent.assertListEqual( 164 | list(result["loss_1"].size()), 165 | [self.batch_size, self.seq_length]) 166 | self.parent.assertListEqual( 167 | list(result["lm_logits_1"].size()), 168 | [self.batch_size, self.seq_length, self.vocab_size]) 169 | self.parent.assertListEqual( 170 | list(list(mem.size()) for mem in result["mems_1"]), 171 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 172 | 173 | self.parent.assertListEqual( 174 | list(result["loss_2"].size()), 175 | [self.batch_size, self.seq_length]) 176 | self.parent.assertListEqual( 177 | list(result["lm_logits_2"].size()), 178 | [self.batch_size, self.seq_length, self.vocab_size]) 179 | self.parent.assertListEqual( 180 | list(list(mem.size()) for mem in result["mems_2"]), 181 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 182 | 183 | def prepare_config_and_inputs_for_common(self): 184 | config_and_inputs = self.prepare_config_and_inputs() 185 | (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs 186 | inputs_dict = {'input_ids': input_ids_1} 187 | return config, inputs_dict 188 | 189 | 190 | def setUp(self): 191 | self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self) 192 | self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37) 193 | 194 | def test_config(self): 195 | self.config_tester.run_common_tests() 196 | 197 | def test_transfo_xl_model(self): 198 | self.model_tester.set_seed() 199 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 200 | output_result = self.model_tester.create_transfo_xl_model(*config_and_inputs) 201 | 
self.model_tester.check_transfo_xl_model_output(output_result) 202 | 203 | def test_transfo_xl_lm_head(self): 204 | self.model_tester.set_seed() 205 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 206 | output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) 207 | self.model_tester.check_transfo_xl_lm_head_output(output_result) 208 | 209 | @slow 210 | def test_model_from_pretrained(self): 211 | cache_dir = "/tmp/transformers_test/" 212 | for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 213 | model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) 214 | shutil.rmtree(cache_dir) 215 | self.assertIsNotNone(model) 216 | 217 | 218 | if __name__ == "__main__": 219 | unittest.main() 220 | -------------------------------------------------------------------------------- /transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | 22 | from transformers import is_torch_available 23 | 24 | if is_torch_available(): 25 | import torch 26 | 27 | from transformers import (AdamW, 28 | get_constant_schedule, 29 | get_constant_schedule_with_warmup, 30 | get_cosine_schedule_with_warmup, 31 | get_cosine_with_hard_restarts_schedule_with_warmup, 32 | get_linear_schedule_with_warmup) 33 | 34 | from .tokenization_tests_commons import TemporaryDirectory 35 | from .utils import require_torch 36 | 37 | 38 | def unwrap_schedule(scheduler, num_steps=10): 39 | lrs = [] 40 | for _ in range(num_steps): 41 | scheduler.step() 42 | lrs.append(scheduler.get_lr()) 43 | return lrs 44 | 45 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 46 | lrs = [] 47 | for step in range(num_steps): 48 | scheduler.step() 49 | lrs.append(scheduler.get_lr()) 50 | if step == num_steps // 2: 51 | with TemporaryDirectory() as tmpdirname: 52 | file_name = os.path.join(tmpdirname, 'schedule.bin') 53 | torch.save(scheduler.state_dict(), file_name) 54 | 55 | state_dict = torch.load(file_name) 56 | scheduler.load_state_dict(state_dict) 57 | return lrs 58 | 59 | @require_torch 60 | class OptimizationTest(unittest.TestCase): 61 | 62 | def assertListAlmostEqual(self, list1, list2, tol): 63 | self.assertEqual(len(list1), len(list2)) 64 | for a, b in zip(list1, list2): 65 | self.assertAlmostEqual(a, b, delta=tol) 66 | 67 | def test_adam_w(self): 68 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 69 | target = torch.tensor([0.4, 0.2, -0.5]) 70 | criterion = torch.nn.MSELoss() 71 | # No warmup, constant schedule, no gradient clipping 72 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 73 | for _ in range(100): 74 | loss = criterion(w, target) 75 | 
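# One optimization step: backprop the MSE loss, apply the AdamW update, then clear w's gradient by hand. The assert below checks that 100 such steps at lr=2e-1 with no weight decay bring w within 1e-2 of the target.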
loss.backward() 76 | optimizer.step() 77 | w.grad.detach_() # Plain tensors have no zero_grad(); detach and zero the gradient manually. 78 | w.grad.zero_() 79 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 80 | 81 | 82 | @require_torch 83 | class ScheduleInitTest(unittest.TestCase): 84 | m = torch.nn.Linear(50, 50) if is_torch_available() else None 85 | optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None 86 | num_steps = 10 87 | 88 | def assertListAlmostEqual(self, list1, list2, tol): 89 | self.assertEqual(len(list1), len(list2)) 90 | for a, b in zip(list1, list2): 91 | self.assertAlmostEqual(a, b, delta=tol) 92 | 93 | def test_constant_scheduler(self): 94 | scheduler = get_constant_schedule(self.optimizer) 95 | lrs = unwrap_schedule(scheduler, self.num_steps) 96 | expected_learning_rates = [10.] * self.num_steps 97 | self.assertEqual(len(lrs[0]), 1) 98 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 99 | 100 | scheduler = get_constant_schedule(self.optimizer) 101 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 102 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 103 | 104 | def test_warmup_constant_scheduler(self): 105 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) 106 | lrs = unwrap_schedule(scheduler, self.num_steps) 107 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 108 | self.assertEqual(len(lrs[0]), 1) 109 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 110 | 111 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) 112 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 113 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 114 | 115 | def test_warmup_linear_scheduler(self): 116 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 117 | lrs = unwrap_schedule(scheduler, self.num_steps) 118 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 119 | self.assertEqual(len(lrs[0]), 1) 120 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 121 | 122 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 123 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 124 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 125 | 126 | def test_warmup_cosine_scheduler(self): 127 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 128 | lrs = unwrap_schedule(scheduler, self.num_steps) 129 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 130 | self.assertEqual(len(lrs[0]), 1) 131 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 132 | 133 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 134 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 135 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 136 | 137 | def test_warmup_cosine_hard_restart_scheduler(self): 138 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) 139 | lrs = unwrap_schedule(scheduler, self.num_steps) 140 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 141 |
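# Schedule shape: two linear warmup steps up to the peak LR of 10.0, then two cosine half-waves (num_cycles=2) that restart at full amplitude mid-run and decay to zero by the final step.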
self.assertEqual(len(lrs[0]), 1) 142 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 143 | 144 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) 145 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 146 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 147 | 148 | 149 | if __name__ == "__main__": 150 | unittest.main() 151 | -------------------------------------------------------------------------------- /transformers/tests/optimization_tf_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import unittest 6 | 7 | from transformers import is_tf_available 8 | 9 | from .utils import require_tf 10 | 11 | if is_tf_available(): 12 | import tensorflow as tf 13 | from tensorflow.python.eager import context 14 | from tensorflow.python.framework import ops 15 | from transformers import (create_optimizer, GradientAccumulator) 16 | 17 | 18 | @require_tf 19 | class OptimizationTFTest(unittest.TestCase): 20 | def assertListAlmostEqual(self, list1, list2, tol): 21 | self.assertEqual(len(list1), len(list2)) 22 | for a, b in zip(list1, list2): 23 | self.assertAlmostEqual(a, b, delta=tol) 24 | 25 | def testGradientAccumulator(self): 26 | accumulator = GradientAccumulator() 27 | accumulator([tf.constant([1.0, 2.0])]) 28 | accumulator([tf.constant([-2.0, 1.0])]) 29 | accumulator([tf.constant([-1.0, 2.0])]) 30 | with self.assertRaises(ValueError): 31 | accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])]) 32 | self.assertEqual(accumulator.step, 3) 33 | self.assertEqual(len(accumulator.gradients), 1) 34 | self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2) 35 | accumulator.reset() 36 | self.assertEqual(accumulator.step, 0) 37 | self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2) 38 | 39 | def testGradientAccumulatorDistributionStrategy(self): 40 | context._context = None 41 | ops.enable_eager_execution_internal() 42 | physical_devices = tf.config.experimental.list_physical_devices("CPU") 43 | tf.config.experimental.set_virtual_device_configuration( 44 | physical_devices[0], 45 | [tf.config.experimental.VirtualDeviceConfiguration(), 46 | tf.config.experimental.VirtualDeviceConfiguration()]) 47 | 48 | devices = tf.config.experimental.list_logical_devices(device_type="CPU") 49 | strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices]) 50 | 51 | with strategy.scope(): 52 | accumulator = GradientAccumulator() 53 | variable = tf.Variable([4.0, 3.0]) 54 | optimizer = create_optimizer(5e-5, 10, 5) 55 | gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False) 56 | 57 | def accumulate_on_replica(gradient): 58 | accumulator([gradient]) 59 | 60 | def apply_on_replica(): 61 | optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0) 62 | 63 | @tf.function 64 | def accumulate(grad1, grad2): 65 | with strategy.scope(): 66 | gradient_placeholder.values[0].assign(grad1) 67 | gradient_placeholder.values[1].assign(grad2) 68 | strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,)) 69 | 70 | @tf.function 71 | def apply_grad(): 72 | with strategy.scope(): 73 | strategy.experimental_run_v2(apply_on_replica) 74 | 75 | accumulate([1.0, 2.0],
[-1.0, 1.0]) 76 | accumulate([3.0, -1.0], [-1.0, -1.0]) 77 | accumulate([-2.0, 2.0], [3.0, -2.0]) 78 | self.assertEqual(accumulator.step, 3) 79 | self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2) 80 | self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2) 81 | apply_grad() 82 | self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2) 83 | accumulator.reset() 84 | self.assertEqual(accumulator.step, 0) 85 | self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2) 86 | self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2) 87 | 88 | 89 | if __name__ == "__main__": 90 | unittest.main() -------------------------------------------------------------------------------- /transformers/tests/tokenization_albert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Hugging Face inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from transformers.tokenization_albert import (AlbertTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 25 | 'fixtures/spiece.model') 26 | 27 | class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | 29 | tokenizer_class = AlbertTokenizer 30 | 31 | def setUp(self): 32 | super(AlbertTokenizationTest, self).setUp() 33 | 34 | # We have a SentencePiece fixture for testing 35 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB) 36 | tokenizer.save_pretrained(self.tmpdirname) 37 | 38 | def get_tokenizer(self, **kwargs): 39 | return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 40 | 41 | def get_input_output_texts(self): 42 | input_text = u"this is a test" 43 | output_text = u"this is a test" 44 | return input_text, output_text 45 | 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB, keep_accents=True) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁this', u'▁is', u'▁a', u'▁test']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [48, 25, 21, 1289]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [u'▁i', u'▁was', u'▁born', u'▁in', u'▁9', u'2000', u',', u'▁and', u'▁this', u'▁is', u'▁fal', u's', u'é', u'.']) 58 | ids = tokenizer.convert_tokens_to_ids(tokens) 59 | self.assertListEqual(ids, [31, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9]) 60 | 61 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 62 | self.assertListEqual(back_tokens, ['▁i', '▁was', 
'▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fal', 's', '<unk>', '.']) 63 | 64 | def test_sequence_builders(self): 65 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB) 66 | 67 | text = tokenizer.encode("sequence builders") 68 | text_2 = tokenizer.encode("multi-sequence build") 69 | 70 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 71 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 72 | 73 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] 74 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [tokenizer.sep_token_id] 75 | 76 | 77 | if __name__ == '__main__': 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import logging 22 | 23 | from transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer 24 | from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 25 | 26 | from .utils import slow, SMALL_MODEL_IDENTIFIER 27 | 28 | 29 | class AutoTokenizerTest(unittest.TestCase): 30 | @slow 31 | def test_tokenizer_from_pretrained(self): 32 | logging.basicConfig(level=logging.INFO) 33 | for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 34 | tokenizer = AutoTokenizer.from_pretrained(model_name) 35 | self.assertIsNotNone(tokenizer) 36 | self.assertIsInstance(tokenizer, BertTokenizer) 37 | self.assertGreater(len(tokenizer), 0) 38 | 39 | for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 40 | tokenizer = AutoTokenizer.from_pretrained(model_name) 41 | self.assertIsNotNone(tokenizer) 42 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 43 | self.assertGreater(len(tokenizer), 0) 44 | 45 | def test_tokenizer_from_pretrained_identifier(self): 46 | logging.basicConfig(level=logging.INFO) 47 | tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) 48 | self.assertIsInstance(tokenizer, BertTokenizer) 49 | self.assertEqual(len(tokenizer), 12) 50 | 51 | if __name__ == "__main__": 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_bert_japanese_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from transformers.tokenization_bert import WordpieceTokenizer 22 | from transformers.tokenization_bert_japanese import (BertJapaneseTokenizer, 23 | MecabTokenizer, CharacterTokenizer, 24 | VOCAB_FILES_NAMES) 25 | 26 | from .tokenization_tests_commons import CommonTestCases 27 | from .utils import slow, custom_tokenizers 28 | 29 | 30 | @custom_tokenizers 31 | class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester): 32 | 33 | tokenizer_class = BertJapaneseTokenizer 34 | 35 | def setUp(self): 36 | super(BertJapaneseTokenizationTest, self).setUp() 37 | 38 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", 39 | u"こんにちは", u"こん", u"にちは", u"ばんは", u"##こん", u"##にちは", u"##ばんは", 40 | u"世界", u"##世界", u"、", u"##、", u"。", u"##。"] 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) 43 | with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: 44 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 45 | 46 | def get_tokenizer(self, **kwargs): 47 | return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs) 48 | 49 | def get_input_output_texts(self): 50 | input_text = u"こんにちは、世界。 \nこんばんは、世界。" 51 | output_text = u"こんにちは 、 世界 。 こんばんは 、 世界 。" 52 | return input_text, output_text 53 | 54 | def test_full_tokenizer(self): 55 | tokenizer = self.tokenizer_class(self.vocab_file) 56 | 57 | tokens = tokenizer.tokenize(u"こんにちは、世界。\nこんばんは、世界。") 58 | self.assertListEqual(tokens, 59 | [u"こんにちは", u"、", u"世界", u"。", 60 | u"こん", u"##ばんは", u"、", u"世界", u"。"]) 61 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), 62 | [3, 12, 10, 14, 4, 9, 12, 10, 14]) 63 | 64 | def test_mecab_tokenizer(self): 65 | tokenizer = MecabTokenizer() 66 | 67 | self.assertListEqual( 68 | tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), 69 | [u"アップルストア", u"で", u"iPhone", u"8", u"が", 70 | u"発売", u"さ", u"れ", u"た", u"。"]) 71 | 72 | def test_mecab_tokenizer_lower(self): 73 | tokenizer = MecabTokenizer(do_lower_case=True) 74 | 75 | self.assertListEqual( 76 | tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), 77 | [u"アップルストア", u"で", u"iphone", u"8", u"が", 78 | u"発売", u"さ", u"れ", u"た", u"。"]) 79 | 80 | def test_mecab_tokenizer_no_normalize(self): 81 | tokenizer = MecabTokenizer(normalize_text=False) 82 | 83 | self.assertListEqual( 84 | tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), 85 | [u"アップルストア", u"で", u"iPhone", u"8", u"が", 86 | u"発売", u"さ", u"れ", u"た", u" ", u"。"]) 87 | 88 | def test_wordpiece_tokenizer(self): 89 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", 90 | u"こんにちは", u"こん", u"にちは", u"ばんは", u"##こん", u"##にちは", u"##ばんは"] 91 | 92 | vocab = {} 93 | for (i, token) in enumerate(vocab_tokens): 94 | vocab[token] = i 95 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token=u"[UNK]") 96 | 97 | self.assertListEqual(tokenizer.tokenize(u""), []) 98 | 99 | self.assertListEqual(tokenizer.tokenize(u"こんにちは"), 100 | [u"こんにちは"]) 101 | 102 |
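# WordPiece is greedy longest-match-first: there is no whole-word entry for "こんばんは", so it backs off to "こん" plus the continuation piece "##ばんは".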
self.assertListEqual(tokenizer.tokenize(u"こんばんは"), 103 | [u"こん", u"##ばんは"]) 104 | 105 | self.assertListEqual(tokenizer.tokenize(u"こんばんは こんばんにちは こんにちは"), 106 | [u"こん", u"##ばんは", u"[UNK]", u"こんにちは"]) 107 | 108 | @slow 109 | def test_sequence_builders(self): 110 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese") 111 | 112 | text = tokenizer.encode(u"ありがとう。", add_special_tokens=False) 113 | text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False) 114 | 115 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 116 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 117 | 118 | # 2 is for "[CLS]", 3 is for "[SEP]" 119 | assert encoded_sentence == [2] + text + [3] 120 | assert encoded_pair == [2] + text + [3] + text_2 + [3] 121 | 122 | 123 | class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTester): 124 | 125 | tokenizer_class = BertJapaneseTokenizer 126 | 127 | def setUp(self): 128 | super(BertJapaneseCharacterTokenizationTest, self).setUp() 129 | 130 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", 131 | u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界", u"、", u"。"] 132 | 133 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) 134 | with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: 135 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 136 | 137 | def get_tokenizer(self, **kwargs): 138 | return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, 139 | subword_tokenizer_type="character", 140 | **kwargs) 141 | 142 | def get_input_output_texts(self): 143 | input_text = u"こんにちは、世界。 \nこんばんは、世界。" 144 | output_text = u"こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。" 145 | return input_text, output_text 146 | 147 | def test_full_tokenizer(self): 148 | tokenizer = self.tokenizer_class(self.vocab_file, 149 | subword_tokenizer_type="character") 150 | 151 | tokens = tokenizer.tokenize(u"こんにちは、世界。 \nこんばんは、世界。") 152 | self.assertListEqual(tokens, 153 | [u"こ", u"ん", u"に", u"ち", u"は", u"、", u"世", u"界", u"。", 154 | u"こ", u"ん", u"ば", u"ん", u"は", u"、", u"世", u"界", u"。"]) 155 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), 156 | [3, 4, 5, 6, 7, 11, 9, 10, 12, 157 | 3, 4, 8, 4, 7, 11, 9, 10, 12]) 158 | 159 | def test_character_tokenizer(self): 160 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", 161 | u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界", u"、", u"。"] 162 | 163 | vocab = {} 164 | for (i, token) in enumerate(vocab_tokens): 165 | vocab[token] = i 166 | tokenizer = CharacterTokenizer(vocab=vocab, unk_token=u"[UNK]") 167 | 168 | self.assertListEqual(tokenizer.tokenize(u""), []) 169 | 170 | self.assertListEqual(tokenizer.tokenize(u"こんにちは"), 171 | [u"こ", u"ん", u"に", u"ち", u"は"]) 172 | 173 | self.assertListEqual(tokenizer.tokenize(u"こんにちほ"), 174 | [u"こ", u"ん", u"に", u"ち", u"[UNK]"]) 175 | 176 | @slow 177 | def test_sequence_builders(self): 178 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese-char") 179 | 180 | text = tokenizer.encode(u"ありがとう。", add_special_tokens=False) 181 | text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False) 182 | 183 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 184 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 185 | 186 | # 2 is for "[CLS]", 3 is for "[SEP]" 187 | assert encoded_sentence == [2] + text + [3] 188 | assert encoded_pair == [2] + text + [3] + text_2 + [3] 189 | 190 | 191 | 192 |
-------------------------------------------------------------------------------- /transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | from .utils import slow 29 | 30 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 31 | 32 | tokenizer_class = BertTokenizer 33 | 34 | def setUp(self): 35 | super(BertTokenizationTest, self).setUp() 36 | 37 | vocab_tokens = [ 38 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 39 | "##ing", ",", "low", "lowest", 40 | ] 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 43 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 47 | 48 | def get_input_output_texts(self): 49 | input_text = u"UNwant\u00E9d,running" 50 | output_text = u"unwanted, running" 51 | return input_text, output_text 52 | 53 | def test_full_tokenizer(self): 54 | tokenizer = self.tokenizer_class(self.vocab_file) 55 | 56 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 57 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 58 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 59 | 60 | def test_chinese(self): 61 | tokenizer = BasicTokenizer() 62 | 63 | self.assertListEqual( 64 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 65 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 66 | 67 | def test_basic_tokenizer_lower(self): 68 | tokenizer = BasicTokenizer(do_lower_case=True) 69 | 70 | self.assertListEqual( 71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 72 | ["hello", "!", "how", "are", "you", "?"]) 73 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 74 | 75 | def test_basic_tokenizer_no_lower(self): 76 | tokenizer = BasicTokenizer(do_lower_case=False) 77 | 78 | self.assertListEqual( 79 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 80 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 81 | 82 | def test_wordpiece_tokenizer(self): 83 | vocab_tokens = [ 84 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 85 | "##ing" 86 | ] 87 | 88 | vocab = {} 89 | for (i, token) in enumerate(vocab_tokens): 90 | vocab[token] = i 91 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 92 | 93 | self.assertListEqual(tokenizer.tokenize(""), []) 94 | 95 | self.assertListEqual( 96 | tokenizer.tokenize("unwanted running"), 97 | ["un", "##want", "##ed", "runn", "##ing"]) 98 | 99 | self.assertListEqual( 100 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 101 | 102 | def test_is_whitespace(self): 103 | self.assertTrue(_is_whitespace(u" ")) 104 | self.assertTrue(_is_whitespace(u"\t")) 105 | self.assertTrue(_is_whitespace(u"\r")) 106 | self.assertTrue(_is_whitespace(u"\n")) 107 | self.assertTrue(_is_whitespace(u"\u00A0")) 108 | 109 | self.assertFalse(_is_whitespace(u"A")) 110 | self.assertFalse(_is_whitespace(u"-")) 111 | 112 | def test_is_control(self): 113 | self.assertTrue(_is_control(u"\u0005")) 114 | 115 | self.assertFalse(_is_control(u"A")) 116 | self.assertFalse(_is_control(u" ")) 117 | self.assertFalse(_is_control(u"\t")) 118 | self.assertFalse(_is_control(u"\r")) 119 | 120 | def test_is_punctuation(self): 121 | self.assertTrue(_is_punctuation(u"-")) 122 | self.assertTrue(_is_punctuation(u"$")) 123 | self.assertTrue(_is_punctuation(u"`")) 124 | self.assertTrue(_is_punctuation(u".")) 125 | 126 | self.assertFalse(_is_punctuation(u"A")) 127 | self.assertFalse(_is_punctuation(u" ")) 128 | 129 | @slow 130 | def test_sequence_builders(self): 131 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 132 | 133 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 134 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 135 | 136 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 137 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 138 | 139 | assert encoded_sentence == [101] + text + [102] 140 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 141 | 142 | 143 | if __name__ == '__main__': 144 | unittest.main() 145 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_ctrl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import absolute_import, division, print_function, unicode_literals 15 | 16 | import os 17 | import unittest 18 | import json 19 | from io import open 20 | 21 | from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = CTRLTokenizer 28 | 29 | def setUp(self): 30 | super(CTRLTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>'] 34 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 35 | merges = ["#version: 0.2", 'a p', 'ap t</w>', 'r e', 'a d', 'ad apt</w>', ''] 36 | self.special_tokens_map = {"unk_token": "<unk>"} 37 | 38 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 39 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 40 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 41 | fp.write(json.dumps(vocab_tokens) + "\n") 42 | with open(self.merges_file, "w", encoding="utf-8") as fp: 43 | fp.write("\n".join(merges)) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | kwargs.update(self.special_tokens_map) 47 | return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 48 | 49 | def get_input_output_texts(self): 50 | input_text = u"adapt react readapt apt" 51 | output_text = u"adapt react readapt apt" 52 | return input_text, output_text 53 | 54 | def test_full_tokenizer(self): 55 | tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 56 | text = "adapt react readapt apt" 57 | bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split() 58 | tokens = tokenizer.tokenize(text) 59 | self.assertListEqual(tokens, bpe_tokens) 60 | 61 | input_tokens = tokens + [tokenizer.unk_token] 62 | 63 | input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6] 64 | self.assertListEqual( 65 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 66 | 67 | 68 | if __name__ == '__main__': 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_distilbert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from transformers.tokenization_distilbert import (DistilBertTokenizer) 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .tokenization_bert_test import BertTokenizationTest 25 | from .utils import slow 26 | 27 | class DistilBertTokenizationTest(BertTokenizationTest): 28 | 29 | tokenizer_class = DistilBertTokenizer 30 | 31 | def get_tokenizer(self, **kwargs): 32 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 33 | 34 | @slow 35 | def test_sequence_builders(self): 36 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 37 | 38 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 39 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 40 | 41 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 42 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 43 | 44 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] 45 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \ 46 | text_2 + [tokenizer.sep_token_id] 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | from io import open 21 | 22 | from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = GPT2Tokenizer 29 | 30 | def setUp(self): 31 | super(GPT2TokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": "<unk>"} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u"lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text, add_prefix_space=True) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | if __name__ == '__main__': 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = OpenAIGPTTokenizer 29 | 30 | def setUp(self): 31 | super(OpenAIGPTTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w</w>", "r</w>", "t</w>", 36 | "lo", "low", "er</w>", 37 | "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er</w>"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + ["<unk>"] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_roberta_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import json 19 | import unittest 20 | from io import open 21 | 22 | from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .utils import slow 25 | 26 | 27 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | tokenizer_class = RobertaTokenizer 29 | 30 | def setUp(self): 31 | super(RobertaTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": "<unk>"} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u"lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text, add_prefix_space=True) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | def roberta_dict_integration_testing(self): 71 | tokenizer = self.get_tokenizer() 72 | 73 | self.assertListEqual( 74 | tokenizer.encode('Hello world!', add_special_tokens=False), 75 | [0, 31414, 232, 328, 2] 76 | ) 77 | self.assertListEqual( 78 | tokenizer.encode('Hello world! cécé herlolip 418', add_special_tokens=False), 79 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 80 | ) 81 | 82 | @slow 83 | def test_sequence_builders(self): 84 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 85 | 86 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 87 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 88 | 89 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 90 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 91 | 92 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 93 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 94 | 95 | assert encoded_sentence == encoded_text_from_decode 96 | assert encoded_pair == encoded_pair_from_decode 97 | 98 | 99 | if __name__ == '__main__': 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from transformers import is_torch_available 22 | 23 | if is_torch_available(): 24 | import torch 25 | from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | from .utils import require_torch 29 | 30 | 31 | @require_torch 32 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): 33 | 34 | tokenizer_class = TransfoXLTokenizer if is_torch_available() else None 35 | 36 | def setUp(self): 37 | super(TransfoXLTokenizationTest, self).setUp() 38 | 39 | vocab_tokens = [ 40 | "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", 41 | "running", ",", "low", "l", 42 | ] 43 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 44 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 45 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 46 | 47 | def get_tokenizer(self, **kwargs): 48 | kwargs['lower_case'] = True 49 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"<unk> UNwanted , running" 53 | output_text = u"<unk> unwanted, running" 54 | return input_text, output_text 55 | 56 | def test_full_tokenizer(self): 57 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True) 58 | 59 | tokens = tokenizer.tokenize(u"<unk> UNwanted , running") 60 | self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"]) 61 | 62 | self.assertListEqual( 63 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 64 | 65 | def test_full_tokenizer_lower(self): 66 | tokenizer = TransfoXLTokenizer(lower_case=True) 67 | 68 | self.assertListEqual( 69 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 70 | ["hello", "!", "how", "are", "you", "?"]) 71 | 72 | def test_full_tokenizer_no_lower(self): 73 | tokenizer = TransfoXLTokenizer(lower_case=False) 74 | 75 | self.assertListEqual( 76 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 77 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 78 | 79 | 80 | if __name__ == '__main__': 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from transformers import PreTrainedTokenizer 23 | from transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | from .utils import slow 26 | 27 | class TokenizerUtilsTest(unittest.TestCase): 28 | 29 | def check_tokenizer_from_pretrained(self, tokenizer_class): 30 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 31 | for model_name in s3_models[:1]: 32 | tokenizer = tokenizer_class.from_pretrained(model_name) 33 | self.assertIsNotNone(tokenizer) 34 | self.assertIsInstance(tokenizer, tokenizer_class) 35 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 36 | 37 | for special_tok in tokenizer.all_special_tokens: 38 | if six.PY2: 39 | self.assertIsInstance(special_tok, unicode) 40 | else: 41 | self.assertIsInstance(special_tok, str) 42 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 43 | self.assertIsInstance(special_tok_id, int) 44 | 45 | @slow 46 | def test_pretrained_tokenizers(self): 47 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .utils import slow 25 | 26 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = XLMTokenizer 29 | 30 | def setUp(self): 31 | super(XLMTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w</w>", "r</w>", "t</w>", 36 | "lo", "low", "er</w>", 37 | "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | def test_full_tokenizer(self): 57 | """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ 58 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er</w>"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + ["<unk>"] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | @slow 71 | def test_sequence_builders(self): 72 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") 73 | 74 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 75 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 76 | 77 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 78 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 79 | 80 | assert encoded_sentence == [1] + text + [1] 81 | assert encoded_pair == [1] + text + [1] + text_2 + [1] 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_xlnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | from .utils import slow 24 | 25 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 26 | 'fixtures/test_sentencepiece.model') 27 | 28 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 29 | 30 | tokenizer_class = XLNetTokenizer 31 | 32 | def setUp(self): 33 | super(XLNetTokenizationTest, self).setUp() 34 | 35 | # We have a SentencePiece fixture for testing 36 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 37 | tokenizer.save_pretrained(self.tmpdirname) 38 | 39 | def get_tokenizer(self, **kwargs): 40 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) 41 | 42 | def get_input_output_texts(self): 43 | input_text = u"This is a test" 44 | output_text = u"This is a test" 45 | return input_text, output_text 46 | 47 | 48 | def test_full_tokenizer(self): 49 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 50 | 51 | tokens = tokenizer.tokenize(u'This is a test') 52 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 53 | 54 | self.assertListEqual( 55 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 56 | 57 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 58 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 59 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 60 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 61 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 62 | ids = tokenizer.convert_tokens_to_ids(tokens) 63 | self.assertListEqual( 64 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 65 | 602, 347, 347, 347, 3, 12, 66, 66 | 46, 72, 80, 6, 0, 4]) 67 | 68 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 69 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 70 | u'or', u'n', SPIECE_UNDERLINE + u'in', 71 | SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',', 72 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 73 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 74 | u'<unk>', u'.']) 75 | 76 | def test_tokenizer_lower(self): 77 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 78 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 79 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 80 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 81 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 82 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 83 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 84 | 85 | def test_tokenizer_no_lower(self): 86 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 87 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 88 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 89 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 90 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + 
u'this', 91 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 92 | 93 | @slow 94 | def test_sequence_builders(self): 95 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 96 | 97 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 98 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 99 | 100 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 101 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 102 | 103 | assert encoded_sentence == text + [4, 3] 104 | assert encoded_pair == text + [4] + text_2 + [4, 3] 105 | 106 | 107 | if __name__ == '__main__': 108 | unittest.main() 109 | -------------------------------------------------------------------------------- /transformers/tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from distutils.util import strtobool 5 | 6 | from transformers.file_utils import _tf_available, _torch_available 7 | 8 | 9 | SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" 10 | 11 | 12 | def parse_flag_from_env(key, default=False): 13 | try: 14 | value = os.environ[key] 15 | except KeyError: 16 | # KEY isn't set, default to `default`. 17 | _value = default 18 | else: 19 | # KEY is set, convert it to True or False. 20 | try: 21 | _value = strtobool(value) 22 | except ValueError: 23 | # More values are supported, but let's keep the message simple. 24 | raise ValueError("If set, {} must be yes or no.".format(key)) 25 | return _value 26 | 27 | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) 28 | _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) 29 | 30 | 31 | def slow(test_case): 32 | """ 33 | Decorator marking a test as slow. 34 | 35 | Slow tests are skipped by default. Set the RUN_SLOW environment variable 36 | to a truthy value to run them. 37 | 38 | """ 39 | if not _run_slow_tests: 40 | test_case = unittest.skip("test is slow")(test_case) 41 | return test_case 42 | 43 | 44 | def custom_tokenizers(test_case): 45 | """ 46 | Decorator marking a test for a custom tokenizer. 47 | 48 | Custom tokenizers require additional dependencies, and are skipped 49 | by default. Set the RUN_CUSTOM_TOKENIZERS environment variable 50 | to a truthy value to run them. 51 | """ 52 | if not _run_custom_tokenizers: 53 | test_case = unittest.skip("test of custom tokenizers")(test_case) 54 | return test_case 55 | 56 | 57 | def require_torch(test_case): 58 | """ 59 | Decorator marking a test that requires PyTorch. 60 | 61 | These tests are skipped when PyTorch isn't installed. 62 | 63 | """ 64 | if not _torch_available: 65 | test_case = unittest.skip("test requires PyTorch")(test_case) 66 | return test_case 67 | 68 | 69 | def require_tf(test_case): 70 | """ 71 | Decorator marking a test that requires TensorFlow. 72 | 73 | These tests are skipped when TensorFlow isn't installed. 74 | 75 | """ 76 | if not _tf_available: 77 | test_case = unittest.skip("test requires TensorFlow")(test_case) 78 | return test_case 79 | 80 | 81 | if _torch_available: 82 | # Set the USE_CUDA environment variable to select a GPU. 83 | torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu" 84 | else: 85 | torch_device = None 86 | --------------------------------------------------------------------------------
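Usage sketch for the helpers defined in transformers/tests/utils.py above: a minimal, hypothetical test module (not part of the repository) showing how the environment-gated decorators combine. The imported names and the RUN_SLOW / USE_CUDA flags come from the file above; the module name and test bodies are illustrative assumptions.

# hypothetical_gated_test.py -- illustrative sketch only, not part of the repository.
import unittest

from transformers.tests.utils import require_torch, slow, torch_device


class GatedTest(unittest.TestCase):

    @slow
    def test_expensive_path(self):
        # Skipped unless RUN_SLOW is set to a truthy value (e.g. RUN_SLOW=yes).
        self.assertTrue(True)

    @require_torch
    def test_uses_torch_device(self):
        # Skipped when PyTorch is not installed; torch_device is "cuda" only
        # when USE_CUDA is set to a truthy value, otherwise "cpu".
        self.assertIn(torch_device, ("cuda", "cpu"))


if __name__ == "__main__":
    unittest.main()  # e.g. RUN_SLOW=yes python hypothetical_gated_test.py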